Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/agent/src/adapters/acp-connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
createTappedWritableStream,
type StreamPair,
} from "@/utils/streams.js";
import { ClaudeAcpAgent } from "./claude/agent.js";
import { ClaudeAcpAgent } from "./claude/claude-agent.js";

export type AgentAdapter = "claude";

Expand Down Expand Up @@ -92,7 +92,7 @@ export function createAcpConnection(
logger.info("Cleaning up ACP connection");

if (agent) {
await agent.closeAllSessions();
await agent.closeSession();
}

// Then close the streams to properly terminate the ACP connection
Expand Down
57 changes: 41 additions & 16 deletions packages/agent/src/adapters/base-acp-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ import type {
WriteTextFileResponse,
} from "@agentclientprotocol/sdk";
import { Logger } from "@/utils/logger.js";
import type { SessionState } from "./types.js";

export interface BaseSession extends SessionState {
abortController: AbortController;
export interface BaseSession {
notificationHistory: SessionNotification[];
cancelled: boolean;
interruptReason?: string;
abortController: AbortController;
}

export abstract class BaseAcpAgent implements Agent {
abstract readonly adapterName: string;
protected sessions: { [key: string]: BaseSession } = {};
protected session: BaseSession | null = null;
protected sessionId: string | null = null;
client: AgentSideConnection;
protected logger: Logger;
protected fileContentCache: { [key: string]: string } = {};
Expand All @@ -38,26 +40,49 @@ export abstract class BaseAcpAgent implements Agent {
abstract initialize(request: InitializeRequest): Promise<InitializeResponse>;
abstract newSession(params: NewSessionRequest): Promise<NewSessionResponse>;
abstract prompt(params: PromptRequest): Promise<PromptResponse>;
abstract cancel(params: CancelNotification): Promise<void>;
protected abstract interruptSession(): Promise<void>;

async cancel(params: CancelNotification): Promise<void> {
if (this.sessionId !== params.sessionId || !this.session) {
throw new Error("Session not found");
}
this.session.cancelled = true;
const meta = params._meta as { interruptReason?: string } | undefined;
if (meta?.interruptReason) {
this.session.interruptReason = meta.interruptReason;
}
await this.interruptSession();
}

async closeAllSessions(): Promise<void> {
for (const [sessionId, session] of Object.entries(this.sessions)) {
try {
await this.cancel({ sessionId });
session.abortController.abort();
this.logger.info("Closed session", { sessionId });
} catch (err) {
this.logger.warn("Failed to close session", { sessionId, error: err });
}
async closeSession(): Promise<void> {
if (!this.session || !this.sessionId) {
return;
}
try {
await this.cancel({ sessionId: this.sessionId });
this.session.abortController.abort();
this.logger.info("Closed session", { sessionId: this.sessionId });
} catch (err) {
this.logger.warn("Failed to close session", {
sessionId: this.sessionId,
error: err,
});
}
this.sessions = {};
this.session = null;
this.sessionId = null;
}

hasSession(sessionId: string): boolean {
return this.sessionId === sessionId && this.session !== null;
}

appendNotification(
sessionId: string,
notification: SessionNotification,
): void {
this.sessions[sessionId]?.notificationHistory.push(notification);
if (this.sessionId === sessionId && this.session) {
this.session.notificationHistory.push(notification);
}
}

async readTextFile(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {
type AgentSideConnection,
type AuthenticateRequest,
type AvailableCommand,
type CancelNotification,
type ClientCapabilities,
type InitializeRequest,
type InitializeResponse,
Expand Down Expand Up @@ -147,7 +146,7 @@ async function getAvailableSlashCommands(

export class ClaudeAcpAgent extends BaseAcpAgent {
readonly adapterName = "claude";
declare sessions: { [key: string]: Session };
declare session: Session | null;
toolUseCache: ToolUseCache;
backgroundTerminals: { [key: string]: BackgroundTerminal } = {};
clientCapabilities?: ClientCapabilities;
Expand Down Expand Up @@ -175,7 +174,8 @@ export class ClaudeAcpAgent extends BaseAcpAgent {
notificationHistory: [],
abortController,
};
this.sessions[sessionId] = session;
this.session = session;
this.sessionId = sessionId;
return session;
}

Expand Down Expand Up @@ -239,12 +239,11 @@ export class ClaudeAcpAgent extends BaseAcpAgent {
}, 0);
}

private getSession(sessionId: string): Session {
const session = this.sessions[sessionId];
if (!session) {
private getSession(): Session {
if (!this.session) {
throw new Error("Session not found");
}
return session;
return this.session;
}

private registerPersistence(
Expand Down Expand Up @@ -421,7 +420,7 @@ Before pushing a "workspace-*" branch to origin, rename it to something descript
}

async prompt(params: PromptRequest): Promise<PromptResponse> {
const session = this.getSession(params.sessionId);
const session = this.getSession();
session.cancelled = false;
session.interruptReason = undefined;

Expand Down Expand Up @@ -506,25 +505,22 @@ Before pushing a "workspace-*" branch to origin, rename it to something descript
throw new Error("Session did not end in result");
}

async cancel(params: CancelNotification): Promise<void> {
const session = this.getSession(params.sessionId);
session.cancelled = true;
const meta = params._meta as { interruptReason?: string } | undefined;
if (meta?.interruptReason) {
session.interruptReason = meta.interruptReason;
protected async interruptSession(): Promise<void> {
if (!this.session) {
return;
}
await session.query.interrupt();
await this.session.query.interrupt();
}

async setSessionModel(params: SetSessionModelRequest) {
const session = this.getSession(params.sessionId);
const session = this.getSession();
await session.query.setModel(params.modelId);
}

async setSessionMode(
params: SetSessionModeRequest,
): Promise<SetSessionModeResponse> {
const session = this.getSession(params.sessionId);
const session = this.getSession();

switch (params.modeId) {
case "default":
Expand Down Expand Up @@ -553,14 +549,14 @@ Before pushing a "workspace-*" branch to origin, rename it to something descript

canUseTool(sessionId: string): CanUseTool {
return async (toolName, toolInput, { suggestions, toolUseID }) => {
const session = this.sessions[sessionId];
if (!session) {
if (this.sessionId !== sessionId || !this.session) {
return {
behavior: "deny",
message: "Session not found",
interrupt: true,
};
}
const session = this.session;

const context = {
session,
Expand Down Expand Up @@ -614,7 +610,7 @@ Before pushing a "workspace-*" branch to origin, rename it to something descript
this.logger.info("[RESUME] Resuming session", { params });
const { sessionId } = params;

if (this.sessions[sessionId]) {
if (this.sessionId === sessionId && this.session) {
return {};
}

Expand Down
10 changes: 5 additions & 5 deletions packages/agent/src/adapters/claude/mcp-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import * as diff from "diff";
import { z } from "zod";
import { Logger } from "@/utils/logger.js";
import type { ClaudeAcpAgent } from "./agent.js";
import type { ClaudeAcpAgent } from "./claude-agent.js";
import { extractLinesWithByteLimit, sleep, unreachable } from "./utils.js";

export const SYSTEM_REMINDER = `
Expand Down Expand Up @@ -111,7 +111,7 @@ Usage:
},
async (input) => {
try {
if (!agent.sessions[sessionId]) {
if (!agent.hasSession(sessionId)) {
return sessionNotFound();
}

Expand Down Expand Up @@ -194,7 +194,7 @@ Usage:
},
async (input) => {
try {
if (!agent.sessions[sessionId]) {
if (!agent.hasSession(sessionId)) {
return sessionNotFound();
}
await agent.writeTextFile({
Expand Down Expand Up @@ -251,7 +251,7 @@ Usage:
},
async (input) => {
try {
if (!agent.sessions[sessionId]) {
if (!agent.hasSession(sessionId)) {
return sessionNotFound();
}

Expand Down Expand Up @@ -334,7 +334,7 @@ Output: Create directory 'foo'`),
},
},
async (input, extra) => {
if (!agent.sessions[sessionId]) {
if (!agent.hasSession(sessionId)) {
return sessionNotFound();
}

Expand Down
2 changes: 1 addition & 1 deletion packages/agent/src/adapters/claude/session-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import type {
PermissionMode,
} from "@anthropic-ai/claude-agent-sdk";
import type { Logger } from "@/utils/logger.js";
import type { ClaudeAcpAgent } from "./agent.js";
import type { ClaudeAcpAgent } from "./claude-agent.js";
import { createPostToolUseHook } from "./hooks.js";
import { createMcpServer, toolNames } from "./mcp-server.js";
import { clearStatsigCache } from "./plan-utils.js";
Expand Down
8 changes: 2 additions & 6 deletions packages/agent/src/adapters/claude/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import type {
SessionNotification,
TerminalHandle,
TerminalOutputResponse,
} from "@agentclientprotocol/sdk";
Expand All @@ -10,18 +9,15 @@ import type {
SDKUserMessage,
} from "@anthropic-ai/claude-agent-sdk";
import type { Pushable } from "@/utils/streams.js";
import type { BaseSession } from "../base-acp-agent.js";

export type Session = {
export type Session = BaseSession & {
query: Query;
input: Pushable<SDKUserMessage>;
cancelled: boolean;
permissionMode: PermissionMode;
notificationHistory: SessionNotification[];
sdkSessionId?: string;
lastPlanFilePath?: string;
lastPlanContent?: string;
abortController: AbortController;
interruptReason?: string;
};

export type ToolUseCache = {
Expand Down
10 changes: 0 additions & 10 deletions packages/agent/src/adapters/types.ts

This file was deleted.

Loading