From 9e44200c6569d6983b89e6eedeb7eaef66cdde7b Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Sat, 25 Apr 2026 12:50:58 +0800 Subject: [PATCH 1/8] =?UTF-8?q?refactor(provider):=20=E5=B0=86=E9=87=8D?= =?UTF-8?q?=E8=AF=95=E9=80=BB=E8=BE=91=E5=86=85=E8=81=9A=E5=88=B0=20provid?= =?UTF-8?q?er=20=E5=B1=82=E5=B9=B6=E5=BB=B6=E9=95=BF=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E6=97=B6=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/runtime-provider-event-flow.md | 1 - internal/provider/anthropic/driver.go | 3 +- internal/provider/constants.go | 5 + internal/provider/gemini/driver.go | 3 +- internal/provider/gemini/provider.go | 151 ++++++++++++-- internal/provider/gemini/provider_test.go | 190 ++++++++++++++++++ internal/provider/openaicompat/provider.go | 3 +- internal/runtime/events.go | 1 - internal/runtime/run.go | 90 +++------ .../runtime_remaining_branches_test.go | 31 ++- internal/runtime/runtime_test.go | 95 +-------- internal/tui/core/app/update.go | 10 - internal/tui/core/app/update_test.go | 8 - .../tui/services/gateway_stream_client.go | 2 +- internal/tui/services/runtime_contract.go | 1 - 15 files changed, 382 insertions(+), 212 deletions(-) diff --git a/docs/runtime-provider-event-flow.md b/docs/runtime-provider-event-flow.md index d0966bd1..5b9a12e0 100644 --- a/docs/runtime-provider-event-flow.md +++ b/docs/runtime-provider-event-flow.md @@ -14,7 +14,6 @@ - `phase_changed` - `progress_evaluated` - `stop_reason_decided` -- `provider_retry` - `permission_requested` - `permission_resolved` - `budget_checked` diff --git a/internal/provider/anthropic/driver.go b/internal/provider/anthropic/driver.go index a5ec554e..95e551c5 100644 --- a/internal/provider/anthropic/driver.go +++ b/internal/provider/anthropic/driver.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "strings" - "time" anthropic "github.com/anthropics/anthropic-sdk-go" anthroption "github.com/anthropics/anthropic-sdk-go/option" @@ -66,7 +65,7 @@ func newSDKClient(cfg provider.RuntimeConfig) (anthropic.Client, error) { } httpClient := &http.Client{ - Timeout: 90 * time.Second, + Timeout: provider.DefaultSDKRequestTimeout, } options := []anthroption.RequestOption{ anthroption.WithHTTPClient(httpClient), diff --git a/internal/provider/constants.go b/internal/provider/constants.go index e0b4b62a..8bc65cae 100644 --- a/internal/provider/constants.go +++ b/internal/provider/constants.go @@ -1,5 +1,7 @@ package provider +import "time" + // Driver 与 OpenAI-compatible 协议常量用于在 config/provider 间共享稳定枚举值,避免字面量漂移。 const ( DriverOpenAICompat = "openaicompat" @@ -8,3 +10,6 @@ const ( DiscoveryEndpointPathModels = "/models" ) + +// DefaultSDKRequestTimeout 定义 provider 层对外部模型 SDK 请求的统一超时,避免流式请求无限悬挂。 +const DefaultSDKRequestTimeout = 10 * time.Minute diff --git a/internal/provider/gemini/driver.go b/internal/provider/gemini/driver.go index 0c627970..10a90b93 100644 --- a/internal/provider/gemini/driver.go +++ b/internal/provider/gemini/driver.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "strings" - "time" "google.golang.org/genai" @@ -71,7 +70,7 @@ func newSDKClient(ctx context.Context, cfg provider.RuntimeConfig) (*genai.Clien return nil, err } httpClient := &http.Client{ - Timeout: 90 * time.Second, + Timeout: provider.DefaultSDKRequestTimeout, } clientConfig := &genai.ClientConfig{ APIKey: apiKey, diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index 017ce282..3a6568a5 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -5,9 +5,12 @@ import ( "encoding/json" "errors" "fmt" + "math/rand/v2" + "net" "net/http" "strings" "sync" + "time" "google.golang.org/genai" @@ -17,12 +20,21 @@ import ( const errorPrefix = "gemini provider: " +const ( + defaultGenerateRetryMax = 2 + generateRetryBaseWait = 1 * time.Second + generateRetryMaxWait = 5 * time.Second +) + // Provider 封装 Gemini native 协议的请求发送与流式响应解析。 type Provider struct { cfg provider.RuntimeConfig mu sync.Mutex prepared *preparedRequest + + retryBackoff func(attempt int) time.Duration + retryWait func(ctx context.Context, wait time.Duration) error } type preparedRequest struct { @@ -67,7 +79,11 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { if strings.TrimSpace(cfg.APIKeyEnv) == "" { return nil, errors.New(errorPrefix + "api_key_env is empty") } - return &Provider{cfg: cfg}, nil + return &Provider{ + cfg: cfg, + retryBackoff: generateRetryBackoff, + retryWait: waitForRetry, + }, nil } // Generate 发起 Gemini 流式请求,并将 SDK chunk 转为统一流式事件。 @@ -84,9 +100,46 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque if normalizedModel == "" { return errors.New(errorPrefix + "model is empty") } + var lastErr error + for attempt := 0; attempt <= defaultGenerateRetryMax; attempt++ { + if attempt > 0 { + wait := generateRetryBaseWait + if p.retryBackoff != nil { + wait = p.retryBackoff(attempt) + } + if p.retryWait != nil { + if err := p.retryWait(ctx, wait); err != nil { + return err + } + } + } + + started, err := p.generateOnce(ctx, normalizedModel, contents, config, events) + if err == nil { + return nil + } + lastErr = err + if started || !isRetryableGenerateError(err) { + return err + } + if ctx.Err() != nil { + return ctx.Err() + } + } + return lastErr +} + +// generateOnce 执行一次 Gemini 流式请求,并在未收到任何输出时返回可重试错误。 +func (p *Provider) generateOnce( + ctx context.Context, + model string, + contents []*genai.Content, + config *genai.GenerateContentConfig, + events chan<- providertypes.StreamEvent, +) (bool, error) { client, err := newSDKClient(ctx, p.cfg) if err != nil { - return err + return false, err } var ( @@ -95,15 +148,12 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque hasPayload bool callSeq int ) - for chunk, streamErr := range client.Models.GenerateContentStream(ctx, normalizedModel, contents, config) { + for chunk, streamErr := range client.Models.GenerateContentStream(ctx, model, contents, config) { if streamErr != nil { if ctxErr := ctx.Err(); ctxErr != nil { - return ctxErr + return hasPayload, ctxErr } - if mappedErr := mapGeminiSDKError(streamErr); mappedErr != nil { - return mappedErr - } - return fmt.Errorf("%sstream generate: %w", errorPrefix, streamErr) + return hasPayload, normalizeGenerateError(streamErr) } if chunk == nil { @@ -125,7 +175,7 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque } if strings.TrimSpace(part.Text) != "" { if err := provider.EmitTextDelta(ctx, events, part.Text); err != nil { - return err + return hasPayload, err } } if part.FunctionCall == nil { @@ -142,25 +192,25 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque continue } if err := provider.EmitToolCallStart(ctx, events, callSeq-1, callID, name); err != nil { - return err + return hasPayload, err } argsJSON, err := encodeArguments(part.FunctionCall.Args) if err != nil { - return err + return hasPayload, err } if err := provider.EmitToolCallDelta(ctx, events, callSeq-1, callID, argsJSON); err != nil { - return err + return hasPayload, err } } } } if !hasPayload { - return fmt.Errorf("%w: empty gemini stream payload", provider.ErrStreamInterrupted) + return false, fmt.Errorf("%w: empty gemini stream payload", provider.ErrStreamInterrupted) } if !usage.InputObserved && !usage.OutputObserved { - return provider.EmitMessageDone(ctx, events, finishReason, nil) + return true, provider.EmitMessageDone(ctx, events, finishReason, nil) } - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return true, provider.EmitMessageDone(ctx, events, finishReason, &usage) } // storePreparedRequest 缓存估算阶段的 Gemini 构建结果,供同轮发送直接复用。 @@ -233,6 +283,77 @@ func normalizeFinishReason(raw string) string { return strings.ToLower(strings.TrimSpace(raw)) } +// normalizeGenerateError 统一归类 Gemini 流式生成错误,避免把网络异常直接泄漏到 runtime。 +func normalizeGenerateError(err error) error { + if mappedErr := mapGeminiSDKError(err); mappedErr != nil { + return mappedErr + } + + message := strings.TrimSpace(err.Error()) + if message == "" { + message = "unknown stream error" + } + if isTimeoutGenerateError(err) { + return provider.NewTimeoutProviderError("gemini generate timeout: " + message) + } + + var netErr net.Error + if errors.As(err, &netErr) { + return provider.NewNetworkProviderError("gemini generate network error: " + message) + } + return fmt.Errorf("%sstream generate: %w", errorPrefix, err) +} + +// isRetryableGenerateError 判断 Gemini provider 是否应在请求级重试当前错误。 +func isRetryableGenerateError(err error) bool { + if err == nil { + return false + } + var providerErr *provider.ProviderError + return errors.As(err, &providerErr) && providerErr.Retryable +} + +// isTimeoutGenerateError 判断 Gemini 流式生成错误是否由超时触发。 +func isTimeoutGenerateError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} + +// generateRetryBackoff 计算 Gemini provider 请求级重试的指数退避时长。 +func generateRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + wait := generateRetryBaseWait << (attempt - 1) + jitter := float64(wait) * (0.5 + rand.Float64()) + wait = time.Duration(jitter) + if wait > generateRetryMaxWait { + wait = generateRetryMaxWait + } + return wait +} + +// waitForRetry 在重试窗口内等待,同时尊重上层上下文取消。 +func waitForRetry(ctx context.Context, wait time.Duration) error { + if wait <= 0 { + return nil + } + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + // mapGeminiSDKError 将 Gemini SDK 错误映射为 provider 领域错误,仅保留状态码级别兜底。 func mapGeminiSDKError(err error) error { var apiErr genai.APIError diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index 8109fe8c..e2d2c5e8 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -3,12 +3,15 @@ package gemini import ( "bytes" "context" + "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "strings" "testing" + "time" "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" @@ -322,6 +325,177 @@ func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { } } +func TestProviderGenerateRetriesRetryableErrorBeforeStreamStarts(t *testing.T) { + t.Parallel() + + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + if requests == 1 { + http.Error(w, "temporary", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"content\":{\"parts\":[{\"text\":\"retry ok\"}]}}],\"usageMetadata\":{\"promptTokenCount\":2,\"candidatesTokenCount\":1,\"totalTokenCount\":3}}\n\n") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[]}}]}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + p.retryBackoff = func(attempt int) time.Duration { + _ = attempt + return 0 + } + p.retryWait = func(ctx context.Context, wait time.Duration) error { + _ = ctx + _ = wait + return nil + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if requests != 2 { + t.Fatalf("expected 2 requests after retry, got %d", requests) + } + if len(drainEvents(events)) == 0 { + t.Fatal("expected stream events after retry recovery") + } +} + +func TestProviderGenerateReturnsRetryableErrorAfterRetryExhausted(t *testing.T) { + t.Parallel() + + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + http.Error(w, "temporary", http.StatusInternalServerError) + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + p.retryBackoff = func(attempt int) time.Duration { + _ = attempt + return 0 + } + p.retryWait = func(ctx context.Context, wait time.Duration) error { + _ = ctx + _ = wait + return nil + } + + events := make(chan providertypes.StreamEvent, 8) + err = p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events) + if err == nil { + t.Fatal("expected retryable error after exhausting retries") + } + if requests != defaultGenerateRetryMax+1 { + t.Fatalf("expected %d requests, got %d", defaultGenerateRetryMax+1, requests) + } + + var providerErr *provider.ProviderError + if !errors.As(err, &providerErr) || !providerErr.Retryable { + t.Fatalf("expected retryable provider error, got %v", err) + } +} + +func TestProviderGenerateRetryStateResetsAfterSuccess(t *testing.T) { + t.Parallel() + + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + if requests == 1 { + http.Error(w, "temporary", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"content\":{\"parts\":[{\"text\":\"ok\"}]}}],\"usageMetadata\":{\"promptTokenCount\":2,\"candidatesTokenCount\":1,\"totalTokenCount\":3}}\n\n") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[]}}]}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + p.retryBackoff = func(attempt int) time.Duration { + _ = attempt + return 0 + } + p.retryWait = func(ctx context.Context, wait time.Duration) error { + _ = ctx + _ = wait + return nil + } + + for i := 0; i < 2; i++ { + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events); err != nil { + t.Fatalf("Generate() call %d error = %v", i+1, err) + } + } + if requests != 3 { + t.Fatalf("expected second request to start from a fresh retry window, got %d total requests", requests) + } +} + +func TestNormalizeGenerateErrorMapsNetworkTimeouts(t *testing.T) { + t.Parallel() + + timeoutErr := timeoutNetError{message: "i/o timeout"} + err := normalizeGenerateError(timeoutErr) + var providerErr *provider.ProviderError + if !errors.As(err, &providerErr) || providerErr.Code != provider.ErrorCodeTimeout { + t.Fatalf("expected timeout provider error, got %v", err) + } + + err = normalizeGenerateError(net.UnknownNetworkError("dns failure")) + if !errors.As(err, &providerErr) || providerErr.Code != provider.ErrorCodeNetwork { + t.Fatalf("expected network provider error, got %v", err) + } +} + func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { var drained []providertypes.StreamEvent for { @@ -340,6 +514,22 @@ type stubSessionAsset struct { err error } +type timeoutNetError struct { + message string +} + +func (e timeoutNetError) Error() string { + return e.message +} + +func (e timeoutNetError) Timeout() bool { + return true +} + +func (e timeoutNetError) Temporary() bool { + return true +} + type stubSessionAssetReader struct { assets map[string]stubSessionAsset openCount int diff --git a/internal/provider/openaicompat/provider.go b/internal/provider/openaicompat/provider.go index d7680cfa..affa532f 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -7,7 +7,6 @@ import ( "net/http" "strings" "sync" - "time" "neo-code/internal/provider" "neo-code/internal/provider/openaicompat/chatcompletions" @@ -131,7 +130,7 @@ func New(cfg provider.RuntimeConfig, opts ...buildOption) (*Provider, error) { return &Provider{ cfg: cfg, client: &http.Client{ - Timeout: 90 * time.Second, + Timeout: provider.DefaultSDKRequestTimeout, Transport: o.transport, }, }, nil diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 2471e8f1..8b5fa223 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -247,7 +247,6 @@ const ( // EventToolCallThinking 表示模型发起工具调用思考阶段。 EventToolCallThinking EventType = "tool_call_thinking" // EventProviderRetry 表示 provider 调用重试。 - EventProviderRetry EventType = "provider_retry" // EventPermissionRequested 表示发起权限请求。 EventPermissionRequested EventType = "permission_requested" // EventPermissionResolved 表示权限请求已决议。 diff --git a/internal/runtime/run.go b/internal/runtime/run.go index efff1116..ea1ed85e 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -183,7 +183,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return nil } - turnOutput, err := s.callProviderWithRetry(ctx, &state, snapshot, modelProvider) + turnOutput, err := s.callProvider(ctx, &state, snapshot, modelProvider) if err != nil { if provider.IsContextTooLong(err) && state.reactiveCompactAttempts < snapshot.Config.Context.Budget.MaxReactiveCompacts { @@ -479,73 +479,35 @@ func resolveRuntimeMaxTurns(rc config.RuntimeConfig) int { return rc.MaxTurns } -// callProviderWithRetry 使用冻结后的 TurnBudgetSnapshot 执行 provider 调用与必要重试。 -func (s *Service) callProviderWithRetry( +// callProvider 使用冻结后的 TurnBudgetSnapshot 执行单次 provider 调用。 +func (s *Service) callProvider( ctx context.Context, state *runState, snapshot TurnBudgetSnapshot, - initialProvider provider.Provider, + modelProvider provider.Provider, ) (turnProviderOutput, error) { - var lastErr error - - for retryAttempt := 0; retryAttempt <= defaultProviderRetryMax; retryAttempt++ { - if retryAttempt > 0 { - wait := providerRetryBackoff(retryAttempt) - s.emitRunScoped(ctx, EventProviderRetry, state, - fmt.Sprintf("retrying provider call (attempt %d/%d, wait=%.1fs)...", - retryAttempt, defaultProviderRetryMax, wait.Seconds())) - - select { - case <-ctx.Done(): - return turnProviderOutput{}, ctx.Err() - case <-time.After(wait): - } - } - - modelProvider := initialProvider - if retryAttempt > 0 { - var err error - modelProvider, err = s.providerFactory.Build(ctx, snapshot.ProviderConfig) - if err != nil { - return turnProviderOutput{}, err - } - } - - streamOutcome := generateStreamingMessage(ctx, modelProvider, snapshot.Request, streaming.Hooks{ - OnTextDelta: func(text string) { - s.emitRunScoped(ctx, EventAgentChunk, state, text) - }, - OnToolCallStart: func(payload providertypes.ToolCallStartPayload) { - s.emitRunScoped(ctx, EventToolCallThinking, state, payload.Name) - }, - }) - if streamOutcome.err != nil { - lastErr = streamOutcome.err - if !isRetryableProviderError(lastErr) { - return turnProviderOutput{}, lastErr - } - if ctx.Err() != nil { - return turnProviderOutput{}, ctx.Err() - } - continue - } - - return turnProviderOutput{ - assistant: streamOutcome.message, - usageObservation: newTurnBudgetUsageObservation( - snapshot.ID, - streamOutcome.inputTokens, - streamOutcome.outputTokens, - streamOutcome.inputObserved, - streamOutcome.outputObserved, - ), - }, nil - } - - if lastErr == nil { - lastErr = errors.New("max retries exceeded") - } - return turnProviderOutput{}, fmt.Errorf("runtime: max retries exhausted, last error: %w", lastErr) + streamOutcome := generateStreamingMessage(ctx, modelProvider, snapshot.Request, streaming.Hooks{ + OnTextDelta: func(text string) { + s.emitRunScoped(ctx, EventAgentChunk, state, text) + }, + OnToolCallStart: func(payload providertypes.ToolCallStartPayload) { + s.emitRunScoped(ctx, EventToolCallThinking, state, payload.Name) + }, + }) + if streamOutcome.err != nil { + return turnProviderOutput{}, streamOutcome.err + } + + return turnProviderOutput{ + assistant: streamOutcome.message, + usageObservation: newTurnBudgetUsageObservation( + snapshot.ID, + streamOutcome.inputTokens, + streamOutcome.outputTokens, + streamOutcome.inputObserved, + streamOutcome.outputObserved, + ), + }, nil } // emitTokenUsage 在单轮 provider 调用成功后发出 token_usage 事件。 diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index 53400e1f..ec084c1d 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -619,24 +619,16 @@ func TestRunAndProviderRetryRemainingBranches(t *testing.T) { } }) - t.Run("callProviderWithRetry exits on context done during backoff wait", func(t *testing.T) { + t.Run("callProvider returns context error from provider", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) - firstCallDone := make(chan struct{}, 1) providerRetry := &scriptedProvider{chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - select { - case firstCallDone <- struct{}{}: - default: - } - return &provider.ProviderError{StatusCode: 500, Code: provider.ErrorCodeServer, Message: "retry", Retryable: true} - }} - go func() { - <-firstCallDone cancel() - }() + return ctx.Err() + }} service := &Service{providerFactory: &scriptedProviderFactory{provider: providerRetry}, events: make(chan RuntimeEvent, 8)} state := newRunState("run-retry-backoff", newRuntimeSession("session-retry-backoff")) - _, err := service.callProviderWithRetry( + _, err := service.callProvider( ctx, &state, TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}, @@ -647,23 +639,24 @@ func TestRunAndProviderRetryRemainingBranches(t *testing.T) { } }) - t.Run("callProviderWithRetry checks ctx after retryable stream error", func(t *testing.T) { + t.Run("callProvider returns retryable provider error without retry", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) providerRetry := &scriptedProvider{chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - cancel() return &provider.ProviderError{StatusCode: 500, Code: provider.ErrorCodeServer, Message: "retry", Retryable: true} }} service := &Service{providerFactory: &scriptedProviderFactory{provider: providerRetry}, events: make(chan RuntimeEvent, 8)} state := newRunState("run-retry-ctx-check", newRuntimeSession("session-retry-ctx-check")) - _, err := service.callProviderWithRetry( - ctx, + _, err := service.callProvider( + context.Background(), &state, TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}, providerRetry, ) - if !errors.Is(err, context.Canceled) { - t.Fatalf("expected context.Canceled, got %v", err) + if err == nil || !containsError(err, "retry") { + t.Fatalf("expected retryable provider error, got %v", err) + } + if providerRetry.callCount != 1 { + t.Fatalf("expected provider to be called once, got %d", providerRetry.callCount) } }) } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index bbfd0df8..e2bd24a7 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -2350,18 +2350,12 @@ func TestServiceRunErrorPaths(t *testing.T) { }, } }(), - expectEvents: []EventType{EventUserMessage, EventProviderRetry, EventAgentDone}, + expectErr: "internal server error", + expectEvents: []EventType{EventUserMessage, EventStopReasonDecided}, assert: func(t *testing.T, store *memoryStore, scripted *scriptedProvider, tool *stubTool) { t.Helper() - if scripted.callCount < 2 { - t.Fatalf("expected at least 2 provider calls (initial + retry), got %d", scripted.callCount) - } - session := onlySession(t, store) - if len(session.Messages) != 2 { - t.Fatalf("expected user + assistant messages, got %d", len(session.Messages)) - } - if renderPartsForTest(session.Messages[1].Parts) != "recovered" { - t.Fatalf("expected assistant content %q, got %q", "recovered", renderPartsForTest(session.Messages[1].Parts)) + if scripted.callCount != 1 { + t.Fatalf("expected runtime to call provider once, got %d", scripted.callCount) } }, }, @@ -2389,7 +2383,7 @@ func TestServiceRunErrorPaths(t *testing.T) { }, }, { - name: "runtime retry exhausted emits error", + name: "retryable provider error does not trigger runtime retry", input: UserInput{RunID: "run-retry-exhausted", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, provider: &scriptedProvider{ name: "always-500", @@ -2403,13 +2397,11 @@ func TestServiceRunErrorPaths(t *testing.T) { }, }, expectErr: "internal server error", - expectEvents: []EventType{EventUserMessage, EventProviderRetry, EventStopReasonDecided}, + expectEvents: []EventType{EventUserMessage, EventStopReasonDecided}, assert: func(t *testing.T, store *memoryStore, scripted *scriptedProvider, tool *stubTool) { t.Helper() - // 1 initial + 2 retries = 3 calls - if scripted.callCount != defaultProviderRetryMax+1 { - t.Fatalf("expected %d provider calls (1 initial + %d retries), got %d", - defaultProviderRetryMax+1, defaultProviderRetryMax, scripted.callCount) + if scripted.callCount != 1 { + t.Fatalf("expected runtime not to retry provider calls, got %d", scripted.callCount) } }, }, @@ -3785,75 +3777,6 @@ func TestWorkdirHelperFunctions(t *testing.T) { }) } -func TestIsRetryableProviderError(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - err error - want bool - }{ - {"nil error", nil, false}, - {"retryable provider error", &provider.ProviderError{Retryable: true}, true}, - {"non-retryable provider error", &provider.ProviderError{Retryable: false}, false}, - {"plain error", errors.New("something failed"), false}, - {"wrapped retryable", fmt.Errorf("wrapped: %w", &provider.ProviderError{Retryable: true}), true}, - {"stream interrupted sentinel", provider.ErrStreamInterrupted, false}, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - if got := isRetryableProviderError(tt.err); got != tt.want { - t.Fatalf("isRetryableProviderError() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestProviderRetryBackoff(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - attempt int - min time.Duration - max time.Duration - }{ - { - name: "first retry stays within jittered base window", - attempt: 1, - min: 500 * time.Millisecond, - max: 1500 * time.Millisecond, - }, - { - name: "second retry stays within jittered doubled window", - attempt: 2, - min: 1 * time.Second, - max: 3 * time.Second, - }, - { - name: "large retry is capped at max wait", - attempt: 20, - min: providerRetryMaxWait, - max: providerRetryMaxWait, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - got := providerRetryBackoff(tt.attempt) - if got < tt.min || got > tt.max { - t.Fatalf("providerRetryBackoff(%d) = %v, want within [%v, %v]", tt.attempt, got, tt.min, tt.max) - } - }) - } -} - func TestStreamAccumulatorBuildMessageRejectsMissingToolName(t *testing.T) { t.Parallel() @@ -4014,7 +3937,7 @@ func TestCallProviderWithRetryReturnsCombinedForwardError(t *testing.T) { }, } - _, err := service.callProviderWithRetry( + _, err := service.callProvider( context.Background(), &state, snapshot, diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index b0c18e9b..79a85074 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -1597,7 +1597,6 @@ var runtimeEventHandlerRegistry = map[tuiservices.EventType]func(*App, tuiservic tuiservices.EventAgentChunk: runtimeEventAgentChunkHandler, tuiservices.EventToolChunk: runtimeEventToolChunkHandler, tuiservices.EventAgentDone: runtimeEventAgentDoneHandler, - tuiservices.EventProviderRetry: runtimeEventProviderRetryHandler, tuiservices.EventPermissionRequested: runtimeEventPermissionRequestHandler, tuiservices.EventPermissionResolved: runtimeEventPermissionResolvedHandler, tuiservices.EventCompactApplied: runtimeEventCompactDoneHandler, @@ -2268,15 +2267,6 @@ func runtimeEventErrorHandler(a *App, event tuiservices.RuntimeEvent) bool { return false } -func runtimeEventProviderRetryHandler(a *App, event tuiservices.RuntimeEvent) bool { - if payload, ok := event.Payload.(string); ok && strings.TrimSpace(payload) != "" { - a.state.StatusText = statusThinking - a.runProgressKnown = false - a.appendActivity("provider", "Retrying provider call", payload, false) - } - return false -} - func runtimeEventPermissionRequestHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := parsePermissionRequestPayload(event.Payload) if !ok { diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index f02bc44f..9defda9b 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -1839,14 +1839,6 @@ func TestRuntimeEventErrorHandler(t *testing.T) { } } -func TestRuntimeEventProviderRetryHandler(t *testing.T) { - app, _ := newTestApp(t) - runtimeEventProviderRetryHandler(&app, agentruntime.RuntimeEvent{Payload: "retry"}) - if app.state.StatusText != statusThinking { - t.Fatalf("expected thinking status") - } -} - func TestRuntimeEventCompactDoneHandler(t *testing.T) { app, _ := newTestApp(t) payload := agentruntime.CompactResult{TriggerMode: "auto", SavedRatio: 0.5, BeforeChars: 10, AfterChars: 5, TranscriptPath: "path"} diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index 1f9274de..c97dbfd9 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go @@ -241,7 +241,7 @@ func restoreRuntimePayload(eventType EventType, payload any) (any, error) { return decodeRuntimePayload[RuntimeToolStatusPayload](payload) case EventType(RuntimeEventUsage): return decodeRuntimePayload[RuntimeUsagePayload](payload) - case EventAgentChunk, EventToolChunk, EventError, EventProviderRetry, EventToolCallThinking: + case EventAgentChunk, EventToolChunk, EventError, EventToolCallThinking: return decodeStringPayload(payload), nil default: return payload, nil diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go index 16524502..18fa6d96 100644 --- a/internal/tui/services/runtime_contract.go +++ b/internal/tui/services/runtime_contract.go @@ -321,7 +321,6 @@ const ( EventRunCanceled EventType = "run_canceled" EventError EventType = "error" EventToolCallThinking EventType = "tool_call_thinking" - EventProviderRetry EventType = "provider_retry" EventPermissionRequested EventType = "permission_requested" EventPermissionResolved EventType = "permission_resolved" EventCompactStart EventType = "compact_start" From a169d52258964cf75b764ca3f9a538d6dcb3ae5f Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Sat, 25 Apr 2026 13:11:13 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=86=97=E4=BD=99?= =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/runtime/events.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 8b5fa223..ff1fa660 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -246,7 +246,6 @@ const ( EventError EventType = "error" // EventToolCallThinking 表示模型发起工具调用思考阶段。 EventToolCallThinking EventType = "tool_call_thinking" - // EventProviderRetry 表示 provider 调用重试。 // EventPermissionRequested 表示发起权限请求。 EventPermissionRequested EventType = "permission_requested" // EventPermissionResolved 表示权限请求已决议。 From 07a69f11748e3a60c7c44e495c46bea3aa416f36 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Sat, 25 Apr 2026 05:15:59 +0000 Subject: [PATCH 3/8] test(gemini): improve retry/error branch coverage Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/gemini/provider_test.go | 218 ++++++++++++++++++++++ 1 file changed, 218 insertions(+) diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index e2d2c5e8..1e540d0f 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "google.golang.org/genai" + "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" ) @@ -496,6 +498,222 @@ func TestNormalizeGenerateErrorMapsNetworkTimeouts(t *testing.T) { } } +func TestProviderGenerateReturnsRetryWaitError(t *testing.T) { + t.Parallel() + + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + http.Error(w, "temporary", http.StatusInternalServerError) + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + sentinel := errors.New("retry wait failed") + p.retryBackoff = func(attempt int) time.Duration { + _ = attempt + return 0 + } + p.retryWait = func(ctx context.Context, wait time.Duration) error { + _ = ctx + _ = wait + return sentinel + } + + events := make(chan providertypes.StreamEvent, 8) + err = p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events) + if !errors.Is(err, sentinel) { + t.Fatalf("expected retry wait error, got %v", err) + } + if requests != 1 { + t.Fatalf("expected one request before retry wait failure, got %d", requests) + } +} + +func TestProviderGenerateReturnsEmptyModelForPreparedRequest(t *testing.T) { + t.Parallel() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: "https://generativelanguage.googleapis.com/v1beta", + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + req := providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + } + signature := provider.BuildGenerateRequestSignature(req) + p.storePreparedRequest(signature, " ", nil, nil) + + events := make(chan providertypes.StreamEvent, 4) + err = p.Generate(context.Background(), req, events) + if err == nil || !strings.Contains(err.Error(), "model is empty") { + t.Fatalf("expected model empty error, got %v", err) + } +} + +func TestGenerateHelpers(t *testing.T) { + t.Parallel() + + t.Run("normalize_model", func(t *testing.T) { + t.Parallel() + if got := normalizeGeminiModelName(" models/gemini-2.5-pro "); got != "gemini-2.5-pro" { + t.Fatalf("normalizeGeminiModelName() = %q", got) + } + if got := normalizeGeminiModelName(" "); got != "" { + t.Fatalf("normalizeGeminiModelName() = %q, want empty", got) + } + }) + + t.Run("encode_arguments", func(t *testing.T) { + t.Parallel() + encoded, err := encodeArguments(nil) + if err != nil || encoded != "{}" { + t.Fatalf("encodeArguments(nil) = %q, %v", encoded, err) + } + _, err = encodeArguments(map[string]any{"bad": make(chan int)}) + if err == nil || !strings.Contains(err.Error(), "encode function args") { + t.Fatalf("expected encode error, got %v", err) + } + }) + + t.Run("retryable_error", func(t *testing.T) { + t.Parallel() + if isRetryableGenerateError(nil) { + t.Fatal("nil should not be retryable") + } + if isRetryableGenerateError(errors.New("plain")) { + t.Fatal("plain error should not be retryable") + } + if !isRetryableGenerateError(provider.NewNetworkProviderError("temporary")) { + t.Fatal("network provider error should be retryable") + } + }) + + t.Run("timeout_error", func(t *testing.T) { + t.Parallel() + if !isTimeoutGenerateError(context.DeadlineExceeded) { + t.Fatal("context deadline should be timeout") + } + if isTimeoutGenerateError(errors.New("plain")) { + t.Fatal("plain error should not be timeout") + } + }) +} + +func TestRetryBackoffAndWait(t *testing.T) { + t.Parallel() + + if wait := generateRetryBackoff(0); wait != 0 { + t.Fatalf("attempt 0 backoff = %v, want 0", wait) + } + + for attempt := 1; attempt <= 6; attempt++ { + wait := generateRetryBackoff(attempt) + if wait < 0 { + t.Fatalf("attempt %d backoff should be non-negative, got %v", attempt, wait) + } + if wait > generateRetryMaxWait { + t.Fatalf("attempt %d backoff should be <= %v, got %v", attempt, generateRetryMaxWait, wait) + } + } + + if err := waitForRetry(context.Background(), 0); err != nil { + t.Fatalf("waitForRetry(0) error = %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := waitForRetry(ctx, time.Millisecond); !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + if err := waitForRetry(context.Background(), time.Millisecond); err != nil { + t.Fatalf("waitForRetry(timeout) error = %v", err) + } +} + +func TestMapGeminiSDKError(t *testing.T) { + t.Parallel() + + if err := mapGeminiSDKError(errors.New("plain")); err != nil { + t.Fatalf("plain error should not map, got %v", err) + } + + cases := []struct { + name string + err error + wantCode provider.ProviderErrorCode + wantSubstr string + }{ + { + name: "status from name unauthenticated", + err: genai.APIError{Status: "UNAUTHENTICATED", Message: "bad token"}, + wantCode: provider.ErrorCodeAuthFailed, + wantSubstr: "bad token", + }, + { + name: "bad request api key heuristic", + err: genai.APIError{Code: http.StatusBadRequest, Message: "x-goog-api-key invalid"}, + wantCode: provider.ErrorCodeAuthFailed, + wantSubstr: "x-goog-api-key invalid", + }, + { + name: "bad request quota heuristic", + err: &genai.APIError{Code: http.StatusBadRequest, Message: "RESOURCE_EXHAUSTED quota"}, + wantCode: provider.ErrorCodeRateLimit, + wantSubstr: "RESOURCE_EXHAUSTED quota", + }, + { + name: "permission denied", + err: genai.APIError{Status: "PERMISSION_DENIED", Message: "forbidden"}, + wantCode: provider.ErrorCodeForbidden, + wantSubstr: "forbidden", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := mapGeminiSDKError(tc.err) + if err == nil { + t.Fatal("expected mapped error") + } + var providerErr *provider.ProviderError + if !errors.As(err, &providerErr) { + t.Fatalf("expected provider error, got %T %v", err, err) + } + if providerErr.Code != tc.wantCode { + t.Fatalf("provider code = %q, want %q", providerErr.Code, tc.wantCode) + } + if !strings.Contains(err.Error(), tc.wantSubstr) { + t.Fatalf("mapped error %q does not contain %q", err.Error(), tc.wantSubstr) + } + }) + } +} + func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { var drained []providertypes.StreamEvent for { From fbe701c9a4f22c4ce36002605f4581806c3dbdc9 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Sat, 25 Apr 2026 20:52:50 +0800 Subject: [PATCH 4/8] =?UTF-8?q?refactor(provider):=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E7=94=9F=E6=88=90=E9=87=8D=E8=AF=95=E4=B8=BA=E9=A6=96=E5=8C=85?= =?UTF-8?q?=E8=B6=85=E6=97=B6=E4=B8=8E=E6=B5=81=E7=A9=BA=E9=97=B2=E8=B6=85?= =?UTF-8?q?=E6=97=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/config-management-detail-design.md | 2 +- docs/guides/adding-providers.md | 4 + docs/guides/configuration.md | 9 + internal/config/loader_test.go | 86 ++++ internal/config/provider.go | 58 ++- internal/config/provider_custom_normalize.go | 46 ++- internal/config/provider_loader.go | 99 +++-- internal/config/provider_test.go | 93 ++++- internal/gateway/launcher/launcher.go | 9 +- internal/gateway/launcher/launcher_test.go | 52 ++- internal/provider/anthropic/driver.go | 25 +- internal/provider/anthropic/provider.go | 26 +- internal/provider/constants.go | 42 +- internal/provider/contracts.go | 19 + internal/provider/gemini/driver.go | 21 +- internal/provider/gemini/provider.go | 117 ++---- internal/provider/generate_attempt.go | 370 ++++++++++++++++++ internal/provider/generate_attempt_test.go | 270 +++++++++++++ .../openaicompat/driver_internal_test.go | 69 ++-- .../provider/openaicompat/generate_sdk.go | 96 ++++- .../openaicompat/generate_sdk_test.go | 109 +++++- .../openaicompat/openaicompat_test.go | 107 ++++- internal/provider/openaicompat/provider.go | 67 +++- 23 files changed, 1530 insertions(+), 266 deletions(-) create mode 100644 internal/provider/generate_attempt.go create mode 100644 internal/provider/generate_attempt_test.go diff --git a/docs/config-management-detail-design.md b/docs/config-management-detail-design.md index 437fb02b..c6b47fe5 100644 --- a/docs/config-management-detail-design.md +++ b/docs/config-management-detail-design.md @@ -69,7 +69,7 @@ custom provider 来自: ``` 当前只接受明确受支持的字段;未知字段会直接报错,不做“旧格式自动迁移”。 -`provider.yaml` 只支持平铺字段:`name/driver/base_url/api_key_env/model_source/chat_endpoint_path/discovery_endpoint_path/models`。 +`provider.yaml` 只支持平铺字段:`name/driver/base_url/api_key_env/model_source/chat_endpoint_path/discovery_endpoint_path/generate_max_retries/generate_start_timeout_sec/generate_idle_timeout_sec/models`。 ## 加载流程 diff --git a/docs/guides/adding-providers.md b/docs/guides/adding-providers.md index f925d711..007534dc 100644 --- a/docs/guides/adding-providers.md +++ b/docs/guides/adding-providers.md @@ -99,6 +99,9 @@ model_source: discover chat_api_mode: responses chat_endpoint_path: / discovery_endpoint_path: /models +generate_max_retries: 5 +generate_start_timeout_sec: 60 +generate_idle_timeout_sec: 300 ``` 说明: @@ -107,6 +110,7 @@ discovery_endpoint_path: /models - `chat_endpoint_path` 为 `/` 表示直连 `base_url`;为空时会按 `chat_api_mode` 自动回填默认子路径(`/chat/completions` 或 `/responses`)。 - 当 `chat_api_mode` 已显式指定时,`chat_endpoint_path` 可使用任意以 `/` 开头的相对路径;未显式指定时,仅支持标准端点推断(`/chat/completions`、`/responses`、`/`)。 - `model_source: manual` 时必须提供 `models`,且会忽略 `discovery_endpoint_path`。 +- `generate_max_retries` / `generate_start_timeout_sec` / `generate_idle_timeout_sec` 用于控制 provider 级生成重试、首包超时和流空闲超时;未填写或 `<= 0` 时会分别回退到 `5 / 60 / 300`。 ## 测试要求 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index f67a98ee..ee4ea99e 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -173,8 +173,17 @@ base_url: https://llm.example.com/v1 chat_api_mode: chat_completions chat_endpoint_path: /chat/completions discovery_endpoint_path: /models +generate_max_retries: 5 +generate_start_timeout_sec: 60 +generate_idle_timeout_sec: 300 ``` +新增的生成链路控制字段含义如下: + +- `generate_max_retries`:额外重试次数,不含首次尝试;`<= 0` 时回退默认值 `5`。 +- `generate_start_timeout_sec`:从发请求到收到首个有效流 payload 的最长等待窗口;`<= 0` 时回退默认值 `60`。 +- `generate_idle_timeout_sec`:首包后连续没有任何新 payload 的最长空闲窗口;`<= 0` 时回退默认值 `300`。 + ## 不写入 `config.yaml` 的字段 以下内容不允许写入主配置文件: diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 77619ac6..0e0632c1 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -1362,6 +1362,92 @@ func TestSaveCustomProviderManualModelsPersistOptionalFields(t *testing.T) { } } +func TestSaveAndLoadCustomProviderPersistsGenerateControls(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + const providerName = "retry-controls-provider" + err := SaveCustomProviderWithModels(baseDir, SaveCustomProviderInput{ + Name: providerName, + Driver: provider.DriverOpenAICompat, + BaseURL: "https://llm.example.com/v1", + APIKeyEnv: "RETRY_CONTROLS_PROVIDER_API_KEY", + ModelSource: ModelSourceDiscover, + DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, + GenerateMaxRetries: 7, + GenerateStartTimeoutSec: 75, + GenerateIdleTimeoutSec: 420, + }) + if err != nil { + t.Fatalf("SaveCustomProviderWithModels() error = %v", err) + } + + cfg, err := loadCustomProvider(filepath.Join(baseDir, providersDirName, providerName)) + if err != nil { + t.Fatalf("loadCustomProvider() error = %v", err) + } + if cfg.GenerateMaxRetries != 7 { + t.Fatalf("expected GenerateMaxRetries=7, got %d", cfg.GenerateMaxRetries) + } + if cfg.GenerateStartTimeoutSec != 75 { + t.Fatalf("expected GenerateStartTimeoutSec=75, got %d", cfg.GenerateStartTimeoutSec) + } + if cfg.GenerateIdleTimeoutSec != 420 { + t.Fatalf("expected GenerateIdleTimeoutSec=420, got %d", cfg.GenerateIdleTimeoutSec) + } +} + +func TestSaveCustomProviderOmitsDefaultGenerateControlsWhenUnset(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + const providerName = "omit-default-generate-controls" + err := SaveCustomProviderWithModels(baseDir, SaveCustomProviderInput{ + Name: providerName, + Driver: provider.DriverOpenAICompat, + BaseURL: "https://llm.example.com/v1", + APIKeyEnv: "OMIT_DEFAULT_GENERATE_CONTROLS_API_KEY", + ModelSource: ModelSourceDiscover, + DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, + }) + if err != nil { + t.Fatalf("SaveCustomProviderWithModels() error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(baseDir, providersDirName, providerName, customProviderConfigName)) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + content := string(data) + if strings.Contains(content, "generate_max_retries") { + t.Fatalf("expected generate_max_retries to be omitted, got %q", content) + } + if strings.Contains(content, "generate_start_timeout_sec") { + t.Fatalf("expected generate_start_timeout_sec to be omitted, got %q", content) + } + if strings.Contains(content, "generate_idle_timeout_sec") { + t.Fatalf("expected generate_idle_timeout_sec to be omitted, got %q", content) + } +} + +func TestSaveCustomProviderRejectsNegativeGenerateControls(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + err := SaveCustomProviderWithModels(baseDir, SaveCustomProviderInput{ + Name: "invalid-generate-controls", + Driver: provider.DriverOpenAICompat, + BaseURL: "https://llm.example.com/v1", + APIKeyEnv: "INVALID_GENERATE_CONTROLS_API_KEY", + ModelSource: ModelSourceDiscover, + DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, + GenerateIdleTimeoutSec: -1, + }) + if err == nil || !strings.Contains(err.Error(), "generate_idle_timeout_sec") { + t.Fatalf("expected negative generate control to be rejected, got %v", err) + } +} + func TestToCustomProviderModelFiles(t *testing.T) { t.Parallel() diff --git a/internal/config/provider.go b/internal/config/provider.go index f71e7200..736cfcbd 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -3,9 +3,11 @@ package config import ( "errors" "fmt" + "math" "net/url" "os" "strings" + "time" "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" @@ -20,17 +22,20 @@ const ( ) type ProviderConfig struct { - Name string `yaml:"name"` - Driver string `yaml:"driver"` - BaseURL string `yaml:"base_url"` - Model string `yaml:"model"` - APIKeyEnv string `yaml:"api_key_env"` - ModelSource string `yaml:"-"` - ChatAPIMode string `yaml:"-"` - ChatEndpointPath string `yaml:"-"` - DiscoveryEndpointPath string `yaml:"-"` - Models []providertypes.ModelDescriptor `yaml:"-"` - Source ProviderSource `yaml:"-"` + Name string `yaml:"name"` + Driver string `yaml:"driver"` + BaseURL string `yaml:"base_url"` + Model string `yaml:"model"` + APIKeyEnv string `yaml:"api_key_env"` + GenerateMaxRetries int `yaml:"generate_max_retries,omitempty"` + GenerateStartTimeoutSec int `yaml:"generate_start_timeout_sec,omitempty"` + GenerateIdleTimeoutSec int `yaml:"generate_idle_timeout_sec,omitempty"` + ModelSource string `yaml:"-"` + ChatAPIMode string `yaml:"-"` + ChatEndpointPath string `yaml:"-"` + DiscoveryEndpointPath string `yaml:"-"` + Models []providertypes.ModelDescriptor `yaml:"-"` + Source ProviderSource `yaml:"-"` } type ResolvedProviderConfig struct { @@ -79,6 +84,15 @@ func (p ProviderConfig) Validate() error { if strings.TrimSpace(p.APIKeyEnv) == "" { return fmt.Errorf("provider %q api_key_env is empty", p.Name) } + if err := validateOptionalNonNegativeGenerateControl("generate_max_retries", p.GenerateMaxRetries); err != nil { + return fmt.Errorf("provider %q: %w", p.Name, err) + } + if err := validateOptionalGenerateDurationSeconds("generate_start_timeout_sec", p.GenerateStartTimeoutSec); err != nil { + return fmt.Errorf("provider %q: %w", p.Name, err) + } + if err := validateOptionalGenerateDurationSeconds("generate_idle_timeout_sec", p.GenerateIdleTimeoutSec); err != nil { + return fmt.Errorf("provider %q: %w", p.Name, err) + } normalizedModelSource := NormalizeModelSource(p.ModelSource) if normalizedModelSource == "" { @@ -108,6 +122,25 @@ func (p ProviderConfig) Validate() error { return nil } +// validateOptionalNonNegativeGenerateControl 校验可选的整型生成控制字段,拒绝会被运行时静默吞掉的负数输入。 +func validateOptionalNonNegativeGenerateControl(field string, value int) error { + if value < 0 { + return fmt.Errorf("%s must be greater than or equal to 0", field) + } + return nil +} + +// validateOptionalGenerateDurationSeconds 校验秒级超时字段,避免负值和 duration 溢出在运行时被悄悄回退为默认值。 +func validateOptionalGenerateDurationSeconds(field string, value int) error { + if err := validateOptionalNonNegativeGenerateControl(field, value); err != nil { + return err + } + if int64(value) > math.MaxInt64/int64(time.Second) { + return fmt.Errorf("%s exceeds supported range", field) + } + return nil +} + func (p ProviderConfig) Identity() (provider.ProviderIdentity, error) { return providerIdentityFromConfig(p) } @@ -239,6 +272,9 @@ func (p ResolvedProviderConfig) ToRuntimeConfig() (provider.RuntimeConfig, error ChatAPIMode: chatAPIMode, ChatEndpointPath: chatEndpointPath, DiscoveryEndpointPath: discoveryEndpointPath, + GenerateMaxRetries: provider.NormalizeGenerateMaxRetries(p.GenerateMaxRetries), + GenerateStartTimeout: provider.NormalizeGenerateStartTimeout(time.Duration(p.GenerateStartTimeoutSec) * time.Second), + GenerateIdleTimeout: provider.NormalizeGenerateIdleTimeout(time.Duration(p.GenerateIdleTimeoutSec) * time.Second), }, nil } diff --git a/internal/config/provider_custom_normalize.go b/internal/config/provider_custom_normalize.go index 3ec3efbb..e6fd7709 100644 --- a/internal/config/provider_custom_normalize.go +++ b/internal/config/provider_custom_normalize.go @@ -14,13 +14,16 @@ const ManualModelOptionalIntUnset = -1 // NormalizeCustomProviderInput 统一归一化 custom provider 的输入字段,并执行协议/模型来源的组合校验。 func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProviderInput, error) { normalized := SaveCustomProviderInput{ - Name: strings.TrimSpace(input.Name), - Driver: normalizeProviderDriver(strings.TrimSpace(input.Driver)), - BaseURL: strings.TrimSpace(input.BaseURL), - ChatAPIMode: strings.TrimSpace(input.ChatAPIMode), - ChatEndpointPath: strings.TrimSpace(input.ChatEndpointPath), - APIKeyEnv: strings.TrimSpace(input.APIKeyEnv), - DiscoveryEndpointPath: strings.TrimSpace(input.DiscoveryEndpointPath), + Name: strings.TrimSpace(input.Name), + Driver: normalizeProviderDriver(strings.TrimSpace(input.Driver)), + BaseURL: strings.TrimSpace(input.BaseURL), + ChatAPIMode: strings.TrimSpace(input.ChatAPIMode), + ChatEndpointPath: strings.TrimSpace(input.ChatEndpointPath), + APIKeyEnv: strings.TrimSpace(input.APIKeyEnv), + GenerateMaxRetries: normalizeOptionalGenerateInt(input.GenerateMaxRetries), + GenerateStartTimeoutSec: normalizeOptionalGenerateInt(input.GenerateStartTimeoutSec), + GenerateIdleTimeoutSec: normalizeOptionalGenerateInt(input.GenerateIdleTimeoutSec), + DiscoveryEndpointPath: strings.TrimSpace(input.DiscoveryEndpointPath), } if err := validateCustomProviderName(normalized.Name); err != nil { @@ -96,11 +99,36 @@ func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProv ) } normalized.DiscoveryEndpointPath = "" - return normalized, nil + return normalized, validateNormalizedCustomProviderInput(normalized) } normalized.DiscoveryEndpointPath = discoveryEndpointPath - return normalized, nil + return normalized, validateNormalizedCustomProviderInput(normalized) +} + +// normalizeOptionalGenerateInt 归一化可选的生成控制字段,仅保留调用方原始输入,避免在保存前静默吞掉非法值。 +func normalizeOptionalGenerateInt(value int) int { + return value +} + +// validateNormalizedCustomProviderInput 复用统一的 provider 配置校验,避免 custom provider 保存路径和加载路径出现两套规则。 +func validateNormalizedCustomProviderInput(input SaveCustomProviderInput) error { + cfg := ProviderConfig{ + Name: input.Name, + Driver: input.Driver, + BaseURL: input.BaseURL, + APIKeyEnv: input.APIKeyEnv, + GenerateMaxRetries: input.GenerateMaxRetries, + GenerateStartTimeoutSec: input.GenerateStartTimeoutSec, + GenerateIdleTimeoutSec: input.GenerateIdleTimeoutSec, + ModelSource: input.ModelSource, + ChatAPIMode: input.ChatAPIMode, + ChatEndpointPath: input.ChatEndpointPath, + DiscoveryEndpointPath: input.DiscoveryEndpointPath, + Models: input.Models, + Source: ProviderSourceCustom, + } + return cfg.Validate() } // NormalizeCustomProviderModels 统一归一化 custom provider 的模型描述并校验必填字段和边界条件。 diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index fb9881c1..4c792fd6 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -21,15 +21,18 @@ const ( ) type customProviderFile struct { - Name string `yaml:"name"` - Driver string `yaml:"driver"` - APIKeyEnv string `yaml:"api_key_env"` - ModelSource string `yaml:"model_source,omitempty"` - ChatAPIMode string `yaml:"chat_api_mode,omitempty"` - BaseURL string `yaml:"base_url,omitempty"` - ChatEndpointPath string `yaml:"chat_endpoint_path,omitempty"` - DiscoveryEndpointPath string `yaml:"discovery_endpoint_path,omitempty"` - Models []customProviderModelFile `yaml:"models,omitempty"` + Name string `yaml:"name"` + Driver string `yaml:"driver"` + APIKeyEnv string `yaml:"api_key_env"` + GenerateMaxRetries int `yaml:"generate_max_retries,omitempty"` + GenerateStartTimeoutSec int `yaml:"generate_start_timeout_sec,omitempty"` + GenerateIdleTimeoutSec int `yaml:"generate_idle_timeout_sec,omitempty"` + ModelSource string `yaml:"model_source,omitempty"` + ChatAPIMode string `yaml:"chat_api_mode,omitempty"` + BaseURL string `yaml:"base_url,omitempty"` + ChatEndpointPath string `yaml:"chat_endpoint_path,omitempty"` + DiscoveryEndpointPath string `yaml:"discovery_endpoint_path,omitempty"` + Models []customProviderModelFile `yaml:"models,omitempty"` } type customProviderModelFile struct { @@ -109,31 +112,37 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { } normalizedInput, err := NormalizeCustomProviderInput(SaveCustomProviderInput{ - Name: strings.TrimSpace(file.Name), - Driver: strings.TrimSpace(file.Driver), - BaseURL: strings.TrimSpace(file.BaseURL), - APIKeyEnv: strings.TrimSpace(file.APIKeyEnv), - ModelSource: strings.TrimSpace(file.ModelSource), - ChatAPIMode: strings.TrimSpace(file.ChatAPIMode), - ChatEndpointPath: strings.TrimSpace(file.ChatEndpointPath), - DiscoveryEndpointPath: strings.TrimSpace(file.DiscoveryEndpointPath), - Models: models, + Name: strings.TrimSpace(file.Name), + Driver: strings.TrimSpace(file.Driver), + BaseURL: strings.TrimSpace(file.BaseURL), + APIKeyEnv: strings.TrimSpace(file.APIKeyEnv), + GenerateMaxRetries: file.GenerateMaxRetries, + GenerateStartTimeoutSec: file.GenerateStartTimeoutSec, + GenerateIdleTimeoutSec: file.GenerateIdleTimeoutSec, + ModelSource: strings.TrimSpace(file.ModelSource), + ChatAPIMode: strings.TrimSpace(file.ChatAPIMode), + ChatEndpointPath: strings.TrimSpace(file.ChatEndpointPath), + DiscoveryEndpointPath: strings.TrimSpace(file.DiscoveryEndpointPath), + Models: models, }) if err != nil { return ProviderConfig{}, fmt.Errorf("config: custom provider %q: %w", filepath.Base(providerDir), err) } cfg := ProviderConfig{ - Name: normalizedInput.Name, - Driver: normalizedInput.Driver, - BaseURL: normalizedInput.BaseURL, - APIKeyEnv: normalizedInput.APIKeyEnv, - ModelSource: normalizedInput.ModelSource, - ChatAPIMode: normalizedInput.ChatAPIMode, - ChatEndpointPath: normalizedInput.ChatEndpointPath, - DiscoveryEndpointPath: normalizedInput.DiscoveryEndpointPath, - Models: normalizedInput.Models, - Source: ProviderSourceCustom, + Name: normalizedInput.Name, + Driver: normalizedInput.Driver, + BaseURL: normalizedInput.BaseURL, + APIKeyEnv: normalizedInput.APIKeyEnv, + GenerateMaxRetries: normalizedInput.GenerateMaxRetries, + GenerateStartTimeoutSec: normalizedInput.GenerateStartTimeoutSec, + GenerateIdleTimeoutSec: normalizedInput.GenerateIdleTimeoutSec, + ModelSource: normalizedInput.ModelSource, + ChatAPIMode: normalizedInput.ChatAPIMode, + ChatEndpointPath: normalizedInput.ChatEndpointPath, + DiscoveryEndpointPath: normalizedInput.DiscoveryEndpointPath, + Models: normalizedInput.Models, + Source: ProviderSourceCustom, } if err := cfg.Validate(); err != nil { @@ -191,15 +200,18 @@ func customProviderModels(models []customProviderModelFile) ([]providertypes.Mod // SaveCustomProviderInput 定义自定义 Provider 的持久化字段。 type SaveCustomProviderInput struct { - Name string - Driver string - BaseURL string - ChatAPIMode string - ChatEndpointPath string - APIKeyEnv string - DiscoveryEndpointPath string - ModelSource string - Models []providertypes.ModelDescriptor + Name string + Driver string + BaseURL string + ChatAPIMode string + ChatEndpointPath string + APIKeyEnv string + GenerateMaxRetries int + GenerateStartTimeoutSec int + GenerateIdleTimeoutSec int + DiscoveryEndpointPath string + ModelSource string + Models []providertypes.ModelDescriptor } // SaveCustomProviderWithModels 保存自定义 provider,并可在 manual 模式下写入手工模型列表。 @@ -215,11 +227,14 @@ func SaveCustomProviderWithModels(baseDir string, input SaveCustomProviderInput) } cfg := customProviderFile{ - Name: normalizedInput.Name, - Driver: normalizedInput.Driver, - APIKeyEnv: normalizedInput.APIKeyEnv, - ModelSource: normalizedInput.ModelSource, - ChatAPIMode: normalizedInput.ChatAPIMode, + Name: normalizedInput.Name, + Driver: normalizedInput.Driver, + APIKeyEnv: normalizedInput.APIKeyEnv, + GenerateMaxRetries: normalizedInput.GenerateMaxRetries, + GenerateStartTimeoutSec: normalizedInput.GenerateStartTimeoutSec, + GenerateIdleTimeoutSec: normalizedInput.GenerateIdleTimeoutSec, + ModelSource: normalizedInput.ModelSource, + ChatAPIMode: normalizedInput.ChatAPIMode, } cfg.BaseURL = normalizedInput.BaseURL diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 18a788b4..63962aa7 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -4,6 +4,7 @@ import ( "os" "strings" "testing" + "time" providerpkg "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" @@ -364,6 +365,59 @@ func TestProviderConfigValidateRejectsBaseURLWithUserinfo(t *testing.T) { } } +func TestProviderConfigValidateRejectsNegativeGenerateControls(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mutate func(*ProviderConfig) + errContain string + }{ + { + name: "negative retries", + mutate: func(cfg *ProviderConfig) { + cfg.GenerateMaxRetries = -1 + }, + errContain: "generate_max_retries", + }, + { + name: "negative start timeout", + mutate: func(cfg *ProviderConfig) { + cfg.GenerateStartTimeoutSec = -1 + }, + errContain: "generate_start_timeout_sec", + }, + { + name: "negative idle timeout", + mutate: func(cfg *ProviderConfig) { + cfg.GenerateIdleTimeoutSec = -1 + }, + errContain: "generate_idle_timeout_sec", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cfg := ProviderConfig{ + Name: "company-gateway", + Driver: providerpkg.DriverOpenAICompat, + BaseURL: "https://llm.example.com/v1", + APIKeyEnv: "TEST_KEY", + ModelSource: ModelSourceDiscover, + DiscoveryEndpointPath: providerpkg.DiscoveryEndpointPathModels, + Source: ProviderSourceCustom, + } + tt.mutate(&cfg) + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), tt.errContain) { + t.Fatalf("expected %q validation error, got %v", tt.errContain, err) + } + }) + } +} + func TestCloneProviderConfigModelDescriptorsIndependence(t *testing.T) { t.Parallel() @@ -721,6 +775,9 @@ func TestResolvedProviderConfigToRuntimeConfig(t *testing.T) { ChatAPIMode: "", ChatEndpointPath: "", DiscoveryEndpointPath: providerpkg.DiscoveryEndpointPathModels, + GenerateMaxRetries: providerpkg.DefaultGenerateMaxRetries, + GenerateStartTimeout: providerpkg.DefaultGenerateStartTimeout, + GenerateIdleTimeout: providerpkg.DefaultGenerateIdleTimeout, } if got.APIKeyResolver == nil { @@ -736,11 +793,45 @@ func TestResolvedProviderConfigToRuntimeConfig(t *testing.T) { got.RequestAssetBudget != want.RequestAssetBudget || got.ChatAPIMode != want.ChatAPIMode || got.ChatEndpointPath != want.ChatEndpointPath || - got.DiscoveryEndpointPath != want.DiscoveryEndpointPath { + got.DiscoveryEndpointPath != want.DiscoveryEndpointPath || + got.GenerateMaxRetries != want.GenerateMaxRetries || + got.GenerateStartTimeout != want.GenerateStartTimeout || + got.GenerateIdleTimeout != want.GenerateIdleTimeout { t.Fatalf("ToRuntimeConfig() = %+v, want %+v", got, want) } } +func TestResolvedProviderConfigToRuntimeConfigMapsGenerateControls(t *testing.T) { + t.Parallel() + + resolved := ResolvedProviderConfig{ + ProviderConfig: ProviderConfig{ + Name: "company-gateway", + Driver: "openaicompat", + BaseURL: "https://llm.example.com/v1", + Model: "server-default", + APIKeyEnv: "COMPANY_GATEWAY_KEY", + GenerateMaxRetries: 7, + GenerateStartTimeoutSec: 75, + GenerateIdleTimeoutSec: 420, + }, + } + + got, err := resolved.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + if got.GenerateMaxRetries != 7 { + t.Fatalf("expected GenerateMaxRetries=7, got %d", got.GenerateMaxRetries) + } + if got.GenerateStartTimeout != 75*time.Second { + t.Fatalf("expected GenerateStartTimeout=75s, got %s", got.GenerateStartTimeout) + } + if got.GenerateIdleTimeout != 420*time.Second { + t.Fatalf("expected GenerateIdleTimeout=420s, got %s", got.GenerateIdleTimeout) + } +} + func TestResolvedProviderConfigToRuntimeConfigStripsBaseURLUserinfo(t *testing.T) { t.Parallel() diff --git a/internal/gateway/launcher/launcher.go b/internal/gateway/launcher/launcher.go index b0b846d9..51412771 100644 --- a/internal/gateway/launcher/launcher.go +++ b/internal/gateway/launcher/launcher.go @@ -138,7 +138,7 @@ func resolveExecutablePath(lookPathFn func(string) (string, error), binary strin // validateExplicitGatewayBinary 校验显式配置的网关二进制路径,禁止使用相对路径降低 PATH 劫持风险。 func validateExplicitGatewayBinary(explicitBinary string) error { - if !filepath.IsAbs(explicitBinary) { + if !isAbsolutePath(explicitBinary) { return fmt.Errorf("explicit gateway binary must be an absolute path: %q", explicitBinary) } return nil @@ -146,8 +146,13 @@ func validateExplicitGatewayBinary(explicitBinary string) error { // validateResolvedExecutablePath 校验解析后的可执行路径必须为绝对路径,避免执行不受控相对路径目标。 func validateResolvedExecutablePath(resolvedPath string, source string) error { - if !filepath.IsAbs(resolvedPath) { + if !isAbsolutePath(resolvedPath) { return fmt.Errorf("resolved executable from %s is not an absolute path: %q", source, resolvedPath) } return nil } + +// isAbsolutePath 按当前平台原生语义校验绝对路径,避免放宽到依赖环境状态的跨平台路径解释。 +func isAbsolutePath(p string) bool { + return filepath.IsAbs(p) +} diff --git a/internal/gateway/launcher/launcher_test.go b/internal/gateway/launcher/launcher_test.go index afa22c31..96243dd9 100644 --- a/internal/gateway/launcher/launcher_test.go +++ b/internal/gateway/launcher/launcher_test.go @@ -26,12 +26,27 @@ func assertLaunchSpecEqual(t *testing.T, spec LaunchSpec, want LaunchSpec) { } } +func testAbsolutePath(name string) string { + if runtime.GOOS == "windows" { + return `C:\tools\` + name + } + return "/opt/tools/" + name +} + +func testPathBinary(name string) string { + if runtime.GOOS == "windows" { + return `C:\usr\local\bin\` + name + } + return "/usr/local/bin/" + name +} + func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { t.Run("explicit binary has highest priority", func(t *testing.T) { + executable := testAbsolutePath("neocode-gateway") spec, err := resolveGatewayLaunchSpecWithDeps( - ResolveOptions{ExplicitBinary: "/opt/tools/neocode-gateway"}, + ResolveOptions{ExplicitBinary: executable}, func(binary string) (string, error) { - if binary == "/opt/tools/neocode-gateway" { + if binary == executable { return binary, nil } return "", errors.New("unexpected lookup") @@ -42,16 +57,17 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { } assertLaunchSpecEqual(t, spec, LaunchSpec{ LaunchMode: LaunchModeExplicitPath, - Executable: "/opt/tools/neocode-gateway", + Executable: executable, }) }) t.Run("path binary preferred over fallback", func(t *testing.T) { + executable := testPathBinary("neocode-gateway") spec, err := resolveGatewayLaunchSpecWithDeps( ResolveOptions{}, func(binary string) (string, error) { if binary == "neocode-gateway" { - return "/usr/local/bin/neocode-gateway", nil + return executable, nil } return "", errors.New("unexpected lookup") }, @@ -61,11 +77,12 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { } assertLaunchSpecEqual(t, spec, LaunchSpec{ LaunchMode: LaunchModePathBinary, - Executable: "/usr/local/bin/neocode-gateway", + Executable: executable, }) }) t.Run("fallback to neocode subcommand", func(t *testing.T) { + executable := testPathBinary("neocode") spec, err := resolveGatewayLaunchSpecWithDeps( ResolveOptions{}, func(binary string) (string, error) { @@ -73,7 +90,7 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { case "neocode-gateway": return "", errors.New("not found") case "neocode": - return "/usr/local/bin/neocode", nil + return executable, nil default: return "", errors.New("unexpected lookup") } @@ -84,7 +101,7 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { } assertLaunchSpecEqual(t, spec, LaunchSpec{ LaunchMode: LaunchModeFallbackSubcommand, - Executable: "/usr/local/bin/neocode", + Executable: executable, Args: []string{"gateway"}, }) }) @@ -118,6 +135,27 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { } }) + t.Run("unix style path is not treated as absolute on windows", func(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("windows-only path semantics") + } + + lookupCalled := false + _, err := resolveGatewayLaunchSpecWithDeps( + ResolveOptions{ExplicitBinary: "/tools/neocode-gateway.exe"}, + func(string) (string, error) { + lookupCalled = true + return "", nil + }, + ) + if err == nil { + t.Fatal("expected windows unix-style path validation error") + } + if lookupCalled { + t.Fatal("lookPath should not be called for invalid explicit path") + } + }) + t.Run("path binary resolution rejects non-absolute path", func(t *testing.T) { _, err := resolveGatewayLaunchSpecWithDeps( ResolveOptions{}, diff --git a/internal/provider/anthropic/driver.go b/internal/provider/anthropic/driver.go index 95e551c5..d39ac82c 100644 --- a/internal/provider/anthropic/driver.go +++ b/internal/provider/anthropic/driver.go @@ -24,7 +24,7 @@ func Driver() provider.DriverDefinition { return New(cfg) }, Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { - client, err := newSDKClient(cfg) + client, err := newDiscoverySDKClient(cfg) if err != nil { return nil, err } @@ -57,20 +57,35 @@ func Driver() provider.DriverDefinition { } } -// newSDKClient 构造 Anthropic SDK 客户端,供生成与模型发现链路共享连接配置。 -func newSDKClient(cfg provider.RuntimeConfig) (anthropic.Client, error) { +// newDiscoverySDKClient 构造模型发现使用的 Anthropic SDK 客户端。 +func newDiscoverySDKClient(cfg provider.RuntimeConfig) (anthropic.Client, error) { + return newSDKClient(cfg, true) +} + +// newGenerateSDKClient 构造生成链路使用的 Anthropic SDK 客户端,并关闭 SDK 内建重试。 +func newGenerateSDKClient(cfg provider.RuntimeConfig) (anthropic.Client, error) { + return newSDKClient(cfg, false) +} + +// newSDKClient 根据调用场景构造 Anthropic SDK 客户端,避免生成链路被底层超时与重试抢占控制权。 +func newSDKClient(cfg provider.RuntimeConfig, discovery bool) (anthropic.Client, error) { apiKey, err := cfg.ResolveAPIKeyValue() if err != nil { return anthropic.Client{}, err } - httpClient := &http.Client{ - Timeout: provider.DefaultSDKRequestTimeout, + httpClient := &http.Client{} + if discovery { + httpClient.Timeout = provider.DefaultSDKRequestTimeout } + options := []anthroption.RequestOption{ anthroption.WithHTTPClient(httpClient), anthroption.WithAPIKey(apiKey), } + if !discovery { + options = append(options, anthroption.WithMaxRetries(0)) + } if strings.TrimSpace(cfg.BaseURL) != "" { options = append(options, anthroption.WithBaseURL(strings.TrimSpace(cfg.BaseURL))) } diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index 777fcba6..df749b86 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -57,7 +57,7 @@ func (p *Provider) EstimateInputTokens( }, nil } -// New 创建 Anthropic provider 实例,并初始化官方 SDK 客户端。 +// New 创建 Anthropic provider 实例。 func New(cfg provider.RuntimeConfig) (*Provider, error) { if strings.TrimSpace(cfg.APIKeyEnv) == "" { return nil, errors.New(errorPrefix + "api_key_env is empty") @@ -65,7 +65,7 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { return &Provider{cfg: cfg}, nil } -// Generate 发起 Anthropic 流式请求,并将 typed stream 转为统一事件。 +// Generate 发起 Anthropic 流式请求,并将重试与超时语义收敛到 provider 公共 runner。 func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { params, ok := p.takePreparedRequest(provider.BuildGenerateRequestSignature(req)) if !ok { @@ -76,7 +76,21 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque } } - client, err := newSDKClient(p.cfg) + return provider.RunGenerateWithRetry(ctx, p.cfg, events, func( + attemptCtx context.Context, + attemptEvents chan<- providertypes.StreamEvent, + ) error { + return p.generateOnce(attemptCtx, params, attemptEvents) + }) +} + +// generateOnce 执行一次 Anthropic 流式尝试,并将 typed stream 转为统一事件。 +func (p *Provider) generateOnce( + ctx context.Context, + params anthropic.MessageNewParams, + events chan<- providertypes.StreamEvent, +) error { + client, err := newGenerateSDKClient(p.cfg) if err != nil { return err } @@ -86,13 +100,13 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque var ( finishReason string usage providertypes.Usage - hasPayload bool + hasChunk bool toolCallSeq int ) toolCalls := make(map[int]toolCallState) for streamReader.Next() { - hasPayload = true + hasChunk = true event := streamReader.Current() switch variant := event.AsAny().(type) { case anthropic.MessageStartEvent: @@ -186,7 +200,7 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque } return fmt.Errorf("%sstream receive: %w", errorPrefix, streamErr) } - if !hasPayload { + if !hasChunk { return fmt.Errorf("%w: empty anthropic stream payload", provider.ErrStreamInterrupted) } for index, state := range toolCalls { diff --git a/internal/provider/constants.go b/internal/provider/constants.go index 8bc65cae..b5b2dafb 100644 --- a/internal/provider/constants.go +++ b/internal/provider/constants.go @@ -2,7 +2,7 @@ package provider import "time" -// Driver 与 OpenAI-compatible 协议常量用于在 config/provider 间共享稳定枚举值,避免字面量漂移。 +// Driver 是 config/provider 间共享的稳定枚举,避免字面量分支漂移。 const ( DriverOpenAICompat = "openaicompat" DriverGemini = "gemini" @@ -11,5 +11,41 @@ const ( DiscoveryEndpointPathModels = "/models" ) -// DefaultSDKRequestTimeout 定义 provider 层对外部模型 SDK 请求的统一超时,避免流式请求无限悬挂。 -const DefaultSDKRequestTimeout = 10 * time.Minute +const ( + // DefaultGenerateMaxRetries 定义生成链路默认额外重试次数,不含首次尝试。 + DefaultGenerateMaxRetries = 5 + // DefaultGenerateStartTimeout 定义生成链路等待首个有效 payload 的默认窗口。 + DefaultGenerateStartTimeout = 60 * time.Second + // DefaultGenerateIdleTimeout 定义首包后默认的流空闲超时窗口。 + DefaultGenerateIdleTimeout = 5 * time.Minute + // DefaultGenerateRetryBaseWait 定义生成链路重试退避的基础等待时长。 + DefaultGenerateRetryBaseWait = 1 * time.Second + // DefaultGenerateRetryMaxWait 定义生成链路重试退避的最大等待时长。 + DefaultGenerateRetryMaxWait = 5 * time.Second + // DefaultSDKRequestTimeout 定义非生成链路访问外部模型 SDK 的统一保底超时。 + DefaultSDKRequestTimeout = 10 * time.Minute +) + +// NormalizeGenerateMaxRetries 归一化生成链路额外重试次数,非正值回退到默认值。 +func NormalizeGenerateMaxRetries(value int) int { + if value <= 0 { + return DefaultGenerateMaxRetries + } + return value +} + +// NormalizeGenerateStartTimeout 归一化生成链路首包超时,非正值回退到默认值。 +func NormalizeGenerateStartTimeout(value time.Duration) time.Duration { + if value <= 0 { + return DefaultGenerateStartTimeout + } + return value +} + +// NormalizeGenerateIdleTimeout 归一化生成链路空闲超时,非正值回退到默认值。 +func NormalizeGenerateIdleTimeout(value time.Duration) time.Duration { + if value <= 0 { + return DefaultGenerateIdleTimeout + } + return value +} diff --git a/internal/provider/contracts.go b/internal/provider/contracts.go index 790bc2b3..32b8cc37 100644 --- a/internal/provider/contracts.go +++ b/internal/provider/contracts.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strings" + "time" providertypes "neo-code/internal/provider/types" "neo-code/internal/session" @@ -27,6 +28,9 @@ type RuntimeConfig struct { ChatAPIMode string ChatEndpointPath string DiscoveryEndpointPath string + GenerateMaxRetries int + GenerateStartTimeout time.Duration + GenerateIdleTimeout time.Duration } // ResolveAPIKeyValue 在 provider 即将发起请求前解析当前配置引用的 API Key。 @@ -58,6 +62,21 @@ func (c RuntimeConfig) ResolveAPIKeyValue() (string, error) { return value, nil } +// ResolvedGenerateMaxRetries 返回运行时已归一化的生成重试次数。 +func (c RuntimeConfig) ResolvedGenerateMaxRetries() int { + return NormalizeGenerateMaxRetries(c.GenerateMaxRetries) +} + +// ResolvedGenerateStartTimeout 返回运行时已归一化的生成首包超时。 +func (c RuntimeConfig) ResolvedGenerateStartTimeout() time.Duration { + return NormalizeGenerateStartTimeout(c.GenerateStartTimeout) +} + +// ResolvedGenerateIdleTimeout 返回运行时已归一化的生成空闲超时。 +func (c RuntimeConfig) ResolvedGenerateIdleTimeout() time.Duration { + return NormalizeGenerateIdleTimeout(c.GenerateIdleTimeout) +} + // StaticAPIKeyResolver 返回一个仅供测试和受控注入场景使用的固定密钥解析器。 func StaticAPIKeyResolver(apiKey string) APIKeyResolver { trimmed := strings.TrimSpace(apiKey) diff --git a/internal/provider/gemini/driver.go b/internal/provider/gemini/driver.go index 10a90b93..92e36610 100644 --- a/internal/provider/gemini/driver.go +++ b/internal/provider/gemini/driver.go @@ -23,7 +23,7 @@ func Driver() provider.DriverDefinition { return New(cfg) }, Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { - client, err := newSDKClient(ctx, cfg) + client, err := newDiscoverySDKClient(ctx, cfg) if err != nil { return nil, err } @@ -63,14 +63,25 @@ func Driver() provider.DriverDefinition { } } -// newSDKClient 构造 Gemini SDK 客户端,供生成与模型发现链路共享连接配置。 -func newSDKClient(ctx context.Context, cfg provider.RuntimeConfig) (*genai.Client, error) { +// newDiscoverySDKClient 构造模型发现使用的 Gemini SDK 客户端。 +func newDiscoverySDKClient(ctx context.Context, cfg provider.RuntimeConfig) (*genai.Client, error) { + return newSDKClient(ctx, cfg, true) +} + +// newGenerateSDKClient 构造生成链路使用的 Gemini SDK 客户端。 +func newGenerateSDKClient(ctx context.Context, cfg provider.RuntimeConfig) (*genai.Client, error) { + return newSDKClient(ctx, cfg, false) +} + +// newSDKClient 根据调用场景构造 Gemini SDK 客户端,避免生成链路被底层总超时抢占控制权。 +func newSDKClient(ctx context.Context, cfg provider.RuntimeConfig, discovery bool) (*genai.Client, error) { apiKey, err := cfg.ResolveAPIKeyValue() if err != nil { return nil, err } - httpClient := &http.Client{ - Timeout: provider.DefaultSDKRequestTimeout, + httpClient := &http.Client{} + if discovery { + httpClient.Timeout = provider.DefaultSDKRequestTimeout } clientConfig := &genai.ClientConfig{ APIKey: apiKey, diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index 3a6568a5..2739e2cf 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "math/rand/v2" "net" "net/http" "strings" @@ -20,11 +19,7 @@ import ( const errorPrefix = "gemini provider: " -const ( - defaultGenerateRetryMax = 2 - generateRetryBaseWait = 1 * time.Second - generateRetryMaxWait = 5 * time.Second -) +const defaultGenerateRetryMax = provider.DefaultGenerateMaxRetries // Provider 封装 Gemini native 协议的请求发送与流式响应解析。 type Provider struct { @@ -74,19 +69,17 @@ func (p *Provider) EstimateInputTokens( }, nil } -// New 创建 Gemini native provider 实例,并初始化官方 SDK 客户端。 +// New 创建 Gemini native provider 实例。 func New(cfg provider.RuntimeConfig) (*Provider, error) { if strings.TrimSpace(cfg.APIKeyEnv) == "" { return nil, errors.New(errorPrefix + "api_key_env is empty") } return &Provider{ - cfg: cfg, - retryBackoff: generateRetryBackoff, - retryWait: waitForRetry, + cfg: cfg, }, nil } -// Generate 发起 Gemini 流式请求,并将 SDK chunk 转为统一流式事件。 +// Generate 发起 Gemini 流式请求,并将重试与超时语义收敛到 provider 公共 runner。 func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { model, contents, config, ok := p.takePreparedRequest(provider.BuildGenerateRequestSignature(req)) if !ok { @@ -100,66 +93,46 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque if normalizedModel == "" { return errors.New(errorPrefix + "model is empty") } - var lastErr error - for attempt := 0; attempt <= defaultGenerateRetryMax; attempt++ { - if attempt > 0 { - wait := generateRetryBaseWait - if p.retryBackoff != nil { - wait = p.retryBackoff(attempt) - } - if p.retryWait != nil { - if err := p.retryWait(ctx, wait); err != nil { - return err - } - } - } - started, err := p.generateOnce(ctx, normalizedModel, contents, config, events) - if err == nil { - return nil - } - lastErr = err - if started || !isRetryableGenerateError(err) { - return err - } - if ctx.Err() != nil { - return ctx.Err() - } - } - return lastErr + return provider.RunGenerateWithRetryUsing(ctx, p.cfg, events, p.retryBackoff, p.retryWait, func( + attemptCtx context.Context, + attemptEvents chan<- providertypes.StreamEvent, + ) error { + return p.generateOnce(attemptCtx, normalizedModel, contents, config, attemptEvents) + }) } -// generateOnce 执行一次 Gemini 流式请求,并在未收到任何输出时返回可重试错误。 +// generateOnce 执行一次 Gemini 流式尝试,并将 SDK chunk 转为统一流式事件。 func (p *Provider) generateOnce( ctx context.Context, model string, contents []*genai.Content, config *genai.GenerateContentConfig, events chan<- providertypes.StreamEvent, -) (bool, error) { - client, err := newSDKClient(ctx, p.cfg) +) error { + client, err := newGenerateSDKClient(ctx, p.cfg) if err != nil { - return false, err + return err } var ( finishReason string usage providertypes.Usage - hasPayload bool + hasChunk bool callSeq int ) for chunk, streamErr := range client.Models.GenerateContentStream(ctx, model, contents, config) { if streamErr != nil { if ctxErr := ctx.Err(); ctxErr != nil { - return hasPayload, ctxErr + return ctxErr } - return hasPayload, normalizeGenerateError(streamErr) + return normalizeGenerateError(streamErr) } if chunk == nil { continue } - hasPayload = true + hasChunk = true extractUsage(&usage, chunk.UsageMetadata) for _, candidate := range chunk.Candidates { @@ -175,7 +148,7 @@ func (p *Provider) generateOnce( } if strings.TrimSpace(part.Text) != "" { if err := provider.EmitTextDelta(ctx, events, part.Text); err != nil { - return hasPayload, err + return err } } if part.FunctionCall == nil { @@ -192,25 +165,25 @@ func (p *Provider) generateOnce( continue } if err := provider.EmitToolCallStart(ctx, events, callSeq-1, callID, name); err != nil { - return hasPayload, err + return err } argsJSON, err := encodeArguments(part.FunctionCall.Args) if err != nil { - return hasPayload, err + return err } if err := provider.EmitToolCallDelta(ctx, events, callSeq-1, callID, argsJSON); err != nil { - return hasPayload, err + return err } } } } - if !hasPayload { - return false, fmt.Errorf("%w: empty gemini stream payload", provider.ErrStreamInterrupted) + if !hasChunk { + return fmt.Errorf("%w: empty gemini stream payload", provider.ErrStreamInterrupted) } if !usage.InputObserved && !usage.OutputObserved { - return true, provider.EmitMessageDone(ctx, events, finishReason, nil) + return provider.EmitMessageDone(ctx, events, finishReason, nil) } - return true, provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, &usage) } // storePreparedRequest 缓存估算阶段的 Gemini 构建结果,供同轮发送直接复用。 @@ -304,15 +277,6 @@ func normalizeGenerateError(err error) error { return fmt.Errorf("%sstream generate: %w", errorPrefix, err) } -// isRetryableGenerateError 判断 Gemini provider 是否应在请求级重试当前错误。 -func isRetryableGenerateError(err error) bool { - if err == nil { - return false - } - var providerErr *provider.ProviderError - return errors.As(err, &providerErr) && providerErr.Retryable -} - // isTimeoutGenerateError 判断 Gemini 流式生成错误是否由超时触发。 func isTimeoutGenerateError(err error) bool { if err == nil { @@ -325,35 +289,6 @@ func isTimeoutGenerateError(err error) bool { return errors.As(err, &netErr) && netErr.Timeout() } -// generateRetryBackoff 计算 Gemini provider 请求级重试的指数退避时长。 -func generateRetryBackoff(attempt int) time.Duration { - if attempt <= 0 { - return 0 - } - wait := generateRetryBaseWait << (attempt - 1) - jitter := float64(wait) * (0.5 + rand.Float64()) - wait = time.Duration(jitter) - if wait > generateRetryMaxWait { - wait = generateRetryMaxWait - } - return wait -} - -// waitForRetry 在重试窗口内等待,同时尊重上层上下文取消。 -func waitForRetry(ctx context.Context, wait time.Duration) error { - if wait <= 0 { - return nil - } - timer := time.NewTimer(wait) - defer timer.Stop() - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return nil - } -} - // mapGeminiSDKError 将 Gemini SDK 错误映射为 provider 领域错误,仅保留状态码级别兜底。 func mapGeminiSDKError(err error) error { var apiErr genai.APIError diff --git a/internal/provider/generate_attempt.go b/internal/provider/generate_attempt.go new file mode 100644 index 00000000..f8bec712 --- /dev/null +++ b/internal/provider/generate_attempt.go @@ -0,0 +1,370 @@ +package provider + +import ( + "context" + "errors" + "fmt" + "math/rand/v2" + "sync/atomic" + "time" + + providertypes "neo-code/internal/provider/types" +) + +var ( + // ErrGenerateStartTimeout 标记生成请求在首包前超时。 + ErrGenerateStartTimeout = errors.New("provider: generate start timeout") + // ErrGenerateIdleTimeout 标记生成请求在首包后空闲超时。 + ErrGenerateIdleTimeout = errors.New("provider: generate idle timeout") +) + +type generateAttemptPhase uint32 + +const ( + generateAttemptPhaseWaitingFirstPayload generateAttemptPhase = iota + generateAttemptPhaseStreaming + generateAttemptPhaseCompleted +) + +type generateAttemptRunner struct { + cfg RuntimeConfig + retryBackoff func(attempt int) time.Duration + retryWait func(ctx context.Context, wait time.Duration) error +} + +type generateAttemptRunFunc func(ctx context.Context, events chan<- providertypes.StreamEvent) error + +type generateAttemptResult struct { + payloadStarted bool + err error + retryable bool +} + +// RunGenerateWithRetry 以统一的首包/空闲超时语义执行 provider 生成请求。 +func RunGenerateWithRetry( + ctx context.Context, + cfg RuntimeConfig, + events chan<- providertypes.StreamEvent, + run generateAttemptRunFunc, +) error { + runner := generateAttemptRunner{cfg: cfg} + return runner.run(ctx, events, run) +} + +// RunGenerateWithRetryUsing 使用可注入的等待策略执行统一生成 runner,供测试场景稳定控制重试节奏。 +func RunGenerateWithRetryUsing( + ctx context.Context, + cfg RuntimeConfig, + events chan<- providertypes.StreamEvent, + retryBackoff func(attempt int) time.Duration, + retryWait func(ctx context.Context, wait time.Duration) error, + run generateAttemptRunFunc, +) error { + runner := generateAttemptRunner{ + cfg: cfg, + retryBackoff: retryBackoff, + retryWait: retryWait, + } + return runner.run(ctx, events, run) +} + +// run 执行统一生成 runner 的外层重试循环。 +func (r generateAttemptRunner) run( + ctx context.Context, + events chan<- providertypes.StreamEvent, + run generateAttemptRunFunc, +) error { + maxRetries := r.cfg.ResolvedGenerateMaxRetries() + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + if attempt > 0 { + wait := generateRetryBackoff(attempt) + if r.retryBackoff != nil { + wait = r.retryBackoff(attempt) + } + waitFn := waitForRetry + if r.retryWait != nil { + waitFn = r.retryWait + } + if err := waitFn(ctx, wait); err != nil { + return err + } + } + + result := r.runOnce(ctx, events, run) + if result.err == nil { + return nil + } + lastErr = result.err + if result.payloadStarted || !result.retryable { + return result.err + } + if ctx.Err() != nil { + return ctx.Err() + } + } + return lastErr +} + +// runOnce 执行一次生成尝试,并基于统一观察器产出重试决策。 +func (r generateAttemptRunner) runOnce( + ctx context.Context, + events chan<- providertypes.StreamEvent, + run generateAttemptRunFunc, +) generateAttemptResult { + attemptCtx, cancelAttempt := context.WithCancelCause(ctx) + defer cancelAttempt(nil) + + proxyEvents := make(chan providertypes.StreamEvent, 32) + phase := &atomic.Uint32{} + forwardDone := make(chan error, 1) + go func() { + forwardDone <- forwardAttemptEvents( + attemptCtx, + proxyEvents, + events, + phase, + r.cfg.ResolvedGenerateStartTimeout(), + r.cfg.ResolvedGenerateIdleTimeout(), + cancelAttempt, + ) + }() + + runErr := run(attemptCtx, proxyEvents) + close(proxyEvents) + forwardErr := <-forwardDone + + phaseValue := generateAttemptPhase(phase.Load()) + if phaseValue == generateAttemptPhaseCompleted { + return generateAttemptResult{} + } + + if ctx.Err() != nil { + return generateAttemptResult{ + payloadStarted: phaseValue == generateAttemptPhaseStreaming, + err: ctx.Err(), + retryable: false, + } + } + if forwardErr != nil && runErr == nil { + runErr = forwardErr + } + + if cause := context.Cause(attemptCtx); cause != nil && !errors.Is(cause, ctx.Err()) { + if runErr == nil || errors.Is(runErr, context.Canceled) || errors.Is(runErr, context.DeadlineExceeded) { + runErr = cause + } + } + + if runErr == nil { + return generateAttemptResult{payloadStarted: phaseValue == generateAttemptPhaseStreaming} + } + payloadStarted := phaseValue == generateAttemptPhaseStreaming + return generateAttemptResult{ + payloadStarted: payloadStarted, + err: runErr, + retryable: !payloadStarted && isRetryableAttemptError(runErr), + } +} + +// forwardAttemptEvents 在转发事件时统一维护首包与空闲超时观察器。 +func forwardAttemptEvents( + ctx context.Context, + source <-chan providertypes.StreamEvent, + target chan<- providertypes.StreamEvent, + phase *atomic.Uint32, + startTimeout time.Duration, + idleTimeout time.Duration, + cancel context.CancelCauseFunc, +) error { + startTimer := time.NewTimer(startTimeout) + defer startTimer.Stop() + ctxDone := ctx.Done() + + var idleTimer *time.Timer + var idleTimerC <-chan time.Time + defer func() { + if idleTimer != nil { + idleTimer.Stop() + } + }() + + draining := false + + for { + select { + case <-ctxDone: + draining = true + ctxDone = nil + stopTimer(startTimer) + if idleTimer != nil { + stopTimer(idleTimer) + idleTimerC = nil + } + case <-startTimer.C: + if phase.Load() == uint32(generateAttemptPhaseWaitingFirstPayload) { + cancel(newGenerateStartTimeoutError(startTimeout)) + } + case <-idleTimerC: + if phase.Load() == uint32(generateAttemptPhaseStreaming) { + cancel(newGenerateIdleTimeoutError(idleTimeout)) + } + case event, ok := <-source: + if !ok { + return nil + } + phaseValue := updateGenerateAttemptPhase(event, phase) + if phaseValue == generateAttemptPhaseStreaming { + stopTimer(startTimer) + if idleTimer == nil { + idleTimer = time.NewTimer(idleTimeout) + idleTimerC = idleTimer.C + } else { + resetTimer(idleTimer, idleTimeout) + } + } + if phaseValue == generateAttemptPhaseCompleted { + stopTimer(startTimer) + if idleTimer != nil { + stopTimer(idleTimer) + idleTimerC = nil + } + } + if draining { + continue + } + if err := emitStreamEvent(ctx, target, event); err != nil { + draining = true + ctxDone = nil + stopTimer(startTimer) + if idleTimer != nil { + stopTimer(idleTimer) + idleTimerC = nil + } + continue + } + if phaseValue == generateAttemptPhaseCompleted { + draining = true + ctxDone = nil + cancel(nil) + } + } + } +} + +// updateGenerateAttemptPhase 统一维护生成尝试的阶段流转,确保首包、完成态和重试边界只在公共层定义一次。 +func updateGenerateAttemptPhase( + event providertypes.StreamEvent, + phase *atomic.Uint32, +) generateAttemptPhase { + current := generateAttemptPhase(phase.Load()) + if current == generateAttemptPhaseCompleted { + return current + } + if event.Type == providertypes.StreamEventMessageDone { + phase.Store(uint32(generateAttemptPhaseCompleted)) + return generateAttemptPhaseCompleted + } + if IsEffectiveGeneratePayloadEvent(event) { + if phase.CompareAndSwap( + uint32(generateAttemptPhaseWaitingFirstPayload), + uint32(generateAttemptPhaseStreaming), + ) { + return generateAttemptPhaseStreaming + } + return generateAttemptPhase(phase.Load()) + } + return current +} + +// IsEffectiveGeneratePayloadEvent 判断事件是否属于“流已开始”的有效 payload。 +func IsEffectiveGeneratePayloadEvent(event providertypes.StreamEvent) bool { + switch event.Type { + case providertypes.StreamEventTextDelta, providertypes.StreamEventToolCallStart, providertypes.StreamEventToolCallDelta: + return true + default: + return false + } +} + +// isRetryableAttemptError 统一判断一次尝试失败后是否仍可在首包前继续重试。 +func isRetryableAttemptError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, ErrStreamInterrupted) || errors.Is(err, ErrGenerateStartTimeout) { + return true + } + var providerErr *ProviderError + return errors.As(err, &providerErr) && providerErr.Retryable +} + +// generateRetryBackoff 计算生成链路统一退避时间,避免三家 provider 各自维护一套重试等待规则。 +func generateRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + wait := DefaultGenerateRetryBaseWait << (attempt - 1) + jitter := float64(wait) * (0.5 + rand.Float64()) + wait = time.Duration(jitter) + if wait > DefaultGenerateRetryMaxWait { + wait = DefaultGenerateRetryMaxWait + } + return wait +} + +// waitForRetry 在统一重试窗口内等待,同时尊重上层上下文取消。 +func waitForRetry(ctx context.Context, wait time.Duration) error { + if wait <= 0 { + return nil + } + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +// newGenerateStartTimeoutError 构造统一的首包超时错误,供 attempt runner 与上层错误判断复用。 +func newGenerateStartTimeoutError(timeout time.Duration) error { + return fmt.Errorf( + "%w: %w", + ErrGenerateStartTimeout, + NewTimeoutProviderError(fmt.Sprintf("generate start timeout after %s", timeout)), + ) +} + +// newGenerateIdleTimeoutError 构造统一的流空闲超时错误,供 attempt runner 与上层错误判断复用。 +func newGenerateIdleTimeoutError(timeout time.Duration) error { + return fmt.Errorf( + "%w: %w", + ErrGenerateIdleTimeout, + NewTimeoutProviderError(fmt.Sprintf("generate idle timeout after %s", timeout)), + ) +} + +// stopTimer 安全停止计时器,并清理可能残留的触发信号。 +func stopTimer(timer *time.Timer) { + if timer == nil { + return + } + if timer.Stop() { + return + } + select { + case <-timer.C: + default: + } +} + +// resetTimer 在统一观察器中安全重置计时器,避免复用时遗留旧信号。 +func resetTimer(timer *time.Timer, wait time.Duration) { + if timer == nil { + return + } + stopTimer(timer) + timer.Reset(wait) +} diff --git a/internal/provider/generate_attempt_test.go b/internal/provider/generate_attempt_test.go new file mode 100644 index 00000000..462abe85 --- /dev/null +++ b/internal/provider/generate_attempt_test.go @@ -0,0 +1,270 @@ +package provider + +import ( + "context" + "errors" + "testing" + "time" + + providertypes "neo-code/internal/provider/types" +) + +func TestRunGenerateWithRetryUsingRetriesBeforePayloadStarts(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 2, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 8) + attempts := 0 + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + attempts++ + if attempts < 3 { + return NewProviderErrorFromStatus(500, "temporary") + } + if emitErr := EmitTextDelta(ctx, attemptEvents, "ok"); emitErr != nil { + return emitErr + } + return EmitMessageDone(ctx, attemptEvents, "stop", nil) + }, + ) + if err != nil { + t.Fatalf("RunGenerateWithRetryUsing() error = %v", err) + } + if attempts != 3 { + t.Fatalf("expected 3 attempts, got %d", attempts) + } + + drained := drainAttemptEvents(events) + if len(drained) != 2 { + t.Fatalf("expected success events to be forwarded once, got %+v", drained) + } +} + +func TestRunGenerateWithRetryUsingDoesNotRetryAfterPayloadStarts(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 3, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 8) + attempts := 0 + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + attempts++ + if emitErr := EmitTextDelta(ctx, attemptEvents, "partial"); emitErr != nil { + return emitErr + } + return NewProviderErrorFromStatus(500, "temporary") + }, + ) + if err == nil { + t.Fatal("expected error after payload-started failure") + } + if attempts != 1 { + t.Fatalf("expected exactly 1 attempt, got %d", attempts) + } +} + +func TestRunGenerateWithRetryUsingRetriesOnStartTimeout(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 1, + GenerateStartTimeout: 20 * time.Millisecond, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 8) + attempts := 0 + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + attempts++ + if attempts == 1 { + <-ctx.Done() + return ctx.Err() + } + if emitErr := EmitTextDelta(ctx, attemptEvents, "ok"); emitErr != nil { + return emitErr + } + return EmitMessageDone(ctx, attemptEvents, "stop", nil) + }, + ) + if err != nil { + t.Fatalf("RunGenerateWithRetryUsing() error = %v", err) + } + if attempts != 2 { + t.Fatalf("expected retry after start timeout, got %d attempts", attempts) + } +} + +func TestRunGenerateWithRetryUsingStopsOnIdleTimeout(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 3, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: 20 * time.Millisecond, + } + events := make(chan providertypes.StreamEvent, 8) + attempts := 0 + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + attempts++ + if emitErr := EmitTextDelta(ctx, attemptEvents, "partial"); emitErr != nil { + return emitErr + } + <-ctx.Done() + return ctx.Err() + }, + ) + if !errors.Is(err, ErrGenerateIdleTimeout) { + t.Fatalf("expected idle timeout error, got %v", err) + } + if attempts != 1 { + t.Fatalf("expected idle timeout to stop retries, got %d attempts", attempts) + } +} + +func TestRunGenerateWithRetryUsingDoesNotStartTimeoutAfterMessageDone(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 1, + GenerateStartTimeout: 20 * time.Millisecond, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 8) + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + if emitErr := EmitMessageDone(ctx, attemptEvents, "stop", nil); emitErr != nil { + return emitErr + } + time.Sleep(40 * time.Millisecond) + return nil + }, + ) + if err != nil { + t.Fatalf("RunGenerateWithRetryUsing() error = %v", err) + } +} + +func TestRunGenerateWithRetryUsingDrainsEventsAfterTimeoutCancel(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 0, + GenerateStartTimeout: 20 * time.Millisecond, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 8) + + done := make(chan error, 1) + go func() { + done <- RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + <-ctx.Done() + for i := 0; i < 128; i++ { + if err := EmitTextDelta(ctx, attemptEvents, "ignored"); err != nil { + return err + } + } + return ctx.Err() + }, + ) + }() + + select { + case err := <-done: + if !errors.Is(err, ErrGenerateStartTimeout) { + t.Fatalf("expected start timeout, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("expected generate attempt to return without deadlock") + } +} + +func TestRunGenerateWithRetryUsingTreatsMessageDoneAsCompletedState(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 1, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 8) + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + if emitErr := EmitMessageDone(ctx, attemptEvents, "stop", nil); emitErr != nil { + return emitErr + } + <-ctx.Done() + return ctx.Err() + }, + ) + if err != nil { + t.Fatalf("expected completed attempt to ignore trailing cancellation, got %v", err) + } + + drained := drainAttemptEvents(events) + if len(drained) != 1 || drained[0].Type != providertypes.StreamEventMessageDone { + t.Fatalf("expected only message_done to be forwarded, got %+v", drained) + } +} + +func drainAttemptEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + out := make([]providertypes.StreamEvent, 0, len(events)) + for { + select { + case event := <-events: + out = append(out, event) + default: + return out + } + } +} diff --git a/internal/provider/openaicompat/driver_internal_test.go b/internal/provider/openaicompat/driver_internal_test.go index 122ba862..c9695f79 100644 --- a/internal/provider/openaicompat/driver_internal_test.go +++ b/internal/provider/openaicompat/driver_internal_test.go @@ -42,7 +42,7 @@ func TestDriverClosuresAndSupportedProtocol(t *testing.T) { t.Fatalf("Build() error = %v", err) } typed, ok := built.(*Provider) - if !ok || typed.client == nil || typed.client.Transport == nil { + if !ok || typed.generateClient == nil || typed.discoveryClient == nil || typed.generateClient.Transport == nil { t.Fatalf("unexpected built provider: %T %+v", built, typed) } @@ -110,35 +110,31 @@ func TestDriverClosuresAndSupportedProtocol(t *testing.T) { func TestFetchModelsAndGenerateExtraBranches(t *testing.T) { t.Parallel() - p := &Provider{ - cfg: provider.RuntimeConfig{ - Name: DriverName, - Driver: DriverName, - BaseURL: "://bad", - APIKeyEnv: "OPENAI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), - DiscoveryEndpointPath: "/models", - }, - client: &http.Client{}, - } - cfg, _ := RequestConfigFromRuntime(p.cfg) - if _, err := DiscoverRawModels(context.Background(), p.client, cfg); err == nil || !strings.Contains(err.Error(), "build models request") { + cfg := provider.RuntimeConfig{ + Name: DriverName, + Driver: DriverName, + BaseURL: "://bad", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + requestCfg, _ := RequestConfigFromRuntime(cfg) + if _, err := DiscoverRawModels(context.Background(), &http.Client{}, requestCfg); err == nil || + !strings.Contains(err.Error(), "build models request") { t.Fatalf("expected build models request error, got %v", err) } - p = &Provider{ - cfg: provider.RuntimeConfig{ - Name: DriverName, - Driver: DriverName, - BaseURL: "https://api.example.com/v1", - APIKeyEnv: "OPENAI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), - DiscoveryEndpointPath: "https://api.example.com/models", - }, - client: &http.Client{}, + cfg = provider.RuntimeConfig{ + Name: DriverName, + Driver: DriverName, + BaseURL: "https://api.example.com/v1", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "https://api.example.com/models", } - cfg, _ = RequestConfigFromRuntime(p.cfg) - if _, err := DiscoverRawModels(context.Background(), p.client, cfg); err == nil || !provider.IsDiscoveryConfigError(err) { + requestCfg, _ = RequestConfigFromRuntime(cfg) + if _, err := DiscoverRawModels(context.Background(), &http.Client{}, requestCfg); err == nil || + !provider.IsDiscoveryConfigError(err) { t.Fatalf("expected discovery config error, got %v", err) } @@ -150,20 +146,17 @@ func TestFetchModelsAndGenerateExtraBranches(t *testing.T) { })) defer server.Close() - p = &Provider{ - cfg: provider.RuntimeConfig{ - Name: DriverName, - Driver: DriverName, - BaseURL: server.URL, - APIKeyEnv: "OPENAI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), - DiscoveryEndpointPath: "/models", - }, - client: server.Client(), + cfg = provider.RuntimeConfig{ + Name: DriverName, + Driver: DriverName, + BaseURL: server.URL, + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", } - cfg, _ = RequestConfigFromRuntime(p.cfg) - if _, err := DiscoverRawModels(context.Background(), p.client, cfg); err != nil { + requestCfg, _ = RequestConfigFromRuntime(cfg) + if _, err := DiscoverRawModels(context.Background(), server.Client(), requestCfg); err != nil { t.Fatalf("DiscoverRawModels() error = %v", err) } if auth != "Bearer test-key" { diff --git a/internal/provider/openaicompat/generate_sdk.go b/internal/provider/openaicompat/generate_sdk.go index 6db1dee3..6b72214e 100644 --- a/internal/provider/openaicompat/generate_sdk.go +++ b/internal/provider/openaicompat/generate_sdk.go @@ -6,8 +6,10 @@ import ( "errors" "fmt" "io" + "net" "net/http" "strings" + "sync/atomic" openai "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" @@ -36,11 +38,12 @@ func (p *Provider) generateSDKChatCompletions( stream := client.Chat.Completions.NewStreaming(ctx, params) defer func() { _ = stream.Close() }() - if err := chatcompletions.EmitFromSDKStream(ctx, stream, events); err != nil { + payloadStarted, err := emitSDKChatCompletionStream(ctx, stream, events) + if err != nil { if mapped, ok := mapOpenAIError(err); ok { return mapped } - if !shouldFallbackToCompatibleChatStream(err) { + if !shouldFallbackToCompatibleChatStream(err, payloadStarted) { return err } return p.generateChatCompletionsWithCompatibleStream(ctx, payload, events) @@ -48,6 +51,49 @@ func (p *Provider) generateSDKChatCompletions( return nil } +// emitSDKChatCompletionStream 转发 SDK typed stream 事件,并额外记录是否已真正发出有效 payload。 +func emitSDKChatCompletionStream( + ctx context.Context, + stream any, + events chan<- providertypes.StreamEvent, +) (bool, error) { + proxyEvents := make(chan providertypes.StreamEvent, 16) + forwardErrCh := make(chan error, 1) + var payloadStarted atomic.Bool + + go func() { + forwardErrCh <- forwardSDKStreamEvents(ctx, proxyEvents, events, &payloadStarted) + }() + + err := chatcompletions.EmitFromSDKStream(ctx, stream, proxyEvents) + close(proxyEvents) + forwardErr := <-forwardErrCh + if err == nil { + err = forwardErr + } + return payloadStarted.Load(), err +} + +// forwardSDKStreamEvents 负责把 SDK 事件转发给上层,同时只按有效 payload 规则标记流已开始。 +func forwardSDKStreamEvents( + ctx context.Context, + source <-chan providertypes.StreamEvent, + target chan<- providertypes.StreamEvent, + payloadStarted *atomic.Bool, +) error { + for event := range source { + if provider.IsEffectiveGeneratePayloadEvent(event) { + payloadStarted.Store(true) + } + select { + case <-ctx.Done(): + return ctx.Err() + case target <- event: + } + } + return nil +} + // convertToChatCompletionParams 将内部 chat/completions 请求映射为 OpenAI SDK 参数对象。 func convertToChatCompletionParams(req chatcompletions.Request) openai.ChatCompletionNewParams { params := openai.ChatCompletionNewParams{ @@ -183,10 +229,13 @@ func toSDKAssistantToolCalls(calls []chatcompletions.ToolCall) []openai.ChatComp } // shouldFallbackToCompatibleChatStream 判断是否需要从 SDK typed stream 降级到兼容流解析。 -func shouldFallbackToCompatibleChatStream(err error) bool { +func shouldFallbackToCompatibleChatStream(err error, payloadStarted bool) bool { if err == nil { return false } + if payloadStarted { + return false + } var syntaxErr *json.SyntaxError if errors.As(err, &syntaxErr) { @@ -332,20 +381,59 @@ func (p *Provider) generateSDKResponses( // wrapSDKRequestError 将 SDK 请求错误映射为统一 ProviderError;无法映射时保留原始错误链。 func wrapSDKRequestError(err error, action string) error { + if err == nil { + return nil + } if mapped, ok := mapOpenAIError(err); ok { return mapped } + if errors.Is(err, context.Canceled) { + return err + } + message := strings.TrimSpace(safeErrorMessage(err)) + if message == "" { + message = "unknown transport error" + } + if isSDKTransportTimeout(err) { + return provider.NewTimeoutProviderError( + fmt.Sprintf("%s%s timeout: %s", errorPrefix, strings.TrimSpace(action), message), + ) + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return provider.NewNetworkProviderError( + fmt.Sprintf("%s%s: %s", errorPrefix, strings.TrimSpace(action), message), + ) + } + var netErr net.Error + if errors.As(err, &netErr) { + return provider.NewNetworkProviderError( + fmt.Sprintf("%s%s: %s", errorPrefix, strings.TrimSpace(action), message), + ) + } return fmt.Errorf("%s%s: %w", errorPrefix, strings.TrimSpace(action), err) } +// isSDKTransportTimeout 统一识别 OpenAI-compatible 发送阶段的超时错误。 +func isSDKTransportTimeout(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} + func (p *Provider) newSDKClient() (openai.Client, error) { apiKey, err := p.cfg.ResolveAPIKeyValue() if err != nil { return openai.Client{}, err } return openai.NewClient( - option.WithHTTPClient(p.client), + option.WithHTTPClient(p.generateClient), option.WithAPIKey(apiKey), + option.WithMaxRetries(0), option.WithBaseURL(strings.TrimRight(strings.TrimSpace(p.cfg.BaseURL), "/")), ), nil } diff --git a/internal/provider/openaicompat/generate_sdk_test.go b/internal/provider/openaicompat/generate_sdk_test.go index e4425063..c677150e 100644 --- a/internal/provider/openaicompat/generate_sdk_test.go +++ b/internal/provider/openaicompat/generate_sdk_test.go @@ -1,6 +1,7 @@ package openaicompat import ( + "context" "encoding/json" "errors" "fmt" @@ -12,6 +13,7 @@ import ( "neo-code/internal/provider" "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" ) type closeTrackingReadCloser struct { @@ -259,21 +261,24 @@ func TestConvertToChatCompletionParamsEnablesUsageInStream(t *testing.T) { func TestShouldFallbackToCompatibleChatStream(t *testing.T) { t.Parallel() - if shouldFallbackToCompatibleChatStream(io.EOF) { + if shouldFallbackToCompatibleChatStream(io.EOF, false) { t.Fatal("did not expect fallback for EOF") } - if !shouldFallbackToCompatibleChatStream(errors.New("SDK stream error: invalid character '[' after top-level value")) { + if !shouldFallbackToCompatibleChatStream(errors.New("SDK stream error: invalid character '[' after top-level value"), false) { t.Fatal("expected fallback for weak SSE decode error") } - if !shouldFallbackToCompatibleChatStream(fmt.Errorf("SDK stream error: %w", &json.SyntaxError{Offset: 1})) { + if !shouldFallbackToCompatibleChatStream(fmt.Errorf("SDK stream error: %w", &json.SyntaxError{Offset: 1}), false) { t.Fatal("expected fallback for json syntax error") } - if !shouldFallbackToCompatibleChatStream(fmt.Errorf("SDK stream error: %w", io.ErrUnexpectedEOF)) { + if !shouldFallbackToCompatibleChatStream(fmt.Errorf("SDK stream error: %w", io.ErrUnexpectedEOF), false) { t.Fatal("expected fallback for unexpected EOF") } - if shouldFallbackToCompatibleChatStream(errors.New("context deadline exceeded")) { + if shouldFallbackToCompatibleChatStream(errors.New("context deadline exceeded"), false) { t.Fatal("did not expect fallback for non-decode error") } + if shouldFallbackToCompatibleChatStream(errors.New("SDK stream error: invalid character '[' after top-level value"), true) { + t.Fatal("did not expect fallback after payload has started") + } } func TestMapOpenAIError(t *testing.T) { @@ -296,12 +301,102 @@ func TestWrapSDKRequestError(t *testing.T) { t.Parallel() wrapped := wrapSDKRequestError(io.EOF, "send request") - if !strings.Contains(wrapped.Error(), "send request") { - t.Fatalf("expected wrapped action in error, got %v", wrapped) + if !strings.Contains(wrapped.Error(), "network_error") { + t.Fatalf("expected network provider error, got %v", wrapped) } mapped := wrapSDKRequestError(&openai.Error{Message: "invalid key", StatusCode: 401}, "send request") if !strings.Contains(mapped.Error(), "auth_failed") { t.Fatalf("expected mapped provider error, got %v", mapped) } + + timeoutErr := wrapSDKRequestError(timeoutNetError{}, "send request") + if !strings.Contains(timeoutErr.Error(), "timeout") { + t.Fatalf("expected timeout provider error, got %v", timeoutErr) + } +} + +func TestEmitSDKChatCompletionStreamTracksPayloadStart(t *testing.T) { + t.Parallel() + + stream := &fakeOpenAICompatSDKStream{ + chunks: []openai.ChatCompletionChunk{ + { + Choices: []openai.ChatCompletionChunkChoice{ + { + Delta: openai.ChatCompletionChunkChoiceDelta{ + Content: "hello", + }, + }, + }, + }, + }, + err: errors.New("decode failed"), + } + + events := make(chan providertypes.StreamEvent, 4) + started, err := emitSDKChatCompletionStream(context.Background(), stream, events) + if err == nil { + t.Fatal("expected stream error") + } + if !started { + t.Fatal("expected payloadStarted=true after text delta") + } + if len(drainOpenAICompatEvents(events)) == 0 { + t.Fatal("expected forwarded events") + } +} + +func TestEmitSDKChatCompletionStreamKeepsPayloadStartFalseBeforeAnyEvent(t *testing.T) { + t.Parallel() + + stream := &fakeOpenAICompatSDKStream{ + err: errors.New("decode failed"), + } + + events := make(chan providertypes.StreamEvent, 4) + started, err := emitSDKChatCompletionStream(context.Background(), stream, events) + if err == nil { + t.Fatal("expected stream error") + } + if started { + t.Fatal("expected payloadStarted=false when no effective event was emitted") + } +} + +type fakeOpenAICompatSDKStream struct { + chunks []openai.ChatCompletionChunk + index int + err error +} + +func (s *fakeOpenAICompatSDKStream) Next() bool { + if s.index >= len(s.chunks) { + return false + } + s.index++ + return true +} + +func (s *fakeOpenAICompatSDKStream) Current() openai.ChatCompletionChunk { + if s.index == 0 { + return openai.ChatCompletionChunk{} + } + return s.chunks[s.index-1] +} + +func (s *fakeOpenAICompatSDKStream) Err() error { + return s.err +} + +func drainOpenAICompatEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + out := make([]providertypes.StreamEvent, 0, len(events)) + for { + select { + case evt := <-events: + out = append(out, evt) + default: + return out + } + } } diff --git a/internal/provider/openaicompat/openaicompat_test.go b/internal/provider/openaicompat/openaicompat_test.go index a45196c9..338b8f2e 100644 --- a/internal/provider/openaicompat/openaicompat_test.go +++ b/internal/provider/openaicompat/openaicompat_test.go @@ -43,8 +43,11 @@ func TestWithTransport(t *testing.T) { t.Fatalf("New() error = %v", err) } - if p.client.Transport != customTransport { - t.Fatal("expected custom transport to be set") + if p.generateClient.Transport != customTransport { + t.Fatal("expected generate client transport to be set") + } + if p.discoveryClient.Transport != customTransport { + t.Fatal("expected discovery client transport to be set") } } @@ -104,7 +107,7 @@ func TestNewDefaultTransportWhenNoOption(t *testing.T) { if err != nil { t.Fatalf("New() error = %v", err) } - if p.client.Transport == nil { + if p.generateClient.Transport == nil { t.Fatal("expected default transport to be set") } } @@ -139,7 +142,7 @@ func TestDiscoverModels(t *testing.T) { if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.discoveryClient = server.Client() models, err := p.DiscoverModels(context.Background()) if err != nil { @@ -175,7 +178,7 @@ func TestDiscoverModelsUsesConfiguredDiscoveryEndpointPath(t *testing.T) { if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.discoveryClient = server.Client() models, err := p.DiscoverModels(context.Background()) if err != nil { @@ -204,7 +207,7 @@ func TestDiscoverModelsParsesGeminiProfileModelList(t *testing.T) { if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.discoveryClient = server.Client() models, err := p.DiscoverModels(context.Background()) if err != nil { @@ -235,7 +238,7 @@ func TestDiscoverModelsParsesNestedContainerAndAliasFields(t *testing.T) { if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.discoveryClient = server.Client() models, err := p.DiscoverModels(context.Background()) if err != nil { @@ -289,7 +292,7 @@ data: [DONE] if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.discoveryClient = server.Client() reader := &singleUseSessionAssetReader{ maxOpen: 1, @@ -320,6 +323,84 @@ data: [DONE] } } +func TestGenerateRetriesReuseFrozenRequestPayload(t *testing.T) { + t.Setenv(config.OpenAIDefaultAPIKeyEnv, "test-key") + + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/gateway/chat/completions" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + attempts++ + if attempts == 1 { + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"message":"retry later"}}`)) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(`data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}} +data: [DONE] + +`)) + })) + defer server.Close() + + cfg := resolvedConfig(server.URL, "gpt-4.1") + cfg.ChatEndpointPath = "/gateway/chat/completions" + p, err := New(cfg) + if err != nil { + t.Fatalf("New() error = %v", err) + } + p.generateClient = server.Client() + + reader := &singleUseSessionAssetReader{ + maxOpen: 1, + assets: map[string]sessionAsset{ + "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, + }, + } + request := providertypes.GenerateRequest{ + Model: "gpt-4.1", + Messages: []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }, + }, + SessionAssetReader: reader, + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), request, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if attempts != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts) + } + if reader.openCount != 1 { + t.Fatalf("expected session asset to be opened once across retries, got %d", reader.openCount) + } +} + +func TestNewCreatesDedicatedDiscoveryClient(t *testing.T) { + t.Parallel() + + customTransport := &http.Transport{} + p, err := New(resolvedConfig("https://api.example.com/v1", "gpt-4.1"), withTransport(customTransport)) + if err != nil { + t.Fatalf("New() error = %v", err) + } + if p.discoveryClient.Timeout != provider.DefaultSDKRequestTimeout { + t.Fatalf("expected discovery timeout %s, got %s", provider.DefaultSDKRequestTimeout, p.discoveryClient.Timeout) + } + if p.discoveryClient.Transport != customTransport { + t.Fatalf("expected discovery client to preserve custom transport") + } + if p.generateClient == p.discoveryClient { + t.Fatal("expected generate and discovery clients to stay separated") + } +} + func TestDiscoverModelsOpenAIProfileFallsBackToGenericListKeys(t *testing.T) { t.Parallel() @@ -338,7 +419,7 @@ func TestDiscoverModelsOpenAIProfileFallsBackToGenericListKeys(t *testing.T) { if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.discoveryClient = server.Client() models, err := p.DiscoverModels(context.Background()) if err != nil { @@ -365,7 +446,7 @@ func TestDiscoverModelsParsesStringModelIDs(t *testing.T) { if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.discoveryClient = server.Client() models, err := p.DiscoverModels(context.Background()) if err != nil { @@ -476,7 +557,7 @@ func TestBuildRequest_EmptyModelReturnsError(t *testing.T) { APIKeyEnv: "OPENAI_TEST_KEY", APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }, - client: &http.Client{}, + generateClient: &http.Client{}, } _, buildErr := chatcompletions.BuildRequest(context.Background(), p.cfg, providertypes.GenerateRequest{}) @@ -641,7 +722,7 @@ data: [DONE] if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.generateClient = server.Client() events := make(chan providertypes.StreamEvent, 8) err = p.Generate(context.Background(), providertypes.GenerateRequest{ @@ -674,7 +755,7 @@ func TestDiscoverModelsSkipsInvalidEntriesAndDedupes(t *testing.T) { if err != nil { t.Fatalf("New() error = %v", err) } - p.client = server.Client() + p.discoveryClient = server.Client() models, err := p.DiscoverModels(context.Background()) if err != nil { diff --git a/internal/provider/openaicompat/provider.go b/internal/provider/openaicompat/provider.go index affa532f..e43f2e18 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -34,10 +34,11 @@ func validateRuntimeConfig(cfg provider.RuntimeConfig) error { return nil } -// Provider 封装 OpenAI-compatible 协议的运行时配置与 HTTP 客户端。 +// Provider 封装 OpenAI-compatible 协议的运行时配置和 HTTP 客户端。 type Provider struct { - cfg provider.RuntimeConfig - client *http.Client + cfg provider.RuntimeConfig + generateClient *http.Client + discoveryClient *http.Client mu sync.Mutex prepared *preparedRequest @@ -127,9 +128,13 @@ func New(cfg provider.RuntimeConfig, opts ...buildOption) (*Provider, error) { apply(o) } + streamClient := &http.Client{ + Transport: o.transport, + } return &Provider{ - cfg: cfg, - client: &http.Client{ + cfg: cfg, + generateClient: streamClient, + discoveryClient: &http.Client{ Timeout: provider.DefaultSDKRequestTimeout, Transport: o.transport, }, @@ -142,42 +147,62 @@ func (p *Provider) DiscoverModels(ctx context.Context) ([]providertypes.ModelDes if err != nil { return nil, err } - return DiscoverModelDescriptors(ctx, p.client, requestCfg) + return DiscoverModelDescriptors(ctx, p.discoveryClient, requestCfg) } -// Generate 发起流式生成请求。 +// Generate 发起流式生成请求,并将重试与超时语义收敛到 provider 公共 runner。 func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { mode, err := resolveExecutionMode(p.cfg) if err != nil { return err } + signature := provider.BuildGenerateRequestSignature(req) + + var completionsPayload chatcompletions.Request + var responsesPayload responses.Request switch mode { case executionModeCompletions: - signature := provider.BuildGenerateRequestSignature(req) if payload, ok := p.takePreparedChatCompletionsRequest(mode, signature); ok { - return p.generateSDKChatCompletions(ctx, payload, events) - } - payload, buildErr := chatcompletions.BuildRequest(ctx, p.cfg, req) - if buildErr != nil { - return buildErr + completionsPayload = payload + } else { + payload, buildErr := chatcompletions.BuildRequest(ctx, p.cfg, req) + if buildErr != nil { + return buildErr + } + completionsPayload = payload } - return p.generateSDKChatCompletions(ctx, payload, events) case executionModeResponses: - signature := provider.BuildGenerateRequestSignature(req) if payload, ok := p.takePreparedResponsesRequest(mode, signature); ok { - return p.generateSDKResponses(ctx, payload, events) + responsesPayload = payload + } else { + payload, buildErr := responses.BuildRequest(ctx, p.cfg, req) + if buildErr != nil { + return buildErr + } + responsesPayload = payload } - payload, buildErr := responses.BuildRequest(ctx, p.cfg, req) - if buildErr != nil { - return buildErr - } - return p.generateSDKResponses(ctx, payload, events) default: return provider.NewDiscoveryConfigError( fmt.Sprintf("openaicompat provider: driver %q resolved unsupported execution mode %q", p.cfg.Driver, mode), ) } + + return provider.RunGenerateWithRetry(ctx, p.cfg, events, func( + attemptCtx context.Context, + attemptEvents chan<- providertypes.StreamEvent, + ) error { + switch mode { + case executionModeCompletions: + return p.generateSDKChatCompletions(attemptCtx, completionsPayload, attemptEvents) + case executionModeResponses: + return p.generateSDKResponses(attemptCtx, responsesPayload, attemptEvents) + default: + return provider.NewDiscoveryConfigError( + fmt.Sprintf("openaicompat provider: driver %q resolved unsupported execution mode %q", p.cfg.Driver, mode), + ) + } + }) } // storePreparedRequest 缓存估算阶段已构建请求,供同轮发送复用以避免重复构建。 From 526ed669a810c3ab1239b423d567153b92365e78 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Sat, 25 Apr 2026 13:06:15 +0000 Subject: [PATCH 5/8] fix(provider): repair gemini tests and cap generate retries Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- docs/config-management-detail-design.md | 2 +- docs/guides/adding-providers.md | 2 +- docs/guides/configuration.md | 2 +- internal/config/provider.go | 13 ++++++- internal/config/provider_test.go | 7 ++++ internal/provider/constants.go | 5 +++ internal/provider/constants_test.go | 28 +++++++++++++++ internal/provider/gemini/provider_test.go | 43 ----------------------- 8 files changed, 55 insertions(+), 47 deletions(-) create mode 100644 internal/provider/constants_test.go diff --git a/docs/config-management-detail-design.md b/docs/config-management-detail-design.md index c6b47fe5..5241bcd2 100644 --- a/docs/config-management-detail-design.md +++ b/docs/config-management-detail-design.md @@ -69,7 +69,7 @@ custom provider 来自: ``` 当前只接受明确受支持的字段;未知字段会直接报错,不做“旧格式自动迁移”。 -`provider.yaml` 只支持平铺字段:`name/driver/base_url/api_key_env/model_source/chat_endpoint_path/discovery_endpoint_path/generate_max_retries/generate_start_timeout_sec/generate_idle_timeout_sec/models`。 +`provider.yaml` 只支持平铺字段:`name/driver/base_url/api_key_env/model_source/chat_api_mode/chat_endpoint_path/discovery_endpoint_path/generate_max_retries/generate_start_timeout_sec/generate_idle_timeout_sec/models`。 ## 加载流程 diff --git a/docs/guides/adding-providers.md b/docs/guides/adding-providers.md index 007534dc..f8e89d65 100644 --- a/docs/guides/adding-providers.md +++ b/docs/guides/adding-providers.md @@ -110,7 +110,7 @@ generate_idle_timeout_sec: 300 - `chat_endpoint_path` 为 `/` 表示直连 `base_url`;为空时会按 `chat_api_mode` 自动回填默认子路径(`/chat/completions` 或 `/responses`)。 - 当 `chat_api_mode` 已显式指定时,`chat_endpoint_path` 可使用任意以 `/` 开头的相对路径;未显式指定时,仅支持标准端点推断(`/chat/completions`、`/responses`、`/`)。 - `model_source: manual` 时必须提供 `models`,且会忽略 `discovery_endpoint_path`。 -- `generate_max_retries` / `generate_start_timeout_sec` / `generate_idle_timeout_sec` 用于控制 provider 级生成重试、首包超时和流空闲超时;未填写或 `<= 0` 时会分别回退到 `5 / 60 / 300`。 +- `generate_max_retries` / `generate_start_timeout_sec` / `generate_idle_timeout_sec` 用于控制 provider 级生成重试、首包超时和流空闲超时;未填写或 `<= 0` 时会分别回退到 `5 / 60 / 300`。其中 `generate_max_retries` 必须 `<= 20`。 ## 测试要求 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index ee4ea99e..3f3eff3e 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -180,7 +180,7 @@ generate_idle_timeout_sec: 300 新增的生成链路控制字段含义如下: -- `generate_max_retries`:额外重试次数,不含首次尝试;`<= 0` 时回退默认值 `5`。 +- `generate_max_retries`:额外重试次数,不含首次尝试;`<= 0` 时回退默认值 `5`,且必须 `<= 20`。 - `generate_start_timeout_sec`:从发请求到收到首个有效流 payload 的最长等待窗口;`<= 0` 时回退默认值 `60`。 - `generate_idle_timeout_sec`:首包后连续没有任何新 payload 的最长空闲窗口;`<= 0` 时回退默认值 `300`。 diff --git a/internal/config/provider.go b/internal/config/provider.go index 736cfcbd..3ce61d23 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -84,7 +84,7 @@ func (p ProviderConfig) Validate() error { if strings.TrimSpace(p.APIKeyEnv) == "" { return fmt.Errorf("provider %q api_key_env is empty", p.Name) } - if err := validateOptionalNonNegativeGenerateControl("generate_max_retries", p.GenerateMaxRetries); err != nil { + if err := validateOptionalGenerateMaxRetries(p.GenerateMaxRetries); err != nil { return fmt.Errorf("provider %q: %w", p.Name, err) } if err := validateOptionalGenerateDurationSeconds("generate_start_timeout_sec", p.GenerateStartTimeoutSec); err != nil { @@ -130,6 +130,17 @@ func validateOptionalNonNegativeGenerateControl(field string, value int) error { return nil } +// validateOptionalGenerateMaxRetries 校验额外重试次数,防止超大值导致生成阶段重试循环过长。 +func validateOptionalGenerateMaxRetries(value int) error { + if err := validateOptionalNonNegativeGenerateControl("generate_max_retries", value); err != nil { + return err + } + if value > provider.MaxGenerateMaxRetries { + return fmt.Errorf("generate_max_retries must be less than or equal to %d", provider.MaxGenerateMaxRetries) + } + return nil +} + // validateOptionalGenerateDurationSeconds 校验秒级超时字段,避免负值和 duration 溢出在运行时被悄悄回退为默认值。 func validateOptionalGenerateDurationSeconds(field string, value int) error { if err := validateOptionalNonNegativeGenerateControl(field, value); err != nil { diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 63962aa7..8210ac7e 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -380,6 +380,13 @@ func TestProviderConfigValidateRejectsNegativeGenerateControls(t *testing.T) { }, errContain: "generate_max_retries", }, + { + name: "retries exceed upper bound", + mutate: func(cfg *ProviderConfig) { + cfg.GenerateMaxRetries = providerpkg.MaxGenerateMaxRetries + 1 + }, + errContain: "generate_max_retries", + }, { name: "negative start timeout", mutate: func(cfg *ProviderConfig) { diff --git a/internal/provider/constants.go b/internal/provider/constants.go index b5b2dafb..80da08a1 100644 --- a/internal/provider/constants.go +++ b/internal/provider/constants.go @@ -14,6 +14,8 @@ const ( const ( // DefaultGenerateMaxRetries 定义生成链路默认额外重试次数,不含首次尝试。 DefaultGenerateMaxRetries = 5 + // MaxGenerateMaxRetries 定义生成链路允许的额外重试次数上限,避免异常配置导致极长重试循环。 + MaxGenerateMaxRetries = 20 // DefaultGenerateStartTimeout 定义生成链路等待首个有效 payload 的默认窗口。 DefaultGenerateStartTimeout = 60 * time.Second // DefaultGenerateIdleTimeout 定义首包后默认的流空闲超时窗口。 @@ -31,6 +33,9 @@ func NormalizeGenerateMaxRetries(value int) int { if value <= 0 { return DefaultGenerateMaxRetries } + if value > MaxGenerateMaxRetries { + return MaxGenerateMaxRetries + } return value } diff --git a/internal/provider/constants_test.go b/internal/provider/constants_test.go new file mode 100644 index 00000000..a05e4010 --- /dev/null +++ b/internal/provider/constants_test.go @@ -0,0 +1,28 @@ +package provider + +import "testing" + +func TestNormalizeGenerateMaxRetries(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input int + want int + }{ + {name: "non positive fallback", input: 0, want: DefaultGenerateMaxRetries}, + {name: "negative fallback", input: -1, want: DefaultGenerateMaxRetries}, + {name: "keep in range", input: 3, want: 3}, + {name: "clamp upper bound", input: MaxGenerateMaxRetries + 10, want: MaxGenerateMaxRetries}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizeGenerateMaxRetries(tt.input); got != tt.want { + t.Fatalf("NormalizeGenerateMaxRetries(%d) = %d, want %d", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index 1e540d0f..596eeb7a 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -599,19 +599,6 @@ func TestGenerateHelpers(t *testing.T) { } }) - t.Run("retryable_error", func(t *testing.T) { - t.Parallel() - if isRetryableGenerateError(nil) { - t.Fatal("nil should not be retryable") - } - if isRetryableGenerateError(errors.New("plain")) { - t.Fatal("plain error should not be retryable") - } - if !isRetryableGenerateError(provider.NewNetworkProviderError("temporary")) { - t.Fatal("network provider error should be retryable") - } - }) - t.Run("timeout_error", func(t *testing.T) { t.Parallel() if !isTimeoutGenerateError(context.DeadlineExceeded) { @@ -623,36 +610,6 @@ func TestGenerateHelpers(t *testing.T) { }) } -func TestRetryBackoffAndWait(t *testing.T) { - t.Parallel() - - if wait := generateRetryBackoff(0); wait != 0 { - t.Fatalf("attempt 0 backoff = %v, want 0", wait) - } - - for attempt := 1; attempt <= 6; attempt++ { - wait := generateRetryBackoff(attempt) - if wait < 0 { - t.Fatalf("attempt %d backoff should be non-negative, got %v", attempt, wait) - } - if wait > generateRetryMaxWait { - t.Fatalf("attempt %d backoff should be <= %v, got %v", attempt, generateRetryMaxWait, wait) - } - } - - if err := waitForRetry(context.Background(), 0); err != nil { - t.Fatalf("waitForRetry(0) error = %v", err) - } - ctx, cancel := context.WithCancel(context.Background()) - cancel() - if err := waitForRetry(ctx, time.Millisecond); !errors.Is(err, context.Canceled) { - t.Fatalf("expected context canceled, got %v", err) - } - if err := waitForRetry(context.Background(), time.Millisecond); err != nil { - t.Fatalf("waitForRetry(timeout) error = %v", err) - } -} - func TestMapGeminiSDKError(t *testing.T) { t.Parallel() From 7226ec10c16ffed9dbfe1a4912f268fc5a17076f Mon Sep 17 00:00:00 2001 From: xgopilot Date: Sat, 25 Apr 2026 13:27:31 +0000 Subject: [PATCH 6/8] fix(provider,runtime): honor zero retries and remove runtime retry leftovers Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/config/provider_test.go | 2 +- internal/provider/constants.go | 4 +- internal/provider/constants_test.go | 2 +- internal/provider/gemini/provider_test.go | 77 ++++++++++--------- .../openaicompat/openaicompat_test.go | 1 + internal/runtime/run_lifecycle.go | 22 ------ internal/runtime/runtime.go | 5 +- 7 files changed, 48 insertions(+), 65 deletions(-) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 8210ac7e..8422a8ca 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -782,7 +782,7 @@ func TestResolvedProviderConfigToRuntimeConfig(t *testing.T) { ChatAPIMode: "", ChatEndpointPath: "", DiscoveryEndpointPath: providerpkg.DiscoveryEndpointPathModels, - GenerateMaxRetries: providerpkg.DefaultGenerateMaxRetries, + GenerateMaxRetries: 0, GenerateStartTimeout: providerpkg.DefaultGenerateStartTimeout, GenerateIdleTimeout: providerpkg.DefaultGenerateIdleTimeout, } diff --git a/internal/provider/constants.go b/internal/provider/constants.go index 80da08a1..5b9216a2 100644 --- a/internal/provider/constants.go +++ b/internal/provider/constants.go @@ -28,9 +28,9 @@ const ( DefaultSDKRequestTimeout = 10 * time.Minute ) -// NormalizeGenerateMaxRetries 归一化生成链路额外重试次数,非正值回退到默认值。 +// NormalizeGenerateMaxRetries 归一化生成链路额外重试次数,负值回退到默认值。 func NormalizeGenerateMaxRetries(value int) int { - if value <= 0 { + if value < 0 { return DefaultGenerateMaxRetries } if value > MaxGenerateMaxRetries { diff --git a/internal/provider/constants_test.go b/internal/provider/constants_test.go index a05e4010..5d82b4df 100644 --- a/internal/provider/constants_test.go +++ b/internal/provider/constants_test.go @@ -10,7 +10,7 @@ func TestNormalizeGenerateMaxRetries(t *testing.T) { input int want int }{ - {name: "non positive fallback", input: 0, want: DefaultGenerateMaxRetries}, + {name: "zero keeps explicit value", input: 0, want: 0}, {name: "negative fallback", input: -1, want: DefaultGenerateMaxRetries}, {name: "keep in range", input: 3, want: 3}, {name: "clamp upper bound", input: MaxGenerateMaxRetries + 10, want: MaxGenerateMaxRetries}, diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index 596eeb7a..9c94f3d4 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -34,11 +34,12 @@ func TestProviderGenerate(t *testing.T) { defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: server.URL, - DefaultModel: "gemini-2.5-flash", - APIKeyEnv: "GEMINI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + GenerateMaxRetries: 1, }) if err != nil { t.Fatalf("New() error = %v", err) @@ -109,11 +110,12 @@ func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) { defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: server.URL, - DefaultModel: "gemini-2.5-flash", - APIKeyEnv: "GEMINI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + GenerateMaxRetries: defaultGenerateRetryMax, }) if err != nil { t.Fatalf("New() error = %v", err) @@ -291,11 +293,12 @@ func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: server.URL, - DefaultModel: "gemini-2.5-flash", - APIKeyEnv: "GEMINI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + GenerateMaxRetries: 1, }) if err != nil { t.Fatalf("New() error = %v", err) @@ -344,11 +347,12 @@ func TestProviderGenerateRetriesRetryableErrorBeforeStreamStarts(t *testing.T) { defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: server.URL, - DefaultModel: "gemini-2.5-flash", - APIKeyEnv: "GEMINI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + GenerateMaxRetries: 1, }) if err != nil { t.Fatalf("New() error = %v", err) @@ -391,11 +395,12 @@ func TestProviderGenerateReturnsRetryableErrorAfterRetryExhausted(t *testing.T) defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: server.URL, - DefaultModel: "gemini-2.5-flash", - APIKeyEnv: "GEMINI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + GenerateMaxRetries: defaultGenerateRetryMax, }) if err != nil { t.Fatalf("New() error = %v", err) @@ -447,11 +452,12 @@ func TestProviderGenerateRetryStateResetsAfterSuccess(t *testing.T) { defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: server.URL, - DefaultModel: "gemini-2.5-flash", - APIKeyEnv: "GEMINI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + GenerateMaxRetries: 1, }) if err != nil { t.Fatalf("New() error = %v", err) @@ -509,11 +515,12 @@ func TestProviderGenerateReturnsRetryWaitError(t *testing.T) { defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: server.URL, - DefaultModel: "gemini-2.5-flash", - APIKeyEnv: "GEMINI_TEST_KEY", - APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + GenerateMaxRetries: 1, }) if err != nil { t.Fatalf("New() error = %v", err) diff --git a/internal/provider/openaicompat/openaicompat_test.go b/internal/provider/openaicompat/openaicompat_test.go index 338b8f2e..db290abf 100644 --- a/internal/provider/openaicompat/openaicompat_test.go +++ b/internal/provider/openaicompat/openaicompat_test.go @@ -347,6 +347,7 @@ data: [DONE] cfg := resolvedConfig(server.URL, "gpt-4.1") cfg.ChatEndpointPath = "/gateway/chat/completions" + cfg.GenerateMaxRetries = 1 p, err := New(cfg) if err != nil { t.Fatalf("New() error = %v", err) diff --git a/internal/runtime/run_lifecycle.go b/internal/runtime/run_lifecycle.go index bc2fff3e..821b3a5c 100644 --- a/internal/runtime/run_lifecycle.go +++ b/internal/runtime/run_lifecycle.go @@ -3,11 +3,9 @@ package runtime import ( "context" "errors" - "math/rand/v2" "strings" "time" - "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/controlplane" ) @@ -205,26 +203,6 @@ func (s *Service) handleRunError(err error) error { return err } -// isRetryableProviderError 判断 provider 错误是否允许 runtime 级重试。 -func isRetryableProviderError(err error) bool { - var providerErr *provider.ProviderError - if !errors.As(err, &providerErr) { - return false - } - return providerErr.Retryable -} - -// providerRetryBackoff 计算 runtime 级 provider 重试等待时长。 -func providerRetryBackoff(attempt int) time.Duration { - wait := providerRetryBaseWait << (attempt - 1) - jitter := float64(wait) * (0.5 + rand.Float64()) - wait = time.Duration(jitter) - if wait > providerRetryMaxWait { - wait = providerRetryMaxWait - } - return wait -} - // cloneMessages 深拷贝消息切片,避免后台调度读取到后续运行态修改。 func cloneMessages(messages []providertypes.Message) []providertypes.Message { if len(messages) == 0 { diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index fa3f438a..70bbb105 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -24,10 +24,7 @@ import ( ) const ( - defaultProviderRetryMax = 2 - providerRetryBaseWait = 1 * time.Second - providerRetryMaxWait = 5 * time.Second - defaultToolParallelism = 4 + defaultToolParallelism = 4 terminationEventEmitTimeout = 500 * time.Millisecond ) From 667057c8523405e91f9b19fac432d861529a6c18 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Sun, 26 Apr 2026 11:35:30 +0800 Subject: [PATCH 7/8] =?UTF-8?q?pref(provider):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E9=87=8D=E8=AF=95=E6=97=B6=E9=95=BF=EF=BC=8C=E5=8E=BB=E6=8E=89?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E5=8E=8B=E7=BC=A9=E5=BE=AA=E7=8E=AF=EF=BC=8C?= =?UTF-8?q?=E4=B8=8A=E6=94=B6=E6=97=B6=E9=95=BF=E9=85=8D=E7=BD=AE=E8=87=B3?= =?UTF-8?q?config.yaml?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/config-management-detail-design.md | 6 +- docs/guides/adding-providers.md | 4 +- docs/guides/configuration.md | 6 +- internal/app/bootstrap_test.go | 39 ++++ internal/config/config.go | 45 +++-- internal/config/config_test.go | 3 + internal/config/context_budget_migration.go | 67 ++++--- .../config/context_budget_migration_test.go | 85 +++++++++ internal/config/loader.go | 57 +++--- internal/config/loader_test.go | 79 +++++++-- internal/config/provider.go | 40 ++--- internal/config/provider_custom_normalize.go | 44 +++-- internal/config/provider_loader.go | 109 ++++++------ internal/config/provider_test.go | 46 +++-- internal/config/runtime.go | 2 +- internal/runtime/run.go | 8 +- internal/runtime/runtime_test.go | 167 ++++++++++++++++++ 17 files changed, 596 insertions(+), 211 deletions(-) diff --git a/docs/config-management-detail-design.md b/docs/config-management-detail-design.md index 5241bcd2..27b98b50 100644 --- a/docs/config-management-detail-design.md +++ b/docs/config-management-detail-design.md @@ -45,6 +45,7 @@ - `workdir` - `shell` - `tool_timeout_sec` +- `generate_start_timeout_sec` - `context` - `tools` @@ -69,7 +70,9 @@ custom provider 来自: ``` 当前只接受明确受支持的字段;未知字段会直接报错,不做“旧格式自动迁移”。 -`provider.yaml` 只支持平铺字段:`name/driver/base_url/api_key_env/model_source/chat_api_mode/chat_endpoint_path/discovery_endpoint_path/generate_max_retries/generate_start_timeout_sec/generate_idle_timeout_sec/models`。 +`provider.yaml` 只支持平铺字段:`name/driver/base_url/api_key_env/model_source/chat_api_mode/chat_endpoint_path/discovery_endpoint_path/generate_max_retries/generate_idle_timeout_sec/models`。 + +`generate_start_timeout_sec` 已上移到根 `config.yaml` 顶层,启动 preflight 会自动将缺失字段补写为默认值 `90`。 ## 加载流程 @@ -101,6 +104,7 @@ custom provider 来自: - `current_model` - `shell` - `tool_timeout_sec` +- `generate_start_timeout_sec` - `context` - `tools` diff --git a/docs/guides/adding-providers.md b/docs/guides/adding-providers.md index f8e89d65..aa93c835 100644 --- a/docs/guides/adding-providers.md +++ b/docs/guides/adding-providers.md @@ -100,7 +100,6 @@ chat_api_mode: responses chat_endpoint_path: / discovery_endpoint_path: /models generate_max_retries: 5 -generate_start_timeout_sec: 60 generate_idle_timeout_sec: 300 ``` @@ -110,7 +109,8 @@ generate_idle_timeout_sec: 300 - `chat_endpoint_path` 为 `/` 表示直连 `base_url`;为空时会按 `chat_api_mode` 自动回填默认子路径(`/chat/completions` 或 `/responses`)。 - 当 `chat_api_mode` 已显式指定时,`chat_endpoint_path` 可使用任意以 `/` 开头的相对路径;未显式指定时,仅支持标准端点推断(`/chat/completions`、`/responses`、`/`)。 - `model_source: manual` 时必须提供 `models`,且会忽略 `discovery_endpoint_path`。 -- `generate_max_retries` / `generate_start_timeout_sec` / `generate_idle_timeout_sec` 用于控制 provider 级生成重试、首包超时和流空闲超时;未填写或 `<= 0` 时会分别回退到 `5 / 60 / 300`。其中 `generate_max_retries` 必须 `<= 20`。 +- `generate_max_retries` / `generate_idle_timeout_sec` 用于控制 provider 级生成重试和流空闲超时;未填写或 `<= 0` 时会分别回退到 `5 / 300`。其中 `generate_max_retries` 必须 `<= 20`。 +- `generate_start_timeout_sec` 已改为根 `config.yaml` 顶层字段,不再允许写入 `provider.yaml`;启动时缺失会自动补写默认值 `90`。 ## 测试要求 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 3f3eff3e..1ddca5f1 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -23,6 +23,7 @@ selected_provider: openai current_model: gpt-5.4 shell: bash tool_timeout_sec: 20 +generate_start_timeout_sec: 90 runtime: max_no_progress_streak: 5 @@ -174,16 +175,17 @@ chat_api_mode: chat_completions chat_endpoint_path: /chat/completions discovery_endpoint_path: /models generate_max_retries: 5 -generate_start_timeout_sec: 60 generate_idle_timeout_sec: 300 ``` 新增的生成链路控制字段含义如下: - `generate_max_retries`:额外重试次数,不含首次尝试;`<= 0` 时回退默认值 `5`,且必须 `<= 20`。 -- `generate_start_timeout_sec`:从发请求到收到首个有效流 payload 的最长等待窗口;`<= 0` 时回退默认值 `60`。 +- `generate_start_timeout_sec`:写在 `config.yaml` 顶层,从发请求到收到首个有效流 payload 的最长等待窗口;`<= 0` 时回退默认值 `90`。 - `generate_idle_timeout_sec`:首包后连续没有任何新 payload 的最长空闲窗口;`<= 0` 时回退默认值 `300`。 +启动时会自动把缺失的 `generate_start_timeout_sec` 规范化写回 `config.yaml`,避免磁盘配置与运行时默认值不一致。 + ## 不写入 `config.yaml` 的字段 以下内容不允许写入主配置文件: diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index adc9a182..ea24545d 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -121,6 +121,45 @@ context: } } +func TestBuildSharedConfigDepsPersistsGenerateStartTimeoutDefault(t *testing.T) { + disableBuiltinProviderAPIKeys(t) + + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + t.Setenv("OPENAI_API_KEY", "test-key") + + configDir := filepath.Join(home, ".neocode") + if err := os.MkdirAll(configDir, 0o755); err != nil { + t.Fatalf("mkdir config dir: %v", err) + } + configPath := filepath.Join(configDir, "config.yaml") + raw := "selected_provider: openai\ncurrent_model: gpt-5.4\nshell: powershell\n" + if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + shared, _, _, err := BuildSharedConfigDeps(context.Background(), BootstrapOptions{}) + if err != nil { + t.Fatalf("BuildSharedConfigDeps() error = %v", err) + } + if shared.Config.GenerateStartTimeoutSec != config.DefaultGenerateStartTimeoutSec { + t.Fatalf( + "expected generate_start_timeout_sec=%d, got %d", + config.DefaultGenerateStartTimeoutSec, + shared.Config.GenerateStartTimeoutSec, + ) + } + + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("read config: %v", err) + } + if !strings.Contains(string(data), "generate_start_timeout_sec: 90") { + t.Fatalf("expected config to persist generate_start_timeout_sec, got:\n%s", string(data)) + } +} + func TestBuildSharedConfigDepsReturnsPreflightError(t *testing.T) { disableBuiltinProviderAPIKeys(t) diff --git a/internal/config/config.go b/internal/config/config.go index 60e08572..125320f8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,32 +10,35 @@ import ( ) const ( - DefaultWorkdir = "." - DefaultToolTimeoutSec = 20 + DefaultWorkdir = "." + DefaultToolTimeoutSec = 20 + DefaultGenerateStartTimeoutSec = 90 ) type Config struct { - Providers []ProviderConfig `yaml:"-"` - SelectedProvider string `yaml:"selected_provider"` - CurrentModel string `yaml:"current_model"` - Workdir string `yaml:"-"` - Shell string `yaml:"shell"` - ToolTimeoutSec int `yaml:"tool_timeout_sec,omitempty"` - Runtime RuntimeConfig `yaml:"runtime,omitempty"` - Context ContextConfig `yaml:"context,omitempty"` - Tools ToolsConfig `yaml:"tools,omitempty"` - Memo MemoConfig `yaml:"memo,omitempty"` - Gateway GatewayConfig `yaml:"gateway,omitempty"` + Providers []ProviderConfig `yaml:"-"` + SelectedProvider string `yaml:"selected_provider"` + CurrentModel string `yaml:"current_model"` + Workdir string `yaml:"-"` + Shell string `yaml:"shell"` + ToolTimeoutSec int `yaml:"tool_timeout_sec,omitempty"` + GenerateStartTimeoutSec int `yaml:"generate_start_timeout_sec,omitempty"` + Runtime RuntimeConfig `yaml:"runtime,omitempty"` + Context ContextConfig `yaml:"context,omitempty"` + Tools ToolsConfig `yaml:"tools,omitempty"` + Memo MemoConfig `yaml:"memo,omitempty"` + Gateway GatewayConfig `yaml:"gateway,omitempty"` } // StaticDefaults 返回 config 层负责的静态默认值骨架,不包含 provider 装配和选择状态修复。 func StaticDefaults() *Config { return &Config{ - Workdir: DefaultWorkdir, - Shell: defaultShell(), - ToolTimeoutSec: DefaultToolTimeoutSec, - Runtime: defaultRuntimeConfig(), - Context: defaultContextConfig(), + Workdir: DefaultWorkdir, + Shell: defaultShell(), + ToolTimeoutSec: DefaultToolTimeoutSec, + GenerateStartTimeoutSec: DefaultGenerateStartTimeoutSec, + Runtime: defaultRuntimeConfig(), + Context: defaultContextConfig(), Tools: ToolsConfig{ WebFetch: defaultWebFetchConfig(), MCP: defaultMCPConfig(), @@ -75,6 +78,9 @@ func (c *Config) applyStaticDefaults(defaults Config) { if c.ToolTimeoutSec <= 0 { c.ToolTimeoutSec = defaults.ToolTimeoutSec } + if c.GenerateStartTimeoutSec <= 0 { + c.GenerateStartTimeoutSec = defaults.GenerateStartTimeoutSec + } c.Runtime.ApplyDefaults(defaults.Runtime) c.Context.ApplyDefaults(defaults.Context) c.Tools.ApplyDefaults(defaults.Tools) @@ -92,6 +98,9 @@ func (c *Config) ValidateSnapshot() error { if len(c.Providers) == 0 { return errors.New("config: providers is empty") } + if err := validateOptionalGenerateDurationSeconds("generate_start_timeout_sec", c.GenerateStartTimeoutSec); err != nil { + return fmt.Errorf("config: %w", err) + } seen := make(map[string]struct{}, len(c.Providers)) seenEndpoints := make(map[string]string, len(c.Providers)) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 9d421a1c..572dac7c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -818,6 +818,9 @@ func TestLoaderLoadAndSaveRoundTrip(t *testing.T) { if reloaded.Tools.WebFetch.MaxResponseBytes != 1024 { t.Fatalf("expected max_response_bytes %d, got %d", 1024, reloaded.Tools.WebFetch.MaxResponseBytes) } + if reloaded.GenerateStartTimeoutSec != DefaultGenerateStartTimeoutSec { + t.Fatalf("expected generate_start_timeout_sec %d, got %d", DefaultGenerateStartTimeoutSec, reloaded.GenerateStartTimeoutSec) + } if len(reloaded.Tools.WebFetch.SupportedContentTypes) != 2 { t.Fatalf("expected persisted supported content types, got %+v", reloaded.Tools.WebFetch.SupportedContentTypes) } diff --git a/internal/config/context_budget_migration.go b/internal/config/context_budget_migration.go index 2e5139c5..0e1fb632 100644 --- a/internal/config/context_budget_migration.go +++ b/internal/config/context_budget_migration.go @@ -56,7 +56,7 @@ func MigrateContextBudgetConfigFile(path string, dryRun bool) (ContextBudgetMigr } result.Notes = append(result.Notes, notes...) if !changed { - result.Reason = "未检测到 context.auto_compact" + result.Reason = "未检测到需要升级的配置字段" return result, nil } @@ -81,44 +81,55 @@ func MigrateContextBudgetConfigContent(raw []byte) ([]byte, bool, []string, erro if len(bytes.TrimSpace(raw)) == 0 { return raw, false, nil, nil } - if !bytes.Contains(raw, []byte("auto_compact")) { - return raw, false, nil, nil - } var doc map[string]any if err := yaml.Unmarshal(raw, &doc); err != nil { return nil, false, nil, err } - contextValue, ok := doc["context"] - if !ok { - return raw, false, nil, nil - } - contextMap, ok := migrationStringMap(contextValue) - if !ok { - return nil, false, nil, errors.New("context must be a mapping") + if doc == nil { + doc = make(map[string]any) } - autoValue, hasAutoCompact := contextMap["auto_compact"] - if !hasAutoCompact { - return raw, false, nil, nil - } - if _, hasBudget := contextMap["budget"]; hasBudget { - return nil, false, nil, errors.New("context.auto_compact and context.budget cannot both exist") + changed := false + if _, exists := doc["generate_start_timeout_sec"]; !exists { + doc["generate_start_timeout_sec"] = DefaultGenerateStartTimeoutSec + changed = true } - autoMap, ok := migrationStringMap(autoValue) - if !ok { - return nil, false, nil, errors.New("context.auto_compact must be a mapping") + var notes []string + contextValue, hasContext := doc["context"] + if hasContext { + contextMap, ok := migrationStringMap(contextValue) + if !ok { + return nil, false, nil, errors.New("context must be a mapping") + } + + autoValue, hasAutoCompact := contextMap["auto_compact"] + if hasAutoCompact { + if _, hasBudget := contextMap["budget"]; hasBudget { + return nil, false, nil, errors.New("context.auto_compact and context.budget cannot both exist") + } + + autoMap, ok := migrationStringMap(autoValue) + if !ok { + return nil, false, nil, errors.New("context.auto_compact must be a mapping") + } + budgetMap := make(map[string]any) + migrationMoveField(autoMap, budgetMap, "input_token_threshold", "prompt_budget") + migrationMoveField(autoMap, budgetMap, "reserve_tokens", "reserve_tokens") + migrationMoveField(autoMap, budgetMap, "fallback_input_token_threshold", "fallback_prompt_budget") + notes = collectContextBudgetMigrationNotes(autoMap) + + delete(contextMap, "auto_compact") + contextMap["budget"] = budgetMap + doc["context"] = contextMap + changed = true + } } - budgetMap := make(map[string]any) - migrationMoveField(autoMap, budgetMap, "input_token_threshold", "prompt_budget") - migrationMoveField(autoMap, budgetMap, "reserve_tokens", "reserve_tokens") - migrationMoveField(autoMap, budgetMap, "fallback_input_token_threshold", "fallback_prompt_budget") - notes := collectContextBudgetMigrationNotes(autoMap) - delete(contextMap, "auto_compact") - contextMap["budget"] = budgetMap - doc["context"] = contextMap + if !changed { + return raw, false, nil, nil + } out, err := yaml.Marshal(doc) if err != nil { diff --git a/internal/config/context_budget_migration_test.go b/internal/config/context_budget_migration_test.go index 97c99357..e8a953c0 100644 --- a/internal/config/context_budget_migration_test.go +++ b/internal/config/context_budget_migration_test.go @@ -37,6 +37,7 @@ context: t.Fatalf("expected auto_compact removed, got:\n%s", text) } for _, want := range []string{ + "generate_start_timeout_sec: 90", "budget:", "prompt_budget: 120000", "reserve_tokens: 13000", @@ -48,6 +49,52 @@ context: } } +func TestMigrateContextBudgetConfigContentAddsGenerateStartTimeoutWhenMissing(t *testing.T) { + t.Parallel() + + input := []byte(strings.TrimSpace(` +selected_provider: openai +current_model: gpt-5.4 +shell: powershell +`) + "\n") + + out, changed, notes, err := MigrateContextBudgetConfigContent(input) + if err != nil { + t.Fatalf("MigrateContextBudgetConfigContent() error = %v", err) + } + if !changed { + t.Fatal("expected migration change") + } + if len(notes) != 0 { + t.Fatalf("expected no migration notes, got %v", notes) + } + if !strings.Contains(string(out), "generate_start_timeout_sec: 90") { + t.Fatalf("expected generate_start_timeout_sec to be added, got:\n%s", out) + } +} + +func TestMigrateContextBudgetConfigContentKeepsExistingGenerateStartTimeout(t *testing.T) { + t.Parallel() + + input := []byte(strings.TrimSpace(` +selected_provider: openai +current_model: gpt-5.4 +shell: powershell +generate_start_timeout_sec: 120 +`) + "\n") + + out, changed, notes, err := MigrateContextBudgetConfigContent(input) + if err != nil { + t.Fatalf("MigrateContextBudgetConfigContent() error = %v", err) + } + if changed { + t.Fatalf("expected no migration change, got:\n%s", out) + } + if len(notes) != 0 { + t.Fatalf("expected no migration notes, got %v", notes) + } +} + func TestMigrateContextBudgetConfigContentRejectsMixedBudgetBlocks(t *testing.T) { t.Parallel() @@ -154,6 +201,44 @@ context: } } +func TestMigrateContextBudgetConfigFileCreatesBackupWhenAddingGenerateStartTimeout(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + target := filepath.Join(dir, configName) + original := "selected_provider: openai\nshell: powershell\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write target: %v", err) + } + + result, err := MigrateContextBudgetConfigFile(target, false) + if err != nil { + t.Fatalf("MigrateContextBudgetConfigFile() error = %v", err) + } + if !result.Changed { + t.Fatal("expected changed result") + } + if result.Backup == "" { + t.Fatal("expected backup path") + } + + migrated, err := os.ReadFile(target) + if err != nil { + t.Fatalf("read migrated: %v", err) + } + if !strings.Contains(string(migrated), "generate_start_timeout_sec: 90") { + t.Fatalf("expected generate_start_timeout_sec to be persisted, got:\n%s", migrated) + } + + backup, err := os.ReadFile(result.Backup) + if err != nil { + t.Fatalf("read backup: %v", err) + } + if string(backup) != original { + t.Fatalf("expected backup to keep original content, got:\n%s", backup) + } +} + func TestUpgradeConfigSchemaReturnsNotes(t *testing.T) { t.Parallel() diff --git a/internal/config/loader.go b/internal/config/loader.go index 0766f734..9158e67e 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -22,15 +22,16 @@ type Loader struct { } type persistedConfig struct { - SelectedProvider string `yaml:"selected_provider,omitempty"` - CurrentModel string `yaml:"current_model,omitempty"` - Shell string `yaml:"shell"` - ToolTimeoutSec int `yaml:"tool_timeout_sec,omitempty"` - Runtime RuntimeConfig `yaml:"runtime,omitempty"` - Context persistedContextConfig `yaml:"context,omitempty"` - Tools ToolsConfig `yaml:"tools,omitempty"` - Memo persistedMemoConfig `yaml:"memo,omitempty"` - Gateway GatewayConfig `yaml:"gateway,omitempty"` + SelectedProvider string `yaml:"selected_provider,omitempty"` + CurrentModel string `yaml:"current_model,omitempty"` + Shell string `yaml:"shell"` + ToolTimeoutSec int `yaml:"tool_timeout_sec,omitempty"` + GenerateStartTimeoutSec int `yaml:"generate_start_timeout_sec,omitempty"` + Runtime RuntimeConfig `yaml:"runtime,omitempty"` + Context persistedContextConfig `yaml:"context,omitempty"` + Tools ToolsConfig `yaml:"tools,omitempty"` + Memo persistedMemoConfig `yaml:"memo,omitempty"` + Gateway GatewayConfig `yaml:"gateway,omitempty"` } type persistedContextConfig struct { @@ -217,15 +218,16 @@ func parseCurrentConfig(data []byte, contextDefaults ContextConfig, memoDefaults return nil, err } cfg := &Config{ - SelectedProvider: strings.TrimSpace(file.SelectedProvider), - CurrentModel: strings.TrimSpace(file.CurrentModel), - Shell: strings.TrimSpace(file.Shell), - ToolTimeoutSec: file.ToolTimeoutSec, - Runtime: file.Runtime, - Context: fromPersistedContextConfig(file.Context, contextDefaults), - Tools: file.Tools, - Memo: fromPersistedMemoConfig(file.Memo, memoDefaults), - Gateway: file.Gateway, + SelectedProvider: strings.TrimSpace(file.SelectedProvider), + CurrentModel: strings.TrimSpace(file.CurrentModel), + Shell: strings.TrimSpace(file.Shell), + ToolTimeoutSec: file.ToolTimeoutSec, + GenerateStartTimeoutSec: file.GenerateStartTimeoutSec, + Runtime: file.Runtime, + Context: fromPersistedContextConfig(file.Context, contextDefaults), + Tools: file.Tools, + Memo: fromPersistedMemoConfig(file.Memo, memoDefaults), + Gateway: file.Gateway, } return cfg, nil @@ -233,15 +235,16 @@ func parseCurrentConfig(data []byte, contextDefaults ContextConfig, memoDefaults func marshalPersistedConfig(snapshot Config) ([]byte, error) { file := persistedConfig{ - SelectedProvider: snapshot.SelectedProvider, - CurrentModel: snapshot.CurrentModel, - Shell: snapshot.Shell, - ToolTimeoutSec: snapshot.ToolTimeoutSec, - Runtime: snapshot.Runtime, - Context: newPersistedContextConfig(snapshot.Context), - Tools: snapshot.Tools, - Memo: newPersistedMemoConfig(snapshot.Memo), - Gateway: snapshot.Gateway, + SelectedProvider: snapshot.SelectedProvider, + CurrentModel: snapshot.CurrentModel, + Shell: snapshot.Shell, + ToolTimeoutSec: snapshot.ToolTimeoutSec, + GenerateStartTimeoutSec: snapshot.GenerateStartTimeoutSec, + Runtime: snapshot.Runtime, + Context: newPersistedContextConfig(snapshot.Context), + Tools: snapshot.Tools, + Memo: newPersistedMemoConfig(snapshot.Memo), + Gateway: snapshot.Gateway, } data, err := yaml.Marshal(&file) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 0e0632c1..30de24eb 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -1368,15 +1368,14 @@ func TestSaveAndLoadCustomProviderPersistsGenerateControls(t *testing.T) { baseDir := t.TempDir() const providerName = "retry-controls-provider" err := SaveCustomProviderWithModels(baseDir, SaveCustomProviderInput{ - Name: providerName, - Driver: provider.DriverOpenAICompat, - BaseURL: "https://llm.example.com/v1", - APIKeyEnv: "RETRY_CONTROLS_PROVIDER_API_KEY", - ModelSource: ModelSourceDiscover, - DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, - GenerateMaxRetries: 7, - GenerateStartTimeoutSec: 75, - GenerateIdleTimeoutSec: 420, + Name: providerName, + Driver: provider.DriverOpenAICompat, + BaseURL: "https://llm.example.com/v1", + APIKeyEnv: "RETRY_CONTROLS_PROVIDER_API_KEY", + ModelSource: ModelSourceDiscover, + DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, + GenerateMaxRetries: 7, + GenerateIdleTimeoutSec: 420, }) if err != nil { t.Fatalf("SaveCustomProviderWithModels() error = %v", err) @@ -1389,9 +1388,6 @@ func TestSaveAndLoadCustomProviderPersistsGenerateControls(t *testing.T) { if cfg.GenerateMaxRetries != 7 { t.Fatalf("expected GenerateMaxRetries=7, got %d", cfg.GenerateMaxRetries) } - if cfg.GenerateStartTimeoutSec != 75 { - t.Fatalf("expected GenerateStartTimeoutSec=75, got %d", cfg.GenerateStartTimeoutSec) - } if cfg.GenerateIdleTimeoutSec != 420 { t.Fatalf("expected GenerateIdleTimeoutSec=420, got %d", cfg.GenerateIdleTimeoutSec) } @@ -1422,14 +1418,38 @@ func TestSaveCustomProviderOmitsDefaultGenerateControlsWhenUnset(t *testing.T) { if strings.Contains(content, "generate_max_retries") { t.Fatalf("expected generate_max_retries to be omitted, got %q", content) } - if strings.Contains(content, "generate_start_timeout_sec") { - t.Fatalf("expected generate_start_timeout_sec to be omitted, got %q", content) - } if strings.Contains(content, "generate_idle_timeout_sec") { t.Fatalf("expected generate_idle_timeout_sec to be omitted, got %q", content) } } +func TestLoaderRejectsCustomProviderGenerateStartTimeoutField(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +base_url: https://llm.example.com/v1 +api_key_env: COMPANY_GATEWAY_API_KEY +generate_start_timeout_sec: 75 +discovery_endpoint_path: /models +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "field generate_start_timeout_sec not found") { + t.Fatalf("expected unknown field rejection for generate_start_timeout_sec, got %v", err) + } +} + func TestSaveCustomProviderRejectsNegativeGenerateControls(t *testing.T) { t.Parallel() @@ -1984,3 +2004,32 @@ func TestLoaderSaveRoundTripsCompactExtendedFields(t *testing.T) { t.Fatalf("expected round-trip max_archived_prompt_chars=3072, got %d", loaded.Context.Compact.MaxArchivedPromptChars) } } + +func TestLoaderSaveAndLoadPersistsGenerateStartTimeoutSec(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + cfg := loader.DefaultConfig() + cfg.GenerateStartTimeoutSec = 120 + + if err := loader.Save(context.Background(), &cfg); err != nil { + t.Fatalf("Save() error = %v", err) + } + + data, err := os.ReadFile(loader.ConfigPath()) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + text := string(data) + if !strings.Contains(text, "generate_start_timeout_sec: 120") { + t.Fatalf("expected generate_start_timeout_sec to be persisted, got:\n%s", text) + } + + loaded, err := loader.Load(context.Background()) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.GenerateStartTimeoutSec != 120 { + t.Fatalf("expected GenerateStartTimeoutSec=120 after reload, got %d", loaded.GenerateStartTimeoutSec) + } +} diff --git a/internal/config/provider.go b/internal/config/provider.go index 3ce61d23..84afae84 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -22,26 +22,26 @@ const ( ) type ProviderConfig struct { - Name string `yaml:"name"` - Driver string `yaml:"driver"` - BaseURL string `yaml:"base_url"` - Model string `yaml:"model"` - APIKeyEnv string `yaml:"api_key_env"` - GenerateMaxRetries int `yaml:"generate_max_retries,omitempty"` - GenerateStartTimeoutSec int `yaml:"generate_start_timeout_sec,omitempty"` - GenerateIdleTimeoutSec int `yaml:"generate_idle_timeout_sec,omitempty"` - ModelSource string `yaml:"-"` - ChatAPIMode string `yaml:"-"` - ChatEndpointPath string `yaml:"-"` - DiscoveryEndpointPath string `yaml:"-"` - Models []providertypes.ModelDescriptor `yaml:"-"` - Source ProviderSource `yaml:"-"` + Name string `yaml:"name"` + Driver string `yaml:"driver"` + BaseURL string `yaml:"base_url"` + Model string `yaml:"model"` + APIKeyEnv string `yaml:"api_key_env"` + GenerateMaxRetries int `yaml:"generate_max_retries,omitempty"` + GenerateIdleTimeoutSec int `yaml:"generate_idle_timeout_sec,omitempty"` + ModelSource string `yaml:"-"` + ChatAPIMode string `yaml:"-"` + ChatEndpointPath string `yaml:"-"` + DiscoveryEndpointPath string `yaml:"-"` + Models []providertypes.ModelDescriptor `yaml:"-"` + Source ProviderSource `yaml:"-"` } type ResolvedProviderConfig struct { ProviderConfig - SessionAssetPolicy session.AssetPolicy `yaml:"-"` - RequestAssetBudget provider.RequestAssetBudget `yaml:"-"` + GenerateStartTimeoutSec int `yaml:"-"` + SessionAssetPolicy session.AssetPolicy `yaml:"-"` + RequestAssetBudget provider.RequestAssetBudget `yaml:"-"` } // ResolveSelectedProvider 解析当前配置中选中的 provider,并补全运行时所需的运行时策略。 @@ -55,7 +55,10 @@ func ResolveSelectedProvider(cfg Config) (ResolvedProviderConfig, error) { if err != nil { return ResolvedProviderConfig{}, err } - resolved := ResolvedProviderConfig{ProviderConfig: providerCfg} + resolved := ResolvedProviderConfig{ + ProviderConfig: providerCfg, + GenerateStartTimeoutSec: cfg.GenerateStartTimeoutSec, + } resolved.SessionAssetPolicy = cfg.Runtime.ResolveSessionAssetPolicy() resolved.RequestAssetBudget = cfg.Runtime.ResolveRequestAssetBudget() return resolved, nil @@ -87,9 +90,6 @@ func (p ProviderConfig) Validate() error { if err := validateOptionalGenerateMaxRetries(p.GenerateMaxRetries); err != nil { return fmt.Errorf("provider %q: %w", p.Name, err) } - if err := validateOptionalGenerateDurationSeconds("generate_start_timeout_sec", p.GenerateStartTimeoutSec); err != nil { - return fmt.Errorf("provider %q: %w", p.Name, err) - } if err := validateOptionalGenerateDurationSeconds("generate_idle_timeout_sec", p.GenerateIdleTimeoutSec); err != nil { return fmt.Errorf("provider %q: %w", p.Name, err) } diff --git a/internal/config/provider_custom_normalize.go b/internal/config/provider_custom_normalize.go index e6fd7709..ca32604f 100644 --- a/internal/config/provider_custom_normalize.go +++ b/internal/config/provider_custom_normalize.go @@ -14,16 +14,15 @@ const ManualModelOptionalIntUnset = -1 // NormalizeCustomProviderInput 统一归一化 custom provider 的输入字段,并执行协议/模型来源的组合校验。 func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProviderInput, error) { normalized := SaveCustomProviderInput{ - Name: strings.TrimSpace(input.Name), - Driver: normalizeProviderDriver(strings.TrimSpace(input.Driver)), - BaseURL: strings.TrimSpace(input.BaseURL), - ChatAPIMode: strings.TrimSpace(input.ChatAPIMode), - ChatEndpointPath: strings.TrimSpace(input.ChatEndpointPath), - APIKeyEnv: strings.TrimSpace(input.APIKeyEnv), - GenerateMaxRetries: normalizeOptionalGenerateInt(input.GenerateMaxRetries), - GenerateStartTimeoutSec: normalizeOptionalGenerateInt(input.GenerateStartTimeoutSec), - GenerateIdleTimeoutSec: normalizeOptionalGenerateInt(input.GenerateIdleTimeoutSec), - DiscoveryEndpointPath: strings.TrimSpace(input.DiscoveryEndpointPath), + Name: strings.TrimSpace(input.Name), + Driver: normalizeProviderDriver(strings.TrimSpace(input.Driver)), + BaseURL: strings.TrimSpace(input.BaseURL), + ChatAPIMode: strings.TrimSpace(input.ChatAPIMode), + ChatEndpointPath: strings.TrimSpace(input.ChatEndpointPath), + APIKeyEnv: strings.TrimSpace(input.APIKeyEnv), + GenerateMaxRetries: normalizeOptionalGenerateInt(input.GenerateMaxRetries), + GenerateIdleTimeoutSec: normalizeOptionalGenerateInt(input.GenerateIdleTimeoutSec), + DiscoveryEndpointPath: strings.TrimSpace(input.DiscoveryEndpointPath), } if err := validateCustomProviderName(normalized.Name); err != nil { @@ -114,19 +113,18 @@ func normalizeOptionalGenerateInt(value int) int { // validateNormalizedCustomProviderInput 复用统一的 provider 配置校验,避免 custom provider 保存路径和加载路径出现两套规则。 func validateNormalizedCustomProviderInput(input SaveCustomProviderInput) error { cfg := ProviderConfig{ - Name: input.Name, - Driver: input.Driver, - BaseURL: input.BaseURL, - APIKeyEnv: input.APIKeyEnv, - GenerateMaxRetries: input.GenerateMaxRetries, - GenerateStartTimeoutSec: input.GenerateStartTimeoutSec, - GenerateIdleTimeoutSec: input.GenerateIdleTimeoutSec, - ModelSource: input.ModelSource, - ChatAPIMode: input.ChatAPIMode, - ChatEndpointPath: input.ChatEndpointPath, - DiscoveryEndpointPath: input.DiscoveryEndpointPath, - Models: input.Models, - Source: ProviderSourceCustom, + Name: input.Name, + Driver: input.Driver, + BaseURL: input.BaseURL, + APIKeyEnv: input.APIKeyEnv, + GenerateMaxRetries: input.GenerateMaxRetries, + GenerateIdleTimeoutSec: input.GenerateIdleTimeoutSec, + ModelSource: input.ModelSource, + ChatAPIMode: input.ChatAPIMode, + ChatEndpointPath: input.ChatEndpointPath, + DiscoveryEndpointPath: input.DiscoveryEndpointPath, + Models: input.Models, + Source: ProviderSourceCustom, } return cfg.Validate() } diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index 4c792fd6..3fab03c5 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -21,18 +21,17 @@ const ( ) type customProviderFile struct { - Name string `yaml:"name"` - Driver string `yaml:"driver"` - APIKeyEnv string `yaml:"api_key_env"` - GenerateMaxRetries int `yaml:"generate_max_retries,omitempty"` - GenerateStartTimeoutSec int `yaml:"generate_start_timeout_sec,omitempty"` - GenerateIdleTimeoutSec int `yaml:"generate_idle_timeout_sec,omitempty"` - ModelSource string `yaml:"model_source,omitempty"` - ChatAPIMode string `yaml:"chat_api_mode,omitempty"` - BaseURL string `yaml:"base_url,omitempty"` - ChatEndpointPath string `yaml:"chat_endpoint_path,omitempty"` - DiscoveryEndpointPath string `yaml:"discovery_endpoint_path,omitempty"` - Models []customProviderModelFile `yaml:"models,omitempty"` + Name string `yaml:"name"` + Driver string `yaml:"driver"` + APIKeyEnv string `yaml:"api_key_env"` + GenerateMaxRetries int `yaml:"generate_max_retries,omitempty"` + GenerateIdleTimeoutSec int `yaml:"generate_idle_timeout_sec,omitempty"` + ModelSource string `yaml:"model_source,omitempty"` + ChatAPIMode string `yaml:"chat_api_mode,omitempty"` + BaseURL string `yaml:"base_url,omitempty"` + ChatEndpointPath string `yaml:"chat_endpoint_path,omitempty"` + DiscoveryEndpointPath string `yaml:"discovery_endpoint_path,omitempty"` + Models []customProviderModelFile `yaml:"models,omitempty"` } type customProviderModelFile struct { @@ -112,37 +111,35 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { } normalizedInput, err := NormalizeCustomProviderInput(SaveCustomProviderInput{ - Name: strings.TrimSpace(file.Name), - Driver: strings.TrimSpace(file.Driver), - BaseURL: strings.TrimSpace(file.BaseURL), - APIKeyEnv: strings.TrimSpace(file.APIKeyEnv), - GenerateMaxRetries: file.GenerateMaxRetries, - GenerateStartTimeoutSec: file.GenerateStartTimeoutSec, - GenerateIdleTimeoutSec: file.GenerateIdleTimeoutSec, - ModelSource: strings.TrimSpace(file.ModelSource), - ChatAPIMode: strings.TrimSpace(file.ChatAPIMode), - ChatEndpointPath: strings.TrimSpace(file.ChatEndpointPath), - DiscoveryEndpointPath: strings.TrimSpace(file.DiscoveryEndpointPath), - Models: models, + Name: strings.TrimSpace(file.Name), + Driver: strings.TrimSpace(file.Driver), + BaseURL: strings.TrimSpace(file.BaseURL), + APIKeyEnv: strings.TrimSpace(file.APIKeyEnv), + GenerateMaxRetries: file.GenerateMaxRetries, + GenerateIdleTimeoutSec: file.GenerateIdleTimeoutSec, + ModelSource: strings.TrimSpace(file.ModelSource), + ChatAPIMode: strings.TrimSpace(file.ChatAPIMode), + ChatEndpointPath: strings.TrimSpace(file.ChatEndpointPath), + DiscoveryEndpointPath: strings.TrimSpace(file.DiscoveryEndpointPath), + Models: models, }) if err != nil { return ProviderConfig{}, fmt.Errorf("config: custom provider %q: %w", filepath.Base(providerDir), err) } cfg := ProviderConfig{ - Name: normalizedInput.Name, - Driver: normalizedInput.Driver, - BaseURL: normalizedInput.BaseURL, - APIKeyEnv: normalizedInput.APIKeyEnv, - GenerateMaxRetries: normalizedInput.GenerateMaxRetries, - GenerateStartTimeoutSec: normalizedInput.GenerateStartTimeoutSec, - GenerateIdleTimeoutSec: normalizedInput.GenerateIdleTimeoutSec, - ModelSource: normalizedInput.ModelSource, - ChatAPIMode: normalizedInput.ChatAPIMode, - ChatEndpointPath: normalizedInput.ChatEndpointPath, - DiscoveryEndpointPath: normalizedInput.DiscoveryEndpointPath, - Models: normalizedInput.Models, - Source: ProviderSourceCustom, + Name: normalizedInput.Name, + Driver: normalizedInput.Driver, + BaseURL: normalizedInput.BaseURL, + APIKeyEnv: normalizedInput.APIKeyEnv, + GenerateMaxRetries: normalizedInput.GenerateMaxRetries, + GenerateIdleTimeoutSec: normalizedInput.GenerateIdleTimeoutSec, + ModelSource: normalizedInput.ModelSource, + ChatAPIMode: normalizedInput.ChatAPIMode, + ChatEndpointPath: normalizedInput.ChatEndpointPath, + DiscoveryEndpointPath: normalizedInput.DiscoveryEndpointPath, + Models: normalizedInput.Models, + Source: ProviderSourceCustom, } if err := cfg.Validate(); err != nil { @@ -200,18 +197,17 @@ func customProviderModels(models []customProviderModelFile) ([]providertypes.Mod // SaveCustomProviderInput 定义自定义 Provider 的持久化字段。 type SaveCustomProviderInput struct { - Name string - Driver string - BaseURL string - ChatAPIMode string - ChatEndpointPath string - APIKeyEnv string - GenerateMaxRetries int - GenerateStartTimeoutSec int - GenerateIdleTimeoutSec int - DiscoveryEndpointPath string - ModelSource string - Models []providertypes.ModelDescriptor + Name string + Driver string + BaseURL string + ChatAPIMode string + ChatEndpointPath string + APIKeyEnv string + GenerateMaxRetries int + GenerateIdleTimeoutSec int + DiscoveryEndpointPath string + ModelSource string + Models []providertypes.ModelDescriptor } // SaveCustomProviderWithModels 保存自定义 provider,并可在 manual 模式下写入手工模型列表。 @@ -227,14 +223,13 @@ func SaveCustomProviderWithModels(baseDir string, input SaveCustomProviderInput) } cfg := customProviderFile{ - Name: normalizedInput.Name, - Driver: normalizedInput.Driver, - APIKeyEnv: normalizedInput.APIKeyEnv, - GenerateMaxRetries: normalizedInput.GenerateMaxRetries, - GenerateStartTimeoutSec: normalizedInput.GenerateStartTimeoutSec, - GenerateIdleTimeoutSec: normalizedInput.GenerateIdleTimeoutSec, - ModelSource: normalizedInput.ModelSource, - ChatAPIMode: normalizedInput.ChatAPIMode, + Name: normalizedInput.Name, + Driver: normalizedInput.Driver, + APIKeyEnv: normalizedInput.APIKeyEnv, + GenerateMaxRetries: normalizedInput.GenerateMaxRetries, + GenerateIdleTimeoutSec: normalizedInput.GenerateIdleTimeoutSec, + ModelSource: normalizedInput.ModelSource, + ChatAPIMode: normalizedInput.ChatAPIMode, } cfg.BaseURL = normalizedInput.BaseURL diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 8422a8ca..a401a6fa 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -108,6 +108,29 @@ func TestResolveSelectedProviderIncludesRuntimeAssetPolicyAndBudget(t *testing.T } } +func TestResolveSelectedProviderPrefersRootGenerateStartTimeout(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.GenerateStartTimeoutSec = 90 + + resolved, err := ResolveSelectedProvider(cfg) + if err != nil { + t.Fatalf("ResolveSelectedProvider() error = %v", err) + } + if resolved.GenerateStartTimeoutSec != 90 { + t.Fatalf("expected resolved GenerateStartTimeoutSec=90, got %d", resolved.GenerateStartTimeoutSec) + } + + runtimeCfg, err := resolved.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + if runtimeCfg.GenerateStartTimeout != 90*time.Second { + t.Fatalf("expected runtime GenerateStartTimeout=90s, got %s", runtimeCfg.GenerateStartTimeout) + } +} + func TestProviderConfigIdentity(t *testing.T) { t.Parallel() @@ -387,13 +410,6 @@ func TestProviderConfigValidateRejectsNegativeGenerateControls(t *testing.T) { }, errContain: "generate_max_retries", }, - { - name: "negative start timeout", - mutate: func(cfg *ProviderConfig) { - cfg.GenerateStartTimeoutSec = -1 - }, - errContain: "generate_start_timeout_sec", - }, { name: "negative idle timeout", mutate: func(cfg *ProviderConfig) { @@ -813,15 +829,15 @@ func TestResolvedProviderConfigToRuntimeConfigMapsGenerateControls(t *testing.T) resolved := ResolvedProviderConfig{ ProviderConfig: ProviderConfig{ - Name: "company-gateway", - Driver: "openaicompat", - BaseURL: "https://llm.example.com/v1", - Model: "server-default", - APIKeyEnv: "COMPANY_GATEWAY_KEY", - GenerateMaxRetries: 7, - GenerateStartTimeoutSec: 75, - GenerateIdleTimeoutSec: 420, + Name: "company-gateway", + Driver: "openaicompat", + BaseURL: "https://llm.example.com/v1", + Model: "server-default", + APIKeyEnv: "COMPANY_GATEWAY_KEY", + GenerateMaxRetries: 7, + GenerateIdleTimeoutSec: 420, }, + GenerateStartTimeoutSec: 75, } got, err := resolved.ToRuntimeConfig() diff --git a/internal/config/runtime.go b/internal/config/runtime.go index 04867dd4..0fcc6838 100644 --- a/internal/config/runtime.go +++ b/internal/config/runtime.go @@ -10,7 +10,7 @@ import ( const ( DefaultMaxNoProgressStreak = 5 DefaultMaxRepeatCycleStreak = 3 - DefaultMaxTurns = 40 + DefaultMaxTurns = 90 ) // RuntimeConfig 定义 runtime 层的可调参数。 diff --git a/internal/runtime/run.go b/internal/runtime/run.go index ea1ed85e..f25ff2d8 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -168,15 +168,19 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } switch decision.Action { case controlplane.TurnBudgetActionCompact: - if _, err := s.applyCompactForState( + applied, err := s.applyCompactForState( ctx, &state, snapshot.Config, contextcompact.ModeProactive, compactErrorBestEffort, - ); err != nil { + ) + if err != nil { return s.handleRunError(err) } + if !applied { + state.compactCount++ + } continue case controlplane.TurnBudgetActionStop: state.budgetExceeded = true diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index e2bd24a7..a764e47b 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4853,6 +4853,173 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateAdvisory(t *testing.T) } } +func TestServiceRunStopsAfterNoOpProactiveCompactWhenEstimateGateable(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 10 + cfg.Context.Budget.FallbackPromptBudget = 10 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{ + EstimatedInputTokens: 99, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateGateable, + }, nil + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + service.compactRunner = &stubCompactRunner{ + result: contextcompact.Result{ + Applied: false, + Metrics: contextcompact.Metrics{ + TriggerMode: string(contextcompact.ModeProactive), + }, + }, + } + + if err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-gateable-stop-noop-compact", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + + compactRunner := service.compactRunner.(*stubCompactRunner) + if len(compactRunner.calls) != 1 { + t.Fatalf("expected one proactive compact attempt, got %d", len(compactRunner.calls)) + } + if scripted.callCount != 0 { + t.Fatalf("expected provider Generate to be skipped after no-op compact budget stop, got %d calls", scripted.callCount) + } + + events := collectRuntimeEvents(service.Events()) + var budgetActions []string + var stopPayload StopReasonDecidedPayload + for _, event := range events { + switch event.Type { + case EventBudgetChecked: + payload, ok := event.Payload.(BudgetCheckedPayload) + if !ok { + t.Fatalf("expected BudgetCheckedPayload, got %T", event.Payload) + } + budgetActions = append(budgetActions, payload.Action) + case EventStopReasonDecided: + payload, ok := event.Payload.(StopReasonDecidedPayload) + if !ok { + t.Fatalf("expected StopReasonDecidedPayload, got %T", event.Payload) + } + stopPayload = payload + } + } + + if len(budgetActions) != 2 || budgetActions[0] != "compact" || budgetActions[1] != "stop" { + t.Fatalf("expected budget actions [compact stop], got %v", budgetActions) + } + if stopPayload.Reason != controlplane.StopReasonBudgetExceeded { + t.Fatalf("expected stop reason %q, got %q", controlplane.StopReasonBudgetExceeded, stopPayload.Reason) + } +} + +func TestServiceRunAllowsAfterNoOpProactiveCompactWhenEstimateAdvisory(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 10 + cfg.Context.Budget.FallbackPromptBudget = 10 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{ + EstimatedInputTokens: 99, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + }, + responses: []scriptedResponse{ + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("继续执行")}, + }, + FinishReason: "stop", + }, + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + service.compactRunner = &stubCompactRunner{ + result: contextcompact.Result{ + Applied: false, + Metrics: contextcompact.Metrics{ + TriggerMode: string(contextcompact.ModeProactive), + }, + }, + } + + if err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-advisory-allow-noop-compact", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + + compactRunner := service.compactRunner.(*stubCompactRunner) + if len(compactRunner.calls) != 1 { + t.Fatalf("expected one proactive compact attempt, got %d", len(compactRunner.calls)) + } + if scripted.callCount != 1 { + t.Fatalf("expected provider Generate to be called once after no-op compact, got %d calls", scripted.callCount) + } + + events := collectRuntimeEvents(service.Events()) + var budgetActions []string + var stopPayload StopReasonDecidedPayload + for _, event := range events { + switch event.Type { + case EventBudgetChecked: + payload, ok := event.Payload.(BudgetCheckedPayload) + if !ok { + t.Fatalf("expected BudgetCheckedPayload, got %T", event.Payload) + } + budgetActions = append(budgetActions, payload.Action) + case EventStopReasonDecided: + payload, ok := event.Payload.(StopReasonDecidedPayload) + if !ok { + t.Fatalf("expected StopReasonDecidedPayload, got %T", event.Payload) + } + stopPayload = payload + } + } + + if len(budgetActions) != 2 || budgetActions[0] != "compact" || budgetActions[1] != "allow" { + t.Fatalf("expected budget actions [compact allow], got %v", budgetActions) + } + if stopPayload.Reason != controlplane.StopReasonCompleted { + t.Fatalf("expected stop reason %q, got %q", controlplane.StopReasonCompleted, stopPayload.Reason) + } +} + func TestServiceRunBypassesBudgetGateWhenEstimateFails(t *testing.T) { t.Parallel() From 1e48b59e38dc25e6aa564d2e249ec31ea4470f86 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Sun, 26 Apr 2026 08:33:29 +0000 Subject: [PATCH 8/8] test(provider): improve generate attempt and timeout normalization coverage Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/constants_test.go | 35 +++- internal/provider/generate_attempt_test.go | 199 +++++++++++++++++++++ 2 files changed, 233 insertions(+), 1 deletion(-) diff --git a/internal/provider/constants_test.go b/internal/provider/constants_test.go index 5d82b4df..41452693 100644 --- a/internal/provider/constants_test.go +++ b/internal/provider/constants_test.go @@ -1,6 +1,9 @@ package provider -import "testing" +import ( + "testing" + "time" +) func TestNormalizeGenerateMaxRetries(t *testing.T) { t.Parallel() @@ -26,3 +29,33 @@ func TestNormalizeGenerateMaxRetries(t *testing.T) { }) } } + +func TestNormalizeGenerateStartTimeout(t *testing.T) { + t.Parallel() + + if got := NormalizeGenerateStartTimeout(0); got != DefaultGenerateStartTimeout { + t.Fatalf("NormalizeGenerateStartTimeout(0) = %s, want %s", got, DefaultGenerateStartTimeout) + } + if got := NormalizeGenerateStartTimeout(-time.Second); got != DefaultGenerateStartTimeout { + t.Fatalf("NormalizeGenerateStartTimeout(-1s) = %s, want %s", got, DefaultGenerateStartTimeout) + } + want := 3 * time.Second + if got := NormalizeGenerateStartTimeout(want); got != want { + t.Fatalf("NormalizeGenerateStartTimeout(3s) = %s, want %s", got, want) + } +} + +func TestNormalizeGenerateIdleTimeout(t *testing.T) { + t.Parallel() + + if got := NormalizeGenerateIdleTimeout(0); got != DefaultGenerateIdleTimeout { + t.Fatalf("NormalizeGenerateIdleTimeout(0) = %s, want %s", got, DefaultGenerateIdleTimeout) + } + if got := NormalizeGenerateIdleTimeout(-time.Second); got != DefaultGenerateIdleTimeout { + t.Fatalf("NormalizeGenerateIdleTimeout(-1s) = %s, want %s", got, DefaultGenerateIdleTimeout) + } + want := 4 * time.Second + if got := NormalizeGenerateIdleTimeout(want); got != want { + t.Fatalf("NormalizeGenerateIdleTimeout(4s) = %s, want %s", got, want) + } +} diff --git a/internal/provider/generate_attempt_test.go b/internal/provider/generate_attempt_test.go index 462abe85..0f386c46 100644 --- a/internal/provider/generate_attempt_test.go +++ b/internal/provider/generate_attempt_test.go @@ -3,6 +3,7 @@ package provider import ( "context" "errors" + "sync/atomic" "testing" "time" @@ -257,6 +258,204 @@ func TestRunGenerateWithRetryUsingTreatsMessageDoneAsCompletedState(t *testing.T } } +func TestRunGenerateWithRetryUsesDefaultRunner(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 0, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 8) + + err := RunGenerateWithRetry( + context.Background(), + cfg, + events, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + if emitErr := EmitTextDelta(ctx, attemptEvents, "ok"); emitErr != nil { + return emitErr + } + return EmitMessageDone(ctx, attemptEvents, "stop", nil) + }, + ) + if err != nil { + t.Fatalf("RunGenerateWithRetry() error = %v", err) + } + + drained := drainAttemptEvents(events) + if len(drained) != 2 { + t.Fatalf("expected two forwarded events, got %d", len(drained)) + } +} + +func TestRunGenerateWithRetryUsingReturnsRetryWaitError(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 1, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 4) + waitErr := errors.New("wait failed") + attempts := 0 + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return time.Millisecond }, + func(context.Context, time.Duration) error { return waitErr }, + func(context.Context, chan<- providertypes.StreamEvent) error { + attempts++ + return ErrStreamInterrupted + }, + ) + if !errors.Is(err, waitErr) { + t.Fatalf("expected retry wait error, got %v", err) + } + if attempts != 1 { + t.Fatalf("expected only first attempt before wait failure, got %d", attempts) + } +} + +func TestRunGenerateWithRetryUsingReturnsLastErrorAfterExhaustedRetries(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 1, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 4) + firstErr := NewProviderErrorFromStatus(500, "first") + lastErr := NewProviderErrorFromStatus(500, "last") + attempts := 0 + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(context.Context, chan<- providertypes.StreamEvent) error { + attempts++ + if attempts == 1 { + return firstErr + } + return lastErr + }, + ) + if !errors.Is(err, lastErr) { + t.Fatalf("expected last retryable error, got %v", err) + } + if attempts != 2 { + t.Fatalf("expected two attempts after exhausting retries, got %d", attempts) + } +} + +func TestWaitForRetryHonorsContextCancel(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := waitForRetry(ctx, time.Second) + if !errors.Is(err, context.Canceled) { + t.Fatalf("waitForRetry() error = %v, want context canceled", err) + } +} + +func TestWaitForRetryReturnsNilOnNonPositiveWait(t *testing.T) { + t.Parallel() + + if err := waitForRetry(context.Background(), 0); err != nil { + t.Fatalf("waitForRetry() error = %v", err) + } +} + +func TestStopAndResetTimerHelpers(t *testing.T) { + t.Parallel() + + stopTimer(nil) + resetTimer(nil, time.Millisecond) + + timer := time.NewTimer(time.Millisecond) + time.Sleep(5 * time.Millisecond) + stopTimer(timer) + select { + case <-timer.C: + t.Fatal("expected stopTimer to drain timer channel") + default: + } + + resetTimer(timer, 30*time.Millisecond) + select { + case <-timer.C: + t.Fatal("expected resetTimer to apply new wait") + case <-time.After(10 * time.Millisecond): + } + select { + case <-timer.C: + case <-time.After(100 * time.Millisecond): + t.Fatal("expected reset timer to fire") + } +} + +func TestUpdateGenerateAttemptPhaseTransitions(t *testing.T) { + t.Parallel() + + var phase atomic.Uint32 + if got := updateGenerateAttemptPhase( + providertypes.StreamEvent{Type: providertypes.StreamEventType("unknown")}, + &phase, + ); got != generateAttemptPhaseWaitingFirstPayload { + t.Fatalf("unexpected initial phase = %v", got) + } + if got := updateGenerateAttemptPhase( + providertypes.StreamEvent{Type: providertypes.StreamEventToolCallStart}, + &phase, + ); got != generateAttemptPhaseStreaming { + t.Fatalf("expected streaming phase, got %v", got) + } + if got := updateGenerateAttemptPhase( + providertypes.StreamEvent{Type: providertypes.StreamEventMessageDone}, + &phase, + ); got != generateAttemptPhaseCompleted { + t.Fatalf("expected completed phase, got %v", got) + } + if got := updateGenerateAttemptPhase( + providertypes.StreamEvent{Type: providertypes.StreamEventTextDelta}, + &phase, + ); got != generateAttemptPhaseCompleted { + t.Fatalf("expected completed phase to remain terminal, got %v", got) + } +} + +func TestIsEffectiveGeneratePayloadEvent(t *testing.T) { + t.Parallel() + + cases := []struct { + eventType providertypes.StreamEventType + want bool + }{ + {eventType: providertypes.StreamEventTextDelta, want: true}, + {eventType: providertypes.StreamEventToolCallStart, want: true}, + {eventType: providertypes.StreamEventToolCallDelta, want: true}, + {eventType: providertypes.StreamEventMessageDone, want: false}, + } + for _, tc := range cases { + tc := tc + t.Run(string(tc.eventType), func(t *testing.T) { + t.Parallel() + if got := IsEffectiveGeneratePayloadEvent(providertypes.StreamEvent{Type: tc.eventType}); got != tc.want { + t.Fatalf("IsEffectiveGeneratePayloadEvent(%s) = %v, want %v", tc.eventType, got, tc.want) + } + }) + } +} + func drainAttemptEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { out := make([]providertypes.StreamEvent, 0, len(events)) for {