// ---- sdk/go/ai/media_provider.go (new file) ----
package ai

import (
	"context"
	"fmt"
	"sort"
	"strings"
	"time"
)

// VideoRequest holds parameters for video generation.
//
// PollInterval, Timeout and Extra are tagged `json:"-"`, so they are never
// emitted when the struct itself is marshaled; they only steer the client.
type VideoRequest struct {
	// Prompt is the text description of the video. The OpenRouter provider
	// rejects a blank (whitespace-only) prompt.
	Prompt string `json:"prompt"`
	// Model selects the generation model. The OpenRouter provider strips a
	// leading "openrouter/" before sending it.
	Model string `json:"model"`
	// Duration is forwarded by the OpenRouter provider only when > 0.
	// NOTE(review): presumably seconds — confirm against the backend API.
	Duration int `json:"duration,omitempty"`
	// Resolution and AspectRatio are forwarded only when non-empty.
	Resolution  string `json:"resolution,omitempty"`
	AspectRatio string `json:"aspect_ratio,omitempty"`
	// GenerateAudio and Seed are pointers so that "unset" (nil) can be told
	// apart from an explicit false / 0; nil values are not forwarded.
	GenerateAudio *bool `json:"generate_audio,omitempty"`
	Seed          *int  `json:"seed,omitempty"`
	// FrameImages and InputReferences are passed through as opaque JSON
	// objects and are forwarded only when non-empty.
	FrameImages     []map[string]any `json:"frame_images,omitempty"`
	InputReferences []map[string]any `json:"input_references,omitempty"`
	// PollInterval is the delay between job-status polls. Zero selects the
	// provider default (30s in the OpenRouter provider).
	PollInterval time.Duration `json:"-"`
	// Timeout bounds the polling phase. Zero selects the provider default
	// (10 minutes in the OpenRouter provider).
	Timeout time.Duration `json:"-"`
	// Extra holds additional payload keys. The OpenRouter provider merges
	// them after the typed fields, so a key here overrides the typed field
	// of the same name.
	Extra map[string]any `json:"-"`
}

// ImageRequest holds parameters for image generation.
type ImageRequest struct {
	// Prompt is the text description of the image.
	Prompt string `json:"prompt"`
	// Model selects the generation model. When empty, the OpenRouter
	// provider uses "openai/gpt-image-1".
	Model string `json:"model,omitempty"`
	// Size and Quality are forwarded by the OpenRouter provider only when
	// non-empty.
	Size    string `json:"size,omitempty"`
	Quality string `json:"quality,omitempty"`
	// ImageConfig is forwarded as "image_config" only when non-nil.
	ImageConfig *ImageConfig `json:"image_config,omitempty"`
}

// ImageConfig holds OpenRouter-specific image configuration.
// All fields are optional and omitted from JSON when empty.
type ImageConfig struct {
	AspectRatio string `json:"aspect_ratio,omitempty"`
	ImageSize   string `json:"image_size,omitempty"`
	// NOTE(review): the expected contents (URLs vs. encoded data) are not
	// visible in this file — confirm against the backend API.
	SuperResolutionReferences []string `json:"super_resolution_references,omitempty"`
}

// AudioRequest holds parameters for audio generation.
type AudioRequest struct {
	// Text is the input sent as the user message. The OpenRouter provider
	// rejects a blank (whitespace-only) text.
	Text string `json:"text"`
	// Model selects the generation model. When empty, the OpenRouter
	// provider uses "openai/gpt-4o-audio-preview".
	Model string `json:"model,omitempty"`
	// Voice selects the voice. When empty, the OpenRouter provider uses
	// "alloy".
	Voice string `json:"voice,omitempty"`
	// Format selects the audio format. When empty, the OpenRouter provider
	// uses "wav".
	Format string `json:"format,omitempty"`
}

// MediaResponse holds the result of a media generation call.
type MediaResponse struct {
	// Text is any textual output that accompanied the generated media.
	Text string `json:"text"`
	// Images, Audio, Files and Videos carry the generated artifacts; a
	// given call fills only the ones relevant to its modality.
	Images []ImageData `json:"images,omitempty"`
	Audio  *AudioData  `json:"audio,omitempty"`
	Files  []FileData  `json:"files,omitempty"`
	Videos []VideoData `json:"videos,omitempty"`
	// RawResponse is a provider-specific representation of the backend
	// reply; its concrete type varies by provider and call.
	RawResponse any `json:"raw_response,omitempty"`
}

