From 1bb68835990e0b45a29364a53c689269b04f9f67 Mon Sep 17 00:00:00 2001 From: joshhn Date: Mon, 6 Oct 2025 23:30:02 -0700 Subject: [PATCH] Move LLM provider strings to type-safe enum in pkg/types --- cmd/cli/createMsg.go | 237 ++++++++++++++--------------- cmd/cli/llmSetup.go | 147 +++++++++--------- cmd/cli/root.go | 3 +- cmd/cli/store/store.go | 97 +++++------- internal/chatgpt/chatgpt.go | 6 +- internal/claude/claude.go | 2 +- internal/git/operations.go | 52 +++---- internal/ollama/ollama.go | 17 +-- internal/scrubber/scrubber.go | 38 ++--- internal/scrubber/scrubber_test.go | 30 ++-- pkg/types/types.go | 49 ++++++ 11 files changed, 359 insertions(+), 319 deletions(-) diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go index b27570b..7c062b1 100644 --- a/cmd/cli/createMsg.go +++ b/cmd/cli/createMsg.go @@ -18,11 +18,10 @@ import ( "github.com/pterm/pterm" ) +func CreateCommitMsg() { -func CreateCommitMsg () { - - // Validate COMMIT_LLM and required API keys - useLLM,err := store.DefaultLLMKey() + // Validate COMMIT_LLM and required API keys + useLLM, err := store.DefaultLLMKey() if err != nil { pterm.Error.Printf("No LLM configured. Run: commit llm setup\n") os.Exit(1) @@ -31,144 +30,142 @@ func CreateCommitMsg () { commitLLM := useLLM.LLM apiKey := useLLM.APIKey + // Get current directory + currentDir, err := os.Getwd() + if err != nil { + pterm.Error.Printf("Failed to get current directory: %v\n", err) + os.Exit(1) + } - // Get current directory - currentDir, err := os.Getwd() - if err != nil { - pterm.Error.Printf("Failed to get current directory: %v\n", err) - os.Exit(1) - } + // Check if current directory is a git repository + if !git.IsRepository(currentDir) { + pterm.Error.Printf("Current directory is not a Git repository: %s\n", currentDir) + os.Exit(1) + } - // Check if current directory is a git repository - if !git.IsRepository(currentDir) { - pterm.Error.Printf("Current directory is not a Git repository: %s\n", currentDir) - os.Exit(1) - } + // Create a minimal config for the API + config := &types.Config{ + GrokAPI: "https://api.x.ai/v1/chat/completions", + } - // Create a minimal config for the API - config := &types.Config{ - GrokAPI: "https://api.x.ai/v1/chat/completions", - } + // Create a repo config for the current directory + repoConfig := types.RepoConfig{ + Path: currentDir, + } - // Create a repo config for the current directory - repoConfig := types.RepoConfig{ - Path: currentDir, - } + // Get file statistics before fetching changes + fileStats, err := stats.GetFileStatistics(&repoConfig) + if err != nil { + pterm.Error.Printf("Failed to get file statistics: %v\n", err) + os.Exit(1) + } - // Get file statistics before fetching changes - fileStats, err := stats.GetFileStatistics(&repoConfig) - if err != nil { - pterm.Error.Printf("Failed to get file statistics: %v\n", err) - os.Exit(1) - } + // Display header + pterm.DefaultHeader.WithFullWidth(). + WithBackgroundStyle(pterm.NewStyle(pterm.BgDarkGray)). + WithTextStyle(pterm.NewStyle(pterm.FgLightWhite)). + Println("Commit Message Generator") - // Display header - pterm.DefaultHeader.WithFullWidth(). - WithBackgroundStyle(pterm.NewStyle(pterm.BgDarkGray)). - WithTextStyle(pterm.NewStyle(pterm.FgLightWhite)). - Println("Commit Message Generator") + pterm.Println() - pterm.Println() + // Display file statistics with icons + display.ShowFileStatistics(fileStats) - // Display file statistics with icons - display.ShowFileStatistics(fileStats) + if fileStats.TotalFiles == 0 { + pterm.Warning.Println("No changes detected in the Git repository.") + pterm.Info.Println("Tips:") + pterm.Info.Println(" - Stage your changes with: git add .") + pterm.Info.Println(" - Check repository status with: git status") + pterm.Info.Println(" - Make sure you're in the correct Git repository") + return + } - if fileStats.TotalFiles == 0 { - pterm.Warning.Println("No changes detected in the Git repository.") - pterm.Info.Println("Tips:") - pterm.Info.Println(" - Stage your changes with: git add .") - pterm.Info.Println(" - Check repository status with: git status") - pterm.Info.Println(" - Make sure you're in the correct Git repository") - return - } + // Get the changes + changes, err := git.GetChanges(&repoConfig) + if err != nil { + pterm.Error.Printf("Failed to get Git changes: %v\n", err) + os.Exit(1) + } - // Get the changes - changes, err := git.GetChanges(&repoConfig) - if err != nil { - pterm.Error.Printf("Failed to get Git changes: %v\n", err) - os.Exit(1) - } + if len(changes) == 0 { + pterm.Warning.Println("No changes detected in the Git repository.") + pterm.Info.Println("Tips:") + pterm.Info.Println(" - Stage your changes with: git add .") + pterm.Info.Println(" - Check repository status with: git status") + pterm.Info.Println(" - Make sure you're in the correct Git repository") + return + } - if len(changes) == 0 { - pterm.Warning.Println("No changes detected in the Git repository.") - pterm.Info.Println("Tips:") - pterm.Info.Println(" - Stage your changes with: git add .") - pterm.Info.Println(" - Check repository status with: git status") - pterm.Info.Println(" - Make sure you're in the correct Git repository") - return - } + pterm.Println() + + // Show generating spinner + spinnerGenerating, err := pterm.DefaultSpinner. + WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"). + Start("Generating commit message with " + commitLLM.String() + "...") + if err != nil { + pterm.Error.Printf("Failed to start spinner: %v\n", err) + os.Exit(1) + } - pterm.Println() + var commitMsg string - // Show generating spinner - spinnerGenerating, err := pterm.DefaultSpinner. - WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"). - Start("Generating commit message with " + commitLLM + "...") - if err != nil { - pterm.Error.Printf("Failed to start spinner: %v\n", err) - os.Exit(1) - } + switch commitLLM { + + case types.ProviderGemini: + commitMsg, err = gemini.GenerateCommitMessage(config, changes, apiKey) + + case types.ProviderOpenAI: + commitMsg, err = chatgpt.GenerateCommitMessage(config, changes, apiKey) - var commitMsg string + case types.ProviderClaude: + commitMsg, err = claude.GenerateCommitMessage(config, changes, apiKey) + case types.ProviderGroq: + commitMsg, err = groq.GenerateCommitMessage(config, changes, apiKey) + case types.ProviderOllama: + model := "llama3:latest" + commitMsg, err = ollama.GenerateCommitMessage(config, changes, apiKey, model) + default: + commitMsg, err = grok.GenerateCommitMessage(config, changes, apiKey) + } + + if err != nil { + spinnerGenerating.Fail("Failed to generate commit message") switch commitLLM { - - case "Gemini": - commitMsg, err = gemini.GenerateCommitMessage(config, changes, apiKey) - - case "OpenAI": - commitMsg, err = chatgpt.GenerateCommitMessage(config, changes, apiKey) - - case "Claude": - commitMsg, err = claude.GenerateCommitMessage(config, changes, apiKey) - case "Groq": - commitMsg, err = groq.GenerateCommitMessage(config, changes, apiKey) - case "Ollama": - model := "llama3:latest" - - commitMsg, err = ollama.GenerateCommitMessage(config, changes, apiKey, model) + case types.ProviderGemini: + pterm.Error.Printf("Gemini API error. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n") + case types.ProviderOpenAI: + pterm.Error.Printf("OpenAI API error. Check your OPENAI_API_KEY environment variable or run: commit llm setup\n") + case types.ProviderClaude: + pterm.Error.Printf("Claude API error. Check your CLAUDE_API_KEY environment variable or run: commit llm setup\n") + case types.ProviderGroq: + pterm.Error.Printf("Groq API error. Check your GROQ_API_KEY environment variable or run: commit llm setup\n") + case types.ProviderGrok: + pterm.Error.Printf("Grok API error. Check your GROK_API_KEY environment variable or run: commit llm setup\n") default: - commitMsg, err = grok.GenerateCommitMessage(config, changes, apiKey) - } - - - if err != nil { - spinnerGenerating.Fail("Failed to generate commit message") - switch commitLLM { - case "Gemini": - pterm.Error.Printf("Gemini API error. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n") - case "OpenAI": - pterm.Error.Printf("OpenAI API error. Check your OPENAI_API_KEY environment variable or run: commit llm setup\n") - case "Claude": - pterm.Error.Printf("Claude API error. Check your CLAUDE_API_KEY environment variable or run: commit llm setup\n") - case "Groq": - pterm.Error.Printf("Groq API error. Check your GROQ_API_KEY environment variable or run: commit llm setup\n") - case "Grok": - pterm.Error.Printf("Grok API error. Check your GROK_API_KEY environment variable or run: commit llm setup\n") - default: - pterm.Error.Printf("LLM API error: %v\n", err) - } - os.Exit(1) + pterm.Error.Printf("LLM API error: %v\n", err) } + os.Exit(1) + } - spinnerGenerating.Success("Commit message generated successfully!") + spinnerGenerating.Success("Commit message generated successfully!") - pterm.Println() + pterm.Println() - // Display the commit message in a styled panel - display.ShowCommitMessage(commitMsg) + // Display the commit message in a styled panel + display.ShowCommitMessage(commitMsg) - // Copy to clipboard - err = clipboard.WriteAll(commitMsg) - if err != nil { - pterm.Warning.Printf("Could not copy to clipboard: %v\n", err) - } else { - pterm.Success.Println("Commit message copied to clipboard!") - } + // Copy to clipboard + err = clipboard.WriteAll(commitMsg) + if err != nil { + pterm.Warning.Printf("Could not copy to clipboard: %v\n", err) + } else { + pterm.Success.Println("Commit message copied to clipboard!") + } - pterm.Println() + pterm.Println() - // Display changes preview - display.ShowChangesPreview(fileStats) + // Display changes preview + display.ShowChangesPreview(fileStats) -} \ No newline at end of file +} diff --git a/cmd/cli/llmSetup.go b/cmd/cli/llmSetup.go index d774786..f551c46 100644 --- a/cmd/cli/llmSetup.go +++ b/cmd/cli/llmSetup.go @@ -5,58 +5,59 @@ import ( "fmt" "github.com/dfanso/commit-msg/cmd/cli/store" + "github.com/dfanso/commit-msg/pkg/types" "github.com/manifoldco/promptui" ) - func SetupLLM() error { - providers := []string{"OpenAI", "Claude", "Gemini", "Grok", "Groq", "Ollama"} + providers := types.GetSupportedProviderStrings() prompt := promptui.Select{ Label: "Select LLM", Items: providers, } - _, model, err := prompt.Run() + _, modelStr, err := prompt.Run() if err != nil { return fmt.Errorf("prompt failed") } + model, valid := types.ParseLLMProvider(modelStr) + if !valid { + return fmt.Errorf("invalid LLM provider: %s", modelStr) + } + var apiKey string - + // Skip API key prompt for Ollama (local LLM) apiKeyPrompt := promptui.Prompt{ - Label: "Enter API Key", - Mask: '*', - } - - - switch model { - case "Ollama": - urlPrompt := promptui.Prompt{ + Label: "Enter API Key", + Mask: '*', + } + + switch model { + case types.ProviderOllama: + urlPrompt := promptui.Prompt{ Label: "Enter URL", - } - apiKey, err = urlPrompt.Run() - if err != nil { + } + apiKey, err = urlPrompt.Run() + if err != nil { return fmt.Errorf("failed to read Url: %w", err) - } - - default: - apiKey, err = apiKeyPrompt.Run() - if err != nil { - return fmt.Errorf("failed to read API Key: %w", err) } + default: + apiKey, err = apiKeyPrompt.Run() + if err != nil { + return fmt.Errorf("failed to read API Key: %w", err) } + } LLMConfig := store.LLMProvider{ LLM: model, APIKey: apiKey, } - - err = store.Save(LLMConfig) if err != nil { return err @@ -67,15 +68,15 @@ func SetupLLM() error { } func UpdateLLM() error { - + SavedModels, err := store.ListSavedModels() if err != nil { return err } if len(SavedModels.LLMProviders) == 0 { - return errors.New("no model exists, Please add atleast one model Run: 'commit llm setup'") - + return errors.New("no model exists, Please add atleast one model Run: 'commit llm setup'") + } models := []string{} @@ -83,7 +84,7 @@ func UpdateLLM() error { options2 := []string{"Set Default", "Change URL", "Delete"} //different option for local model for _, p := range SavedModels.LLMProviders { - models = append(models, p.LLM) + models = append(models, p.LLM.String()) } prompt := promptui.Select{ @@ -91,65 +92,75 @@ func UpdateLLM() error { Items: models, } - _,model,err := prompt.Run() + _, model, err := prompt.Run() if err != nil { return err } - prompt = promptui.Select{ + prompt = promptui.Select{ Label: "Select Option", Items: options1, - } + } - apiKeyPrompt := promptui.Prompt { + apiKeyPrompt := promptui.Prompt{ Label: "Enter API Key", - } + } - - if model == "Ollama" { - prompt = promptui.Select{ + if model == types.ProviderOllama.String() { + prompt = promptui.Select{ Label: "Select Option", - Items: options2, - } - - apiKeyPrompt = promptui.Prompt { - Label: "Enter URL", - } + Items: options2, } + apiKeyPrompt = promptui.Prompt{ + Label: "Enter URL", + } + } - opNo,_,err := prompt.Run() + opNo, _, err := prompt.Run() if err != nil { return err } - - switch opNo { - case 0: - err := store.ChangeDefault(model) - if err != nil { - return err - } - fmt.Printf("%s set as default", model) - case 1: - apiKey, err := apiKeyPrompt.Run() - if err != nil { - return err - } - err = store.UpdateAPIKey(model, apiKey) - if err != nil { - return err - } - event := "API Key" - if model == "Ollama"{event = "URL"} - fmt.Printf("%s %s Updated", model,event) - case 2: - err := store.DeleteModel(model) - if err != nil { - return err - } - fmt.Printf("%s model deleted", model) + case 0: + modelProvider, valid := types.ParseLLMProvider(model) + if !valid { + return fmt.Errorf("invalid LLM provider: %s", model) + } + err := store.ChangeDefault(modelProvider) + if err != nil { + return err + } + fmt.Printf("%s set as default", model) + case 1: + apiKey, err := apiKeyPrompt.Run() + if err != nil { + return err + } + modelProvider, valid := types.ParseLLMProvider(model) + if !valid { + return fmt.Errorf("invalid LLM provider: %s", model) + } + err = store.UpdateAPIKey(modelProvider, apiKey) + if err != nil { + return err + } + event := "API Key" + if model == types.ProviderOllama.String() { + event = "URL" + } + fmt.Printf("%s %s Updated", model, event) + case 2: + modelProvider, valid := types.ParseLLMProvider(model) + if !valid { + return fmt.Errorf("invalid LLM provider: %s", model) + } + err := store.DeleteModel(modelProvider) + if err != nil { + return err + } + fmt.Printf("%s model deleted", model) } return nil diff --git a/cmd/cli/root.go b/cmd/cli/root.go index 8b3d565..9b61366 100644 --- a/cmd/cli/root.go +++ b/cmd/cli/root.go @@ -50,7 +50,7 @@ var llmUpdateCmd = &cobra.Command{ } var creatCommitMsg = &cobra.Command{ - Use: ".", + Use: ".", Short: "Create Commit Message", RunE: func(cmd *cobra.Command, args []string) error { CreateCommitMsg() @@ -73,4 +73,3 @@ func init() { llmCmd.AddCommand(llmSetupCmd) llmCmd.AddCommand(llmUpdateCmd) } - diff --git a/cmd/cli/store/store.go b/cmd/cli/store/store.go index d60315a..b18c1c2 100644 --- a/cmd/cli/store/store.go +++ b/cmd/cli/store/store.go @@ -7,26 +7,27 @@ import ( "os" "path/filepath" "runtime" + + "github.com/dfanso/commit-msg/pkg/types" ) type LLMProvider struct { - LLM string `json:"model"` - APIKey string `json:"api_key"` + LLM types.LLMProvider `json:"model"` + APIKey string `json:"api_key"` } type Config struct { - Default string `json:"default"` - LLMProviders []LLMProvider `json:"models"` + Default types.LLMProvider `json:"default"` + LLMProviders []LLMProvider `json:"models"` } func Save(LLMConfig LLMProvider) error { - + cfg := Config{ - LLMConfig.LLM, - []LLMProvider{LLMConfig}, + Default: LLMConfig.LLM, + LLMProviders: []LLMProvider{LLMConfig}, } - configPath, err := getConfigPath() if err != nil { return err @@ -40,23 +41,20 @@ func Save(LLMConfig LLMProvider) error { } } - data, err := os.ReadFile(configPath) - if errors.Is(err, os.ErrNotExist){ + if errors.Is(err, os.ErrNotExist) { data = []byte("{}") } else if err != nil { return err } - if len(data) > 0 { err = json.Unmarshal(data, &cfg) if err != nil { return err } } - - + updated := false for i, p := range cfg.LLMProviders { if p.LLM == LLMConfig.LLM { @@ -70,8 +68,7 @@ func Save(LLMConfig LLMProvider) error { cfg.LLMProviders = append(cfg.LLMProviders, LLMConfig) } - cfg.Default = LLMConfig.LLM - + cfg.Default = LLMConfig.LLM data, err = json.MarshalIndent(cfg, "", " ") if err != nil { @@ -81,18 +78,16 @@ func Save(LLMConfig LLMProvider) error { return os.WriteFile(configPath, data, 0600) } - func checkConfig(configPath string) bool { - _,err := os.Stat(configPath) - if err != nil ||os.IsNotExist(err) { + _, err := os.Stat(configPath) + if err != nil || os.IsNotExist(err) { return false } return true } - func createConfigFile(configPath string) error { err := os.MkdirAll(filepath.Dir(configPath), 0700) @@ -148,7 +143,7 @@ func DefaultLLMKey() (*LLMProvider, error) { var cfg Config var useModel LLMProvider - + configPath, err := getConfigPath() if err != nil { return nil, err @@ -159,7 +154,6 @@ func DefaultLLMKey() (*LLMProvider, error) { return nil, errors.New("config file Not exists") } - data, err := os.ReadFile(configPath) if err != nil { return nil, err @@ -174,22 +168,19 @@ func DefaultLLMKey() (*LLMProvider, error) { return nil, errors.New("config file is empty, Please add at least one LLM Key") } - - defaultLLM := cfg.Default for i, p := range cfg.LLMProviders { if p.LLM == defaultLLM { - useModel= cfg.LLMProviders[i] + useModel = cfg.LLMProviders[i] return &useModel, nil } } return nil, errors.New("not found default model in config") } +func ListSavedModels() (*Config, error) { -func ListSavedModels() (*Config, error){ - var cfg Config configPath, err := getConfigPath() @@ -216,13 +207,11 @@ func ListSavedModels() (*Config, error){ return nil, errors.New("config file is empty, Please add at least one LLM Key") } - return &cfg, nil } - -func ChangeDefault(Model string) error { +func ChangeDefault(Model types.LLMProvider) error { var cfg Config @@ -242,10 +231,10 @@ func ChangeDefault(Model string) error { } if len(data) > 0 { - err = json.Unmarshal(data, &cfg) - if err != nil { - return err - } + err = json.Unmarshal(data, &cfg) + if err != nil { + return err + } } cfg.Default = Model @@ -258,9 +247,8 @@ func ChangeDefault(Model string) error { return os.WriteFile(configPath, data, 0600) } +func DeleteModel(Model types.LLMProvider) error { -func DeleteModel(Model string) error { - var cfg Config var newCfg Config @@ -280,45 +268,42 @@ func DeleteModel(Model string) error { } if len(data) > 0 { - err = json.Unmarshal(data, &cfg) - if err != nil { - return err - } + err = json.Unmarshal(data, &cfg) + if err != nil { + return err + } } - if Model == cfg.Default { if len(cfg.LLMProviders) > 1 { - return fmt.Errorf("cannot delete %s while it is default, set other model default first", Model) + return fmt.Errorf("cannot delete %s while it is default, set other model default first", Model.String()) } else { return os.WriteFile(configPath, []byte("{}"), 0600) } } else { - for _,p := range cfg.LLMProviders { - + for _, p := range cfg.LLMProviders { + if p.LLM != Model { newCfg.LLMProviders = append(newCfg.LLMProviders, p) } } - newCfg.Default = cfg.Default + newCfg.Default = cfg.Default - data, err = json.MarshalIndent(newCfg, "", " ") - if err != nil { + data, err = json.MarshalIndent(newCfg, "", " ") + if err != nil { return err - } - return os.WriteFile(configPath, data, 0600) + } + return os.WriteFile(configPath, data, 0600) } } - -func UpdateAPIKey(Model, APIKey string) error { +func UpdateAPIKey(Model types.LLMProvider, APIKey string) error { var cfg Config - configPath, err := getConfigPath() if err != nil { return err @@ -335,10 +320,10 @@ func UpdateAPIKey(Model, APIKey string) error { } if len(data) > 0 { - err = json.Unmarshal(data, &cfg) - if err != nil { - return err - } + err = json.Unmarshal(data, &cfg) + if err != nil { + return err + } } for i, p := range cfg.LLMProviders { @@ -354,4 +339,4 @@ func UpdateAPIKey(Model, APIKey string) error { return os.WriteFile(configPath, data, 0600) -} \ No newline at end of file +} diff --git a/internal/chatgpt/chatgpt.go b/internal/chatgpt/chatgpt.go index 93429c4..ba1cf2e 100644 --- a/internal/chatgpt/chatgpt.go +++ b/internal/chatgpt/chatgpt.go @@ -11,7 +11,7 @@ import ( ) func GenerateCommitMessage(config *types.Config, changes string, apiKey string) (string, error) { - + client := openai.NewClient(option.WithAPIKey(apiKey)) prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) @@ -20,7 +20,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string) Messages: []openai.ChatCompletionMessageParamUnion{ openai.UserMessage(prompt), }, - Model: openai.ChatModelGPT4o, + Model: openai.ChatModelGPT4o, }) if err != nil { return "", fmt.Errorf("OpenAI error: %w", err) @@ -29,4 +29,4 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string) // Extract and return the commit message commitMsg := resp.Choices[0].Message.Content return commitMsg, nil -} \ No newline at end of file +} diff --git a/internal/claude/claude.go b/internal/claude/claude.go index 85cad4b..1ed0dc6 100644 --- a/internal/claude/claude.go +++ b/internal/claude/claude.go @@ -68,7 +68,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("claude AI response %d", resp.StatusCode) + return "", fmt.Errorf("claude AI response %d", resp.StatusCode) } var claudeResponse ClaudeResponse diff --git a/internal/git/operations.go b/internal/git/operations.go index ebd7f2b..8b26fc3 100644 --- a/internal/git/operations.go +++ b/internal/git/operations.go @@ -87,31 +87,31 @@ func GetChanges(config *types.RepoConfig) (string, error) { changes.WriteString("\n\n") // Try to get content of untracked files (limited to text files and smaller size) - untrackedFiles := strings.Split(strings.TrimSpace(string(untrackedOutput)), "\n") - for _, file := range untrackedFiles { - if file == "" { - continue - } - - fullPath := filepath.Join(config.Path, file) - if utils.IsTextFile(fullPath) && utils.IsSmallFile(fullPath) { - fileContent, err := os.ReadFile(fullPath) - if err != nil { - // Log but don't fail - untracked file may have been deleted or is inaccessible - continue - } - changes.WriteString(fmt.Sprintf("Content of new file %s:\n", file)) - - // Use special scrubbing for .env files - if strings.HasSuffix(strings.ToLower(file), ".env") || - strings.Contains(strings.ToLower(file), ".env.") { - changes.WriteString(scrubber.ScrubEnvFile(string(fileContent))) - } else { - changes.WriteString(string(fileContent)) - } - changes.WriteString("\n\n") - } - } + untrackedFiles := strings.Split(strings.TrimSpace(string(untrackedOutput)), "\n") + for _, file := range untrackedFiles { + if file == "" { + continue + } + + fullPath := filepath.Join(config.Path, file) + if utils.IsTextFile(fullPath) && utils.IsSmallFile(fullPath) { + fileContent, err := os.ReadFile(fullPath) + if err != nil { + // Log but don't fail - untracked file may have been deleted or is inaccessible + continue + } + changes.WriteString(fmt.Sprintf("Content of new file %s:\n", file)) + + // Use special scrubbing for .env files + if strings.HasSuffix(strings.ToLower(file), ".env") || + strings.Contains(strings.ToLower(file), ".env.") { + changes.WriteString(scrubber.ScrubEnvFile(string(fileContent))) + } else { + changes.WriteString(string(fileContent)) + } + changes.WriteString("\n\n") + } + } } // 4. Get recent commits for context @@ -125,6 +125,6 @@ func GetChanges(config *types.RepoConfig) (string, error) { // Scrub sensitive data before returning scrubbedChanges := scrubber.ScrubDiff(changes.String()) - + return scrubbedChanges, nil } diff --git a/internal/ollama/ollama.go b/internal/ollama/ollama.go index 90dae08..2094485 100644 --- a/internal/ollama/ollama.go +++ b/internal/ollama/ollama.go @@ -1,30 +1,29 @@ package ollama import ( - "fmt" - "net/http" - "encoding/json" "bytes" + "encoding/json" + "fmt" "io" + "net/http" "github.com/dfanso/commit-msg/pkg/types" - ) type OllamaRequest struct { - Model string `json:"model"` + Model string `json:"model"` Prompt string `json:"prompt"` } type OllamaResponse struct { Response string `json:"response"` - Done bool `json:"done"` + Done bool `json:"done"` } func GenerateCommitMessage(_ *types.Config, changes string, url string, model string) (string, error) { // Use llama3:latest as the default model if model == "" { - model = "llama3:latest" + model = "llama3:latest" } // Preparing the prompt @@ -36,7 +35,7 @@ func GenerateCommitMessage(_ *types.Config, changes string, url string, model st "prompt": prompt, "stream": false, } - + // Generating the body body, err := json.Marshal(reqBody) if err != nil { @@ -72,4 +71,4 @@ func GenerateCommitMessage(_ *types.Config, changes string, url string, model st } return response.Response, nil -} \ No newline at end of file +} diff --git a/internal/scrubber/scrubber.go b/internal/scrubber/scrubber.go index abd2430..69edb29 100644 --- a/internal/scrubber/scrubber.go +++ b/internal/scrubber/scrubber.go @@ -31,7 +31,7 @@ var ( Pattern: regexp.MustCompile(`(?i)(authorization\s*[=:]\s*["\']?)([a-zA-Z0-9_\-\.]{20,})["\']?`), Redact: "${1}[REDACTED_AUTH_TOKEN]\"", }, - + // AWS Credentials { Name: "AWS Access Key", @@ -43,7 +43,7 @@ var ( Pattern: regexp.MustCompile(`(?i)(aws[_-]?secret[_-]?access[_-]?key|AWS_SECRET_ACCESS_KEY)\s*[=:]\s*["\']?([a-zA-Z0-9/+=]{40})["\']?`), Redact: "${1}=\"[REDACTED_AWS_SECRET]\"", }, - + // Database Credentials { Name: "Database URL with Password", @@ -55,7 +55,7 @@ var ( Pattern: regexp.MustCompile(`(?i)(db[_-]?password|database[_-]?password|DB_PASSWORD)\s*[=:]\s*["\']?([^\s"']+)["\']?`), Redact: "${1}=\"[REDACTED_DB_PASSWORD]\"", }, - + // OAuth and Social Media { Name: "GitHub Token", @@ -87,49 +87,49 @@ var ( Pattern: regexp.MustCompile(`(?i)(slack[_-]?token|SLACK_TOKEN)\s*[=:]\s*["\']?(xox[baprs]-[a-zA-Z0-9\-]{10,})["\']?`), Redact: "${1}=\"[REDACTED_SLACK_TOKEN]\"", }, - + // Private Keys { Name: "Private Key", Pattern: regexp.MustCompile(`(?s)(-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----).*?(-----END (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----)`), Redact: "${1}\n[REDACTED_PRIVATE_KEY]\n${2}", }, - + // JWT Tokens { Name: "JWT Token", Pattern: regexp.MustCompile(`(?i)(jwt|token)\s*[=:]\s*["\']?(eyJ[a-zA-Z0-9_\-]*\.eyJ[a-zA-Z0-9_\-]*\.[a-zA-Z0-9_\-]+)["\']?`), Redact: "${1}=\"[REDACTED_JWT_TOKEN]\"", }, - + // Generic Passwords { Name: "Password", Pattern: regexp.MustCompile(`(?i)(password|passwd|pwd)\s*[=:]\s*["\']([^\s"']{8,})["\']`), Redact: "${1}=\"[REDACTED_PASSWORD]\"", }, - + // Generic Secrets { Name: "Secret", Pattern: regexp.MustCompile(`(?i)(secret|SECRET)\s*[=:]\s*["\']?([a-zA-Z0-9_\-]{20,})["\']?`), Redact: "${1}=\"[REDACTED_SECRET]\"", }, - + // Environment Variable Assignments (catch-all for .env patterns) { Name: "Generic Token", Pattern: regexp.MustCompile(`(?i)(access[_-]?token|auth[_-]?token|client[_-]?secret|private[_-]?key)\s*[=:]\s*["\']?([a-zA-Z0-9_\-\.]{20,})["\']?`), Redact: "${1}=\"[REDACTED_TOKEN]\"", }, - + // Credit Card Numbers (basic pattern) { Name: "Credit Card", Pattern: regexp.MustCompile(`\b([0-9]{4}[\s\-]?){3}[0-9]{4}\b`), Redact: "[REDACTED_CREDIT_CARD]", }, - + // Email in credentials context { Name: "Email in Credentials", @@ -142,12 +142,12 @@ var ( // ScrubDiff removes sensitive information from git diff output func ScrubDiff(diff string) string { scrubbed := diff - + // Apply each pattern for _, pattern := range sensitivePatterns { scrubbed = pattern.Pattern.ReplaceAllString(scrubbed, pattern.Redact) } - + return scrubbed } @@ -156,7 +156,7 @@ func ScrubDiff(diff string) string { func ScrubLines(content string) string { lines := strings.Split(content, "\n") scrubbedLines := make([]string, len(lines)) - + for i, line := range lines { scrubbedLine := line for _, pattern := range sensitivePatterns { @@ -164,7 +164,7 @@ func ScrubLines(content string) string { } scrubbedLines[i] = scrubbedLine } - + return strings.Join(scrubbedLines, "\n") } @@ -193,16 +193,16 @@ func GetDetectedPatterns(content string) []string { func ScrubEnvFile(content string) string { lines := strings.Split(content, "\n") scrubbedLines := make([]string, len(lines)) - + for i, line := range lines { trimmed := strings.TrimSpace(line) - + // Skip comments and empty lines if trimmed == "" || strings.HasPrefix(trimmed, "#") { scrubbedLines[i] = line continue } - + // Check if line contains an assignment if strings.Contains(line, "=") { parts := strings.SplitN(line, "=", 2) @@ -222,10 +222,10 @@ func ScrubEnvFile(content string) string { } } } - + // Apply normal scrubbing scrubbedLines[i] = ScrubDiff(line) } - + return strings.Join(scrubbedLines, "\n") } diff --git a/internal/scrubber/scrubber_test.go b/internal/scrubber/scrubber_test.go index 1dcde7d..cfff4e5 100644 --- a/internal/scrubber/scrubber_test.go +++ b/internal/scrubber/scrubber_test.go @@ -117,7 +117,7 @@ func TestScrubDatabaseCredentials(t *testing.T) { func TestScrubJWTTokens(t *testing.T) { input := `token="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"` result := ScrubDiff(input) - + if !strings.Contains(result, "[REDACTED_JWT_TOKEN]") { t.Errorf("ScrubDiff() failed to redact JWT token.\nInput: %s\nOutput: %s", input, result) } @@ -128,9 +128,9 @@ func TestScrubPrivateKeys(t *testing.T) { MIIEpAIBAAKCAQEA1234567890abcdef ghijklmnopqrstuvwxyz -----END RSA PRIVATE KEY-----` - + result := ScrubDiff(input) - + if strings.Contains(result, "MIIEpAIBAAKCAQEA") || strings.Contains(result, "ghijklmnopqrstuvwxyz") { t.Errorf("ScrubDiff() failed to redact private key.\nOutput: %s", result) } @@ -142,7 +142,7 @@ ghijklmnopqrstuvwxyz func TestScrubBearerToken(t *testing.T) { input := `Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9abcdefghijk` result := ScrubDiff(input) - + if !strings.Contains(result, "[REDACTED") { t.Errorf("ScrubDiff() failed to redact bearer token.\nInput: %s\nOutput: %s", input, result) } @@ -176,7 +176,7 @@ func TestScrubGitHubToken(t *testing.T) { func TestScrubSlackToken(t *testing.T) { input := `SLACK_TOKEN=xoxb-1234567890-1234567890-abcdefghijk` result := ScrubDiff(input) - + if !strings.Contains(result, "[REDACTED_SLACK_TOKEN]") { t.Errorf("ScrubDiff() failed to redact Slack token.\nInput: %s\nOutput: %s", input, result) } @@ -221,19 +221,19 @@ PORT=3000 OPENAI_API_KEY=sk-proj-abcdefghijk` result := ScrubEnvFile(input) - + // Check that sensitive values are redacted if strings.Contains(result, "sk-1234567890") || strings.Contains(result, "mysecrettoken123") || strings.Contains(result, "sk-proj-abcdefghijk") { t.Errorf("ScrubEnvFile() failed to redact sensitive values.\nOutput: %s", result) } - + // Check that non-sensitive values are preserved if !strings.Contains(result, "DEBUG=true") || !strings.Contains(result, "PORT=3000") { t.Errorf("ScrubEnvFile() incorrectly redacted non-sensitive values.\nOutput: %s", result) } - + // Check that comments are preserved if !strings.Contains(result, "# Environment variables") { t.Errorf("ScrubEnvFile() removed comments.\nOutput: %s", result) @@ -284,18 +284,18 @@ func TestGetDetectedPatterns(t *testing.T) { password="mySecretPass123" GITHUB_TOKEN=ghp_1234567890abcdefghijklmnopqrstuvw ` - + patterns := GetDetectedPatterns(input) - + if len(patterns) == 0 { t.Error("GetDetectedPatterns() returned no patterns for input with sensitive data") } - + // Should detect at least these patterns hasOpenAI := false hasPassword := false hasGitHub := false - + for _, p := range patterns { if strings.Contains(p, "OpenAI") { hasOpenAI = true @@ -307,7 +307,7 @@ func TestGetDetectedPatterns(t *testing.T) { hasGitHub = true } } - + if !hasOpenAI { t.Error("GetDetectedPatterns() did not detect OpenAI API key") } @@ -344,7 +344,7 @@ index abcdef..123456 100644 };` result := ScrubDiff(input) - + // Check that sensitive values are removed if strings.Contains(result, "secretpass") { t.Error("Failed to scrub database password from diff") @@ -358,7 +358,7 @@ index abcdef..123456 100644 if strings.Contains(result, "sk-1234567890abcdefghijklmnop") { t.Error("Failed to scrub API key from config file") } - + // Check that non-sensitive values are preserved if !strings.Contains(result, "PORT=3000") { t.Error("Incorrectly removed non-sensitive PORT value") diff --git a/pkg/types/types.go b/pkg/types/types.go index 4e6f309..2e6f848 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -1,5 +1,54 @@ package types +type LLMProvider string + +const ( + ProviderOpenAI LLMProvider = "OpenAI" + ProviderClaude LLMProvider = "Claude" + ProviderGemini LLMProvider = "Gemini" + ProviderGrok LLMProvider = "Grok" + ProviderGroq LLMProvider = "Groq" + ProviderOllama LLMProvider = "Ollama" +) + +func (p LLMProvider) String() string { + return string(p) +} + +func (p LLMProvider) IsValid() bool { + switch p { + case ProviderOpenAI, ProviderClaude, ProviderGemini, ProviderGrok, ProviderGroq, ProviderOllama: + return true + default: + return false + } +} + +func GetSupportedProviders() []LLMProvider { + return []LLMProvider{ + ProviderOpenAI, + ProviderClaude, + ProviderGemini, + ProviderGrok, + ProviderGroq, + ProviderOllama, + } +} + +func GetSupportedProviderStrings() []string { + providers := GetSupportedProviders() + strings := make([]string, len(providers)) + for i, provider := range providers { + strings[i] = provider.String() + } + return strings +} + +func ParseLLMProvider(s string) (LLMProvider, bool) { + provider := LLMProvider(s) + return provider, provider.IsValid() +} + // Configuration structure type Config struct { GrokAPI string `json:"grok_api"`