diff --git a/README.md b/README.md index 6c04354..52b2d94 100644 --- a/README.md +++ b/README.md @@ -327,6 +327,10 @@ export TMUXAI_DEBUG=true export TMUXAI_MAX_CAPTURE_LINES=300 export TMUXAI_OPENROUTER_API_KEY="your-api-key-here" export TMUXAI_OPENROUTER_MODEL="..." +export TMUXAI_AZURE_OPENAI_API_KEY="your-azure-api-key" +export TMUXAI_AZURE_OPENAI_API_BASE="https://your-resource.openai.azure.com/" +export TMUXAI_AZURE_OPENAI_API_VERSION="2025-04-01-preview" +export TMUXAI_AZURE_OPENAI_DEPLOYMENT_NAME="gpt-4o" ``` You can also use environment variables directly within your configuration file values. The application will automatically expand these variables when loading the configuration: @@ -393,6 +397,16 @@ openrouter: base_url: http://localhost:11434/v1 ``` +For Azure OpenAI: + +```yaml +azure_openai: + api_key: "your-azure-openai-key" + api_base: "https://your-resource.openai.azure.com/" + api_version: "2025-04-01-preview" + deployment_name: "gpt-4o" +``` + _Prompts are currently tuned for Gemini 2.5 by default; behavior with other models may vary._ ## Contributing diff --git a/config.example.yaml b/config.example.yaml index 772ce06..53d461a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -12,6 +12,13 @@ openrouter: model: google/gemini-2.5-flash-preview # default model base_url: https://openrouter.ai/api/v1 # default base url +# Azure OpenAI configuration +# azure_openai: +# api_key: +# api_base: https://your-resource.openai.azure.com/ +# api_version: 2025-04-01-preview +# deployment_name: gpt-4o + # OpenAI example # openrouter: # api_key: sk-XXXXXXXXX diff --git a/config/config.go b/config/config.go index 9c28aba..542d72c 100644 --- a/config/config.go +++ b/config/config.go @@ -12,17 +12,18 @@ import ( // Config holds the application configuration type Config struct { - Debug bool `mapstructure:"debug"` - MaxCaptureLines int `mapstructure:"max_capture_lines"` - MaxContextSize int `mapstructure:"max_context_size"` - WaitInterval int `mapstructure:"wait_interval"` - SendKeysConfirm bool `mapstructure:"send_keys_confirm"` - PasteMultilineConfirm bool `mapstructure:"paste_multiline_confirm"` - ExecConfirm bool `mapstructure:"exec_confirm"` - WhitelistPatterns []string `mapstructure:"whitelist_patterns"` - BlacklistPatterns []string `mapstructure:"blacklist_patterns"` - OpenRouter OpenRouterConfig `mapstructure:"openrouter"` - Prompts PromptsConfig `mapstructure:"prompts"` + Debug bool `mapstructure:"debug"` + MaxCaptureLines int `mapstructure:"max_capture_lines"` + MaxContextSize int `mapstructure:"max_context_size"` + WaitInterval int `mapstructure:"wait_interval"` + SendKeysConfirm bool `mapstructure:"send_keys_confirm"` + PasteMultilineConfirm bool `mapstructure:"paste_multiline_confirm"` + ExecConfirm bool `mapstructure:"exec_confirm"` + WhitelistPatterns []string `mapstructure:"whitelist_patterns"` + BlacklistPatterns []string `mapstructure:"blacklist_patterns"` + OpenRouter OpenRouterConfig `mapstructure:"openrouter"` + AzureOpenAI AzureOpenAIConfig `mapstructure:"azure_openai"` + Prompts PromptsConfig `mapstructure:"prompts"` } // OpenRouterConfig holds OpenRouter API configuration @@ -32,6 +33,14 @@ type OpenRouterConfig struct { BaseURL string `mapstructure:"base_url"` } +// AzureOpenAIConfig holds Azure OpenAI API configuration +type AzureOpenAIConfig struct { + APIKey string `mapstructure:"api_key"` + APIBase string `mapstructure:"api_base"` + APIVersion string `mapstructure:"api_version"` + DeploymentName string `mapstructure:"deployment_name"` +} + // PromptsConfig holds customizable prompt templates type PromptsConfig struct { BaseSystem string `mapstructure:"base_system"` @@ -56,6 +65,7 @@ func DefaultConfig() *Config { BaseURL: "https://openrouter.ai/api/v1", Model: "google/gemini-2.5-flash-preview", }, + AzureOpenAI: AzureOpenAIConfig{}, Prompts: PromptsConfig{ BaseSystem: ``, ChatAssistant: ``, diff --git a/internal/ai_client.go b/internal/ai_client.go index 5befe4c..40a8332 100644 --- a/internal/ai_client.go +++ b/internal/ai_client.go @@ -15,9 +15,9 @@ import ( "github.com/alvinunreal/tmuxai/logger" ) -// AiClient represents an AI client for interacting with OpenRouter API +// AiClient represents an AI client for interacting with OpenAI-compatible APIs including Azure OpenAI type AiClient struct { - config *config.OpenRouterConfig + config *config.Config client *http.Client } @@ -29,7 +29,7 @@ type Message struct { // ChatCompletionRequest represents a request to the chat completion API type ChatCompletionRequest struct { - Model string `json:"model"` + Model string `json:"model,omitempty"` Messages []Message `json:"messages"` } @@ -47,7 +47,7 @@ type ChatCompletionResponse struct { Choices []ChatCompletionChoice `json:"choices"` } -func NewAiClient(cfg *config.OpenRouterConfig) *AiClient { +func NewAiClient(cfg *config.Config) *AiClient { return &AiClient{ config: cfg, client: &http.Client{}, @@ -94,16 +94,37 @@ func (c *AiClient) ChatCompletion(ctx context.Context, messages []Message, model Messages: messages, } + // determine endpoint and headers based on configuration + var url string + var apiKeyHeader string + var apiKey string + + if c.config.AzureOpenAI.APIKey != "" { + // Use Azure OpenAI endpoint + base := strings.TrimSuffix(c.config.AzureOpenAI.APIBase, "/") + url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", + base, + c.config.AzureOpenAI.DeploymentName, + c.config.AzureOpenAI.APIVersion) + apiKeyHeader = "api-key" + apiKey = c.config.AzureOpenAI.APIKey + + // Azure endpoint doesn't expect model in body + reqBody.Model = "" + } else { + // default OpenRouter/OpenAI compatible endpoint + baseURL := strings.TrimSuffix(c.config.OpenRouter.BaseURL, "/") + url = baseURL + "/chat/completions" + apiKeyHeader = "Authorization" + apiKey = "Bearer " + c.config.OpenRouter.APIKey + } + reqJSON, err := json.Marshal(reqBody) if err != nil { logger.Error("Failed to marshal request: %v", err) return "", fmt.Errorf("failed to marshal request: %w", err) } - // Remove trailing slash from BaseURL if present: https://github.com/alvinunreal/tmuxai/issues/13 - baseURL := strings.TrimSuffix(c.config.BaseURL, "/") - url := baseURL + "/chat/completions" - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqJSON)) if err != nil { logger.Error("Failed to create request: %v", err) @@ -112,7 +133,7 @@ func (c *AiClient) ChatCompletion(ctx context.Context, messages []Message, model // Set headers req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.config.APIKey) + req.Header.Set(apiKeyHeader, apiKey) req.Header.Set("HTTP-Referer", "https://github.com/alvinunreal/tmuxai") req.Header.Set("X-Title", "TmuxAI") diff --git a/internal/ai_client_test.go b/internal/ai_client_test.go new file mode 100644 index 0000000..87b6dba --- /dev/null +++ b/internal/ai_client_test.go @@ -0,0 +1,47 @@ +package internal + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/alvinunreal/tmuxai/config" +) + +func TestAzureOpenAIEndpoint(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/openai/deployments/test-dep/chat/completions" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.URL.Query().Get("api-version") != "2025-04-01-preview" { + t.Errorf("missing api-version query") + } + if r.Header.Get("api-key") != "test-key" { + t.Errorf("missing api-key header") + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"choices":[{"message":{"content":"ok"}}]}`)) + })) + defer server.Close() + + cfg := &config.Config{ + OpenRouter: config.OpenRouterConfig{}, + AzureOpenAI: config.AzureOpenAIConfig{ + APIKey: "test-key", + APIBase: server.URL, + APIVersion: "2025-04-01-preview", + DeploymentName: "test-dep", + }, + } + + client := NewAiClient(cfg) + msg := []Message{{Role: "user", Content: "hi"}} + resp, err := client.ChatCompletion(context.Background(), msg, "model") + if err != nil { + t.Fatalf("ChatCompletion error: %v", err) + } + if resp != "ok" { + t.Errorf("unexpected response: %s", resp) + } +} diff --git a/internal/manager.go b/internal/manager.go index 465dc44..9e7353e 100644 --- a/internal/manager.go +++ b/internal/manager.go @@ -46,9 +46,9 @@ type Manager struct { // NewManager creates a new manager agent func NewManager(cfg *config.Config) (*Manager, error) { - if cfg.OpenRouter.APIKey == "" { - fmt.Println("OpenRouter API key is required. Set it in the config file or as an environment variable: TMUXAI_OPENROUTER_API_KEY") - return nil, fmt.Errorf("OpenRouter API key is required") + if cfg.OpenRouter.APIKey == "" && cfg.AzureOpenAI.APIKey == "" { + fmt.Println("An API key is required. Set OpenRouter or Azure OpenAI credentials in the config file or environment variables.") + return nil, fmt.Errorf("API key required") } paneId, err := system.TmuxCurrentPaneId() @@ -71,7 +71,7 @@ func NewManager(cfg *config.Config) (*Manager, error) { os.Exit(0) } - aiClient := NewAiClient(&cfg.OpenRouter) + aiClient := NewAiClient(cfg) os := system.GetOSDetails() manager := &Manager{