// ImageData holds data for a generated image.
type ImageData struct {
	// URL is the location of the image, when the provider returns one.
	URL string `json:"url,omitempty"`
	// B64JSON is the value of a "b64_json" field returned by the backend.
	// NOTE(review): presumably base64-encoded image bytes — confirm.
	B64JSON string `json:"b64_json,omitempty"`
	// RevisedPrompt is the prompt as rewritten by the backend, if reported.
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}

// AudioData holds data for generated audio.
type AudioData struct {
	// Data is the audio payload. The OpenRouter provider stores it as
	// standard (padded) base64.
	Data string `json:"data,omitempty"`
	// Format names the audio format (for example "wav"). Unlike the other
	// fields it is always present in JSON output.
	Format string `json:"format"`
	// URL is the location of the audio, when the provider returns one.
	URL string `json:"url,omitempty"`
}

// FileData holds data for a generated file.
type FileData struct {
	URL      string `json:"url,omitempty"`
	Data     string `json:"data,omitempty"`
	MimeType string `json:"mime_type,omitempty"`
	Filename string `json:"filename,omitempty"`
}

// VideoData holds data for a generated video.
type VideoData struct {
	// URL is the location of the finished video.
	URL string `json:"url,omitempty"`
	// Data is an inline payload, when the provider returns one.
	Data     string `json:"data,omitempty"`
	MimeType string `json:"mime_type,omitempty"`
	Filename string `json:"filename,omitempty"`
	// Duration is the length reported by the backend.
	// NOTE(review): presumably seconds — confirm against the backend API.
	Duration    float64 `json:"duration,omitempty"`
	Resolution  string  `json:"resolution,omitempty"`
	AspectRatio string  `json:"aspect_ratio,omitempty"`
	HasAudio    bool    `json:"has_audio,omitempty"`
	// CostUSD is the cost reported by the backend for this generation.
	CostUSD float64 `json:"cost_usd,omitempty"`
}

// MediaProvider defines the interface for media generation backends.
type MediaProvider interface {
	// Name returns the provider's identifier.
	Name() string
	// SupportedModalities lists the capabilities the provider offers.
	// MediaRouter.Resolve compares these strings for exact equality with
	// the requested capability (for example "image", "audio", "video").
	SupportedModalities() []string
	// GenerateImage produces one or more images for req.
	GenerateImage(ctx context.Context, req ImageRequest) (*MediaResponse, error)
	// GenerateAudio produces audio for req.
	GenerateAudio(ctx context.Context, req AudioRequest) (*MediaResponse, error)
	// GenerateVideo produces a video for req.
	GenerateVideo(ctx context.Context, req VideoRequest) (*MediaResponse, error)
}

