Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import { ProjectionCheckpointRepository } from "../src/persistence/Services/Proj
import { ProjectionPendingApprovalRepository } from "../src/persistence/Services/ProjectionPendingApprovals.ts";
import { makeAdapterRegistryMock } from "../src/provider/testUtils/providerAdapterRegistryMock.ts";
import { ProviderAdapterRegistry } from "../src/provider/Services/ProviderAdapterRegistry.ts";
import { makeProviderRegistryLayer } from "../src/provider/testUtils/providerRegistryMock.ts";
import { ProviderSessionDirectoryLive } from "../src/provider/Layers/ProviderSessionDirectory.ts";
import { ServerSettingsService } from "../src/serverSettings.ts";
import { makeProviderServiceLive } from "../src/provider/Layers/ProviderService.ts";
Expand Down Expand Up @@ -292,6 +293,7 @@ export const makeOrchestrationIntegrationHarness = (
Layer.provide(AnalyticsService.layerTest),
Layer.provide(providerEventLoggersLayer),
);
const providerRegistryLayer = makeProviderRegistryLayer();

const checkpointStoreLayer = CheckpointStoreLive.pipe(Layer.provide(VcsDriverRegistry.layer));
const projectionSnapshotQueryLayer = OrchestrationProjectionSnapshotQueryLive;
Expand Down Expand Up @@ -368,6 +370,7 @@ export const makeOrchestrationIntegrationHarness = (
const layer = Layer.empty.pipe(
Layer.provideMerge(runtimeServicesLayer),
Layer.provideMerge(orchestrationReactorLayer),
Layer.provideMerge(providerRegistryLayer),
Layer.provide(persistenceLayer),
Layer.provideMerge(RepositoryIdentityResolverLive),
Layer.provideMerge(ServerSettingsService.layerTest()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import {
ProviderService,
type ProviderServiceShape,
} from "../../provider/Services/ProviderService.ts";
import { makeProviderRegistryLayer } from "../../provider/testUtils/providerRegistryMock.ts";
import { TextGeneration, type TextGenerationShape } from "../../textGeneration/TextGeneration.ts";
import { RepositoryIdentityResolverLive } from "../../project/Layers/RepositoryIdentityResolver.ts";
import { OrchestrationEngineLive } from "./OrchestrationEngine.ts";
Expand Down Expand Up @@ -142,6 +143,7 @@ describe("ProviderCommandReactor", () => {
readonly baseDir?: string;
readonly threadModelSelection?: ModelSelection;
readonly sessionModelSwitch?: "unsupported" | "in-session";
readonly requiresNewThreadForModelChange?: boolean;
}) {
const now = "2026-01-01T00:00:00.000Z";
const baseDir = input?.baseDir ?? fs.mkdtempSync(path.join(os.tmpdir(), "t3code-reactor-"));
Expand Down Expand Up @@ -280,6 +282,14 @@ describe("ProviderCommandReactor", () => {
}),
),
);
const providerSnapshots = [
{
instanceId: modelSelection.instanceId,
...(input?.requiresNewThreadForModelChange === true
? { requiresNewThreadForModelChange: true }
: {}),
},
];

const unsupported = () => Effect.die(new Error("Unsupported provider call in test")) as never;
const service: ProviderServiceShape = {
Expand Down Expand Up @@ -335,6 +345,7 @@ describe("ProviderCommandReactor", () => {
Layer.provideMerge(orchestrationLayer),
Layer.provideMerge(projectionSnapshotLayer),
Layer.provideMerge(Layer.succeed(ProviderService, service)),
Layer.provideMerge(makeProviderRegistryLayer(providerSnapshots as never)),
Layer.provideMerge(
Layer.mock(GitWorkflowService)({
renameBranch,
Expand Down Expand Up @@ -879,6 +890,71 @@ describe("ProviderCommandReactor", () => {
});
});

it("rejects changing models after start when the provider requires a new thread", async () => {
const harness = await createHarness({ requiresNewThreadForModelChange: true });
const now = "2026-01-01T00:00:00.000Z";

await Effect.runPromise(
harness.engine.dispatch({
type: "thread.turn.start",
commandId: CommandId.make("cmd-turn-start-restricted-1"),
threadId: ThreadId.make("thread-1"),
message: {
messageId: asMessageId("user-message-restricted-1"),
role: "user",
text: "first",
attachments: [],
},
interactionMode: DEFAULT_PROVIDER_INTERACTION_MODE,
runtimeMode: "approval-required",
createdAt: now,
}),
);

await waitFor(() => harness.sendTurn.mock.calls.length === 1);

await Effect.runPromise(
harness.engine.dispatch({
type: "thread.turn.start",
commandId: CommandId.make("cmd-turn-start-restricted-2"),
threadId: ThreadId.make("thread-1"),
message: {
messageId: asMessageId("user-message-restricted-2"),
role: "user",
text: "second",
attachments: [],
},
modelSelection: {
instanceId: ProviderInstanceId.make("codex"),
model: "gpt-5.1-codex",
},
interactionMode: DEFAULT_PROVIDER_INTERACTION_MODE,
runtimeMode: "approval-required",
createdAt: now,
}),
);

await waitFor(async () => {
const readModel = await harness.readModel();
const thread = readModel.threads.find((entry) => entry.id === ThreadId.make("thread-1"));
return (
thread?.activities.some((activity) => activity.kind === "provider.turn.start.failed") ??
false
);
});

expect(harness.sendTurn).toHaveBeenCalledTimes(1);
const readModel = await harness.readModel();
const thread = readModel.threads.find((entry) => entry.id === ThreadId.make("thread-1"));
expect(
thread?.activities.find((activity) => activity.kind === "provider.turn.start.failed"),
).toMatchObject({
payload: {
detail: expect.stringContaining("cannot switch models after the conversation has started"),
},
});
});

it("starts a first turn on the requested provider instance even when it differs from the thread model", async () => {
const harness = await createHarness({
threadModelSelection: { instanceId: ProviderInstanceId.make("codex"), model: "gpt-5-codex" },
Expand Down
48 changes: 48 additions & 0 deletions apps/server/src/orchestration/Layers/ProviderCommandReactor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { ProviderAdapterRequestError } from "../../provider/Errors.ts";
import type { ProviderServiceError } from "../../provider/Errors.ts";
import { TextGeneration } from "../../textGeneration/TextGeneration.ts";
import { ProviderService } from "../../provider/Services/ProviderService.ts";
import { ProviderRegistry } from "../../provider/Services/ProviderRegistry.ts";
import { OrchestrationEngineService } from "../Services/OrchestrationEngine.ts";
import { ProjectionSnapshotQuery } from "../Services/ProjectionSnapshotQuery.ts";
import {
Expand Down Expand Up @@ -180,6 +181,7 @@ const make = Effect.gen(function* () {
const orchestrationEngine = yield* OrchestrationEngineService;
const projectionSnapshotQuery = yield* ProjectionSnapshotQuery;
const providerService = yield* ProviderService;
const providerRegistry = yield* ProviderRegistry;
const gitWorkflow = yield* GitWorkflowService;
const vcsStatusBroadcaster = yield* VcsStatusBroadcaster;
const textGeneration = yield* TextGeneration;
Expand Down Expand Up @@ -305,6 +307,38 @@ const make = Effect.gen(function* () {
.pipe(Effect.map(Option.getOrUndefined));
});

const rejectStartedThreadModelChangeIfRequired = Effect.fnUntraced(function* (input: {
readonly threadId: ThreadId;
readonly currentModelSelection: ModelSelection;
readonly requestedModelSelection: ModelSelection | undefined;
}) {
const requestedModelSelection = input.requestedModelSelection;
if (
requestedModelSelection === undefined ||
(input.currentModelSelection.instanceId === requestedModelSelection.instanceId &&
input.currentModelSelection.model === requestedModelSelection.model)
) {
return;
}
const providers = yield* providerRegistry.getProviders;
const requiresNewThread =
providers.find((snapshot) => snapshot.instanceId === input.currentModelSelection.instanceId)
?.requiresNewThreadForModelChange === true ||
providers.find((snapshot) => snapshot.instanceId === requestedModelSelection.instanceId)
?.requiresNewThreadForModelChange === true;
if (!requiresNewThread) {
return;
}
return yield* new ProviderAdapterRequestError({
provider: providerErrorLabelFromInstanceHint({
instanceId: String(requestedModelSelection.instanceId),
modelSelectionInstanceId: String(input.currentModelSelection.instanceId),
}),
method: "thread.turn.start",
detail: `Thread '${input.threadId}' cannot switch models after the conversation has started. Start a new thread to use '${requestedModelSelection.model}'.`,
});
});

const ensureSessionForThread = Effect.fn("ensureSessionForThread")(function* (
threadId: ThreadId,
createdAt: string,
Expand Down Expand Up @@ -384,6 +418,20 @@ const make = Effect.gen(function* () {
});
}
const preferredProvider: ProviderDriverKind = desiredDriverKind;
if (thread.session !== null) {
yield* rejectStartedThreadModelChangeIfRequired({
threadId,
currentModelSelection:
activeSession?.model !== undefined
? {
...thread.modelSelection,
instanceId: currentInstanceId,
model: activeSession.model,
}
: thread.modelSelection,
requestedModelSelection,
});
}
if (
thread.session !== null &&
requestedModelSelection !== undefined &&
Expand Down
1 change: 1 addition & 0 deletions apps/server/src/provider/Layers/GrokProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ describe("buildInitialGrokProviderSnapshot", () => {
expect(snapshot.status).toBe("warning");
expect(snapshot.version).toBeNull();
expect(snapshot.message).toContain("Checking Grok");
expect(snapshot.requiresNewThreadForModelChange).toBe(true);
});
});

Expand Down
1 change: 1 addition & 0 deletions apps/server/src/provider/Layers/GrokProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const GROK_PRESENTATION = {
displayName: "Grok",
badgeLabel: "Early Access",
showInteractionModeToggle: false,
requiresNewThreadForModelChange: true,
} as const;
const PROVIDER = ProviderDriverKind.make("grok");
const EMPTY_CAPABILITIES: ModelCapabilities = createModelCapabilities({
Expand Down
4 changes: 4 additions & 0 deletions apps/server/src/provider/providerSnapshot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export interface ServerProviderPresentation {
readonly displayName: string;
readonly badgeLabel?: string;
readonly showInteractionModeToggle?: boolean;
readonly requiresNewThreadForModelChange?: boolean;
}

export type ServerProviderDraft = Omit<ServerProvider, "instanceId" | "driver">;
Expand Down Expand Up @@ -214,6 +215,9 @@ export function buildServerProvider(input: {
...(typeof input.presentation.showInteractionModeToggle === "boolean"
? { showInteractionModeToggle: input.presentation.showInteractionModeToggle }
: {}),
...(typeof input.presentation.requiresNewThreadForModelChange === "boolean"
? { requiresNewThreadForModelChange: input.presentation.requiresNewThreadForModelChange }
: {}),
enabled: input.enabled,
installed: input.probe.installed,
version: input.probe.version,
Expand Down
21 changes: 21 additions & 0 deletions apps/server/src/provider/testUtils/providerRegistryMock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { ProviderRegistry, type ProviderRegistryShape } from "../Services/ProviderRegistry.ts";
import type { ServerProvider } from "@t3tools/contracts";
import * as Effect from "effect/Effect";
import * as Layer from "effect/Layer";
import * as Stream from "effect/Stream";
import { makeManualOnlyProviderMaintenanceCapabilities } from "../providerMaintenance.ts";

export const makeProviderRegistryMock = (
providers: ReadonlyArray<ServerProvider> = [],
): ProviderRegistryShape => ({
getProviders: Effect.succeed(providers),
refresh: () => Effect.succeed(providers),
refreshInstance: () => Effect.succeed(providers),
getProviderMaintenanceCapabilitiesForInstance: (_instanceId, provider) =>
Effect.succeed(makeManualOnlyProviderMaintenanceCapabilities({ provider, packageName: null })),
setProviderMaintenanceActionState: () => Effect.succeed(providers),
streamChanges: Stream.empty,
});

export const makeProviderRegistryLayer = (providers: ReadonlyArray<ServerProvider> = []) =>
Layer.succeed(ProviderRegistry, makeProviderRegistryMock(providers));
68 changes: 68 additions & 0 deletions apps/web/src/components/ChatView.logic.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
buildExpiredTerminalContextToastCopy,
createLocalDispatchSnapshot,
deriveComposerSendState,
getStartedThreadModelChangeBlockReason,
hasServerAcknowledgedLocalDispatch,
reconcileMountedTerminalThreadIds,
resolveSendEnvMode,
Expand Down Expand Up @@ -90,6 +91,73 @@ describe("buildExpiredTerminalContextToastCopy", () => {
});
});

describe("getStartedThreadModelChangeBlockReason", () => {
const providers = [
{
instanceId: ProviderInstanceId.make("codex"),
},
{
instanceId: ProviderInstanceId.make("grok"),
requiresNewThreadForModelChange: true,
},
];

it("allows model changes before a provider session has started", () => {
expect(
getStartedThreadModelChangeBlockReason({
providers,
hasStartedSession: false,
currentModelSelection: {
instanceId: ProviderInstanceId.make("grok"),
model: "grok-build",
},
nextModelSelection: {
instanceId: ProviderInstanceId.make("grok"),
model: "grok-other",
},
}),
).toBeNull();
});

it("allows unchanged model selections for restricted providers", () => {
expect(
getStartedThreadModelChangeBlockReason({
providers,
hasStartedSession: true,
currentModelSelection: {
instanceId: ProviderInstanceId.make("grok"),
model: "grok-build",
},
nextModelSelection: {
instanceId: ProviderInstanceId.make("grok"),
model: "grok-build",
},
}),
).toBeNull();
});

it("blocks started-session model changes when either provider requires a new thread", () => {
expect(
getStartedThreadModelChangeBlockReason({
providers,
hasStartedSession: true,
currentModelSelection: {
instanceId: ProviderInstanceId.make("codex"),
model: "gpt-5.4",
},
nextModelSelection: {
instanceId: ProviderInstanceId.make("grok"),
model: "grok-build",
},
}),
).toEqual({
title: "Start a new chat to change models",
description:
"This provider does not allow switching models after a conversation has started.",
});
});
});

describe("resolveSendEnvMode", () => {
it("keeps worktree mode for git repositories", () => {
expect(resolveSendEnvMode({ requestedEnvMode: "worktree", isGitRepo: true })).toBe("worktree");
Expand Down
Loading