Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 44 additions & 25 deletions cmd/cli/createMsg.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"errors"
"fmt"
"os"
Expand All @@ -10,14 +11,9 @@ import (

"github.com/atotto/clipboard"
"github.com/dfanso/commit-msg/cmd/cli/store"
"github.com/dfanso/commit-msg/internal/chatgpt"
"github.com/dfanso/commit-msg/internal/claude"
"github.com/dfanso/commit-msg/internal/display"
"github.com/dfanso/commit-msg/internal/gemini"
"github.com/dfanso/commit-msg/internal/git"
"github.com/dfanso/commit-msg/internal/grok"
"github.com/dfanso/commit-msg/internal/groq"
"github.com/dfanso/commit-msg/internal/ollama"
"github.com/dfanso/commit-msg/internal/llm"
"github.com/dfanso/commit-msg/internal/stats"
"github.com/dfanso/commit-msg/pkg/types"
"github.com/google/shlex"
Expand Down Expand Up @@ -102,6 +98,17 @@ func CreateCommitMsg(dryRun bool, autoCommit bool) {
return
}

ctx := context.Background()

providerInstance, err := llm.NewProvider(commitLLM, llm.ProviderOptions{
Credential: apiKey,
Config: config,
})
if err != nil {
displayProviderError(commitLLM, err)
os.Exit(1)
}

pterm.Println()
spinnerGenerating, err := pterm.DefaultSpinner.
WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").
Expand All @@ -112,7 +119,7 @@ func CreateCommitMsg(dryRun bool, autoCommit bool) {
}

attempt := 1
commitMsg, err := generateMessage(commitLLM, config, changes, apiKey, withAttempt(nil, attempt))
commitMsg, err := generateMessage(ctx, providerInstance, changes, withAttempt(nil, attempt))
if err != nil {
spinnerGenerating.Fail("Failed to generate commit message")
displayProviderError(commitLLM, err)
Expand Down Expand Up @@ -174,7 +181,7 @@ interactionLoop:
pterm.Error.Printf("Failed to start spinner: %v\n", err)
continue
}
updatedMessage, genErr := generateMessage(commitLLM, config, changes, apiKey, generationOpts)
updatedMessage, genErr := generateMessage(ctx, providerInstance, changes, generationOpts)
if genErr != nil {
spinner.Fail("Regeneration failed")
displayProviderError(commitLLM, genErr)
Expand Down Expand Up @@ -283,22 +290,8 @@ func resolveOllamaConfig(apiKey string) (url, model string) {
return url, model
}

func generateMessage(provider types.LLMProvider, config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
switch provider {
case types.ProviderGemini:
return gemini.GenerateCommitMessage(config, changes, apiKey, opts)
case types.ProviderOpenAI:
return chatgpt.GenerateCommitMessage(config, changes, apiKey, opts)
case types.ProviderClaude:
return claude.GenerateCommitMessage(config, changes, apiKey, opts)
case types.ProviderGroq:
return groq.GenerateCommitMessage(config, changes, apiKey, opts)
case types.ProviderOllama:
url, model := resolveOllamaConfig(apiKey)
return ollama.GenerateCommitMessage(config, changes, url, model, opts)
default:
return grok.GenerateCommitMessage(config, changes, apiKey, opts)
}
// generateMessage asks the resolved provider for a commit message covering the
// supplied repository changes, forwarding the generation options (style,
// attempt number) untouched.
func generateMessage(ctx context.Context, provider llm.Provider, changes string, opts *types.GenerationOptions) (string, error) {
	return provider.Generate(ctx, changes, opts)
}

func promptActionSelection() (string, error) {
Expand Down Expand Up @@ -456,6 +449,11 @@ func withAttempt(styleOpts *types.GenerationOptions, attempt int) *types.Generat
}

func displayProviderError(provider types.LLMProvider, err error) {
if errors.Is(err, llm.ErrMissingCredential) {
displayMissingCredentialHint(provider)
return
}

switch provider {
case types.ProviderGemini:
pterm.Error.Printf("Gemini API error: %v. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n", err)
Expand All @@ -467,8 +465,29 @@ func displayProviderError(provider types.LLMProvider, err error) {
pterm.Error.Printf("Groq API error: %v. Check your GROQ_API_KEY environment variable or run: commit llm setup\n", err)
case types.ProviderGrok:
pterm.Error.Printf("Grok API error: %v. Check your GROK_API_KEY environment variable or run: commit llm setup\n", err)
case types.ProviderOllama:
pterm.Error.Printf("Ollama error: %v. Verify the Ollama service URL or run: commit llm setup\n", err)
default:
pterm.Error.Printf("LLM error: %v\n", err)
}
}

// displayMissingCredentialHint prints a provider-specific message explaining
// which credential is missing and how to configure it (setup command plus the
// relevant environment variable). Called when the provider error wraps
// llm.ErrMissingCredential.
//
// Fix: the default case referenced an undefined variable `err` (diff residue
// from the old displayProviderError); it now prints only the generic
// missing-credentials hint.
func displayMissingCredentialHint(provider types.LLMProvider) {
	switch provider {
	case types.ProviderGemini:
		pterm.Error.Println("Gemini requires an API key. Run: commit llm setup or set GEMINI_API_KEY.")
	case types.ProviderOpenAI:
		pterm.Error.Println("OpenAI requires an API key. Run: commit llm setup or set OPENAI_API_KEY.")
	case types.ProviderClaude:
		pterm.Error.Println("Claude requires an API key. Run: commit llm setup or set CLAUDE_API_KEY.")
	case types.ProviderGroq:
		pterm.Error.Println("Groq requires an API key. Run: commit llm setup or set GROQ_API_KEY.")
	case types.ProviderGrok:
		pterm.Error.Println("Grok requires an API key. Run: commit llm setup or set GROK_API_KEY.")
	case types.ProviderOllama:
		pterm.Error.Println("Ollama requires a reachable service URL. Run: commit llm setup or set OLLAMA_URL.")
	default:
		pterm.Error.Printf("%s is missing credentials. Run: commit llm setup.\n", provider)
	}
}

Expand Down
246 changes: 246 additions & 0 deletions internal/llm/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
package llm

import (
"context"
"errors"
"fmt"
"os"
"strings"
"sync"

"github.com/dfanso/commit-msg/internal/chatgpt"
"github.com/dfanso/commit-msg/internal/claude"
"github.com/dfanso/commit-msg/internal/gemini"
"github.com/dfanso/commit-msg/internal/grok"
"github.com/dfanso/commit-msg/internal/groq"
"github.com/dfanso/commit-msg/internal/ollama"
"github.com/dfanso/commit-msg/pkg/types"
)

// ErrMissingCredential signals that a provider requires a credential such as an API key or URL.
// Callers detect it with errors.Is to show a provider-specific setup hint
// instead of a raw API error.
var ErrMissingCredential = errors.New("llm: missing credential")

// Provider declares the behaviour required by commit-msg to talk to an LLM backend.
type Provider interface {
	// Name returns the LLM provider identifier this instance represents.
	Name() types.LLMProvider
	// Generate requests a commit message for the supplied repository changes.
	Generate(ctx context.Context, changes string, opts *types.GenerationOptions) (string, error)
}

// ProviderOptions captures the data needed to construct a provider instance.
type ProviderOptions struct {
	// Credential is the API key (or, for Ollama, the service URL). When empty,
	// each factory falls back to its provider-specific environment variable.
	Credential string
	// Config carries application configuration passed through to the backend;
	// a nil value is replaced with an empty config by NewProvider.
	Config *types.Config
}

// Factory describes a function capable of building a Provider.
type Factory func(ProviderOptions) (Provider, error)

var (
	// factoryMu guards factories against concurrent registration and lookup.
	factoryMu sync.RWMutex
	// factories maps each supported provider name to its constructor;
	// entries can be overridden via RegisterFactory.
	factories = map[types.LLMProvider]Factory{
		types.ProviderOpenAI: newOpenAIProvider,
		types.ProviderClaude: newClaudeProvider,
		types.ProviderGemini: newGeminiProvider,
		types.ProviderGrok:   newGrokProvider,
		types.ProviderGroq:   newGroqProvider,
		types.ProviderOllama: newOllamaProvider,
	}
)

// RegisterFactory allows callers (primarily tests) to override or extend provider creation logic.
func RegisterFactory(name types.LLMProvider, factory Factory) {
	factoryMu.Lock()
	factories[name] = factory
	factoryMu.Unlock()
}

// NewProvider returns a concrete Provider implementation for the requested name.
// Unknown names yield an error; a nil Config is normalised to an empty one
// before the factory runs.
func NewProvider(name types.LLMProvider, opts ProviderOptions) (Provider, error) {
	factoryMu.RLock()
	build, known := factories[name]
	factoryMu.RUnlock()

	if !known {
		return nil, fmt.Errorf("llm: unsupported provider %s", name)
	}
	opts.Config = ensureConfig(opts.Config)
	return build(opts)
}

// missingCredentialError records which provider lacked a credential while
// still matching ErrMissingCredential through errors.Is.
type missingCredentialError struct {
	provider types.LLMProvider
}

// newMissingCredentialError builds a credential error for the given provider.
func newMissingCredentialError(provider types.LLMProvider) error {
	return &missingCredentialError{provider: provider}
}

// Error names the provider whose credential is required.
func (e *missingCredentialError) Error() string {
	return e.provider.String() + " credential is required"
}

// Unwrap exposes the ErrMissingCredential sentinel to errors.Is.
func (e *missingCredentialError) Unwrap() error {
	return ErrMissingCredential
}

// ensureConfig substitutes an empty configuration when cfg is nil so that
// providers never have to nil-check it themselves.
func ensureConfig(cfg *types.Config) *types.Config {
	if cfg == nil {
		return &types.Config{}
	}
	return cfg
}

// --- Provider implementations ------------------------------------------------

// keyedProvider is the shared implementation for every backend that
// authenticates with a single API key (OpenAI, Claude, Gemini, Grok, Groq).
// The five previous provider structs were byte-for-byte duplicates differing
// only in the provider name, the environment-variable fallback, and the
// backend generate call; this consolidates them behind one type.
type keyedProvider struct {
	name     types.LLMProvider
	apiKey   string
	config   *types.Config
	generate func(cfg *types.Config, changes, apiKey string, opts *types.GenerationOptions) (string, error)
}

// newKeyedProvider resolves the API key — explicit credential first, then the
// named environment variable, both trimmed — and returns a provider bound to
// the given backend generate function. A blank key yields a
// missing-credential error matching ErrMissingCredential.
func newKeyedProvider(
	name types.LLMProvider,
	envVar string,
	opts ProviderOptions,
	generate func(*types.Config, string, string, *types.GenerationOptions) (string, error),
) (Provider, error) {
	key := strings.TrimSpace(opts.Credential)
	if key == "" {
		key = strings.TrimSpace(os.Getenv(envVar))
	}
	if key == "" {
		return nil, newMissingCredentialError(name)
	}
	return &keyedProvider{name: name, apiKey: key, config: opts.Config, generate: generate}, nil
}

// Name returns the LLM provider identifier this instance represents.
func (p *keyedProvider) Name() types.LLMProvider {
	return p.name
}

// Generate delegates to the backend-specific generate function. The context
// is currently unused because the underlying clients do not accept one.
func (p *keyedProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) {
	return p.generate(p.config, changes, p.apiKey, opts)
}

// newOpenAIProvider builds the OpenAI provider (falls back to OPENAI_API_KEY).
func newOpenAIProvider(opts ProviderOptions) (Provider, error) {
	return newKeyedProvider(types.ProviderOpenAI, "OPENAI_API_KEY", opts, chatgpt.GenerateCommitMessage)
}

// newClaudeProvider builds the Claude provider (falls back to CLAUDE_API_KEY).
func newClaudeProvider(opts ProviderOptions) (Provider, error) {
	return newKeyedProvider(types.ProviderClaude, "CLAUDE_API_KEY", opts, claude.GenerateCommitMessage)
}

// newGeminiProvider builds the Gemini provider (falls back to GEMINI_API_KEY).
func newGeminiProvider(opts ProviderOptions) (Provider, error) {
	return newKeyedProvider(types.ProviderGemini, "GEMINI_API_KEY", opts, gemini.GenerateCommitMessage)
}

// newGrokProvider builds the Grok provider (falls back to GROK_API_KEY).
func newGrokProvider(opts ProviderOptions) (Provider, error) {
	return newKeyedProvider(types.ProviderGrok, "GROK_API_KEY", opts, grok.GenerateCommitMessage)
}

// newGroqProvider builds the Groq provider (falls back to GROQ_API_KEY).
func newGroqProvider(opts ProviderOptions) (Provider, error) {
	return newKeyedProvider(types.ProviderGroq, "GROQ_API_KEY", opts, groq.GenerateCommitMessage)
}

// ollamaProvider talks to a local or remote Ollama service.
type ollamaProvider struct {
	url    string
	model  string
	config *types.Config
}

// newOllamaProvider builds the Ollama provider. The credential is treated as
// the service URL; when blank it falls back to OLLAMA_URL and finally to the
// local default endpoint. The model comes from OLLAMA_MODEL, defaulting to
// llama3.1. Construction never fails because every value has a default.
func newOllamaProvider(opts ProviderOptions) (Provider, error) {
	endpoint := strings.TrimSpace(opts.Credential)
	if endpoint == "" {
		endpoint = strings.TrimSpace(os.Getenv("OLLAMA_URL"))
	}
	if endpoint == "" {
		endpoint = "http://localhost:11434/api/generate"
	}

	model := strings.TrimSpace(os.Getenv("OLLAMA_MODEL"))
	if model == "" {
		model = "llama3.1"
	}

	return &ollamaProvider{url: endpoint, model: model, config: opts.Config}, nil
}

// Name returns the LLM provider identifier this instance represents.
func (p *ollamaProvider) Name() types.LLMProvider {
	return types.ProviderOllama
}

// Generate delegates to the Ollama client with the resolved URL and model;
// the context is unused because the underlying client does not accept one.
func (p *ollamaProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) {
	return ollama.GenerateCommitMessage(p.config, changes, p.url, p.model, opts)
}
Loading
Loading