// MediaRouter dispatches (model, capability) pairs to the correct MediaProvider.
+type MediaRouter struct { + providers []routerEntry +} + +type routerEntry struct { + prefix string + provider MediaProvider +} + +// NewMediaRouter creates a new MediaRouter. +func NewMediaRouter() *MediaRouter { + return &MediaRouter{} +} + +// Register adds a provider with a model prefix. Longer prefixes match first. +func (r *MediaRouter) Register(prefix string, provider MediaProvider) { + r.providers = append(r.providers, routerEntry{prefix: prefix, provider: provider}) + sort.Slice(r.providers, func(i, j int) bool { + return len(r.providers[i].prefix) > len(r.providers[j].prefix) + }) +} + +// Resolve finds the provider for a model and capability. +func (r *MediaRouter) Resolve(model, capability string) (MediaProvider, error) { + for _, entry := range r.providers { + if strings.HasPrefix(model, entry.prefix) { + for _, mod := range entry.provider.SupportedModalities() { + if mod == capability { + return entry.provider, nil + } + } + } + } + return nil, fmt.Errorf("no provider for model %q with %q capability", model, capability) +} diff --git a/sdk/go/ai/media_provider_test.go b/sdk/go/ai/media_provider_test.go new file mode 100644 index 00000000..2700cb6e --- /dev/null +++ b/sdk/go/ai/media_provider_test.go @@ -0,0 +1,329 @@ +package ai + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockProvider implements MediaProvider for testing. 
+type mockProvider struct { + name string + modalities []string +} + +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) SupportedModalities() []string { return m.modalities } +func (m *mockProvider) GenerateImage(_ context.Context, _ ImageRequest) (*MediaResponse, error) { + return &MediaResponse{Text: m.name + ":image"}, nil +} +func (m *mockProvider) GenerateAudio(_ context.Context, _ AudioRequest) (*MediaResponse, error) { + return &MediaResponse{Text: m.name + ":audio"}, nil +} +func (m *mockProvider) GenerateVideo(_ context.Context, _ VideoRequest) (*MediaResponse, error) { + return &MediaResponse{Text: m.name + ":video"}, nil +} + +func TestMediaRouterResolve(t *testing.T) { + router := NewMediaRouter() + + or := &mockProvider{name: "openrouter", modalities: []string{"image", "audio", "video"}} + other := &mockProvider{name: "other", modalities: []string{"image"}} + + router.Register("openrouter/", or) + router.Register("other/", other) + + tests := []struct { + model string + capability string + wantName string + wantErr bool + }{ + {"openrouter/kling", "video", "openrouter", false}, + {"openrouter/gpt-image-1", "image", "openrouter", false}, + {"other/some-model", "image", "other", false}, + {"other/some-model", "video", "", true}, // other doesn't support video + {"unknown/model", "image", "", true}, // no matching prefix + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_%s", tt.model, tt.capability), func(t *testing.T) { + p, err := router.Resolve(tt.model, tt.capability) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, p) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantName, p.Name()) + } + }) + } +} + +func TestMediaRouterLongestPrefixMatch(t *testing.T) { + router := NewMediaRouter() + + general := &mockProvider{name: "general", modalities: []string{"image"}} + specific := &mockProvider{name: "specific", modalities: []string{"image"}} + + router.Register("openrouter/", general) + 
router.Register("openrouter/dall-e", specific) + + p, err := router.Resolve("openrouter/dall-e-3", "image") + require.NoError(t, err) + assert.Equal(t, "specific", p.Name(), "longer prefix should match first") + + p, err = router.Resolve("openrouter/kling-v2", "image") + require.NoError(t, err) + assert.Equal(t, "general", p.Name(), "shorter prefix should match as fallback") +} + +func TestMediaRouterEmpty(t *testing.T) { + router := NewMediaRouter() + _, err := router.Resolve("any/model", "image") + assert.Error(t, err) + assert.Contains(t, err.Error(), "no provider") +} + +func TestStripPrefix(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"openrouter/kling-video/v2", "kling-video/v2"}, + {"openrouter/gpt-image-1", "gpt-image-1"}, + {"plain-model", "plain-model"}, + {"openrouter/", ""}, + } + for _, tt := range tests { + assert.Equal(t, tt.want, stripPrefix(tt.input)) + } +} + +func TestOpenRouterMediaProviderName(t *testing.T) { + p, err := NewOpenRouterMediaProvider("test-key") + require.NoError(t, err) + assert.Equal(t, "openrouter", p.Name()) + assert.Equal(t, []string{"image", "audio", "video"}, p.SupportedModalities()) +} + +func TestOpenRouterMediaProviderDefaultKey(t *testing.T) { + t.Setenv("OPENROUTER_API_KEY", "env-key") + p, err := NewOpenRouterMediaProvider("") + require.NoError(t, err) + assert.Equal(t, "env-key", p.APIKey) +} + +func TestOpenRouterMediaProviderEmptyKey(t *testing.T) { + t.Setenv("OPENROUTER_API_KEY", "") + _, err := NewOpenRouterMediaProvider("") + assert.Error(t, err) + assert.Contains(t, err.Error(), "API key required") +} + +func TestOpenRouterGenerateImage(t *testing.T) { + // Mock server returning a chat completion with image content + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/chat/completions", r.URL.Path) + assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + + var payload map[string]any + require.NoError(t, 
json.NewDecoder(r.Body).Decode(&payload)) + assert.Equal(t, "gpt-image-1", payload["model"]) + + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "Here is your image"}, + {"type": "image_url", "b64_json": "aW1hZ2VkYXRh"}, + }, + }, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + p := &OpenRouterMediaProvider{ + APIKey: "test-key", + BaseURL: srv.URL, + Client: srv.Client(), + } + + resp, err := p.GenerateImage(context.Background(), ImageRequest{ + Prompt: "a cat", + Model: "openrouter/gpt-image-1", + }) + require.NoError(t, err) + assert.Equal(t, "Here is your image", resp.Text) + require.Len(t, resp.Images, 1) + assert.Equal(t, "aW1hZ2VkYXRh", resp.Images[0].B64JSON) +} + +func TestOpenRouterGenerateImageError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": "bad request"}`)) + })) + defer srv.Close() + + p := &OpenRouterMediaProvider{ + APIKey: "test-key", + BaseURL: srv.URL, + Client: srv.Client(), + } + + _, err := p.GenerateImage(context.Background(), ImageRequest{Prompt: "test"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "400") +} + +func TestOpenRouterGenerateAudio(t *testing.T) { + // Mock SSE streaming server + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/chat/completions", r.URL.Path) + + var payload map[string]any + require.NoError(t, json.NewDecoder(r.Body).Decode(&payload)) + assert.Equal(t, true, payload["stream"]) + + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + // Send text chunk + fmt.Fprintf(w, "data: %s\n\n", `{"choices":[{"delta":{"content":"Hello"}}]}`) + flusher.Flush() + + // Send audio chunk (base64 of "audio") + 
fmt.Fprintf(w, "data: %s\n\n", `{"choices":[{"delta":{"audio":{"data":"YXVkaW8="}}}]}`) + flusher.Flush() + + fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer srv.Close() + + p := &OpenRouterMediaProvider{ + APIKey: "test-key", + BaseURL: srv.URL, + Client: srv.Client(), + } + + resp, err := p.GenerateAudio(context.Background(), AudioRequest{ + Text: "Say hello", + Voice: "nova", + }) + require.NoError(t, err) + assert.Equal(t, "Hello", resp.Text) + require.NotNil(t, resp.Audio) + assert.Equal(t, "wav", resp.Audio.Format) + assert.NotEmpty(t, resp.Audio.Data) +} + +func TestOpenRouterGenerateVideoSubmitError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer srv.Close() + + p := &OpenRouterMediaProvider{ + APIKey: "test-key", + BaseURL: srv.URL, + Client: srv.Client(), + } + + _, err := p.GenerateVideo(context.Background(), VideoRequest{ + Prompt: "test", + Model: "openrouter/kling", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "403") +} + +func TestOpenRouterGenerateVideoFullLifecycle(t *testing.T) { + callCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodPost && r.URL.Path == "/videos": + json.NewEncoder(w).Encode(map[string]string{"id": "job-123"}) + case r.Method == http.MethodGet && r.URL.Path == "/videos/job-123": + callCount++ + if callCount == 1 { + json.NewEncoder(w).Encode(map[string]any{ + "id": "job-123", "status": "processing", + }) + } else { + json.NewEncoder(w).Encode(map[string]any{ + "id": "job-123", + "status": "completed", + "unsigned_url": "https://example.com/video.mp4", + "duration": 5.0, + "cost_usd": 0.05, + }) + } + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + p := 
&OpenRouterMediaProvider{ + APIKey: "test-key", + BaseURL: srv.URL, + Client: srv.Client(), + } + + resp, err := p.GenerateVideo(context.Background(), VideoRequest{ + Prompt: "test video", + Model: "openrouter/kling", + PollInterval: 10 * time.Millisecond, + Timeout: 5 * time.Second, + }) + require.NoError(t, err) + require.Len(t, resp.Videos, 1) + assert.Equal(t, "https://example.com/video.mp4", resp.Videos[0].URL) + assert.Equal(t, "generated_video.mp4", resp.Videos[0].Filename) + assert.Equal(t, 5.0, resp.Videos[0].Duration) + assert.Equal(t, 0.05, resp.Videos[0].CostUSD) + assert.Equal(t, "video/mp4", resp.Videos[0].MimeType) +} + +func TestOpenRouterGenerateVideoEmptyPrompt(t *testing.T) { + p := &OpenRouterMediaProvider{ + APIKey: "test-key", + BaseURL: "http://localhost", + Client: &http.Client{}, + } + _, err := p.GenerateVideo(context.Background(), VideoRequest{ + Prompt: "", + Model: "openrouter/kling", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "prompt must not be empty") +} + +func TestOpenRouterGenerateAudioEmptyText(t *testing.T) { + p := &OpenRouterMediaProvider{ + APIKey: "test-key", + BaseURL: "http://localhost", + Client: &http.Client{}, + } + _, err := p.GenerateAudio(context.Background(), AudioRequest{ + Text: " ", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "text input must not be empty") +} diff --git a/sdk/go/ai/openrouter_media.go b/sdk/go/ai/openrouter_media.go new file mode 100644 index 00000000..13fd811e --- /dev/null +++ b/sdk/go/ai/openrouter_media.go @@ -0,0 +1,469 @@ +package ai + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "regexp" + "strings" + "time" +) + +// validJobID restricts job IDs to safe characters (prevents SSRF via path traversal). 
var validJobID = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)

