Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions backend/internal/repository/openai_403_counter_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package repository

import (
"context"
"fmt"

"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)

const openAI403CounterPrefix = "openai_403_count:account:"

// openAI403CounterIncrScript atomically increments a counter key and applies a
// TTL (in seconds, ARGV[1]) so the count expires with its window. The TTL is
// set when the key is newly created (count == 1) and also re-applied whenever
// the key has no expiry (TTL < 0) — e.g. if an earlier EXPIRE was lost —
// preventing a stale counter from lingering forever. Returns the new count.
var openAI403CounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])

local count = redis.call('INCR', key)
if count == 1 or redis.call('TTL', key) < 0 then
	redis.call('EXPIRE', key, ttl)
end

return count
`)

// openAI403CounterCache is a Redis-backed counter of per-account OpenAI 403
// responses. It satisfies service.OpenAI403CounterCache (see NewOpenAI403CounterCache).
type openAI403CounterCache struct {
	rdb *redis.Client // Redis connection used for all counter operations
}

// NewOpenAI403CounterCache wires a Redis-backed implementation of
// service.OpenAI403CounterCache around the given client.
func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache {
	cache := &openAI403CounterCache{rdb: rdb}
	return cache
}

// IncrementOpenAI403Count bumps the 403 counter for accountID and returns the
// updated count. The window is windowMinutes long, clamped to a minimum of one
// minute, and is enforced by a TTL set atomically inside a Lua script so the
// increment and expiry cannot race.
func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
	// Never allow a window shorter than one minute.
	const minTTLSeconds = 60

	ttl := windowMinutes * 60
	if ttl < minTTLSeconds {
		ttl = minTTLSeconds
	}

	key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
	count, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttl).Int64()
	if err != nil {
		return 0, fmt.Errorf("increment openai 403 count: %w", err)
	}
	return count, nil
}

// ResetOpenAI403Count removes the 403 counter for accountID so the next
// increment starts a fresh window. Deleting a key that does not exist is a
// no-op in Redis, so resetting an untracked account returns nil.
func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error {
	return c.rdb.Del(ctx, fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)).Err()
}
1 change: 1 addition & 0 deletions backend/internal/repository/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache,
NewTempUnschedCache,
NewTimeoutCounterCache,
NewOpenAI403CounterCache,
NewInternal500CounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
Expand Down
4 changes: 1 addition & 3 deletions backend/internal/service/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit
return false
}
switch capability {
case OpenAIImagesCapabilityBasic:
case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative:
return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
case OpenAIImagesCapabilityNative:
return a.Type == AccountTypeAPIKey
default:
return true
}
Expand Down
231 changes: 46 additions & 185 deletions backend/internal/service/account_test_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
Expand Down Expand Up @@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
return nil
}

// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API.
// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API.
func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
authToken := account.GetOpenAIAccessToken()
if authToken == "" {
Expand All @@ -1153,119 +1152,82 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
c.Writer.Flush()

s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"})
s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"})

// Build headers (replicating buildOpenAIBackendAPIHeaders logic)
headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo)

proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesGenerationsEndpoint,
Model: strings.TrimSpace(modelID),
Prompt: prompt,
}
applyOpenAIImagesDefaults(parsed)

client, err := newOpenAIBackendAPIClient(proxyURL)
responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error()))
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error()))
}

// Bootstrap
if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr)
}

// Fetch chat requirements
s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"})
chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody))
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error()))
return s.sendErrorAndEnd(c, "Failed to create request")
}
if chatReqs.Arkose.Required {
return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required")
req.Host = "chatgpt.com"
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("originator", "opencode")
if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
req.Header.Set("User-Agent", customUA)
} else {
req.Header.Set("User-Agent", codexCLIUserAgent)
}

// Initialize and prepare conversation
s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"})
parentMessageID := uuid.NewString()
proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
_ = initializeOpenAIImageConversation(ctx, client, headers)
conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error()))
if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}

// Build simplified conversation request (no file uploads)
convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID)
convHeaders := cloneHTTPHeader(headers)
convHeaders.Set("Accept", "text/event-stream")
convHeaders.Set("Content-Type", "application/json")
convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
if conduitToken != "" {
convHeaders.Set("x-conduit-token", conduitToken)
}
if proofToken != "" {
convHeaders.Set("openai-sentinel-proof-token", proofToken)
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}

s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"})

resp, err := client.R().
SetContext(ctx).
DisableAutoReadResponse().
SetHeaders(headerToMap(convHeaders)).
SetBodyJsonMarshal(convReq).
Post(openAIChatGPTConversationURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error()))
return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error()))
}
defer func() {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
}()
if resp.StatusCode >= 400 {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode))
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
message := strings.TrimSpace(extractUpstreamErrorMessage(body))
if message == "" {
message = fmt.Sprintf("Responses API returned %d", resp.StatusCode)
}
return s.sendErrorAndEnd(c, message)
}

startTime := time.Now()
conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime)
body, err := io.ReadAll(resp.Body)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: %s", err.Error()))
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error()))
}

pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"})
polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
if pollErr != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error()))
}
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error()))
}
pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
if len(pointerInfos) == 0 {
return s.sendErrorAndEnd(c, "No images returned from conversation")
if len(results) == 0 {
return s.sendErrorAndEnd(c, "No images returned from responses API")
}

s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"})

// Download and encode each image
for _, pointer := range pointerInfos {
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error()))
}
data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error()))
}
b64 := base64.StdEncoding.EncodeToString(data)
mimeType := http.DetectContentType(data)
if pointer.Prompt != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt})
for _, item := range results {
if item.RevisedPrompt != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
}
mimeType := openAIImageOutputMIMEType(item.OutputFormat)
s.sendEvent(c, TestEvent{
Type: "image",
ImageURL: "data:" + mimeType + ";base64," + b64,
ImageURL: "data:" + mimeType + ";base64," + item.Result,
MimeType: mimeType,
})
}
Expand All @@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
return nil
}

// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes.
// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without
// requiring the full gateway service dependency.
func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header {
// Ensure device and session IDs exist
deviceID := account.GetOpenAIDeviceID()
sessionID := account.GetOpenAISessionID()
if deviceID == "" || sessionID == "" {
updates := map[string]any{}
if deviceID == "" {
deviceID = uuid.NewString()
updates["openai_device_id"] = deviceID
}
if sessionID == "" {
sessionID = uuid.NewString()
updates["openai_session_id"] = sessionID
}
if account.Extra == nil {
account.Extra = map[string]any{}
}
for key, value := range updates {
account.Extra[key] = value
}
if repo != nil {
updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
_ = repo.UpdateExtra(updateCtx, account.ID, updates)
}
}

headers := make(http.Header)
headers.Set("Authorization", "Bearer "+token)
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://chatgpt.com")
headers.Set("Referer", "https://chatgpt.com/")
headers.Set("Sec-Fetch-Dest", "empty")
headers.Set("Sec-Fetch-Mode", "cors")
headers.Set("Sec-Fetch-Site", "same-origin")
headers.Set("User-Agent", openAIImageBackendUserAgent)
if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
headers.Set("User-Agent", customUA)
}
if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
headers.Set("chatgpt-account-id", chatgptAccountID)
}
if deviceID != "" {
headers.Set("oai-device-id", deviceID)
headers.Set("Cookie", "oai-did="+deviceID)
}
if sessionID != "" {
headers.Set("oai-session-id", sessionID)
}
return headers
}

// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request.
func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any {
promptText := strings.TrimSpace(prompt)
if promptText == "" {
promptText = "Generate an image."
}
metadata := map[string]any{
"developer_mode_connector_ids": []any{},
"selected_github_repos": []any{},
"selected_all_github_repos": false,
"system_hints": []string{"picture_v2"},
"serialization_metadata": map[string]any{
"custom_symbol_offsets": []any{},
},
}
message := map[string]any{
"id": uuid.NewString(),
"author": map[string]any{"role": "user"},
"content": map[string]any{
"content_type": "text",
"parts": []any{promptText},
},
"metadata": metadata,
"create_time": float64(time.Now().UnixMilli()) / 1000,
}
return map[string]any{
"action": "next",
"client_prepare_state": "sent",
"parent_message_id": parentMessageID,
"messages": []any{message},
"model": "auto",
"timezone_offset_min": openAITimezoneOffsetMinutes(),
"timezone": openAITimezoneName(),
"conversation_mode": map[string]any{"kind": "primary_assistant"},
"system_hints": []string{"picture_v2"},
"supports_buffering": true,
"supported_encodings": []string{"v1"},
"client_contextual_info": map[string]any{"app_name": "chatgpt.com"},
"force_nulligen": false,
"force_paragen": false,
"force_paragen_model_slug": "",
"force_rate_limit": false,
"websocket_request_id": uuid.NewString(),
}
}

func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event)
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
Expand Down
Loading
Loading