diff --git a/cmd/cli/createMsg.go b/cmd/cli/createMsg.go index 88a81fb..e2f3a31 100644 --- a/cmd/cli/createMsg.go +++ b/cmd/cli/createMsg.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "errors" "fmt" "os" @@ -10,14 +11,9 @@ import ( "github.com/atotto/clipboard" "github.com/dfanso/commit-msg/cmd/cli/store" - "github.com/dfanso/commit-msg/internal/chatgpt" - "github.com/dfanso/commit-msg/internal/claude" "github.com/dfanso/commit-msg/internal/display" - "github.com/dfanso/commit-msg/internal/gemini" "github.com/dfanso/commit-msg/internal/git" - "github.com/dfanso/commit-msg/internal/grok" - "github.com/dfanso/commit-msg/internal/groq" - "github.com/dfanso/commit-msg/internal/ollama" + "github.com/dfanso/commit-msg/internal/llm" "github.com/dfanso/commit-msg/internal/stats" "github.com/dfanso/commit-msg/pkg/types" "github.com/google/shlex" @@ -102,6 +98,17 @@ func CreateCommitMsg(dryRun bool, autoCommit bool) { return } + ctx := context.Background() + + providerInstance, err := llm.NewProvider(commitLLM, llm.ProviderOptions{ + Credential: apiKey, + Config: config, + }) + if err != nil { + displayProviderError(commitLLM, err) + os.Exit(1) + } + pterm.Println() spinnerGenerating, err := pterm.DefaultSpinner. WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"). @@ -112,7 +119,7 @@ func CreateCommitMsg(dryRun bool, autoCommit bool) { } attempt := 1 - commitMsg, err := generateMessage(commitLLM, config, changes, apiKey, withAttempt(nil, attempt)) + commitMsg, err := generateMessage(ctx, providerInstance, changes, withAttempt(nil, attempt)) if err != nil { spinnerGenerating.Fail("Failed to generate commit message") displayProviderError(commitLLM, err) @@ -174,7 +181,7 @@ interactionLoop: pterm.Error.Printf("Failed to start spinner: %v\n", err) continue } - updatedMessage, genErr := generateMessage(commitLLM, config, changes, apiKey, generationOpts) + updatedMessage, genErr := generateMessage(ctx, providerInstance, changes, generationOpts) if genErr != nil { spinner.Fail("Regeneration failed") displayProviderError(commitLLM, genErr) @@ -283,22 +290,8 @@ func resolveOllamaConfig(apiKey string) (url, model string) { return url, model } -func generateMessage(provider types.LLMProvider, config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) { - switch provider { - case types.ProviderGemini: - return gemini.GenerateCommitMessage(config, changes, apiKey, opts) - case types.ProviderOpenAI: - return chatgpt.GenerateCommitMessage(config, changes, apiKey, opts) - case types.ProviderClaude: - return claude.GenerateCommitMessage(config, changes, apiKey, opts) - case types.ProviderGroq: - return groq.GenerateCommitMessage(config, changes, apiKey, opts) - case types.ProviderOllama: - url, model := resolveOllamaConfig(apiKey) - return ollama.GenerateCommitMessage(config, changes, url, model, opts) - default: - return grok.GenerateCommitMessage(config, changes, apiKey, opts) - } +func generateMessage(ctx context.Context, provider llm.Provider, changes string, opts *types.GenerationOptions) (string, error) { + return provider.Generate(ctx, changes, opts) } func promptActionSelection() (string, error) { @@ -456,6 +449,11 @@ func withAttempt(styleOpts *types.GenerationOptions, attempt int) *types.Generat } func displayProviderError(provider types.LLMProvider, err error) { + if errors.Is(err, llm.ErrMissingCredential) { + displayMissingCredentialHint(provider) + return + } + switch provider { case types.ProviderGemini: pterm.Error.Printf("Gemini API error: %v. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n", err) @@ -467,8 +465,29 @@ func displayProviderError(provider types.LLMProvider, err error) { pterm.Error.Printf("Groq API error: %v. Check your GROQ_API_KEY environment variable or run: commit llm setup\n", err) case types.ProviderGrok: pterm.Error.Printf("Grok API error: %v. Check your GROK_API_KEY environment variable or run: commit llm setup\n", err) + case types.ProviderOllama: + pterm.Error.Printf("Ollama error: %v. Verify the Ollama service URL or run: commit llm setup\n", err) + default: + pterm.Error.Printf("LLM error: %v\n", err) + } +} + +func displayMissingCredentialHint(provider types.LLMProvider) { + switch provider { + case types.ProviderGemini: + pterm.Error.Println("Gemini requires an API key. Run: commit llm setup or set GEMINI_API_KEY.") + case types.ProviderOpenAI: + pterm.Error.Println("OpenAI requires an API key. Run: commit llm setup or set OPENAI_API_KEY.") + case types.ProviderClaude: + pterm.Error.Println("Claude requires an API key. Run: commit llm setup or set CLAUDE_API_KEY.") + case types.ProviderGroq: + pterm.Error.Println("Groq requires an API key. Run: commit llm setup or set GROQ_API_KEY.") + case types.ProviderGrok: + pterm.Error.Println("Grok requires an API key. Run: commit llm setup or set GROK_API_KEY.") + case types.ProviderOllama: + pterm.Error.Println("Ollama requires a reachable service URL. Run: commit llm setup or set OLLAMA_URL.") default: - pterm.Error.Printf("LLM API error: %v\n", err) + pterm.Error.Printf("%s is missing credentials. Run: commit llm setup.\n", provider) } } diff --git a/internal/llm/provider.go b/internal/llm/provider.go new file mode 100644 index 0000000..e2e6402 --- /dev/null +++ b/internal/llm/provider.go @@ -0,0 +1,246 @@ +package llm + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "sync" + + "github.com/dfanso/commit-msg/internal/chatgpt" + "github.com/dfanso/commit-msg/internal/claude" + "github.com/dfanso/commit-msg/internal/gemini" + "github.com/dfanso/commit-msg/internal/grok" + "github.com/dfanso/commit-msg/internal/groq" + "github.com/dfanso/commit-msg/internal/ollama" + "github.com/dfanso/commit-msg/pkg/types" +) + +// ErrMissingCredential signals that a provider requires a credential such as an API key or URL. +var ErrMissingCredential = errors.New("llm: missing credential") + +// Provider declares the behaviour required by commit-msg to talk to an LLM backend. +type Provider interface { + // Name returns the LLM provider identifier this instance represents. + Name() types.LLMProvider + // Generate requests a commit message for the supplied repository changes. + Generate(ctx context.Context, changes string, opts *types.GenerationOptions) (string, error) +} + +// ProviderOptions captures the data needed to construct a provider instance. +type ProviderOptions struct { + Credential string + Config *types.Config +} + +// Factory describes a function capable of building a Provider. +type Factory func(ProviderOptions) (Provider, error) + +var ( + factoryMu sync.RWMutex + factories = map[types.LLMProvider]Factory{ + types.ProviderOpenAI: newOpenAIProvider, + types.ProviderClaude: newClaudeProvider, + types.ProviderGemini: newGeminiProvider, + types.ProviderGrok: newGrokProvider, + types.ProviderGroq: newGroqProvider, + types.ProviderOllama: newOllamaProvider, + } +) + +// RegisterFactory allows callers (primarily tests) to override or extend provider creation logic. +func RegisterFactory(name types.LLMProvider, factory Factory) { + factoryMu.Lock() + defer factoryMu.Unlock() + factories[name] = factory +} + +// NewProvider returns a concrete Provider implementation for the requested name. +func NewProvider(name types.LLMProvider, opts ProviderOptions) (Provider, error) { + factoryMu.RLock() + factory, ok := factories[name] + factoryMu.RUnlock() + if !ok { + return nil, fmt.Errorf("llm: unsupported provider %s", name) + } + + opts.Config = ensureConfig(opts.Config) + return factory(opts) +} + +type missingCredentialError struct { + provider types.LLMProvider +} + +func (e *missingCredentialError) Error() string { + return fmt.Sprintf("%s credential is required", e.provider.String()) +} + +func (e *missingCredentialError) Unwrap() error { + return ErrMissingCredential +} + +func newMissingCredentialError(provider types.LLMProvider) error { + return &missingCredentialError{provider: provider} +} + +func ensureConfig(cfg *types.Config) *types.Config { + if cfg != nil { + return cfg + } + return &types.Config{} +} + +// --- Provider implementations ------------------------------------------------ + +type openAIProvider struct { + apiKey string + config *types.Config +} + +func newOpenAIProvider(opts ProviderOptions) (Provider, error) { + key := strings.TrimSpace(opts.Credential) + if key == "" { + key = strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + } + if key == "" { + return nil, newMissingCredentialError(types.ProviderOpenAI) + } + return &openAIProvider{apiKey: key, config: opts.Config}, nil +} + +func (p *openAIProvider) Name() types.LLMProvider { + return types.ProviderOpenAI +} + +func (p *openAIProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) { + return chatgpt.GenerateCommitMessage(p.config, changes, p.apiKey, opts) +} + +type claudeProvider struct { + apiKey string + config *types.Config +} + +func newClaudeProvider(opts ProviderOptions) (Provider, error) { + key := strings.TrimSpace(opts.Credential) + if key == "" { + key = strings.TrimSpace(os.Getenv("CLAUDE_API_KEY")) + } + if key == "" { + return nil, newMissingCredentialError(types.ProviderClaude) + } + return &claudeProvider{apiKey: key, config: opts.Config}, nil +} + +func (p *claudeProvider) Name() types.LLMProvider { + return types.ProviderClaude +} + +func (p *claudeProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) { + return claude.GenerateCommitMessage(p.config, changes, p.apiKey, opts) +} + +type geminiProvider struct { + apiKey string + config *types.Config +} + +func newGeminiProvider(opts ProviderOptions) (Provider, error) { + key := strings.TrimSpace(opts.Credential) + if key == "" { + key = strings.TrimSpace(os.Getenv("GEMINI_API_KEY")) + } + if key == "" { + return nil, newMissingCredentialError(types.ProviderGemini) + } + return &geminiProvider{apiKey: key, config: opts.Config}, nil +} + +func (p *geminiProvider) Name() types.LLMProvider { + return types.ProviderGemini +} + +func (p *geminiProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) { + return gemini.GenerateCommitMessage(p.config, changes, p.apiKey, opts) +} + +type grokProvider struct { + apiKey string + config *types.Config +} + +func newGrokProvider(opts ProviderOptions) (Provider, error) { + key := strings.TrimSpace(opts.Credential) + if key == "" { + key = strings.TrimSpace(os.Getenv("GROK_API_KEY")) + } + if key == "" { + return nil, newMissingCredentialError(types.ProviderGrok) + } + return &grokProvider{apiKey: key, config: opts.Config}, nil +} + +func (p *grokProvider) Name() types.LLMProvider { + return types.ProviderGrok +} + +func (p *grokProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) { + return grok.GenerateCommitMessage(p.config, changes, p.apiKey, opts) +} + +type groqProvider struct { + apiKey string + config *types.Config +} + +func newGroqProvider(opts ProviderOptions) (Provider, error) { + key := strings.TrimSpace(opts.Credential) + if key == "" { + key = strings.TrimSpace(os.Getenv("GROQ_API_KEY")) + } + if key == "" { + return nil, newMissingCredentialError(types.ProviderGroq) + } + return &groqProvider{apiKey: key, config: opts.Config}, nil +} + +func (p *groqProvider) Name() types.LLMProvider { + return types.ProviderGroq +} + +func (p *groqProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) { + return groq.GenerateCommitMessage(p.config, changes, p.apiKey, opts) +} + +type ollamaProvider struct { + url string + model string + config *types.Config +} + +func newOllamaProvider(opts ProviderOptions) (Provider, error) { + url := strings.TrimSpace(opts.Credential) + if url == "" { + url = strings.TrimSpace(os.Getenv("OLLAMA_URL")) + if url == "" { + url = "http://localhost:11434/api/generate" + } + } + + model := strings.TrimSpace(os.Getenv("OLLAMA_MODEL")) + if model == "" { + model = "llama3.1" + } + + return &ollamaProvider{url: url, model: model, config: opts.Config}, nil +} + +func (p *ollamaProvider) Name() types.LLMProvider { + return types.ProviderOllama +} + +func (p *ollamaProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) { + return ollama.GenerateCommitMessage(p.config, changes, p.url, p.model, opts) +} diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go new file mode 100644 index 0000000..00201e6 --- /dev/null +++ b/internal/llm/provider_test.go @@ -0,0 +1,130 @@ +package llm + +import ( + "context" + "errors" + "testing" + + "github.com/dfanso/commit-msg/pkg/types" +) + +func TestNewProviderRequiresCredential(t *testing.T) { + remoteProviders := []types.LLMProvider{ + types.ProviderOpenAI, + types.ProviderClaude, + types.ProviderGemini, + types.ProviderGrok, + types.ProviderGroq, + } + + for _, provider := range remoteProviders { + provider := provider + t.Run(provider.String(), func(t *testing.T) { + switch provider { + case types.ProviderOpenAI: + t.Setenv("OPENAI_API_KEY", "") + case types.ProviderClaude: + t.Setenv("CLAUDE_API_KEY", "") + case types.ProviderGemini: + t.Setenv("GEMINI_API_KEY", "") + case types.ProviderGrok: + t.Setenv("GROK_API_KEY", "") + case types.ProviderGroq: + t.Setenv("GROQ_API_KEY", "") + } + + _, err := NewProvider(provider, ProviderOptions{}) + if !errors.Is(err, ErrMissingCredential) { + t.Fatalf("expected ErrMissingCredential for %s, got %v", provider, err) + } + }) + } +} + +func TestNewProviderUsesEnvFallback(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "env-key") + provider, err := NewProvider(types.ProviderOpenAI, ProviderOptions{}) + if err != nil { + t.Fatalf("expected no error using env fallback, got %v", err) + } + + p, ok := provider.(*openAIProvider) + if !ok { + t.Fatalf("expected *openAIProvider, got %T", provider) + } + + if p.apiKey != "env-key" { + t.Fatalf("expected api key to come from env, got %q", p.apiKey) + } +} + +func TestNewProviderUnsupported(t *testing.T) { + _, err := NewProvider(types.LLMProvider("unknown"), ProviderOptions{}) + if err == nil { + t.Fatal("expected error for unsupported provider") + } +} + +func TestNewProviderOllamaDefaults(t *testing.T) { + t.Setenv("OLLAMA_URL", "") + t.Setenv("OLLAMA_MODEL", "") + + provider, err := NewProvider(types.ProviderOllama, ProviderOptions{}) + if err != nil { + t.Fatalf("expected no error for ollama provider, got %v", err) + } + + p, ok := provider.(*ollamaProvider) + if !ok { + t.Fatalf("expected *ollamaProvider, got %T", provider) + } + + if p.url == "" { + t.Fatalf("expected default URL to be set") + } + + if p.model == "" { + t.Fatalf("expected default model to be set") + } +} + +func TestRegisterFactoryOverrides(t *testing.T) { + factoryMu.Lock() + original := factories[types.ProviderOpenAI] + factoryMu.Unlock() + + t.Cleanup(func() { + RegisterFactory(types.ProviderOpenAI, original) + }) + + called := 0 + RegisterFactory(types.ProviderOpenAI, func(opts ProviderOptions) (Provider, error) { + called++ + return fakeProvider{name: types.ProviderOpenAI}, nil + }) + + provider, err := NewProvider(types.ProviderOpenAI, ProviderOptions{}) + if err != nil { + t.Fatalf("expected no error from overridden factory, got %v", err) + } + + if called != 1 { + t.Fatalf("expected overridden factory to be called once, got %d", called) + } + + if provider.Name() != types.ProviderOpenAI { + t.Fatalf("expected provider name %s, got %s", types.ProviderOpenAI, provider.Name()) + } +} + +type fakeProvider struct { + name types.LLMProvider +} + +func (f fakeProvider) Name() types.LLMProvider { + return f.name +} + +func (f fakeProvider) Generate(context.Context, string, *types.GenerationOptions) (string, error) { + return "", nil +}