// Defaults applied when the corresponding setting is left unset.
const (
	defaultOpenRouterBaseURL = "https://openrouter.ai/api/v1"
	defaultVideoPollInterval = 30 * time.Second
	defaultVideoTimeout      = 10 * time.Minute
)

// OpenRouterMediaProvider is the MediaProvider backed by OpenRouter's media
// endpoints.
type OpenRouterMediaProvider struct {
	// APIKey is sent as a bearer token on every request.
	APIKey string
	// BaseURL is the API root; when empty the default OpenRouter URL is used.
	BaseURL string
	// Client performs the HTTP requests.
	Client *http.Client
}

// NewOpenRouterMediaProvider builds a provider that targets the default
// OpenRouter base URL with a 60-second HTTP client timeout. An empty apiKey
// falls back to the OPENROUTER_API_KEY environment variable; if that is
// empty as well, an error is returned.
func NewOpenRouterMediaProvider(apiKey string) (*OpenRouterMediaProvider, error) {
	key := apiKey
	if key == "" {
		key = os.Getenv("OPENROUTER_API_KEY")
	}
	if key == "" {
		return nil, fmt.Errorf("OpenRouter API key required: pass apiKey or set OPENROUTER_API_KEY")
	}

	provider := &OpenRouterMediaProvider{
		APIKey:  key,
		BaseURL: defaultOpenRouterBaseURL,
		Client:  &http.Client{Timeout: 60 * time.Second},
	}
	return provider, nil
}

