Skip to content

Commit

Permalink
refactor(config): using constant for viper config keys
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadBnei committed Feb 18, 2024
1 parent 1ecf78c commit 8e77790
Show file tree
Hide file tree
Showing 18 changed files with 146 additions and 79 deletions.
17 changes: 9 additions & 8 deletions api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"

"github.com/MohammadBnei/go-ai-cli/config"
"github.com/samber/lo"
"github.com/spf13/viper"
"github.com/tmc/langchaingo/llms"
Expand All @@ -26,16 +27,16 @@ const (
)

func GetGenerateFunction() (func(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error), error) {
model := viper.GetString("model")
switch viper.GetString("API_TYPE") {
model := viper.GetString(config.AI_MODEL_NAME)
switch viper.GetString(config.AI_API_TYPE) {
case API_OPENAI:
llm, err := openai.New(openai.WithToken(viper.GetString("OPENAI_KEY")), openai.WithModel(model))
llm, err := openai.New(openai.WithToken(viper.GetString(config.AI_OPENAI_KEY)), openai.WithModel(model))
return llm.GenerateContent, err
case API_HUGGINGFACE:
llm, err := huggingface.New(huggingface.WithToken(viper.GetString("HUGGINGFACE_KEY")), huggingface.WithModel(model))
llm, err := huggingface.New(huggingface.WithToken(viper.GetString(config.AI_HUGGINGFACE_KEY)), huggingface.WithModel(model))
return llm.GenerateContent, err
case API_OLLAMA:
llama, err := ollama.New(ollama.WithModel(model), ollama.WithServerURL(viper.GetString("OLLAMA_HOST")))
llama, err := ollama.New(ollama.WithModel(model), ollama.WithServerURL(viper.GetString(config.AI_OLLAMA_HOST)))
return llama.GenerateContent, err
default:
return nil, errors.New("invalid api type")
Expand All @@ -47,7 +48,7 @@ func GetApiTypeList() []string {
}

func GetApiModelList() ([]string, error) {
switch viper.GetString("API_TYPE") {
switch viper.GetString(config.AI_API_TYPE) {
case API_OPENAI:
return GetOpenAiModelList()
case API_HUGGINGFACE:
Expand All @@ -61,7 +62,7 @@ func GetApiModelList() ([]string, error) {
}

func GetOllamaModelList() ([]string, error) {
req, err := http.NewRequest("GET", "http://127.0.0.1:11434/api/tags", nil)
req, err := http.NewRequest("GET", viper.GetString(config.AI_OLLAMA_HOST)+"/api/tags", nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -91,7 +92,7 @@ func GetOllamaModelList() ([]string, error) {
}

func GetOpenAiModelList() ([]string, error) {
c := openaiHelper.NewClient(viper.GetString("OPENAI_KEY"))
c := openaiHelper.NewClient(viper.GetString(config.AI_OPENAI_KEY))
models, err := c.ListModels(context.Background())
if err != nil {
return nil, err
Expand Down
5 changes: 3 additions & 2 deletions api/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@ import (
"testing"

"github.com/MohammadBnei/go-ai-cli/api"
"github.com/MohammadBnei/go-ai-cli/config"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
)

func TestGetOllamaModelList(t *testing.T) {

// Set the API type to "OLLAMA"
viper.Set("API_TYPE", api.API_OLLAMA)
viper.Set(config.AI_API_TYPE, api.API_OLLAMA)

// Set the OLLAMA_HOST to your test server URL
viper.Set("OLLAMA_HOST", "127.0.0.1:11434")
viper.Set(config.AI_OLLAMA_HOST, "127.0.0.1:11434")

// Call the function
models, err := api.GetOllamaModelList()
Expand Down
3 changes: 2 additions & 1 deletion api/hugging.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ import (
"errors"
"fmt"

"github.com/MohammadBnei/go-ai-cli/config"
"github.com/TannerKvarfordt/hfapigo"
"github.com/spf13/viper"
)

func Mask(prompt string) (string, error) {
apiKey := viper.GetString("HUGGINGFACE_KEY") // Your Hugging Face API key
apiKey := viper.GetString(config.AI_HUGGINGFACE_KEY) // Your Hugging Face API key
if apiKey == "" {
return "", errors.New("hugging Face API key not found")
}
Expand Down
5 changes: 3 additions & 2 deletions api/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ import (
"context"
"io"

"github.com/MohammadBnei/go-ai-cli/config"
"github.com/sashabaranov/go-openai"
"github.com/spf13/viper"
)

func SpeechToText(ctx context.Context, filename string, lang string) (string, error) {
c := openai.NewClient(viper.GetString("OPENAI_KEY"))
c := openai.NewClient(viper.GetString(config.AI_OPENAI_KEY))

if lang == "" {
lang = "en"
Expand All @@ -29,7 +30,7 @@ func SpeechToText(ctx context.Context, filename string, lang string) (string, er
}

func TextToSpeech(ctx context.Context, content string) (io.ReadCloser, error) {
c := openai.NewClient(viper.GetString("OPENAI_KEY"))
c := openai.NewClient(viper.GetString(config.AI_OPENAI_KEY))

response, err := c.CreateSpeech(ctx, openai.CreateSpeechRequest{
Model: openai.TTSModel1,
Expand Down
2 changes: 1 addition & 1 deletion cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var configCmd = &cobra.Command{
Use: "config",
Short: "Set the configuration in a file",
Run: func(cmd *cobra.Command, args []string) {
filePath := viper.GetString("configfile")
filePath := viper.ConfigFileUsed()
folders := strings.Split(filePath, "/")
created := false
if _, err := os.Stat(filePath); folders[0] != filePath && errors.Is(err, os.ErrNotExist) {
Expand Down
13 changes: 5 additions & 8 deletions cmd/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Copyright © 2023 NAME HERE <EMAIL ADDRESS>
package cmd

import (
"github.com/MohammadBnei/go-ai-cli/config"
"github.com/MohammadBnei/go-ai-cli/service"
"github.com/MohammadBnei/go-ai-cli/ui/chat"
"github.com/spf13/cobra"
Expand All @@ -15,14 +16,12 @@ var promptCmd = &cobra.Command{
Use: "prompt",
Short: "Start the prompt loop",
Run: func(cmd *cobra.Command, args []string) {
viper.BindPFlag("md", cmd.Flags().Lookup("md"))

promptConfig := &service.PromptConfig{
ChatMessages: service.NewChatMessages("default"),
}

defaulSystemPrompt := viper.GetStringMapString("default-systems")
savedSystemPrompt := viper.GetStringMapString("systems")
defaulSystemPrompt := viper.GetStringMapString(config.PR_SYSTEM_DEFAULT)
savedSystemPrompt := viper.GetStringMapString(config.PR_SYSTEM)
for k := range defaulSystemPrompt {
promptConfig.ChatMessages.AddMessage(savedSystemPrompt[k], service.RoleSystem)
}
Expand All @@ -41,9 +40,7 @@ var promptCmd = &cobra.Command{
func init() {
RootCmd.AddCommand(promptCmd)

promptCmd.PersistentFlags().Int("depth", 2, "the depth of the tree view, when in file mode")
promptCmd.PersistentFlags().Bool("md", false, "markdown mode enabled")
promptCmd.PersistentFlags().BoolP("auto-load", "s", false, "Automatically save the prompt to a file")
// promptCmd.PersistentFlags().BoolP("auto-load", "s", false, "Automatically save the prompt to a file")

viper.BindPFlag("autoSave", promptCmd.Flags().Lookup("auto-load"))
// viper.BindPFlag("autoSave", promptCmd.Flags().Lookup("auto-load"))
}
11 changes: 10 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,20 @@ func init() {
return []string{api.API_HUGGINGFACE, api.API_OLLAMA, api.API_OPENAI}, cobra.ShellCompDirectiveDefault
})

RootCmd.PersistentFlags().String(config.AI_OLLAMA_HOST, "http://127.0.0.1:11434", "the ollama host to be added to config")

RootCmd.PersistentFlags().Bool(config.UI_MARKDOWN_MODE, false, "enable markdown mode")
RootCmd.PersistentFlags().Bool(config.UI_CODE_MODE, false, "enable code mode")

RootCmd.PersistentFlags().Float64(config.AI_TEMPERATURE, 0.7, "the temperature of the ai model's response")
RootCmd.PersistentFlags().Int(config.AI_TOP_K, 50, "The top-k parameter limits the model’s predictions to the top k most probable tokens at each step of generation")
RootCmd.PersistentFlags().Float64(config.AI_TOP_P, 0.5, "Top-p controls the cumulative probability of the generated tokens")

RootCmd.PersistentFlags().StringP(config.AI_MODEL_NAME, "m", openai.GPT4, "the model to use")
defaultModel := openai.GPT4
if v, _ := RootCmd.Flags().GetString(config.AI_API_TYPE); v == api.API_OLLAMA {
defaultModel = "llama2"
}
RootCmd.PersistentFlags().StringP(config.AI_MODEL_NAME, "m", defaultModel, "the model to use")
RootCmd.RegisterFlagCompletionFunc(config.AI_MODEL_NAME, func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
apiType, err := cmd.Flags().GetString(config.AI_API_TYPE)
if err != nil || apiType == "" {
Expand Down
15 changes: 9 additions & 6 deletions config/constant.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package config

const (
UI_MARKDOWN_MODE = "markdown-mode"
UI_CODE_MODE = "code-mode"
UI_MARKDOWN_MODE = "markdown-mode"
UI_CODE_MODE = "code-mode"

AI_TEMPERATURE = "temperature"
AI_TOP_P = "top-p"
AI_TOP_K = "top-k"
AI_TOP_P = "top-p"
AI_TOP_K = "top-k"
AI_MODEL_NAME = "model-name"
AI_API_TYPE = "api-type"
AI_API_TYPE = "api-type"
AI_OPENAI_KEY = "openai-key"
AI_HUGGINGFACE_KEY = "huggingface-key"
AI_OLLAMA_HOST = "ollama-host"

PR_SYSTEM_DEFAULT = "default-systems"
PR_SYSTEM = "systems"
)
5 changes: 3 additions & 2 deletions service/chatMessages.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sort"
"time"

"github.com/MohammadBnei/go-ai-cli/config"
"github.com/MohammadBnei/go-ai-cli/tool"
"github.com/bwmarrin/snowflake"
"github.com/jinzhu/copier"
Expand Down Expand Up @@ -152,8 +153,8 @@ func (c *ChatMessages) AddMessage(content string, role ROLES) (*ChatMessage, err
Type: TypeUser,
AssociatedMessageId: -1,
Meta: Meta{
ApiType: viper.GetString("API_TYPE"),
Model: viper.GetString("model"),
ApiType: viper.GetString(config.AI_API_TYPE),
Model: viper.GetString(config.AI_MODEL_NAME),
},
}

Expand Down
5 changes: 3 additions & 2 deletions service/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ import (
"os"
"time"

"github.com/MohammadBnei/go-ai-cli/config"
"github.com/briandowns/spinner"
"github.com/c2h5oh/datasize"
"github.com/sashabaranov/go-openai"
"github.com/spf13/viper"
)

func AskImage(prompt string, size string) ([]byte, error) {
c := openai.NewClient(viper.GetString("OPENAI_KEY"))
c := openai.NewClient(viper.GetString(config.AI_OPENAI_KEY))

s := spinner.New(spinner.CharSets[26], 100*time.Millisecond)
s.Start()
Expand Down Expand Up @@ -47,7 +48,7 @@ func AskImage(prompt string, size string) ([]byte, error) {
}

func EditImage(filePath, prompt, size string) ([]byte, error) {
c := openai.NewClient(viper.GetString("OPENAI_KEY"))
c := openai.NewClient(viper.GetString(config.AI_OPENAI_KEY))

file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_APPEND, os.ModePerm)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions ui/chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"reflect"

"github.com/MohammadBnei/go-ai-cli/config"
"github.com/MohammadBnei/go-ai-cli/service"
"github.com/MohammadBnei/go-ai-cli/ui/audio"
"github.com/MohammadBnei/go-ai-cli/ui/event"
Expand Down Expand Up @@ -234,14 +235,14 @@ func (m chatModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
if m.userPrompt != "" {
aiRes := m.aiResponse
if viper.GetBool("md") && m.userPrompt != "Infos" {
if viper.GetBool(config.UI_MARKDOWN_MODE) && m.userPrompt != "Infos" {
str, err := m.mdRenderer.Render(aiRes)
if err != nil {
return m, event.Error(err)
}
aiRes = str
}
if !viper.GetBool("md") {
if !viper.GetBool(config.UI_MARKDOWN_MODE) {
aiRes = wordwrap.String(aiRes, m.viewport.Width)
}

Expand Down
11 changes: 6 additions & 5 deletions ui/chat/chatUpdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"

"github.com/MohammadBnei/go-ai-cli/api"
"github.com/MohammadBnei/go-ai-cli/config"
"github.com/MohammadBnei/go-ai-cli/service"
"github.com/MohammadBnei/go-ai-cli/ui/event"
"github.com/MohammadBnei/go-ai-cli/ui/style"
Expand All @@ -24,8 +25,8 @@ func getInfoContent(m chatModel) string {
smallTitleStyle := style.TitleStyle.Copy().Margin(0).Padding(0, 2)
return banner.Inline("go ai cli") + "\n" +
lipgloss.NewStyle().AlignVertical(lipgloss.Center).Height(m.viewport.Height).Render(
"Api : "+smallTitleStyle.Render(viper.GetString("API_TYPE"))+"\n"+
"Model : "+smallTitleStyle.Render(viper.GetString("model"))+"\n"+
"Api : "+smallTitleStyle.Render(viper.GetString(config.AI_API_TYPE))+"\n"+
"Model : "+smallTitleStyle.Render(viper.GetString(config.AI_MODEL_NAME))+"\n"+
"Messages : "+smallTitleStyle.Render(fmt.Sprintf("%d", len(m.promptConfig.ChatMessages.Messages)))+"\n"+
"Tokens : "+smallTitleStyle.Render(fmt.Sprintf("%d", m.promptConfig.ChatMessages.TotalTokens))+"\n",
)
Expand Down Expand Up @@ -207,14 +208,14 @@ func sendPrompt(pc *service.PromptConfig, currentChatMsgs currentChatMessages) e
}),
}

if v := viper.GetFloat64("temperature"); v >= 0 {
if v := viper.GetFloat64(config.AI_TEMPERATURE); v >= 0 {
options = append(options, llms.WithTemperature(v))
}
if v := viper.GetInt("topK"); v >= 0 {
if v := viper.GetInt(config.AI_TOP_K); v >= 0 {
options = append(options, llms.WithTopK(v))

}
if v := viper.GetFloat64("topP"); v >= 0 {
if v := viper.GetFloat64(config.AI_TOP_P); v >= 0 {
options = append(options, llms.WithTopP(v))
}

Expand Down
17 changes: 9 additions & 8 deletions ui/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sort"
"strconv"

"github.com/MohammadBnei/go-ai-cli/config"
"github.com/MohammadBnei/go-ai-cli/service"
"github.com/MohammadBnei/go-ai-cli/ui/event"
"github.com/MohammadBnei/go-ai-cli/ui/form"
Expand All @@ -18,10 +19,10 @@ import (

func NewConfigModel(promptConfig *service.PromptConfig) tea.Model {

savedDefaultSystemPrompt := viper.GetStringMapString("default-systems")
savedDefaultSystemPrompt := viper.GetStringMapString(config.PR_SYSTEM_DEFAULT)
if savedDefaultSystemPrompt == nil {
savedDefaultSystemPrompt = make(map[string]string)
viper.Set("default-systems", savedDefaultSystemPrompt)
viper.Set(config.PR_SYSTEM_DEFAULT, savedDefaultSystemPrompt)
}

items := getItemsAsUiList(promptConfig)
Expand Down Expand Up @@ -79,23 +80,23 @@ func getEditModel(id string) (tea.Model, error) {
var editModel *huh.Form
var afterCmd tea.Cmd
switch id {
case "model":
case config.AI_MODEL_NAME:
modelSelectForm, err := newModelSelectForm(value)
if err != nil {
return nil, err
}
editModel = modelSelectForm

case "api_type":
case config.AI_API_TYPE:
editModel = newApiTypeSelectForm(value)
afterCmd = func() tea.Msg {
modelSelectForm, err := newModelSelectForm(viper.GetString("model"))
modelSelectForm, err := newModelSelectForm(viper.GetString(config.AI_MODEL_NAME))
if err != nil {
return err
}
return event.AddStackEvent{Stack: form.NewEditModel("Editing config model after updating the api type", modelSelectForm, func(form *huh.Form) tea.Cmd {
result := form.GetString("model")
return UpdateConfigValue("model", result, result)
result := form.GetString(config.AI_MODEL_NAME)
return UpdateConfigValue(config.AI_MODEL_NAME, result, result)
})}
}

Expand Down Expand Up @@ -129,7 +130,7 @@ func getEditModel(id string) (tea.Model, error) {
), func(form *huh.Form) tea.Cmd {
result := form.GetBool(id)
updateEvent := UpdateConfigValue(id, result, helper.CheckedStringHelper(result))
if id == "md" {
if id == config.UI_MARKDOWN_MODE {
return tea.Sequence(updateEvent, event.UpdateChatContent("", ""))
}
return updateEvent
Expand Down
5 changes: 3 additions & 2 deletions ui/config/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"github.com/MohammadBnei/go-ai-cli/api"
"github.com/MohammadBnei/go-ai-cli/config"
"github.com/charmbracelet/huh"
)

Expand All @@ -11,9 +12,9 @@ func newModelSelectForm(value string) (*huh.Form, error) {
return nil, err
}

return huh.NewForm(huh.NewGroup(huh.NewSelect[string]().Key("model").Value(&value).Title("Model").Options(huh.NewOptions[string](models...)...))), nil
return huh.NewForm(huh.NewGroup(huh.NewSelect[string]().Key(config.AI_MODEL_NAME).Value(&value).Title("Model").Options(huh.NewOptions[string](models...)...))), nil
}

func newApiTypeSelectForm(value string) *huh.Form {
return huh.NewForm(huh.NewGroup(huh.NewSelect[string]().Key("api_type").Value(&value).Title("API Type").Options(huh.NewOptions[string](api.GetApiTypeList()...)...)))
return huh.NewForm(huh.NewGroup(huh.NewSelect[string]().Key(config.AI_API_TYPE).Value(&value).Title("API Type").Options(huh.NewOptions[string](api.GetApiTypeList()...)...)))
}
Loading

0 comments on commit 8e77790

Please sign in to comment.