diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt index 10b96e0..a38f87f 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt @@ -14,6 +14,9 @@ import kotlinx.atomicfu.update import kotlinx.collections.immutable.persistentMapOf import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.update import kotlinx.serialization.json.JsonElement private val logger = KotlinLogging.logger {} @@ -37,6 +40,7 @@ public class Client( private val _sessions = atomic(persistentMapOf>()) private val _clientInfo = CompletableDeferred() private val _agentInfo = CompletableDeferred() + private val _currentlyInitializingSessionsCount = MutableStateFlow(0) init { // Set up request handlers for incoming agent requests @@ -158,15 +162,18 @@ public class Client( * @return a [ClientSession] instance for the new session */ public suspend fun newSession(sessionParameters: SessionCreationParameters, operationsFactory: ClientOperationsFactory): ClientSession { - val newSessionResponse = AcpMethod.AgentMethods.SessionNew(protocol, - NewSessionRequest( - sessionParameters.cwd, - sessionParameters.mcpServers, - sessionParameters._meta + return withInitializingSession { + val newSessionResponse = AcpMethod.AgentMethods.SessionNew( + protocol, + NewSessionRequest( + sessionParameters.cwd, + sessionParameters.mcpServers, + sessionParameters._meta + ) ) - ) - val sessionId = newSessionResponse.sessionId - return createSession(sessionId, sessionParameters, newSessionResponse, operationsFactory) + val sessionId = newSessionResponse.sessionId + return@withInitializingSession createSession(sessionId, sessionParameters, newSessionResponse, operationsFactory) + } } /** @@ -180,15 +187,18 @@ public class Client( * @return a [ClientSession] instance for the new session */ public suspend fun loadSession(sessionId: SessionId, sessionParameters: SessionCreationParameters, operationsFactory: ClientOperationsFactory): ClientSession { - val loadSessionResponse = AcpMethod.AgentMethods.SessionLoad(protocol, - LoadSessionRequest( - sessionId, - sessionParameters.cwd, - sessionParameters.mcpServers, - sessionParameters._meta - )) - - return createSession(sessionId, sessionParameters, loadSessionResponse, operationsFactory) + return withInitializingSession { + val loadSessionResponse = AcpMethod.AgentMethods.SessionLoad( + protocol, + LoadSessionRequest( + sessionId, + sessionParameters.cwd, + sessionParameters.mcpServers, + sessionParameters._meta + ) + ) + return@withInitializingSession createSession(sessionId, sessionParameters, loadSessionResponse, operationsFactory) + } } private suspend fun createSession(sessionId: SessionId, sessionParameters: SessionCreationParameters, sessionResponse: AcpCreatedSessionResponse, factory: ClientOperationsFactory): ClientSession { @@ -215,7 +225,27 @@ public class Client( return completableDeferred.getCompleted() } - private suspend fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl = (_sessions.value[sessionId] ?: acpFail("Session $sessionId not found")).await() + private suspend fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl { + _sessions.value[sessionId]?.let { + return it.await() + } + // try to wait for all pending sessions to initialize + _currentlyInitializingSessionsCount.first { it == 0 } + // try to get the session again + _sessions.value[sessionId]?.let { + return it.await() + } + acpFail("Session $sessionId not found") + } + + private suspend fun withInitializingSession(block: suspend () -> T): T { + _currentlyInitializingSessionsCount.update { it + 1 } + try { + return block() + } finally { + _currentlyInitializingSessionsCount.update { it - 1 } + } + } } private inline fun sessionMethodNotFound(method: AcpMethod): Nothing { diff --git a/build.gradle.kts b/build.gradle.kts index 3c3b10d..271f753 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -7,7 +7,7 @@ plugins { private val buildNumber: String? = System.getenv("GITHUB_RUN_NUMBER") private val isReleasePublication = System.getenv("RELEASE_PUBLICATION")?.toBoolean() ?: false -private val baseVersion = "0.7.1" +private val baseVersion = "0.7.2" allprojects { group = "com.agentclientprotocol"