diff --git a/package.json b/package.json index 33cefcb..e74860d 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "mcpcat", - "version": "0.1.7", + "version": "0.1.8", "description": "Analytics tool for MCP (Model Context Protocol) servers - tracks tool usage patterns and provides insights", "type": "module", "main": "dist/index.js", diff --git a/src/index.ts b/src/index.ts index b128d14..50b6c9f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -16,7 +16,10 @@ import { writeToLog } from "./modules/logging.js"; import { setupMCPCatTools } from "./modules/tools.js"; import { setupToolCallTracing } from "./modules/tracing.js"; import { getSessionInfo, newSessionId } from "./modules/session.js"; -import { setServerTrackingData } from "./modules/internal.js"; +import { + setServerTrackingData, + getServerTrackingData, +} from "./modules/internal.js"; import { setupTracking } from "./modules/tracingV2.js"; import { TelemetryManager } from "./modules/telemetry.js"; import { setTelemetryManager } from "./modules/eventQueue.js"; @@ -136,6 +139,15 @@ function track( : validatedServer ) as MCPServerLike; + // Check if server is already being tracked + const existingData = getServerTrackingData(lowLevelServer); + if (existingData) { + writeToLog( + "[SESSION DEBUG] track() - Server already being tracked, skipping initialization", + ); + return validatedServer; + } + // Initialize telemetry if exporters are configured if (options.exporters) { const telemetryManager = new TelemetryManager(options.exporters); @@ -167,6 +179,7 @@ function track( identify: options.identify, redactSensitiveInformation: options.redactSensitiveInformation, }, + sessionSource: "mcpcat", // Initially MCPCat-generated, will change to "mcp" if MCP sessionId is provided }; setServerTrackingData(lowLevelServer, mcpcatData); diff --git a/src/modules/internal.ts b/src/modules/internal.ts index 10ec3da..eaa8f9b 100644 --- a/src/modules/internal.ts +++ b/src/modules/internal.ts @@ -1,4 +1,130 @@ -import { MCPCatData, MCPServerLike, UserIdentity } from "../types.js"; +import { + MCPCatData, + MCPServerLike, + UserIdentity, + CompatibleRequestHandlerExtra, + UnredactedEvent, +} from "../types.js"; +import { PublishEventRequestEventTypeEnum } from "mcpcat-api"; +import { publishEvent } from "./eventQueue.js"; +import { getMCPCompatibleErrorMessage } from "./compatibility.js"; +import { writeToLog } from "./logging.js"; +import { INACTIVITY_TIMEOUT_IN_MINUTES } from "./constants.js"; + +/** + * Simple LRU cache for session identities. + * Prevents memory leaks by capping at maxSize entries. + * This cache persists across server instance restarts. + */ +class IdentityCache { + private cache: Map; + private maxSize: number; + + constructor(maxSize: number = 1000) { + this.cache = new Map(); + this.maxSize = maxSize; + } + + get(sessionId: string): UserIdentity | undefined { + const entry = this.cache.get(sessionId); + if (entry) { + // Update timestamp on access (LRU behavior) + entry.timestamp = Date.now(); + // Move to end (most recently used) + this.cache.delete(sessionId); + this.cache.set(sessionId, entry); + return entry.identity; + } + return undefined; + } + + set(sessionId: string, identity: UserIdentity): void { + // Remove if already exists (to re-add at end) + this.cache.delete(sessionId); + + // Evict oldest if at capacity + if (this.cache.size >= this.maxSize) { + const oldestKey = this.cache.keys().next().value; + if (oldestKey !== undefined) { + this.cache.delete(oldestKey); + } + } + + this.cache.set(sessionId, { identity, timestamp: Date.now() }); + } + + has(sessionId: string): boolean { + return this.cache.has(sessionId); + } + + size(): number { + return this.cache.size; + } +} + +// Global identity cache shared across all server instances +// This prevents duplicate identify events when server objects are recreated +const _globalIdentityCache = new IdentityCache(1000); + +/** + * Maps userId to recent session IDs for reconnection support. + * When a user reconnects (new initialize without MCP sessionId), + * we can reuse their previous session if it's recent enough. + */ +class UserSessionCache { + private cache: Map; + private maxSize: number; + + constructor(maxSize: number = 1000) { + this.cache = new Map(); + this.maxSize = maxSize; + } + + getRecentSession(userId: string, timeoutMs: number): string | undefined { + const entry = this.cache.get(userId); + if (!entry) return undefined; + + // Check if session has expired + if (Date.now() - entry.lastSeen > timeoutMs) { + this.cache.delete(userId); + return undefined; + } + + return entry.sessionId; + } + + set(userId: string, sessionId: string): void { + // Remove if already exists (to re-add at end for LRU) + this.cache.delete(userId); + + // Evict oldest if at capacity + if (this.cache.size >= this.maxSize) { + const oldestKey = this.cache.keys().next().value; + if (oldestKey !== undefined) { + this.cache.delete(oldestKey); + } + } + + this.cache.set(userId, { sessionId, lastSeen: Date.now() }); + } +} + +// Global user session cache for reconnection support +const _globalUserSessionCache = new UserSessionCache(1000); + +/** + * FOR TESTING ONLY: Manually set a user session cache entry with custom lastSeen timestamp + */ +export function _testSetUserSession( + userId: string, + sessionId: string, + lastSeenMs: number, +): void { + (_globalUserSessionCache as any).cache.set(userId, { + sessionId, + lastSeen: lastSeenMs, + }); +} // Internal tracking storage const _serverTracking = new WeakMap(); @@ -61,3 +187,147 @@ export function mergeIdentities( }, }; } + +/** + * Handles user identification for a request. + * Calls the identify function if configured, compares with previous identity, + * and publishes an identify event only if the identity has changed. + * + * @param server - The MCP server instance + * @param data - The server tracking data + * @param request - The request object to pass to identify function + * @param extra - Optional extra parameters containing headers, sessionId, etc. + */ +export async function handleIdentify( + server: MCPServerLike, + data: MCPCatData, + request: any, + extra?: CompatibleRequestHandlerExtra, +): Promise { + if (!data.options.identify) { + return; + } + + const sessionId = data.sessionId; + let identifyEvent: UnredactedEvent = { + sessionId: sessionId, + resourceName: request.params?.name || "Unknown", + eventType: PublishEventRequestEventTypeEnum.mcpcatIdentify, + parameters: { + request: request, + extra: extra, + }, + timestamp: new Date(), + redactionFn: data.options.redactSensitiveInformation, + }; + + try { + const identityResult = await data.options.identify(request, extra); + if (identityResult) { + // Check for session reconnection (if no MCP sessionId provided in extra) + // If this user had a recent session, switch to it instead of creating new one + if (!extra?.sessionId && identityResult.userId) { + const timeoutMs = INACTIVITY_TIMEOUT_IN_MINUTES * 60 * 1000; + const previousSessionId = _globalUserSessionCache.getRecentSession( + identityResult.userId, + timeoutMs, + ); + + if (previousSessionId && previousSessionId !== data.sessionId) { + // User has a previous session - reconnect to it + const currentSessionIdentity = _globalIdentityCache.get( + data.sessionId, + ); + + if (!currentSessionIdentity) { + // Current session is brand new (no identity) - reconnect to previous session + data.sessionId = previousSessionId; + data.lastActivity = new Date(); + setServerTrackingData(server, data); + + writeToLog( + `Reconnected user ${identityResult.userId} to previous session ${previousSessionId} (current session was new)`, + ); + } else if (currentSessionIdentity.userId !== identityResult.userId) { + // Current session belongs to different user - reconnect to user's previous session + data.sessionId = previousSessionId; + data.lastActivity = new Date(); + setServerTrackingData(server, data); + + writeToLog( + `Reconnected user ${identityResult.userId} to previous session ${previousSessionId}`, + ); + } + // If current session already belongs to this user, no need to do anything + } else if (!previousSessionId) { + // User has NO previous session - check if current session belongs to someone else + const currentSessionIdentity = _globalIdentityCache.get( + data.sessionId, + ); + if ( + currentSessionIdentity && + currentSessionIdentity.userId !== identityResult.userId + ) { + // Current session belongs to different user - create new session + const { newSessionId } = await import("./session.js"); + data.sessionId = newSessionId(); + data.sessionSource = "mcpcat"; + data.lastActivity = new Date(); + setServerTrackingData(server, data); + + writeToLog( + `Created new session ${data.sessionId} for user ${identityResult.userId} (previous session belonged to ${currentSessionIdentity.userId})`, + ); + } + } + } + + // Now use the (possibly updated) sessionId for all subsequent operations + const currentSessionId = data.sessionId; + + // Check global cache first (works across server instance restarts) + const previousIdentity = _globalIdentityCache.get(currentSessionId); + + // Merge identities (overwrite userId/userName, merge userData) + const mergedIdentity = mergeIdentities(previousIdentity, identityResult); + + // Only publish if identity has changed + const hasChanged = + !previousIdentity || + !areIdentitiesEqual(previousIdentity, mergedIdentity); + + // Update BOTH caches to keep them in sync + // Global cache: persists across server instances + _globalIdentityCache.set(currentSessionId, mergedIdentity); + // Per-server cache: used by getSessionInfo() for fast local access + data.identifiedSessions.set(data.sessionId, mergedIdentity); + + // Track userId → sessionId mapping for reconnection support + _globalUserSessionCache.set(mergedIdentity.userId, currentSessionId); + + if (hasChanged) { + writeToLog( + `Identified session ${currentSessionId} with identity: ${JSON.stringify(mergedIdentity)}`, + ); + publishEvent(server, identifyEvent); + } + } else { + writeToLog( + `Warning: Supplied identify function returned null for session ${sessionId}`, + ); + } + } catch (error) { + writeToLog( + `Warning: Supplied identify function threw an error while identifying session ${sessionId} - ${error}`, + ); + identifyEvent.duration = + (identifyEvent.timestamp && + new Date().getTime() - identifyEvent.timestamp.getTime()) || + undefined; + identifyEvent.isError = true; + identifyEvent.error = { + message: getMCPCompatibleErrorMessage(error), + }; + publishEvent(server, identifyEvent); + } +} diff --git a/src/modules/session.ts b/src/modules/session.ts index ee2557e..3a915cf 100644 --- a/src/modules/session.ts +++ b/src/modules/session.ts @@ -3,10 +3,12 @@ import { MCPServerLike, ServerClientInfoLike, SessionInfo, + CompatibleRequestHandlerExtra, } from "../types.js"; import { getServerTrackingData, setServerTrackingData } from "./internal.js"; import KSUID from "../thirdparty/ksuid/index.js"; import packageJson from "../../package.json" with { type: "json" }; +import { createHash } from "crypto"; import { INACTIVITY_TIMEOUT_IN_MINUTES } from "./constants.js"; @@ -14,18 +16,89 @@ export function newSessionId(): string { return KSUID.withPrefix("ses").randomSync(); } -export function getServerSessionId(server: MCPServerLike): string { +/** + * Creates a deterministic KSUID session ID from an MCP sessionId and optional projectId. + * The same inputs will always produce the same session ID, enabling correlation across server restarts. + * + * @param mcpSessionId - The session ID from the MCP protocol + * @param projectId - Optional MCPCat project ID to include in the hash + * @returns A KSUID with "ses" prefix derived deterministically from the inputs + */ +export function deriveSessionIdFromMCPSession( + mcpSessionId: string, + projectId?: string, +): string { + // Create input string for hashing + const input = projectId ? `${mcpSessionId}:${projectId}` : mcpSessionId; + + // Hash the input with SHA-256 + const hash = createHash("sha256").update(input).digest(); + + // Extract timestamp from first 4 bytes of hash (for deterministic but reasonable timestamp) + // We'll use a fixed epoch (2024-01-01) plus the hash value to get a deterministic but valid timestamp + const EPOCH_2024 = new Date("2024-01-01T00:00:00Z").getTime(); + const timestampOffset = hash.readUInt32BE(0) % (365 * 24 * 60 * 60 * 1000); // Max 1 year offset + const timestamp = EPOCH_2024 + timestampOffset; + + // Use the remaining 16 bytes of hash as the KSUID payload + const payload = hash.subarray(4, 20); + + // Create deterministic KSUID with prefix + return KSUID.withPrefix("ses").fromParts(timestamp, payload); +} + +/** + * Gets or generates a session ID for the server. + * Prioritizes MCP protocol sessionId over MCPCat-generated sessionId. + * + * @param server - The MCP server instance + * @param extra - Optional extra data containing MCP sessionId + * @returns The session ID to use for events + */ +export function getServerSessionId( + server: MCPServerLike, + extra?: CompatibleRequestHandlerExtra, +): string { const data = getServerTrackingData(server); if (!data) { throw new Error("Server tracking data not found"); } + const mcpSessionId = extra?.sessionId; + + // If MCP sessionId is provided + if (mcpSessionId) { + // Check if it's a new or changed MCP sessionId + if (mcpSessionId !== data.lastMcpSessionId) { + // Derive deterministic KSUID from MCP sessionId + data.sessionId = deriveSessionIdFromMCPSession( + mcpSessionId, + data.projectId || undefined, + ); + data.lastMcpSessionId = mcpSessionId; + data.sessionSource = "mcp"; + setServerTrackingData(server, data); + } + // If MCP sessionId hasn't changed, continue using the existing derived KSUID + setLastActivity(server); + return data.sessionId; + } + + // No MCP sessionId provided - handle MCPCat-generated sessions + // If we had an MCP sessionId before but it disappeared, keep using the last derived ID + if (data.sessionSource === "mcp" && data.lastMcpSessionId) { + setLastActivity(server); + return data.sessionId; + } + + // For MCPCat-generated sessions, apply timeout logic const now = Date.now(); const timeoutMs = INACTIVITY_TIMEOUT_IN_MINUTES * 60 * 1000; // If last activity timed out if (now - data.lastActivity.getTime() > timeoutMs) { data.sessionId = newSessionId(); + data.sessionSource = "mcpcat"; setServerTrackingData(server, data); } setLastActivity(server); diff --git a/src/modules/tracing.ts b/src/modules/tracing.ts index 16b8041..f77f2bc 100644 --- a/src/modules/tracing.ts +++ b/src/modules/tracing.ts @@ -11,11 +11,7 @@ import { } from "../types.js"; import { writeToLog } from "./logging.js"; import { handleReportMissing } from "./tools.js"; -import { - getServerTrackingData, - areIdentitiesEqual, - mergeIdentities, -} from "./internal.js"; +import { getServerTrackingData, handleIdentify } from "./internal.js"; import { getServerSessionId } from "./session.js"; import { PublishEventRequestEventTypeEnum } from "mcpcat-api"; import { publishEvent } from "./eventQueue.js"; @@ -25,8 +21,8 @@ function isToolResultError(result: any): boolean { return result && typeof result === "object" && result.isError === true; } -// Track if we've already set up list tools tracing -let listToolsTracingSetup = false; +// Track if we've already set up list tools tracing per server instance +const listToolsTracingSetup = new WeakMap(); export function setupListToolsTracing( highLevelServer: HighLevelMCPServerLike, @@ -39,8 +35,8 @@ export function setupListToolsTracing( return; } - // Check if we've already set up tracing - if (listToolsTracingSetup) { + // Check if we've already set up tracing for this server instance + if (listToolsTracingSetup.get(server)) { return; } @@ -57,7 +53,7 @@ export function setupListToolsTracing( let tools: any[] = []; const data = getServerTrackingData(server); let event: UnredactedEvent = { - sessionId: getServerSessionId(server), + sessionId: getServerSessionId(server, extra), parameters: { request: request, extra: extra, @@ -117,8 +113,8 @@ export function setupListToolsTracing( return { tools }; }); - // Mark as setup successful - listToolsTracingSetup = true; + // Mark as setup successful for this server instance + listToolsTracingSetup.set(server, true); } catch (error) { writeToLog(`Warning: Failed to override list tools handler - ${error}`); } @@ -143,7 +139,11 @@ export function setupInitializeTracing( return await originalInitializeHandler(request, extra); } - const sessionId = getServerSessionId(server); + const sessionId = getServerSessionId(server, extra); + + // Try to identify the session + await handleIdentify(server, data, request, extra); + let event: UnredactedEvent = { sessionId: sessionId, resourceName: request.params?.name || "Unknown Tool Name", @@ -183,7 +183,11 @@ export function setupToolCallTracing(server: MCPServerLike): void { return await originalInitializeHandler(request, extra); } - const sessionId = getServerSessionId(server); + const sessionId = getServerSessionId(server, extra); + + // Try to identify the session + await handleIdentify(server, data, request, extra); + let event: UnredactedEvent = { sessionId: sessionId, resourceName: request.params?.name || "Unknown Tool Name", @@ -212,7 +216,7 @@ export function setupToolCallTracing(server: MCPServerLike): void { return await originalCallToolHandler?.(request, extra); } - const sessionId = getServerSessionId(server); + const sessionId = getServerSessionId(server, extra); let event: UnredactedEvent = { sessionId: sessionId, resourceName: request.params?.name || "Unknown Tool Name", @@ -227,58 +231,7 @@ export function setupToolCallTracing(server: MCPServerLike): void { try { // Try to identify the session if we haven't already and identify function is provided - if (data.options.identify) { - let identifyEvent: UnredactedEvent = { - ...event, - eventType: PublishEventRequestEventTypeEnum.mcpcatIdentify, - }; - try { - const identityResult = await data.options.identify(request, extra); - if (identityResult) { - // Get previous identity for this session - const previousIdentity = data.identifiedSessions.get(sessionId); - - // Merge identities (overwrite userId/userName, merge userData) - const mergedIdentity = mergeIdentities( - previousIdentity, - identityResult, - ); - - // Only publish if identity has changed - const hasChanged = - !previousIdentity || - !areIdentitiesEqual(previousIdentity, mergedIdentity); - - // Always update the stored identity with the merged version FIRST - // so that publishEvent can get the latest identity in sessionInfo - data.identifiedSessions.set(sessionId, mergedIdentity); - - if (hasChanged) { - writeToLog( - `Identified session ${sessionId} with identity: ${JSON.stringify(mergedIdentity)}`, - ); - publishEvent(server, identifyEvent); - } - } else { - writeToLog( - `Warning: Supplied identify function returned null for session ${sessionId}`, - ); - } - } catch (error) { - writeToLog( - `Warning: Supplied identify function threw an error while identifying session ${sessionId} - ${error}`, - ); - identifyEvent.duration = - (identifyEvent.timestamp && - new Date().getTime() - identifyEvent.timestamp.getTime()) || - undefined; - identifyEvent.isError = true; - identifyEvent.error = { - message: getMCPCompatibleErrorMessage(error), - }; - publishEvent(server, identifyEvent); - } - } + await handleIdentify(server, data, request, extra); // Check for missing context if enableToolCallContext is true and it's not report_missing if ( diff --git a/src/modules/tracingV2.ts b/src/modules/tracingV2.ts index 3e0bff1..aaf0518 100644 --- a/src/modules/tracingV2.ts +++ b/src/modules/tracingV2.ts @@ -8,11 +8,7 @@ import { CompatibleRequestHandlerExtra, } from "../types.js"; import { writeToLog } from "./logging.js"; -import { - getServerTrackingData, - areIdentitiesEqual, - mergeIdentities, -} from "./internal.js"; +import { getServerTrackingData, handleIdentify } from "./internal.js"; import { getServerSessionId } from "./session.js"; import { PublishEventRequestEventTypeEnum } from "mcpcat-api"; import { publishEvent } from "./eventQueue.js"; @@ -308,7 +304,7 @@ function addTracingToToolCallback( )(cleanedArgs, extra)); } - const sessionId = getServerSessionId(lowLevelServer); + const sessionId = getServerSessionId(lowLevelServer, extra); // Create a request-like object for compatibility with existing code const request = { @@ -332,58 +328,10 @@ function addTracingToToolCallback( try { // Try to identify the session if identify function is provided - if (data.options.identify) { - let identifyEvent: UnredactedEvent = { - ...event, - eventType: PublishEventRequestEventTypeEnum.mcpcatIdentify, - }; - try { - const identityResult = await data.options.identify(request, extra); - if (identityResult) { - // Get previous identity for this session - const previousIdentity = data.identifiedSessions.get(sessionId); - - // Merge identities (overwrite userId/userName, merge userData) - const mergedIdentity = mergeIdentities( - previousIdentity, - identityResult, - ); + await handleIdentify(lowLevelServer, data, request, extra); - // Only publish if identity has changed - const hasChanged = - !previousIdentity || - !areIdentitiesEqual(previousIdentity, mergedIdentity); - - // Always update the stored identity with the merged version FIRST - // so that publishEvent can get the latest identity in sessionInfo - data.identifiedSessions.set(sessionId, mergedIdentity); - - if (hasChanged) { - writeToLog( - `Identified session ${sessionId} with identity: ${JSON.stringify(mergedIdentity)}`, - ); - publishEvent(lowLevelServer, identifyEvent); - } - } else { - writeToLog( - `Warning: Supplied identify function returned null for session ${sessionId}`, - ); - } - } catch (error) { - writeToLog( - `Warning: Supplied identify function threw an error while identifying session ${sessionId} - ${error}`, - ); - identifyEvent.duration = - (identifyEvent.timestamp && - new Date().getTime() - identifyEvent.timestamp.getTime()) || - undefined; - identifyEvent.isError = true; - identifyEvent.error = { - message: getMCPCompatibleErrorMessage(error), - }; - publishEvent(lowLevelServer, identifyEvent); - } - } + // Update event sessionId in case handleIdentify reconnected to a different session + event.sessionId = data.sessionId; // Extract context for userIntent if present if (args && typeof args === "object" && "context" in args) { diff --git a/src/tests/identify.test.ts b/src/tests/identify.test.ts index c167f77..dcd273e 100644 --- a/src/tests/identify.test.ts +++ b/src/tests/identify.test.ts @@ -753,7 +753,7 @@ describe("Identify Feature", () => { }); describe("Identity Merging Behavior", () => { - it("should override userId/userName but merge userData fields", async () => { + it("should create separate sessions for different users", async () => { const eventCapture = new EventCapture(); await eventCapture.start(); @@ -780,15 +780,15 @@ describe("Identify Feature", () => { userId: secondUserId, userName: "Bob", userData: { - department: "Sales", // This should overwrite - location: "NYC", // This should be added + department: "Sales", + location: "NYC", }, }; } }, }); - // First tool call - sets initial identity + // First tool call - sets initial identity for Alice await client.request( { method: "tools/call", @@ -796,7 +796,7 @@ describe("Identify Feature", () => { name: "add_todo", arguments: { text: "First todo", - context: "Testing identity merge", + context: "Testing identity separation", }, }, }, @@ -806,12 +806,12 @@ describe("Identify Feature", () => { // Wait for first identify to complete await new Promise((resolve) => setTimeout(resolve, 50)); - // Verify first identity was stored + // Verify Alice's identity was stored const data = getServerTrackingData(server.server); - const sessionId = data?.sessionId; - let storedIdentity = data?.identifiedSessions.get(sessionId!); + const aliceSessionId = data?.sessionId; + const aliceIdentity = data?.identifiedSessions.get(aliceSessionId!); - expect(storedIdentity).toEqual({ + expect(aliceIdentity).toEqual({ userId: firstUserId, userName: "Alice", userData: { @@ -820,14 +820,14 @@ describe("Identify Feature", () => { }, }); - // Second tool call - should merge identities + // Second tool call - Bob should get his own NEW session (not take over Alice's) await client.request( { method: "tools/call", params: { name: "list_todos", arguments: { - context: "Testing identity merge again", + context: "Testing identity separation again", }, }, }, @@ -837,20 +837,35 @@ describe("Identify Feature", () => { // Wait for second identify to complete await new Promise((resolve) => setTimeout(resolve, 50)); - // Verify merged identity - storedIdentity = data?.identifiedSessions.get(sessionId!); + // Verify Bob got his own session (different from Alice's) + const bobSessionId = data?.sessionId; + expect(bobSessionId).not.toEqual(aliceSessionId); - expect(storedIdentity).toEqual({ - userId: secondUserId, // Overwritten - userName: "Bob", // Overwritten + // Verify Bob's identity is stored in his own session + const bobIdentity = data?.identifiedSessions.get(bobSessionId!); + expect(bobIdentity).toEqual({ + userId: secondUserId, + userName: "Bob", + userData: { + department: "Sales", + location: "NYC", + }, + }); + + // Verify Alice's session still has her identity (unchanged) + const aliceIdentityStillThere = data?.identifiedSessions.get( + aliceSessionId!, + ); + expect(aliceIdentityStillThere).toEqual({ + userId: firstUserId, + userName: "Alice", userData: { - role: "admin", // Preserved from first call - department: "Sales", // Overwritten from first call - location: "NYC", // Added in second call + role: "admin", + department: "Engineering", }, }); - // Verify two identify events were published (both represent changes) + // Verify two identify events were published (one for each user) const events = eventCapture.getEvents(); const identifyEvents = events.filter( (e) => e.eventType === PublishEventRequestEventTypeEnum.mcpcatIdentify, diff --git a/src/tests/session-id.test.ts b/src/tests/session-id.test.ts new file mode 100644 index 0000000..f9c8be0 --- /dev/null +++ b/src/tests/session-id.test.ts @@ -0,0 +1,694 @@ +import { describe, it, expect, beforeEach, afterEach } from "vitest"; +import { + setupTestServerAndClient, + resetTodos, +} from "./test-utils/client-server-factory"; +import { track } from "../index"; +import { CallToolResultSchema } from "@modelcontextprotocol/sdk/types"; +import { EventCapture } from "./test-utils"; +import { getServerTrackingData } from "../modules/internal"; +import { HighLevelMCPServerLike } from "../types"; +import { + deriveSessionIdFromMCPSession, + getServerSessionId, +} from "../modules/session"; + +describe("Session ID Management", () => { + let server: HighLevelMCPServerLike; + let client: any; + let cleanup: () => Promise; + + beforeEach(async () => { + resetTodos(); + const setup = await setupTestServerAndClient(); + server = setup.server; + client = setup.client; + cleanup = setup.cleanup; + }); + + afterEach(async () => { + await cleanup(); + }); + + describe("Deterministic KSUID Derivation", () => { + it("should generate deterministic session IDs from the same MCP sessionId", () => { + const mcpSessionId = "test-session-123"; + const projectId = "proj_abc"; + + const sessionId1 = deriveSessionIdFromMCPSession(mcpSessionId, projectId); + const sessionId2 = deriveSessionIdFromMCPSession(mcpSessionId, projectId); + + expect(sessionId1).toBe(sessionId2); + expect(sessionId1).toMatch(/^ses_/); + }); + + it("should generate different session IDs for different MCP sessionIds", () => { + const projectId = "proj_abc"; + + const sessionId1 = deriveSessionIdFromMCPSession("session-1", projectId); + const sessionId2 = deriveSessionIdFromMCPSession("session-2", projectId); + + expect(sessionId1).not.toBe(sessionId2); + expect(sessionId1).toMatch(/^ses_/); + expect(sessionId2).toMatch(/^ses_/); + }); + + it("should generate different session IDs for different projectIds", () => { + const mcpSessionId = "test-session-123"; + + const sessionId1 = deriveSessionIdFromMCPSession( + mcpSessionId, + "proj_abc", + ); + const sessionId2 = deriveSessionIdFromMCPSession( + mcpSessionId, + "proj_xyz", + ); + + expect(sessionId1).not.toBe(sessionId2); + expect(sessionId1).toMatch(/^ses_/); + expect(sessionId2).toMatch(/^ses_/); + }); + + it("should handle missing projectId gracefully", () => { + const mcpSessionId = "test-session-123"; + + const sessionId1 = deriveSessionIdFromMCPSession(mcpSessionId); + const sessionId2 = deriveSessionIdFromMCPSession(mcpSessionId); + const sessionId3 = deriveSessionIdFromMCPSession(mcpSessionId, undefined); + + expect(sessionId1).toBe(sessionId2); + expect(sessionId1).toBe(sessionId3); + expect(sessionId1).toMatch(/^ses_/); + }); + + it("should generate different session IDs when projectId is present vs absent", () => { + const mcpSessionId = "test-session-123"; + + const sessionIdWithProject = deriveSessionIdFromMCPSession( + mcpSessionId, + "proj_abc", + ); + const sessionIdWithoutProject = + deriveSessionIdFromMCPSession(mcpSessionId); + + expect(sessionIdWithProject).not.toBe(sessionIdWithoutProject); + }); + }); + + describe("MCP SessionId Prioritization", () => { + it("should use MCP sessionId when provided in extra parameter", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-mcp"; + const mcpSessionId = "mcp-session-abc-123"; + + track(server, projectId, { + enableTracing: true, + }); + + // Get the low-level server + const lowLevelServer = server.server; + + // Simulate MCP sessionId in extra parameter + const extra = { sessionId: mcpSessionId }; + + // Get session ID with MCP sessionId provided + const sessionId = getServerSessionId(lowLevelServer, extra); + + // Verify it's deterministically derived + const expectedSessionId = deriveSessionIdFromMCPSession( + mcpSessionId, + projectId, + ); + expect(sessionId).toBe(expectedSessionId); + + // Verify tracking data is updated + const data = getServerTrackingData(lowLevelServer); + expect(data?.lastMcpSessionId).toBe(mcpSessionId); + expect(data?.sessionSource).toBe("mcp"); + + await eventCapture.stop(); + }); + + it("should use MCPCat-generated sessionId when no MCP sessionId provided", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + track(server, "test-project", { + enableTracing: true, + }); + + const lowLevelServer = server.server; + + // Get initial session ID without MCP sessionId + const sessionId1 = getServerSessionId(lowLevelServer); + expect(sessionId1).toMatch(/^ses_/); + + // Verify tracking data shows MCPCat source + const data = getServerTrackingData(lowLevelServer); + expect(data?.sessionSource).toBe("mcpcat"); + expect(data?.lastMcpSessionId).toBeUndefined(); + + // Get session ID again - should be the same + const sessionId2 = getServerSessionId(lowLevelServer); + expect(sessionId2).toBe(sessionId1); + + await eventCapture.stop(); + }); + + it("should switch to MCP-derived sessionId when MCP sessionId appears", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-switch"; + const mcpSessionId = "mcp-session-appears"; + + track(server, projectId, { + enableTracing: true, + }); + + const lowLevelServer = server.server; + + // Start with no MCP sessionId + const mcpcatSessionId = getServerSessionId(lowLevelServer); + expect(mcpcatSessionId).toMatch(/^ses_/); + + let data = getServerTrackingData(lowLevelServer); + expect(data?.sessionSource).toBe("mcpcat"); + + // Now provide MCP sessionId + const extra = { sessionId: mcpSessionId }; + const mcpDerivedSessionId = getServerSessionId(lowLevelServer, extra); + + // Verify it switched to MCP-derived ID + const expectedSessionId = deriveSessionIdFromMCPSession( + mcpSessionId, + projectId, + ); + expect(mcpDerivedSessionId).toBe(expectedSessionId); + expect(mcpDerivedSessionId).not.toBe(mcpcatSessionId); + + // Verify tracking data is updated + data = getServerTrackingData(lowLevelServer); + expect(data?.lastMcpSessionId).toBe(mcpSessionId); + expect(data?.sessionSource).toBe("mcp"); + + await eventCapture.stop(); + }); + + it("should keep last derived sessionId when MCP sessionId disappears", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-disappear"; + const mcpSessionId = "mcp-session-disappears"; + + track(server, projectId, { + enableTracing: true, + }); + + const lowLevelServer = server.server; + + // Provide MCP sessionId + const extra = { sessionId: mcpSessionId }; + const mcpDerivedSessionId = getServerSessionId(lowLevelServer, extra); + + const expectedSessionId = deriveSessionIdFromMCPSession( + mcpSessionId, + projectId, + ); + expect(mcpDerivedSessionId).toBe(expectedSessionId); + + // Now call without MCP sessionId (it disappeared) + const sessionIdAfterDisappear = getServerSessionId(lowLevelServer); + + // Should keep using the last derived sessionId + expect(sessionIdAfterDisappear).toBe(mcpDerivedSessionId); + + // Verify tracking data still shows MCP source + const data = getServerTrackingData(lowLevelServer); + expect(data?.sessionSource).toBe("mcp"); + expect(data?.lastMcpSessionId).toBe(mcpSessionId); + + await eventCapture.stop(); + }); + + it("should regenerate sessionId when MCP sessionId changes", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-change"; + const mcpSessionId1 = "mcp-session-first"; + const mcpSessionId2 = "mcp-session-second"; + + track(server, projectId, { + enableTracing: true, + }); + + const lowLevelServer = server.server; + + // Provide first MCP sessionId + const extra1 = { sessionId: mcpSessionId1 }; + const sessionId1 = getServerSessionId(lowLevelServer, extra1); + + const expectedSessionId1 = deriveSessionIdFromMCPSession( + mcpSessionId1, + projectId, + ); + expect(sessionId1).toBe(expectedSessionId1); + + // Change to second MCP sessionId + const extra2 = { sessionId: mcpSessionId2 }; + const sessionId2 = getServerSessionId(lowLevelServer, extra2); + + const expectedSessionId2 = deriveSessionIdFromMCPSession( + mcpSessionId2, + projectId, + ); + expect(sessionId2).toBe(expectedSessionId2); + expect(sessionId2).not.toBe(sessionId1); + + // Verify tracking data is updated + const data = getServerTrackingData(lowLevelServer); + expect(data?.lastMcpSessionId).toBe(mcpSessionId2); + expect(data?.sessionSource).toBe("mcp"); + + await eventCapture.stop(); + }); + }); + + describe("Session Timeout Behavior", () => { + it("should NOT apply timeout to MCP-derived sessions", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-timeout"; + const mcpSessionId = "mcp-session-persistent"; + + track(server, projectId, { + enableTracing: true, + }); + + const lowLevelServer = server.server; + + // Get MCP-derived session ID + const extra = { sessionId: mcpSessionId }; + const sessionId1 = getServerSessionId(lowLevelServer, extra); + + // Manually set lastActivity to simulate timeout (31 minutes ago) + const data = getServerTrackingData(lowLevelServer); + if (data) { + data.lastActivity = new Date(Date.now() - 31 * 60 * 1000); + } + + // Get session ID again with same MCP sessionId + const sessionId2 = getServerSessionId(lowLevelServer, extra); + + // Should still be the same (no timeout for MCP sessions) + expect(sessionId2).toBe(sessionId1); + + await eventCapture.stop(); + }); + + it("should apply timeout to MCPCat-generated sessions", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + track(server, "test-project", { + enableTracing: true, + }); + + const lowLevelServer = server.server; + + // Get MCPCat-generated session ID + const sessionId1 = getServerSessionId(lowLevelServer); + + // Manually set lastActivity to simulate timeout (31 minutes ago) + const data = getServerTrackingData(lowLevelServer); + if (data) { + data.lastActivity = new Date(Date.now() - 31 * 60 * 1000); + } + + // Get session ID again without MCP sessionId + const sessionId2 = getServerSessionId(lowLevelServer); + + // Should be different (timeout occurred) + expect(sessionId2).not.toBe(sessionId1); + expect(sessionId2).toMatch(/^ses_/); + + await eventCapture.stop(); + }); + }); + + describe("Event Publishing with Session IDs", () => { + it("should publish events with MCP-derived session IDs", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-events"; + const mcpSessionId = "mcp-session-for-events"; + + track(server, projectId, { + enableTracing: true, + }); + + // TODO: This test would require mocking the transport to inject sessionId into extra + // For now, we'll verify the logic with direct function calls above + // In a real MCP environment, the sessionId would come from the transport layer + + await eventCapture.stop(); + }); + }); + + describe("Session Reconnection", () => { + it("should reconnect user to previous session when reinitializing without MCP sessionId", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-reconnect"; + let identifyCallCount = 0; + + track(server, projectId, { + enableTracing: true, + identify: async (request: any, extra?: any) => { + identifyCallCount++; + return { + userId: "user-123", + userName: "Test User", + }; + }, + }); + + const lowLevelServer = server.server; + + // First initialize - creates initial session + const request1 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request1, {}); + + const data1 = getServerTrackingData(lowLevelServer); + const firstSessionId = data1?.sessionId; + expect(firstSessionId).toMatch(/^ses_/); + expect(identifyCallCount).toBe(1); + + // Second initialize WITHOUT MCP sessionId - should reconnect to first session + const request2 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request2, {}); + + const data2 = getServerTrackingData(lowLevelServer); + const secondSessionId = data2?.sessionId; + + // Should reuse the first session + expect(secondSessionId).toBe(firstSessionId); + expect(identifyCallCount).toBe(2); + + await eventCapture.stop(); + }); + + it("should create new session if previous session expired (>30 min)", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-timeout-reconnect"; + let identifyCallCount = 0; + + track(server, projectId, { + enableTracing: true, + identify: async (request: any, extra?: any) => { + identifyCallCount++; + return { + userId: "user-456", + userName: "Test User", + }; + }, + }); + + const lowLevelServer = server.server; + + // First initialize + const request1 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request1, {}); + + const data1 = getServerTrackingData(lowLevelServer); + const firstSessionId = data1?.sessionId; + + // Manually expire the session by setting lastActivity to 31 minutes ago + // Also update the cache entry to reflect this expiration + if (data1) { + data1.lastActivity = new Date(Date.now() - 31 * 60 * 1000); + // Update cache entry to 31 minutes ago + const { _testSetUserSession } = await import("../modules/internal.js"); + _testSetUserSession( + "user-456", + firstSessionId!, + Date.now() - 31 * 60 * 1000, + ); + } + + // Second initialize - session expired, should get new session + const request2 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request2, {}); + + const data2 = getServerTrackingData(lowLevelServer); + const secondSessionId = data2?.sessionId; + + // Should have a different session (expired) + expect(secondSessionId).not.toBe(firstSessionId); + expect(secondSessionId).toMatch(/^ses_/); + expect(identifyCallCount).toBe(2); + + await eventCapture.stop(); + }); + + it("should NOT reconnect when MCP sessionId is provided (MCP takes priority)", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-mcp-priority"; + const mcpSessionId1 = "mcp-session-aaa"; + const mcpSessionId2 = "mcp-session-bbb"; + + track(server, projectId, { + enableTracing: true, + identify: async (request: any, extra?: any) => { + return { + userId: "user-789", + userName: "Test User", + }; + }, + }); + + const lowLevelServer = server.server; + + // First initialize with MCP sessionId + const request1 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request1, { + sessionId: mcpSessionId1, + }); + + const data1 = getServerTrackingData(lowLevelServer); + const firstSessionId = data1?.sessionId; + const expectedSessionId1 = deriveSessionIdFromMCPSession( + mcpSessionId1, + projectId, + ); + expect(firstSessionId).toBe(expectedSessionId1); + + // Second initialize with DIFFERENT MCP sessionId + // Should use the new MCP sessionId, NOT reconnect to previous + const request2 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request2, { + sessionId: mcpSessionId2, + }); + + const data2 = getServerTrackingData(lowLevelServer); + const secondSessionId = data2?.sessionId; + const expectedSessionId2 = deriveSessionIdFromMCPSession( + mcpSessionId2, + projectId, + ); + + // Should use new MCP-derived session, not reconnect + expect(secondSessionId).toBe(expectedSessionId2); + expect(secondSessionId).not.toBe(firstSessionId); + + await eventCapture.stop(); + }); + + it("should create new session when no identify function configured", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-no-identify"; + + track(server, projectId, { + enableTracing: true, + // No identify function + }); + + const lowLevelServer = server.server; + + // First initialize + const request1 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request1, {}); + + const data1 = getServerTrackingData(lowLevelServer); + const firstSessionId = data1?.sessionId; + + // Second initialize - no identify, should timeout and create new session + // Simulate timeout + if (data1) { + data1.lastActivity = new Date(Date.now() - 31 * 60 * 1000); + } + + const request2 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request2, {}); + + const data2 = getServerTrackingData(lowLevelServer); + const secondSessionId = data2?.sessionId; + + // Should have different session (no reconnection without identify) + expect(secondSessionId).not.toBe(firstSessionId); + + await eventCapture.stop(); + }); + + it("should handle different users reconnecting to their own sessions", async () => { + const eventCapture = new EventCapture(); + await eventCapture.start(); + + const projectId = "test-project-multi-user"; + let currentUserId = "user-alice"; + + track(server, projectId, { + enableTracing: true, + identify: async (request: any, extra?: any) => { + return { + userId: currentUserId, + userName: currentUserId === "user-alice" ? "Alice" : "Bob", + }; + }, + }); + + const lowLevelServer = server.server; + + // Alice's first session + currentUserId = "user-alice"; + const request1 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request1, {}); + const aliceSession1 = getServerTrackingData(lowLevelServer)?.sessionId; + + // Bob's first session + currentUserId = "user-bob"; + const request2 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request2, {}); + const bobSession1 = getServerTrackingData(lowLevelServer)?.sessionId; + + // Different users should have different sessions + expect(aliceSession1).not.toBe(bobSession1); + + // Alice reconnects - should get her original session back + currentUserId = "user-alice"; + const request3 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request3, {}); + const aliceSession2 = getServerTrackingData(lowLevelServer)?.sessionId; + + expect(aliceSession2).toBe(aliceSession1); + + // Bob reconnects - should get his original session back + currentUserId = "user-bob"; + const request4 = { + method: "initialize", + params: { + protocolVersion: "1.0", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0" }, + }, + }; + await lowLevelServer._requestHandlers.get("initialize")?.(request4, {}); + const bobSession2 = getServerTrackingData(lowLevelServer)?.sessionId; + + expect(bobSession2).toBe(bobSession1); + + await eventCapture.stop(); + }); + }); +}); diff --git a/src/types.ts b/src/types.ts index 2f6e0e9..3604a50 100644 --- a/src/types.ts +++ b/src/types.ts @@ -162,9 +162,11 @@ export interface SessionInfo { export interface MCPCatData { projectId: string; // Project ID for MCPCat - sessionId: string; // Unique identifier for the session + sessionId: string; // Unique identifier for the session (KSUID with ses prefix) lastActivity: Date; // Last activity timestamp identifiedSessions: Map; sessionInfo: SessionInfo; options: MCPCatOptions; + lastMcpSessionId?: string; // Track the last MCP sessionId we saw + sessionSource: "mcp" | "mcpcat"; // Track whether session ID came from MCP protocol or MCPCat generation }