// Name reports the provider identifier, "openrouter".
func (p *OpenRouterMediaProvider) Name() string {
	return "openrouter"
}

// SupportedModalities reports the capabilities this provider serves.
func (p *OpenRouterMediaProvider) SupportedModalities() []string {
	return []string{"image", "audio", "video"}
}

// baseURL yields the API root without a trailing slash, substituting the
// default when BaseURL has not been set.
func (p *OpenRouterMediaProvider) baseURL() string {
	if p.BaseURL == "" {
		return defaultOpenRouterBaseURL
	}
	return strings.TrimSuffix(p.BaseURL, "/")
}

// stripPrefix drops a leading "openrouter/" from a model name and returns
// every other name untouched.
func stripPrefix(model string) string {
	const routerPrefix = "openrouter/"
	if strings.HasPrefix(model, routerPrefix) {
		return model[len(routerPrefix):]
	}
	return model
}

// GenerateVideo submits a video job and polls until it completes, returning the URL of the result.
+func (p *OpenRouterMediaProvider) GenerateVideo(ctx context.Context, req VideoRequest) (*MediaResponse, error) { + if strings.TrimSpace(req.Prompt) == "" { + return nil, fmt.Errorf("video prompt must not be empty") + } + + pollInterval := req.PollInterval + if pollInterval == 0 { + pollInterval = defaultVideoPollInterval + } + timeout := req.Timeout + if timeout == 0 { + timeout = defaultVideoTimeout + } + + // Build submit payload + payload := map[string]any{ + "model": stripPrefix(req.Model), + "prompt": req.Prompt, + } + if req.Duration > 0 { + payload["duration"] = req.Duration + } + if req.Resolution != "" { + payload["resolution"] = req.Resolution + } + if req.AspectRatio != "" { + payload["aspect_ratio"] = req.AspectRatio + } + if req.GenerateAudio != nil { + payload["generate_audio"] = *req.GenerateAudio + } + if req.Seed != nil { + payload["seed"] = *req.Seed + } + if len(req.FrameImages) > 0 { + payload["frame_images"] = req.FrameImages + } + if len(req.InputReferences) > 0 { + payload["input_references"] = req.InputReferences + } + for k, v := range req.Extra { + payload[k] = v + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("marshal video request: %w", err) + } + + // Submit job + submitURL := p.baseURL() + "/videos" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, submitURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create submit request: %w", err) + } + p.setHeaders(httpReq) + + resp, err := p.Client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("submit video job: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + if err != nil { + return nil, fmt.Errorf("read submit response: %w", err) + } + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("video submit error (%d): %s", resp.StatusCode, string(respBody)) + } + + var submitResp struct { + ID string `json:"id"` + } + if err := 
json.Unmarshal(respBody, &submitResp); err != nil { + return nil, fmt.Errorf("parse submit response: %w", err) + } + if submitResp.ID == "" { + return nil, fmt.Errorf("no job ID in submit response: %s", string(respBody)) + } + + // Validate job ID to prevent SSRF via path traversal + if !validJobID.MatchString(submitResp.ID) { + return nil, fmt.Errorf("invalid job ID in submit response: %q", submitResp.ID) + } + + // Derive a context with the video-specific timeout, but respect caller's deadline + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + pollURL := p.baseURL() + "/videos/" + submitResp.ID + + // Poll loop using context for deadline enforcement + const maxTransientErrors = 3 + transientErrors := 0 + + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("video generation: %w", ctx.Err()) + case <-ticker.C: + } + + status, err := p.pollVideoJob(ctx, pollURL) + if err != nil { + transientErrors++ + if transientErrors >= maxTransientErrors { + return nil, fmt.Errorf("video poll failed after %d retries: %w", transientErrors, err) + } + continue // retry on next tick + } + transientErrors = 0 + + switch status.Status { + case "completed": + return p.buildVideoResponse(ctx, status) + case "failed": + return nil, fmt.Errorf("video generation failed: %s", status.Error) + } + // pending/processing — continue polling + } +} + +type videoJobStatus struct { + ID string `json:"id"` + Status string `json:"status"` + Error string `json:"error,omitempty"` + UnsignedURL string `json:"unsigned_url,omitempty"` + Duration float64 `json:"duration,omitempty"` + CostUSD float64 `json:"cost_usd,omitempty"` +} + +func (p *OpenRouterMediaProvider) pollVideoJob(ctx context.Context, url string) (*videoJobStatus, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("create poll request: %w", err) + } + 
p.setHeaders(httpReq) + + resp, err := p.Client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("poll video job: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + if err != nil { + return nil, fmt.Errorf("read poll response: %w", err) + } + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("poll error (%d): %s", resp.StatusCode, string(body)) + } + + var status videoJobStatus + if err := json.Unmarshal(body, &status); err != nil { + return nil, fmt.Errorf("parse poll response: %w", err) + } + return &status, nil +} + +func (p *OpenRouterMediaProvider) buildVideoResponse(_ context.Context, status *videoJobStatus) (*MediaResponse, error) { + video := VideoData{ + URL: status.UnsignedURL, + MimeType: "video/mp4", + Filename: "generated_video.mp4", + Duration: status.Duration, + CostUSD: status.CostUSD, + } + + return &MediaResponse{ + Videos: []VideoData{video}, + RawResponse: status, + }, nil +} + +// GenerateImage uses chat completions with image modality. 
+func (p *OpenRouterMediaProvider) GenerateImage(ctx context.Context, req ImageRequest) (*MediaResponse, error) { + model := req.Model + if model == "" { + model = "openai/gpt-image-1" + } + model = stripPrefix(model) + + payload := map[string]any{ + "model": model, + "messages": []map[string]any{ + {"role": "user", "content": req.Prompt}, + }, + "modalities": []string{"image", "text"}, + } + if req.Size != "" { + payload["size"] = req.Size + } + if req.Quality != "" { + payload["quality"] = req.Quality + } + if req.ImageConfig != nil { + payload["image_config"] = req.ImageConfig + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("marshal image request: %w", err) + } + + url := p.baseURL() + "/chat/completions" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create image request: %w", err) + } + p.setHeaders(httpReq) + + resp, err := p.Client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("execute image request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + if err != nil { + return nil, fmt.Errorf("read image response: %w", err) + } + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("image generation error (%d): %s", resp.StatusCode, string(respBody)) + } + + var chatResp struct { + Choices []struct { + Message struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + } `json:"content"` + } `json:"message"` + } `json:"choices"` + } + if err := json.Unmarshal(respBody, &chatResp); err != nil { + return nil, fmt.Errorf("parse image response: %w", err) + } + + result := &MediaResponse{RawResponse: json.RawMessage(respBody)} + for _, choice := range chatResp.Choices { + for _, part := range choice.Message.Content { + switch part.Type { + case "text": + result.Text = part.Text + case 
"image_url", "image": + result.Images = append(result.Images, ImageData{ + B64JSON: part.B64JSON, + }) + } + } + } + + return result, nil +} + +// GenerateAudio uses streaming chat completions with audio modality. +func (p *OpenRouterMediaProvider) GenerateAudio(ctx context.Context, req AudioRequest) (*MediaResponse, error) { + if strings.TrimSpace(req.Text) == "" { + return nil, fmt.Errorf("audio text input must not be empty") + } + + model := req.Model + if model == "" { + model = "openai/gpt-4o-audio-preview" + } + model = stripPrefix(model) + + payload := map[string]any{ + "model": model, + "messages": []map[string]any{ + {"role": "user", "content": req.Text}, + }, + "modalities": []string{"text", "audio"}, + "stream": true, + } + + audioConfig := map[string]string{"format": "wav"} + if req.Voice != "" { + audioConfig["voice"] = req.Voice + } else { + audioConfig["voice"] = "alloy" + } + if req.Format != "" { + audioConfig["format"] = req.Format + } + payload["audio"] = audioConfig + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("marshal audio request: %w", err) + } + + url := p.baseURL() + "/chat/completions" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create audio request: %w", err) + } + p.setHeaders(httpReq) + httpReq.Header.Set("Accept", "text/event-stream") + + resp, err := p.Client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("execute audio request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + return nil, fmt.Errorf("audio generation error (%d): %s", resp.StatusCode, string(respBody)) + } + + // Parse SSE stream, collect audio chunks + var audioChunks []string + var textParts []string + + scanner := bufio.NewScanner(resp.Body) + // SSE audio chunks can be large base64; set 1MB max line size + scanner.Buffer(make([]byte, 
0, 64*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + data = strings.TrimSpace(data) + if data == "[DONE]" { + break + } + + var chunk struct { + Choices []struct { + Delta struct { + Content string `json:"content,omitempty"` + Audio *struct { + Data string `json:"data,omitempty"` + } `json:"audio,omitempty"` + } `json:"delta"` + } `json:"choices"` + } + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + textParts = append(textParts, choice.Delta.Content) + } + if choice.Delta.Audio != nil && choice.Delta.Audio.Data != "" { + audioChunks = append(audioChunks, choice.Delta.Audio.Data) + } + } + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read audio stream: %w", err) + } + + // Concatenate base64 audio chunks + audioFormat := "wav" + if req.Format != "" { + audioFormat = req.Format + } + + var audioData string + if len(audioChunks) > 0 { + // Decode all chunks, concatenate raw bytes, re-encode + var raw []byte + for _, chunk := range audioChunks { + decoded, err := base64.StdEncoding.DecodeString(chunk) + if err != nil { + // Try without padding + decoded, err = base64.RawStdEncoding.DecodeString(chunk) + if err != nil { + return nil, fmt.Errorf("decode audio chunk: %w (chunk length: %d)", err, len(chunk)) + } + } + raw = append(raw, decoded...) + } + audioData = base64.StdEncoding.EncodeToString(raw) + } + + return &MediaResponse{ + Text: strings.Join(textParts, ""), + Audio: &AudioData{ + Data: audioData, + Format: audioFormat, + }, + }, nil +} + +func (p *OpenRouterMediaProvider) setHeaders(req *http.Request) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.APIKey) +}