From 9649d2039cb8807ba3b3a5b81c0bd86660277cf7 Mon Sep 17 00:00:00 2001 From: Jaaneek Date: Thu, 4 Jun 2026 17:01:40 +0100 Subject: [PATCH] Handle provider model-change restrictions. --- .../OrchestrationEngineHarness.integration.ts | 3 + .../Layers/ProviderCommandReactor.test.ts | 76 +++++++++++++++++++ .../Layers/ProviderCommandReactor.ts | 48 ++++++++++++ .../src/provider/Layers/GrokProvider.test.ts | 1 + .../src/provider/Layers/GrokProvider.ts | 1 + apps/server/src/provider/providerSnapshot.ts | 4 + .../testUtils/providerRegistryMock.ts | 21 +++++ .../web/src/components/ChatView.logic.test.ts | 68 +++++++++++++++++ apps/web/src/components/ChatView.logic.ts | 39 ++++++++++ apps/web/src/components/ChatView.tsx | 17 +++++ packages/contracts/src/server.ts | 1 + 11 files changed, 279 insertions(+) create mode 100644 apps/server/src/provider/testUtils/providerRegistryMock.ts diff --git a/apps/server/integration/OrchestrationEngineHarness.integration.ts b/apps/server/integration/OrchestrationEngineHarness.integration.ts index 837c32fc4fd..10390056950 100644 --- a/apps/server/integration/OrchestrationEngineHarness.integration.ts +++ b/apps/server/integration/OrchestrationEngineHarness.integration.ts @@ -35,6 +35,7 @@ import { ProjectionCheckpointRepository } from "../src/persistence/Services/Proj import { ProjectionPendingApprovalRepository } from "../src/persistence/Services/ProjectionPendingApprovals.ts"; import { makeAdapterRegistryMock } from "../src/provider/testUtils/providerAdapterRegistryMock.ts"; import { ProviderAdapterRegistry } from "../src/provider/Services/ProviderAdapterRegistry.ts"; +import { makeProviderRegistryLayer } from "../src/provider/testUtils/providerRegistryMock.ts"; import { ProviderSessionDirectoryLive } from "../src/provider/Layers/ProviderSessionDirectory.ts"; import { ServerSettingsService } from "../src/serverSettings.ts"; import { makeProviderServiceLive } from "../src/provider/Layers/ProviderService.ts"; @@ -292,6 +293,7 @@ export const makeOrchestrationIntegrationHarness = ( Layer.provide(AnalyticsService.layerTest), Layer.provide(providerEventLoggersLayer), ); + const providerRegistryLayer = makeProviderRegistryLayer(); const checkpointStoreLayer = CheckpointStoreLive.pipe(Layer.provide(VcsDriverRegistry.layer)); const projectionSnapshotQueryLayer = OrchestrationProjectionSnapshotQueryLive; @@ -368,6 +370,7 @@ export const makeOrchestrationIntegrationHarness = ( const layer = Layer.empty.pipe( Layer.provideMerge(runtimeServicesLayer), Layer.provideMerge(orchestrationReactorLayer), + Layer.provideMerge(providerRegistryLayer), Layer.provide(persistenceLayer), Layer.provideMerge(RepositoryIdentityResolverLive), Layer.provideMerge(ServerSettingsService.layerTest()), diff --git a/apps/server/src/orchestration/Layers/ProviderCommandReactor.test.ts b/apps/server/src/orchestration/Layers/ProviderCommandReactor.test.ts index 571164fad93..e6ac5436e92 100644 --- a/apps/server/src/orchestration/Layers/ProviderCommandReactor.test.ts +++ b/apps/server/src/orchestration/Layers/ProviderCommandReactor.test.ts @@ -40,6 +40,7 @@ import { ProviderService, type ProviderServiceShape, } from "../../provider/Services/ProviderService.ts"; +import { makeProviderRegistryLayer } from "../../provider/testUtils/providerRegistryMock.ts"; import { TextGeneration, type TextGenerationShape } from "../../textGeneration/TextGeneration.ts"; import { RepositoryIdentityResolverLive } from "../../project/Layers/RepositoryIdentityResolver.ts"; import { OrchestrationEngineLive } from "./OrchestrationEngine.ts"; @@ -142,6 +143,7 @@ describe("ProviderCommandReactor", () => { readonly baseDir?: string; readonly threadModelSelection?: ModelSelection; readonly sessionModelSwitch?: "unsupported" | "in-session"; + readonly requiresNewThreadForModelChange?: boolean; }) { const now = "2026-01-01T00:00:00.000Z"; const baseDir = input?.baseDir ?? fs.mkdtempSync(path.join(os.tmpdir(), "t3code-reactor-")); @@ -280,6 +282,14 @@ describe("ProviderCommandReactor", () => { }), ), ); + const providerSnapshots = [ + { + instanceId: modelSelection.instanceId, + ...(input?.requiresNewThreadForModelChange === true + ? { requiresNewThreadForModelChange: true } + : {}), + }, + ]; const unsupported = () => Effect.die(new Error("Unsupported provider call in test")) as never; const service: ProviderServiceShape = { @@ -335,6 +345,7 @@ describe("ProviderCommandReactor", () => { Layer.provideMerge(orchestrationLayer), Layer.provideMerge(projectionSnapshotLayer), Layer.provideMerge(Layer.succeed(ProviderService, service)), + Layer.provideMerge(makeProviderRegistryLayer(providerSnapshots as never)), Layer.provideMerge( Layer.mock(GitWorkflowService)({ renameBranch, @@ -879,6 +890,71 @@ describe("ProviderCommandReactor", () => { }); }); + it("rejects changing models after start when the provider requires a new thread", async () => { + const harness = await createHarness({ requiresNewThreadForModelChange: true }); + const now = "2026-01-01T00:00:00.000Z"; + + await Effect.runPromise( + harness.engine.dispatch({ + type: "thread.turn.start", + commandId: CommandId.make("cmd-turn-start-restricted-1"), + threadId: ThreadId.make("thread-1"), + message: { + messageId: asMessageId("user-message-restricted-1"), + role: "user", + text: "first", + attachments: [], + }, + interactionMode: DEFAULT_PROVIDER_INTERACTION_MODE, + runtimeMode: "approval-required", + createdAt: now, + }), + ); + + await waitFor(() => harness.sendTurn.mock.calls.length === 1); + + await Effect.runPromise( + harness.engine.dispatch({ + type: "thread.turn.start", + commandId: CommandId.make("cmd-turn-start-restricted-2"), + threadId: ThreadId.make("thread-1"), + message: { + messageId: asMessageId("user-message-restricted-2"), + role: "user", + text: "second", + attachments: [], + }, + modelSelection: { + instanceId: ProviderInstanceId.make("codex"), + model: "gpt-5.1-codex", + }, + interactionMode: DEFAULT_PROVIDER_INTERACTION_MODE, + runtimeMode: "approval-required", + createdAt: now, + }), + ); + + await waitFor(async () => { + const readModel = await harness.readModel(); + const thread = readModel.threads.find((entry) => entry.id === ThreadId.make("thread-1")); + return ( + thread?.activities.some((activity) => activity.kind === "provider.turn.start.failed") ?? + false + ); + }); + + expect(harness.sendTurn).toHaveBeenCalledTimes(1); + const readModel = await harness.readModel(); + const thread = readModel.threads.find((entry) => entry.id === ThreadId.make("thread-1")); + expect( + thread?.activities.find((activity) => activity.kind === "provider.turn.start.failed"), + ).toMatchObject({ + payload: { + detail: expect.stringContaining("cannot switch models after the conversation has started"), + }, + }); + }); + it("starts a first turn on the requested provider instance even when it differs from the thread model", async () => { const harness = await createHarness({ threadModelSelection: { instanceId: ProviderInstanceId.make("codex"), model: "gpt-5-codex" }, diff --git a/apps/server/src/orchestration/Layers/ProviderCommandReactor.ts b/apps/server/src/orchestration/Layers/ProviderCommandReactor.ts index f63b873bc3d..e0db0fc320c 100644 --- a/apps/server/src/orchestration/Layers/ProviderCommandReactor.ts +++ b/apps/server/src/orchestration/Layers/ProviderCommandReactor.ts @@ -31,6 +31,7 @@ import { ProviderAdapterRequestError } from "../../provider/Errors.ts"; import type { ProviderServiceError } from "../../provider/Errors.ts"; import { TextGeneration } from "../../textGeneration/TextGeneration.ts"; import { ProviderService } from "../../provider/Services/ProviderService.ts"; +import { ProviderRegistry } from "../../provider/Services/ProviderRegistry.ts"; import { OrchestrationEngineService } from "../Services/OrchestrationEngine.ts"; import { ProjectionSnapshotQuery } from "../Services/ProjectionSnapshotQuery.ts"; import { @@ -180,6 +181,7 @@ const make = Effect.gen(function* () { const orchestrationEngine = yield* OrchestrationEngineService; const projectionSnapshotQuery = yield* ProjectionSnapshotQuery; const providerService = yield* ProviderService; + const providerRegistry = yield* ProviderRegistry; const gitWorkflow = yield* GitWorkflowService; const vcsStatusBroadcaster = yield* VcsStatusBroadcaster; const textGeneration = yield* TextGeneration; @@ -305,6 +307,38 @@ const make = Effect.gen(function* () { .pipe(Effect.map(Option.getOrUndefined)); }); + const rejectStartedThreadModelChangeIfRequired = Effect.fnUntraced(function* (input: { + readonly threadId: ThreadId; + readonly currentModelSelection: ModelSelection; + readonly requestedModelSelection: ModelSelection | undefined; + }) { + const requestedModelSelection = input.requestedModelSelection; + if ( + requestedModelSelection === undefined || + (input.currentModelSelection.instanceId === requestedModelSelection.instanceId && + input.currentModelSelection.model === requestedModelSelection.model) + ) { + return; + } + const providers = yield* providerRegistry.getProviders; + const requiresNewThread = + providers.find((snapshot) => snapshot.instanceId === input.currentModelSelection.instanceId) + ?.requiresNewThreadForModelChange === true || + providers.find((snapshot) => snapshot.instanceId === requestedModelSelection.instanceId) + ?.requiresNewThreadForModelChange === true; + if (!requiresNewThread) { + return; + } + return yield* new ProviderAdapterRequestError({ + provider: providerErrorLabelFromInstanceHint({ + instanceId: String(requestedModelSelection.instanceId), + modelSelectionInstanceId: String(input.currentModelSelection.instanceId), + }), + method: "thread.turn.start", + detail: `Thread '${input.threadId}' cannot switch models after the conversation has started. Start a new thread to use '${requestedModelSelection.model}'.`, + }); + }); + const ensureSessionForThread = Effect.fn("ensureSessionForThread")(function* ( threadId: ThreadId, createdAt: string, @@ -384,6 +418,20 @@ const make = Effect.gen(function* () { }); } const preferredProvider: ProviderDriverKind = desiredDriverKind; + if (thread.session !== null) { + yield* rejectStartedThreadModelChangeIfRequired({ + threadId, + currentModelSelection: + activeSession?.model !== undefined + ? { + ...thread.modelSelection, + instanceId: currentInstanceId, + model: activeSession.model, + } + : thread.modelSelection, + requestedModelSelection, + }); + } if ( thread.session !== null && requestedModelSelection !== undefined && diff --git a/apps/server/src/provider/Layers/GrokProvider.test.ts b/apps/server/src/provider/Layers/GrokProvider.test.ts index 8de684cdd00..4fa1aa4b77d 100644 --- a/apps/server/src/provider/Layers/GrokProvider.test.ts +++ b/apps/server/src/provider/Layers/GrokProvider.test.ts @@ -39,6 +39,7 @@ describe("buildInitialGrokProviderSnapshot", () => { expect(snapshot.status).toBe("warning"); expect(snapshot.version).toBeNull(); expect(snapshot.message).toContain("Checking Grok"); + expect(snapshot.requiresNewThreadForModelChange).toBe(true); }); }); diff --git a/apps/server/src/provider/Layers/GrokProvider.ts b/apps/server/src/provider/Layers/GrokProvider.ts index 7348f9e7445..bead8b1a407 100644 --- a/apps/server/src/provider/Layers/GrokProvider.ts +++ b/apps/server/src/provider/Layers/GrokProvider.ts @@ -35,6 +35,7 @@ const GROK_PRESENTATION = { displayName: "Grok", badgeLabel: "Early Access", showInteractionModeToggle: false, + requiresNewThreadForModelChange: true, } as const; const PROVIDER = ProviderDriverKind.make("grok"); const EMPTY_CAPABILITIES: ModelCapabilities = createModelCapabilities({ diff --git a/apps/server/src/provider/providerSnapshot.ts b/apps/server/src/provider/providerSnapshot.ts index c40903e1b45..ce43c5e6eab 100644 --- a/apps/server/src/provider/providerSnapshot.ts +++ b/apps/server/src/provider/providerSnapshot.ts @@ -45,6 +45,7 @@ export interface ServerProviderPresentation { readonly displayName: string; readonly badgeLabel?: string; readonly showInteractionModeToggle?: boolean; + readonly requiresNewThreadForModelChange?: boolean; } export type ServerProviderDraft = Omit; @@ -214,6 +215,9 @@ export function buildServerProvider(input: { ...(typeof input.presentation.showInteractionModeToggle === "boolean" ? { showInteractionModeToggle: input.presentation.showInteractionModeToggle } : {}), + ...(typeof input.presentation.requiresNewThreadForModelChange === "boolean" + ? { requiresNewThreadForModelChange: input.presentation.requiresNewThreadForModelChange } + : {}), enabled: input.enabled, installed: input.probe.installed, version: input.probe.version, diff --git a/apps/server/src/provider/testUtils/providerRegistryMock.ts b/apps/server/src/provider/testUtils/providerRegistryMock.ts new file mode 100644 index 00000000000..36598b05900 --- /dev/null +++ b/apps/server/src/provider/testUtils/providerRegistryMock.ts @@ -0,0 +1,21 @@ +import { ProviderRegistry, type ProviderRegistryShape } from "../Services/ProviderRegistry.ts"; +import type { ServerProvider } from "@t3tools/contracts"; +import * as Effect from "effect/Effect"; +import * as Layer from "effect/Layer"; +import * as Stream from "effect/Stream"; +import { makeManualOnlyProviderMaintenanceCapabilities } from "../providerMaintenance.ts"; + +export const makeProviderRegistryMock = ( + providers: ReadonlyArray = [], +): ProviderRegistryShape => ({ + getProviders: Effect.succeed(providers), + refresh: () => Effect.succeed(providers), + refreshInstance: () => Effect.succeed(providers), + getProviderMaintenanceCapabilitiesForInstance: (_instanceId, provider) => + Effect.succeed(makeManualOnlyProviderMaintenanceCapabilities({ provider, packageName: null })), + setProviderMaintenanceActionState: () => Effect.succeed(providers), + streamChanges: Stream.empty, +}); + +export const makeProviderRegistryLayer = (providers: ReadonlyArray = []) => + Layer.succeed(ProviderRegistry, makeProviderRegistryMock(providers)); diff --git a/apps/web/src/components/ChatView.logic.test.ts b/apps/web/src/components/ChatView.logic.test.ts index 83c90edaddc..0806953e108 100644 --- a/apps/web/src/components/ChatView.logic.test.ts +++ b/apps/web/src/components/ChatView.logic.test.ts @@ -16,6 +16,7 @@ import { buildExpiredTerminalContextToastCopy, createLocalDispatchSnapshot, deriveComposerSendState, + getStartedThreadModelChangeBlockReason, hasServerAcknowledgedLocalDispatch, reconcileMountedTerminalThreadIds, resolveSendEnvMode, @@ -90,6 +91,73 @@ describe("buildExpiredTerminalContextToastCopy", () => { }); }); +describe("getStartedThreadModelChangeBlockReason", () => { + const providers = [ + { + instanceId: ProviderInstanceId.make("codex"), + }, + { + instanceId: ProviderInstanceId.make("grok"), + requiresNewThreadForModelChange: true, + }, + ]; + + it("allows model changes before a provider session has started", () => { + expect( + getStartedThreadModelChangeBlockReason({ + providers, + hasStartedSession: false, + currentModelSelection: { + instanceId: ProviderInstanceId.make("grok"), + model: "grok-build", + }, + nextModelSelection: { + instanceId: ProviderInstanceId.make("grok"), + model: "grok-other", + }, + }), + ).toBeNull(); + }); + + it("allows unchanged model selections for restricted providers", () => { + expect( + getStartedThreadModelChangeBlockReason({ + providers, + hasStartedSession: true, + currentModelSelection: { + instanceId: ProviderInstanceId.make("grok"), + model: "grok-build", + }, + nextModelSelection: { + instanceId: ProviderInstanceId.make("grok"), + model: "grok-build", + }, + }), + ).toBeNull(); + }); + + it("blocks started-session model changes when either provider requires a new thread", () => { + expect( + getStartedThreadModelChangeBlockReason({ + providers, + hasStartedSession: true, + currentModelSelection: { + instanceId: ProviderInstanceId.make("codex"), + model: "gpt-5.4", + }, + nextModelSelection: { + instanceId: ProviderInstanceId.make("grok"), + model: "grok-build", + }, + }), + ).toEqual({ + title: "Start a new chat to change models", + description: + "This provider does not allow switching models after a conversation has started.", + }); + }); +}); + describe("resolveSendEnvMode", () => { it("keeps worktree mode for git repositories", () => { expect(resolveSendEnvMode({ requestedEnvMode: "worktree", isGitRepo: true })).toBe("worktree"); diff --git a/apps/web/src/components/ChatView.logic.ts b/apps/web/src/components/ChatView.logic.ts index bf87add28d9..de69c573046 100644 --- a/apps/web/src/components/ChatView.logic.ts +++ b/apps/web/src/components/ChatView.logic.ts @@ -4,6 +4,7 @@ import { ProjectId, type ModelSelection, type ProviderDriverKind, + type ServerProvider, type ScopedThreadRef, type ThreadId, type TurnId, @@ -262,6 +263,44 @@ export function deriveLockedProvider(input: { return narrowedThreadProvider ?? narrowedSelectedProvider ?? null; } +export function getStartedThreadModelChangeBlockReason(input: { + providers: ReadonlyArray>; + hasStartedSession: boolean; + currentModelSelection: ModelSelection; + currentProviderInstanceId?: ModelSelection["instanceId"] | null | undefined; + nextModelSelection: ModelSelection; +}): { title: string; description: string } | null { + if (!input.hasStartedSession) { + return null; + } + const currentModelSelection = { + ...input.currentModelSelection, + instanceId: input.currentProviderInstanceId ?? input.currentModelSelection.instanceId, + }; + if ( + currentModelSelection.instanceId === input.nextModelSelection.instanceId && + currentModelSelection.model === input.nextModelSelection.model + ) { + return null; + } + const currentProvider = input.providers.find( + (snapshot) => snapshot.instanceId === currentModelSelection.instanceId, + ); + const nextProvider = input.providers.find( + (snapshot) => snapshot.instanceId === input.nextModelSelection.instanceId, + ); + if ( + currentProvider?.requiresNewThreadForModelChange !== true && + nextProvider?.requiresNewThreadForModelChange !== true + ) { + return null; + } + return { + title: "Start a new chat to change models", + description: "This provider does not allow switching models after a conversation has started.", + }; +} + export async function waitForStartedServerThread( threadRef: ScopedThreadRef, timeoutMs = 1_000, diff --git a/apps/web/src/components/ChatView.tsx b/apps/web/src/components/ChatView.tsx index 5f5a96647d3..df5d5e4a2ff 100644 --- a/apps/web/src/components/ChatView.tsx +++ b/apps/web/src/components/ChatView.tsx @@ -163,6 +163,7 @@ import { createLocalDispatchSnapshot, deriveComposerSendState, hasServerAcknowledgedLocalDispatch, + getStartedThreadModelChangeBlockReason, LAST_INVOKED_SCRIPT_BY_PROJECT_KEY, LastInvokedScriptByProjectSchema, type LocalDispatchSnapshot, @@ -3639,6 +3640,22 @@ export default function ChatView(props: ChatViewProps) { instanceId, model: resolvedModel, }; + const modelChangeBlockReason = getStartedThreadModelChangeBlockReason({ + providers: providerStatuses, + hasStartedSession: activeThread.session !== null, + currentModelSelection: activeThread.modelSelection, + currentProviderInstanceId: activeThread.session?.providerInstanceId ?? null, + nextModelSelection, + }); + if (modelChangeBlockReason) { + toastManager.add({ + type: "warning", + title: modelChangeBlockReason.title, + description: modelChangeBlockReason.description, + }); + scheduleComposerFocus(); + return; + } setComposerDraftModelSelection( scopeThreadRef(activeThread.environmentId, activeThread.id), nextModelSelection, diff --git a/packages/contracts/src/server.ts b/packages/contracts/src/server.ts index 85ff4a4b2cb..dbc5d3ca549 100644 --- a/packages/contracts/src/server.ts +++ b/packages/contracts/src/server.ts @@ -165,6 +165,7 @@ export const ServerProvider = Schema.Struct({ badgeLabel: Schema.optional(TrimmedNonEmptyString), continuation: Schema.optional(ServerProviderContinuation), showInteractionModeToggle: Schema.optional(Schema.Boolean), + requiresNewThreadForModelChange: Schema.optional(Schema.Boolean), enabled: Schema.Boolean, installed: Schema.Boolean, version: Schema.NullOr(TrimmedNonEmptyString),