diff --git a/cmd/commit-msg/main.go b/cmd/commit-msg/main.go index 7689d28..79c0c38 100644 --- a/cmd/commit-msg/main.go +++ b/cmd/commit-msg/main.go @@ -11,13 +11,20 @@ import ( "github.com/dfanso/commit-msg/internal/gemini" "github.com/dfanso/commit-msg/internal/git" "github.com/dfanso/commit-msg/internal/grok" + "github.com/dfanso/commit-msg/internal/ollama" "github.com/dfanso/commit-msg/internal/stats" "github.com/dfanso/commit-msg/pkg/types" + "github.com/joho/godotenv" "github.com/pterm/pterm" ) // main is the entry point of the commit message generator func main() { + // Load the .env file + if err := godotenv.Load(); err != nil { + log.Printf("warning: unable to load .env file: %v", err) + } + // Validate COMMIT_LLM and required API keys commitLLM := os.Getenv("COMMIT_LLM") var apiKey string @@ -43,6 +50,9 @@ func main() { if apiKey == "" { log.Fatalf("CLAUDE_API_KEY is not set") } + case "ollama": + // No API key required to run a local LLM + apiKey = "" default: log.Fatalf("Invalid COMMIT_LLM value: %s", commitLLM) } @@ -119,6 +129,16 @@ func main() { commitMsg, err = chatgpt.GenerateCommitMessage(config, changes, apiKey) case "claude": commitMsg, err = claude.GenerateCommitMessage(config, changes, apiKey) + case "ollama": + url := os.Getenv("OLLAMA_URL") + if url == "" { + url = "http://localhost:11434/api/generate" + } + model := os.Getenv("OLLAMA_MODEL") + if model == "" { + model = "llama3:latest" + } + commitMsg, err = ollama.GenerateCommitMessage(config, changes, url, model) default: commitMsg, err = grok.GenerateCommitMessage(config, changes, apiKey) } diff --git a/go.mod b/go.mod index 31b87c3..0c29269 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/gookit/color v1.5.4 // indirect + github.com/joho/godotenv v1.5.1 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/rivo/uniseg v0.4.7 // indirect diff --git a/go.sum b/go.sum index a6d5c71..fb0f184 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQ github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= diff --git a/internal/ollama/ollama.go b/internal/ollama/ollama.go new file mode 100644 index 0000000..90dae08 --- /dev/null +++ b/internal/ollama/ollama.go @@ -0,0 +1,75 @@ +package ollama + +import ( + "fmt" + "net/http" + "encoding/json" + "bytes" + "io" + + "github.com/dfanso/commit-msg/pkg/types" + +) + +type OllamaRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type OllamaResponse struct { + Response string `json:"response"` + Done bool `json:"done"` +} + +func GenerateCommitMessage(_ *types.Config, changes string, url string, model string) (string, error) { + // Use llama3:latest as the default model + if model == "" { + model = "llama3:latest" + } + + // Preparing the prompt + prompt := fmt.Sprintf("%s\n\n%s", types.CommitPrompt, changes) + + // Generating the request body - add stream: false for non-streaming response + reqBody := map[string]interface{}{ + "model": model, + "prompt": prompt, + "stream": false, + } + + // Generating the body + body, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %v", err) + } + + resp, err := http.Post(url, "application/json", bytes.NewBuffer(body)) + if err != nil { + return "", fmt.Errorf("failed to send request to Ollama: %v", err) + } + defer resp.Body.Close() + + // Read the full response body for better error handling + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %v", err) + } + + // Check HTTP status + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("Ollama API returned status %d: %s", resp.StatusCode, string(responseBody)) + } + + // Since we set stream: false, we get a single response object + var response OllamaResponse + if err := json.Unmarshal(responseBody, &response); err != nil { + return "", fmt.Errorf("failed to decode response: %v", err) + } + + // Check if we got any response + if response.Response == "" { + return "", fmt.Errorf("received empty response from Ollama") + } + + return response.Response, nil +} \ No newline at end of file