diff --git a/docs/guides/adding-providers.md b/docs/guides/adding-providers.md index 538a65a4..af0be185 100644 --- a/docs/guides/adding-providers.md +++ b/docs/guides/adding-providers.md @@ -108,7 +108,7 @@ generate_idle_timeout_sec: 300 - `chat_endpoint_path` 为 `/` 表示直连 `base_url`;为空时会按 `chat_api_mode` 自动回填默认子路径(`/chat/completions` 或 `/responses`)。 - 当 `chat_api_mode` 已显式指定时,`chat_endpoint_path` 可使用任意以 `/` 开头的相对路径;未显式指定时,仅支持标准端点推断(`/chat/completions`、`/responses`、`/`)。 - `model_source: manual` 时必须提供 `models`,且会忽略 `discovery_endpoint_path`。 -- `generate_max_retries` / `generate_idle_timeout_sec` 用于控制 provider 级生成重试和流空闲超时;未填写或 `<= 0` 时会分别回退到 `5 / 300`。其中 `generate_max_retries` 必须 `<= 20`。 +- `generate_max_retries` / `generate_idle_timeout_sec` 用于控制 provider 级生成重试和流空闲超时;`generate_max_retries` 未填写时默认使用 `5`,显式填写 `0` 表示关闭生成重试,`generate_idle_timeout_sec` 未填写或 `<= 0` 时回退到 `300`。其中 `generate_max_retries` 必须 `<= 20`。 - `generate_start_timeout_sec` 已改为根 `config.yaml` 顶层字段,不再允许写入 `provider.yaml`;启动时缺失会自动补写默认值 `90`。 ## 测试要求 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index f259eea3..c6d883f9 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -179,7 +179,7 @@ generate_idle_timeout_sec: 300 新增的生成链路控制字段含义如下: -- `generate_max_retries`:额外重试次数,不含首次尝试;`<= 0` 时回退默认值 `5`,且必须 `<= 20`。 +- `generate_max_retries`:额外重试次数,不含首次尝试;未填写时默认使用 `5`,显式填写 `0` 表示关闭生成重试,且必须 `<= 20`。 - `generate_start_timeout_sec`:写在 `config.yaml` 顶层,从发请求到收到首个有效流 payload 的最长等待窗口;`<= 0` 时回退默认值 `90`。 - `generate_idle_timeout_sec`:首包后连续没有任何新 payload 的最长空闲窗口;`<= 0` 时回退默认值 `300`。 diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 3c5e54c9..c655b145 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -1393,6 +1393,55 @@ func TestSaveAndLoadCustomProviderPersistsGenerateControls(t *testing.T) { } } +func TestSaveAndLoadCustomProviderPreservesExplicitZeroGenerateRetries(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + const providerName = "zero-retry-provider" + err := SaveCustomProviderWithModels(baseDir, SaveCustomProviderInput{ + Name: providerName, + Driver: provider.DriverOpenAICompat, + BaseURL: "https://llm.example.com/v1", + APIKeyEnv: "ZERO_RETRY_PROVIDER_API_KEY", + ModelSource: ModelSourceDiscover, + DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, + GenerateMaxRetries: 0, + GenerateMaxRetriesSet: true, + GenerateIdleTimeoutSec: 420, + }) + if err != nil { + t.Fatalf("SaveCustomProviderWithModels() error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(baseDir, providersDirName, providerName, customProviderConfigName)) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + content := string(data) + if !strings.Contains(content, "generate_max_retries: 0") { + t.Fatalf("expected generate_max_retries: 0 to be persisted, got %q", content) + } + + cfg, err := loadCustomProvider(filepath.Join(baseDir, providersDirName, providerName)) + if err != nil { + t.Fatalf("loadCustomProvider() error = %v", err) + } + if !cfg.GenerateMaxRetriesSet { + t.Fatal("expected explicit zero retry setting to remain marked as configured") + } + runtimeCfg, err := cfg.Resolve() + if err != nil { + t.Fatalf("Resolve() error = %v", err) + } + providerRuntimeCfg, err := runtimeCfg.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + if providerRuntimeCfg.GenerateMaxRetries != 0 { + t.Fatalf("expected explicit zero retry setting to disable retries, got %d", providerRuntimeCfg.GenerateMaxRetries) + } +} + func TestSaveCustomProviderOmitsDefaultGenerateControlsWhenUnset(t *testing.T) { t.Parallel() @@ -1423,6 +1472,43 @@ func TestSaveCustomProviderOmitsDefaultGenerateControlsWhenUnset(t *testing.T) { } } +func TestLoadCustomProviderUsesDefaultGenerateRetriesWhenUnset(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + const providerName = "default-retry-provider" + err := SaveCustomProviderWithModels(baseDir, SaveCustomProviderInput{ + Name: providerName, + Driver: provider.DriverOpenAICompat, + BaseURL: "https://llm.example.com/v1", + APIKeyEnv: "DEFAULT_RETRY_PROVIDER_API_KEY", + ModelSource: ModelSourceDiscover, + DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, + }) + if err != nil { + t.Fatalf("SaveCustomProviderWithModels() error = %v", err) + } + + cfg, err := loadCustomProvider(filepath.Join(baseDir, providersDirName, providerName)) + if err != nil { + t.Fatalf("loadCustomProvider() error = %v", err) + } + if cfg.GenerateMaxRetriesSet { + t.Fatal("expected omitted generate_max_retries to remain unset") + } + resolved, err := cfg.Resolve() + if err != nil { + t.Fatalf("Resolve() error = %v", err) + } + runtimeCfg, err := resolved.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + if runtimeCfg.GenerateMaxRetries != provider.DefaultGenerateMaxRetries { + t.Fatalf("expected omitted generate_max_retries to use default %d, got %d", provider.DefaultGenerateMaxRetries, runtimeCfg.GenerateMaxRetries) + } +} + func TestLoaderRejectsCustomProviderGenerateStartTimeoutField(t *testing.T) { t.Parallel() diff --git a/internal/config/provider.go b/internal/config/provider.go index 437ba380..f05bcd55 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -28,6 +28,7 @@ type ProviderConfig struct { Model string `yaml:"model"` APIKeyEnv string `yaml:"api_key_env"` GenerateMaxRetries int `yaml:"generate_max_retries,omitempty"` + GenerateMaxRetriesSet bool `yaml:"-"` GenerateIdleTimeoutSec int `yaml:"generate_idle_timeout_sec,omitempty"` ModelSource string `yaml:"-"` ChatAPIMode string `yaml:"-"` @@ -169,6 +170,14 @@ func (p ProviderConfig) Resolve() (ResolvedProviderConfig, error) { }, nil } +// resolveGenerateMaxRetries 统一解析 provider 级生成重试次数,兼容“未配置使用默认值”和“显式 0 关闭重试”两种语义。 +func (p ProviderConfig) resolveGenerateMaxRetries() int { + if p.GenerateMaxRetriesSet || p.GenerateMaxRetries > 0 { + return provider.NormalizeGenerateMaxRetries(p.GenerateMaxRetries) + } + return provider.DefaultGenerateMaxRetries +} + func cloneProviders(providers []ProviderConfig) []ProviderConfig { if len(providers) == 0 { return nil @@ -283,7 +292,7 @@ func (p ResolvedProviderConfig) ToRuntimeConfig() (provider.RuntimeConfig, error ChatAPIMode: chatAPIMode, ChatEndpointPath: chatEndpointPath, DiscoveryEndpointPath: discoveryEndpointPath, - GenerateMaxRetries: provider.NormalizeGenerateMaxRetries(p.GenerateMaxRetries), + GenerateMaxRetries: p.resolveGenerateMaxRetries(), GenerateStartTimeout: provider.NormalizeGenerateStartTimeout(time.Duration(p.GenerateStartTimeoutSec) * time.Second), GenerateIdleTimeout: provider.NormalizeGenerateIdleTimeout(time.Duration(p.GenerateIdleTimeoutSec) * time.Second), }, nil diff --git a/internal/config/provider_custom_normalize.go b/internal/config/provider_custom_normalize.go index ca32604f..bffb7d03 100644 --- a/internal/config/provider_custom_normalize.go +++ b/internal/config/provider_custom_normalize.go @@ -21,6 +21,7 @@ func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProv ChatEndpointPath: strings.TrimSpace(input.ChatEndpointPath), APIKeyEnv: strings.TrimSpace(input.APIKeyEnv), GenerateMaxRetries: normalizeOptionalGenerateInt(input.GenerateMaxRetries), + GenerateMaxRetriesSet: input.GenerateMaxRetriesSet || input.GenerateMaxRetries > 0, GenerateIdleTimeoutSec: normalizeOptionalGenerateInt(input.GenerateIdleTimeoutSec), DiscoveryEndpointPath: strings.TrimSpace(input.DiscoveryEndpointPath), } @@ -118,6 +119,7 @@ func validateNormalizedCustomProviderInput(input SaveCustomProviderInput) error BaseURL: input.BaseURL, APIKeyEnv: input.APIKeyEnv, GenerateMaxRetries: input.GenerateMaxRetries, + GenerateMaxRetriesSet: input.GenerateMaxRetriesSet, GenerateIdleTimeoutSec: input.GenerateIdleTimeoutSec, ModelSource: input.ModelSource, ChatAPIMode: input.ChatAPIMode, diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index 3fab03c5..f460a76a 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -24,7 +24,7 @@ type customProviderFile struct { Name string `yaml:"name"` Driver string `yaml:"driver"` APIKeyEnv string `yaml:"api_key_env"` - GenerateMaxRetries int `yaml:"generate_max_retries,omitempty"` + GenerateMaxRetries *int `yaml:"generate_max_retries,omitempty"` GenerateIdleTimeoutSec int `yaml:"generate_idle_timeout_sec,omitempty"` ModelSource string `yaml:"model_source,omitempty"` ChatAPIMode string `yaml:"chat_api_mode,omitempty"` @@ -115,7 +115,8 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { Driver: strings.TrimSpace(file.Driver), BaseURL: strings.TrimSpace(file.BaseURL), APIKeyEnv: strings.TrimSpace(file.APIKeyEnv), - GenerateMaxRetries: file.GenerateMaxRetries, + GenerateMaxRetries: optionalIntValue(file.GenerateMaxRetries), + GenerateMaxRetriesSet: file.GenerateMaxRetries != nil, GenerateIdleTimeoutSec: file.GenerateIdleTimeoutSec, ModelSource: strings.TrimSpace(file.ModelSource), ChatAPIMode: strings.TrimSpace(file.ChatAPIMode), @@ -133,6 +134,7 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { BaseURL: normalizedInput.BaseURL, APIKeyEnv: normalizedInput.APIKeyEnv, GenerateMaxRetries: normalizedInput.GenerateMaxRetries, + GenerateMaxRetriesSet: normalizedInput.GenerateMaxRetriesSet, GenerateIdleTimeoutSec: normalizedInput.GenerateIdleTimeoutSec, ModelSource: normalizedInput.ModelSource, ChatAPIMode: normalizedInput.ChatAPIMode, @@ -204,6 +206,7 @@ type SaveCustomProviderInput struct { ChatEndpointPath string APIKeyEnv string GenerateMaxRetries int + GenerateMaxRetriesSet bool GenerateIdleTimeoutSec int DiscoveryEndpointPath string ModelSource string @@ -226,7 +229,7 @@ func SaveCustomProviderWithModels(baseDir string, input SaveCustomProviderInput) Name: normalizedInput.Name, Driver: normalizedInput.Driver, APIKeyEnv: normalizedInput.APIKeyEnv, - GenerateMaxRetries: normalizedInput.GenerateMaxRetries, + GenerateMaxRetries: optionalIntPointer(normalizedInput.GenerateMaxRetries, normalizedInput.GenerateMaxRetriesSet || normalizedInput.GenerateMaxRetries > 0), GenerateIdleTimeoutSec: normalizedInput.GenerateIdleTimeoutSec, ModelSource: normalizedInput.ModelSource, ChatAPIMode: normalizedInput.ChatAPIMode, @@ -313,3 +316,20 @@ func validateCustomProviderName(name string) error { } return nil } + +// optionalIntValue 统一读取可选整数字段,避免缺省值和显式零值在解析阶段丢失原始语义。 +func optionalIntValue(value *int) int { + if value == nil { + return 0 + } + return *value +} + +// optionalIntPointer 根据是否显式配置决定是否输出 YAML 字段,保留“未配置”和“显式 0”两种语义差异。 +func optionalIntPointer(value int, configured bool) *int { + if !configured { + return nil + } + out := value + return &out +} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 07a056af..0668e21f 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -776,7 +776,7 @@ func TestResolvedProviderConfigToRuntimeConfig(t *testing.T) { ChatAPIMode: "", ChatEndpointPath: "", DiscoveryEndpointPath: providerpkg.DiscoveryEndpointPathModels, - GenerateMaxRetries: 0, + GenerateMaxRetries: providerpkg.DefaultGenerateMaxRetries, GenerateStartTimeout: providerpkg.DefaultGenerateStartTimeout, GenerateIdleTimeout: providerpkg.DefaultGenerateIdleTimeout, } @@ -802,6 +802,30 @@ func TestResolvedProviderConfigToRuntimeConfig(t *testing.T) { } } +func TestResolvedProviderConfigToRuntimeConfigPreservesExplicitZeroGenerateRetries(t *testing.T) { + t.Parallel() + + resolved := ResolvedProviderConfig{ + ProviderConfig: ProviderConfig{ + Name: "company-gateway", + Driver: "openaicompat", + BaseURL: "https://llm.example.com/v1", + Model: "server-default", + APIKeyEnv: "COMPANY_GATEWAY_KEY", + GenerateMaxRetries: 0, + GenerateMaxRetriesSet: true, + }, + } + + got, err := resolved.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + if got.GenerateMaxRetries != 0 { + t.Fatalf("expected explicit GenerateMaxRetries=0 to disable retries, got %d", got.GenerateMaxRetries) + } +} + func TestResolvedProviderConfigToRuntimeConfigMapsGenerateControls(t *testing.T) { t.Parallel() diff --git a/internal/provider/constants.go b/internal/provider/constants.go index 5b9216a2..ea28a24d 100644 --- a/internal/provider/constants.go +++ b/internal/provider/constants.go @@ -23,7 +23,7 @@ const ( // DefaultGenerateRetryBaseWait 定义生成链路重试退避的基础等待时长。 DefaultGenerateRetryBaseWait = 1 * time.Second // DefaultGenerateRetryMaxWait 定义生成链路重试退避的最大等待时长。 - DefaultGenerateRetryMaxWait = 5 * time.Second + DefaultGenerateRetryMaxWait = 7 * time.Second // DefaultSDKRequestTimeout 定义非生成链路访问外部模型 SDK 的统一保底超时。 DefaultSDKRequestTimeout = 10 * time.Minute ) diff --git a/internal/provider/constants_test.go b/internal/provider/constants_test.go index 41452693..889133cc 100644 --- a/internal/provider/constants_test.go +++ b/internal/provider/constants_test.go @@ -59,3 +59,11 @@ func TestNormalizeGenerateIdleTimeout(t *testing.T) { t.Fatalf("NormalizeGenerateIdleTimeout(4s) = %s, want %s", got, want) } } + +func TestDefaultGenerateRetryMaxWait(t *testing.T) { + t.Parallel() + + if DefaultGenerateRetryMaxWait != 7*time.Second { + t.Fatalf("DefaultGenerateRetryMaxWait = %s, want 7s", DefaultGenerateRetryMaxWait) + } +}