diff --git a/packages/app/src/context/local.tsx b/packages/app/src/context/local.tsx index 84a613c0d285..f05da70b6c2d 100644 --- a/packages/app/src/context/local.tsx +++ b/packages/app/src/context/local.tsx @@ -7,7 +7,12 @@ import { useModels } from "@/context/models" import { useProviders } from "@/hooks/use-providers" import { modelEnabled, modelProbe } from "@/testing/model-selection" import { Persist, persisted } from "@/utils/persist" -import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant" +import { + cycleModelVariant, + getConfiguredAgentVariant, + getConfiguredModelVariant, + resolveModelVariant, +} from "./model-variant" import { useSDK } from "./sdk" import { useSync } from "./sync" @@ -235,11 +240,22 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ const configured = () => { const item = agent.current() const model = current() - if (!item || !model) return undefined - return getConfiguredAgentVariant({ - agent: { model: item.model, variant: item.variant }, - model: { providerID: model.provider.id, modelID: model.id, variants: model.variants }, - }) + if (!model) return undefined + return ( + (item && + getConfiguredAgentVariant({ + agent: { model: item.model, variant: item.variant }, + model: { providerID: model.provider.id, modelID: model.id, variants: model.variants }, + })) ?? + getConfiguredModelVariant({ + model: { + providerID: model.provider.id, + modelID: model.id, + options: model.options, + variants: model.variants, + }, + }) + ) } const selected = () => scope()?.variant diff --git a/packages/app/src/context/model-variant.test.ts b/packages/app/src/context/model-variant.test.ts index 583bc5c3dc71..2b9ddbe87d72 100644 --- a/packages/app/src/context/model-variant.test.ts +++ b/packages/app/src/context/model-variant.test.ts @@ -1,5 +1,10 @@ import { describe, expect, test } from "bun:test" -import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant" +import { + cycleModelVariant, + getConfiguredAgentVariant, + getConfiguredModelVariant, + resolveModelVariant, +} from "./model-variant" describe("model variant", () => { test("resolves configured agent variant when model matches", () => { @@ -34,6 +39,67 @@ describe("model variant", () => { expect(value).toBeUndefined() }) + test("infers configured model variant from matching options", () => { + const value = getConfiguredModelVariant({ + model: { + providerID: "openai", + modelID: "gpt-5.4", + options: { reasoningEffort: "high" }, + variants: { low: { reasoningEffort: "low" }, high: { reasoningEffort: "high" } }, + }, + }) + + expect(value).toBe("high") + }) + + test("infers configured model variant when built-in variant adds extra defaults", () => { + const value = getConfiguredModelVariant({ + model: { + providerID: "openai", + modelID: "gpt-5.4", + options: { reasoningEffort: "high" }, + variants: { + low: { reasoningEffort: "low", reasoningSummary: "auto", include: ["reasoning.encrypted_content"] }, + high: { reasoningEffort: "high", reasoningSummary: "auto", include: ["reasoning.encrypted_content"] }, + }, + }, + }) + + expect(value).toBe("high") + }) + + test("infers configured model variant from nested options", () => { + const value = getConfiguredModelVariant({ + model: { + providerID: "google", + modelID: "gemini-3", + options: { thinkingConfig: { thinkingLevel: "high", includeThoughts: true } }, + variants: { + low: { thinkingConfig: { thinkingLevel: "low", includeThoughts: true } }, + high: { thinkingConfig: { thinkingLevel: "high", includeThoughts: true } }, + }, + }, + }) + + expect(value).toBe("high") + }) + + test("does not infer a variant from auxiliary defaults alone", () => { + const value = getConfiguredModelVariant({ + model: { + providerID: "openai", + modelID: "gpt-5.4", + options: { reasoningSummary: "auto" }, + variants: { + low: { reasoningEffort: "low", reasoningSummary: "auto" }, + high: { reasoningEffort: "high", reasoningSummary: "auto" }, + }, + }, + }) + + expect(value).toBeUndefined() + }) + test("prefers selected variant over configured variant", () => { const value = resolveModelVariant({ variants: ["low", "high", "xhigh"], @@ -54,6 +120,16 @@ describe("model variant", () => { expect(value).toBeUndefined() }) + test("lets an explicit default override the inferred model variant", () => { + const value = resolveModelVariant({ + variants: ["low", "high", "xhigh"], + selected: null, + configured: "high", + }) + + expect(value).toBeUndefined() + }) + test("cycles from configured variant to next", () => { const value = cycleModelVariant({ variants: ["low", "high", "xhigh"], diff --git a/packages/app/src/context/model-variant.ts b/packages/app/src/context/model-variant.ts index 525acbba3219..2117d163e77e 100644 --- a/packages/app/src/context/model-variant.ts +++ b/packages/app/src/context/model-variant.ts @@ -9,6 +9,7 @@ type Agent = { } type Model = AgentModel & { + options?: Record variants?: Record } @@ -28,6 +29,46 @@ export function getConfiguredAgentVariant(input: { agent: Agent | undefined; mod return input.agent.variant } +function record(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value) +} + +function pick(value: unknown, path: string[]) { + return path.reduce((acc, key) => (record(acc) ? acc[key] : undefined), value) +} + +function signal(value: unknown) { + const keys = [ + ["reasoningEffort"], + ["reasoning", "effort"], + ["effort"], + ["thinkingLevel"], + ["thinkingBudget"], + ["thinking_budget"], + ["thinkingConfig", "thinkingLevel"], + ["thinkingConfig", "thinkingBudget"], + ["thinking", "budgetTokens"], + ["reasoningConfig", "budgetTokens"], + ["reasoningConfig", "maxReasoningEffort"], + ] + + return keys.flatMap((path) => { + const item = pick(value, path) + return item === undefined ? [] : [[path.join("."), item] as const] + }) +} + +export function getConfiguredModelVariant(input: { model: Model | undefined }) { + if (!input.model?.variants) return undefined + if (!input.model.options) return undefined + const cfg = signal(input.model.options) + if (cfg.length === 0) return undefined + return Object.entries(input.model.variants).find(([, value]) => { + const variant = new Map(signal(value)) + return cfg.every(([key, item]) => variant.get(key) === item) + })?.[0] +} + export function resolveModelVariant(input: VariantInput) { if (input.selected === null) return undefined if (input.selected && input.variants.includes(input.selected)) return input.selected