diff --git a/stream-android-core/src/main/java/io/getstream/android/core/internal/client/StreamClientImpl.kt b/stream-android-core/src/main/java/io/getstream/android/core/internal/client/StreamClientImpl.kt index d9fc88a..4a065ec 100644 --- a/stream-android-core/src/main/java/io/getstream/android/core/internal/client/StreamClientImpl.kt +++ b/stream-android-core/src/main/java/io/getstream/android/core/internal/client/StreamClientImpl.kt @@ -36,6 +36,7 @@ import io.getstream.android.core.api.subscribe.StreamSubscription import io.getstream.android.core.api.subscribe.StreamSubscriptionManager import io.getstream.android.core.api.utils.flatMap import io.getstream.android.core.api.utils.onTokenError +import io.getstream.android.core.api.utils.runCatchingCancellable import io.getstream.android.core.api.utils.update import io.getstream.android.core.internal.observers.StreamNetworkAndLifeCycleMonitor import io.getstream.android.core.internal.observers.StreamNetworkAndLifecycleMonitorListener @@ -133,8 +134,10 @@ internal class StreamClientImpl( .subscribe(networkAndLifecycleMonitorListener, retentionOptions) .getOrThrow() } - tokenManager - .loadIfAbsent() + // Network and Lifecycle manager must start first + networkAndLifeCycleMonitor + .start() + .flatMap { tokenManager.loadIfAbsent() } .flatMap { token -> connectSocketSession(token) } .fold( onSuccess = { connected -> @@ -150,9 +153,6 @@ internal class StreamClientImpl( Result.failure(error) }, ) - .flatMap { connectedUser -> - networkAndLifeCycleMonitor.start().map { connectedUser } - } .getOrThrow() } @@ -211,7 +211,13 @@ internal class StreamClientImpl( is Recovery.Disconnect<*> -> { logger.v { "[recovery] Disconnecting: $recovery" } - socketSession.disconnect().notifyFailure(subscriptionManager) + mutableConnectionState.update(StreamConnectionState.Disconnected()) + runCatchingCancellable { + singleFlight.cancel(connectKey).getOrThrow() + connectionIdHolder.clear().getOrThrow() + socketSession.disconnect().getOrThrow() + } + .notifyFailure(subscriptionManager) } is Recovery.Error -> { diff --git a/stream-android-core/src/test/java/io/getstream/android/core/internal/client/StreamClientIImplTest.kt b/stream-android-core/src/test/java/io/getstream/android/core/internal/client/StreamClientIImplTest.kt index 68498eb..16f55cd 100644 --- a/stream-android-core/src/test/java/io/getstream/android/core/internal/client/StreamClientIImplTest.kt +++ b/stream-android-core/src/test/java/io/getstream/android/core/internal/client/StreamClientIImplTest.kt @@ -20,6 +20,7 @@ package io.getstream.android.core.internal.client import io.getstream.android.core.api.authentication.StreamTokenManager import io.getstream.android.core.api.log.StreamLogger +import io.getstream.android.core.api.model.StreamTypedKey import io.getstream.android.core.api.model.connection.StreamConnectedUser import io.getstream.android.core.api.model.connection.StreamConnectionState import io.getstream.android.core.api.model.connection.lifecycle.StreamLifecycleState @@ -31,6 +32,10 @@ import io.getstream.android.core.api.model.exceptions.StreamEndpointErrorData import io.getstream.android.core.api.model.exceptions.StreamEndpointException import io.getstream.android.core.api.model.value.StreamToken import io.getstream.android.core.api.model.value.StreamUserId +import io.getstream.android.core.api.observers.lifecycle.StreamLifecycleListener +import io.getstream.android.core.api.observers.lifecycle.StreamLifecycleMonitor +import io.getstream.android.core.api.observers.network.StreamNetworkMonitor +import io.getstream.android.core.api.observers.network.StreamNetworkMonitorListener import io.getstream.android.core.api.processing.StreamSerialProcessingQueue import io.getstream.android.core.api.processing.StreamSingleFlightProcessor import io.getstream.android.core.api.recovery.StreamConnectionRecoveryEvaluator @@ -38,17 +43,24 @@ import io.getstream.android.core.api.socket.StreamConnectionIdHolder import io.getstream.android.core.api.socket.listeners.StreamClientListener import io.getstream.android.core.api.subscribe.StreamSubscription import io.getstream.android.core.api.subscribe.StreamSubscriptionManager +import io.getstream.android.core.api.subscribe.StreamSubscriptionManager.Options import io.getstream.android.core.internal.observers.StreamNetworkAndLifeCycleMonitor import io.getstream.android.core.internal.observers.StreamNetworkAndLifecycleMonitorListener +import io.getstream.android.core.internal.recovery.StreamConnectionRecoveryEvaluatorImpl import io.getstream.android.core.internal.socket.StreamSocketSession +import io.getstream.android.core.testing.TestLogger import io.mockk.* import kotlin.time.ExperimentalTime +import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.update +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.suspendCancellableCoroutine +import kotlinx.coroutines.test.advanceTimeBy import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest -import org.bouncycastle.util.test.SimpleTest.runTest import org.junit.Assert.* import org.junit.Before import org.junit.Test @@ -156,6 +168,99 @@ class StreamClientIImplTest { } } + private class TestLifecycleMonitor : StreamLifecycleMonitor { + private val listeners = mutableSetOf() + private var started = false + + override fun start(): Result = Result.success(Unit).also { started = true } + + override fun stop(): Result = + Result.success(Unit).also { + started = false + listeners.clear() + } + + override fun subscribe( + listener: StreamLifecycleListener, + options: Options, + ): Result { + listeners += listener + return Result.success( + object : StreamSubscription { + override fun cancel() { + listeners -= listener + } + } + ) + } + + override fun getCurrentState(): StreamLifecycleState = StreamLifecycleState.Unknown + + fun emitBackground() { + if (!started) return + listeners.forEach { it.onBackground() } + } + + fun emitForeground() { + if (!started) return + listeners.forEach { it.onForeground() } + } + } + + private class TestNetworkMonitor : StreamNetworkMonitor { + private var listener: StreamNetworkMonitorListener? = null + private var started = false + + override fun start(): Result = Result.success(Unit).also { started = true } + + override fun stop(): Result = + Result.success(Unit).also { + started = false + listener = null + } + + override fun subscribe( + listener: StreamNetworkMonitorListener, + options: Options, + ): Result { + this.listener = listener + return Result.success( + object : StreamSubscription { + override fun cancel() { + if (this@TestNetworkMonitor.listener === listener) { + this@TestNetworkMonitor.listener = null + } + } + } + ) + } + + fun emitConnected(snapshot: StreamNetworkInfo.Snapshot?) { + if (!started) return + runBlocking { listener?.onNetworkConnected(snapshot) } + } + + fun emitLost(permanent: Boolean) { + if (!started) return + runBlocking { listener?.onNetworkLost(permanent) } + } + } + + private class ImmediateSingleFlightProcessor : StreamSingleFlightProcessor { + override suspend fun run(key: StreamTypedKey, block: suspend () -> T): Result = + runCatching { + block() + } + + override fun has(key: StreamTypedKey): Boolean = false + + override fun cancel(key: StreamTypedKey): Result = Result.success(Unit) + + override fun clear(cancelRunning: Boolean): Result = Result.success(Unit) + + override fun stop(): Result = Result.success(Unit) + } + @Test fun `connect short-circuits when already connected`() = runTest { val connectedUser = mockk(relaxed = true) @@ -457,6 +562,143 @@ class StreamClientIImplTest { assertTrue(recoveries.contains(expectedRecovery)) } + @Test + fun `recovery disconnects when backgrounding during long connect`() = runTest { + var networkListener: StreamNetworkAndLifecycleMonitorListener? = null + val networkMonitor = capturingNetworkMonitor { networkListener = it } + val recoveryEvaluator = mockk() + val expectedRecovery = Recovery.Disconnect("background") + coEvery { recoveryEvaluator.evaluate(any(), any(), any()) } returns + Result.success(expectedRecovery) + coEvery { socketSession.disconnect() } returns Result.success(Unit) + coEvery { socketSession.subscribe(any(), any()) } returns + Result.success(mockk(relaxed = true)) + coEvery { tokenManager.loadIfAbsent() } returns + Result.success(StreamToken.fromString("tok")) + coEvery { socketSession.connect(any()) } coAnswers + { + suspendCancellableCoroutine> {} + } + + val client = createClient(this, networkMonitor, recoveryEvaluator) + + val connectJob = launch { client.connect().onFailure {} } + advanceUntilIdle() + + val listener = networkListener ?: error("Network listener not registered") + listener.onNetworkAndLifecycleState( + StreamNetworkState.Disconnected, + StreamLifecycleState.Background, + ) + advanceUntilIdle() + + coVerify(exactly = 1) { recoveryEvaluator.evaluate(any(), any(), any()) } + verify(exactly = 1) { socketSession.disconnect() } + + connectJob.cancel() + } + + @Test + fun `backgrounding while initial connect is pending cancels the session`() = runTest { + val lifecycleMonitor = TestLifecycleMonitor() + val networkMonitor = TestNetworkMonitor() + val downstreamSubscriptionManager = + StreamSubscriptionManager(TestLogger) + val monitor = + StreamNetworkAndLifeCycleMonitor( + logger = TestLogger, + lifecycleMonitor = lifecycleMonitor, + networkMonitor = networkMonitor, + mutableNetworkState = MutableStateFlow(StreamNetworkState.Unknown), + mutableLifecycleState = MutableStateFlow(StreamLifecycleState.Unknown), + subscriptionManager = downstreamSubscriptionManager, + ) + val recoveryEvaluator = + StreamConnectionRecoveryEvaluatorImpl(TestLogger, ImmediateSingleFlightProcessor()) + + val firstConnectDeferred = CompletableDeferred>() + val connectedUser = mockk(relaxed = true) + val connectedState = StreamConnectionState.Connected(connectedUser, "conn-2") + coEvery { tokenManager.loadIfAbsent() } returns + Result.success(StreamToken.fromString("tok")) + every { socketSession.subscribe(any(), any()) } returns + Result.success(mockk(relaxed = true)) + every { socketSession.disconnect() } returns Result.success(Unit) + var firstCall = true + coEvery { socketSession.connect(any()) } coAnswers + { + if (firstCall) { + firstCall = false + connFlow.update { StreamConnectionState.Connecting.Opening(userId.rawValue) } + firstConnectDeferred.await() + } else { + Result.success(connectedState) + } + } + + val client = createClient(this, monitor, recoveryEvaluator) + + val connectJob = launch { client.connect().onFailure {} } + advanceUntilIdle() + + lifecycleMonitor.emitBackground() + advanceUntilIdle() + + verify(exactly = 1) { socketSession.disconnect() } + + connectJob.cancel() + firstConnectDeferred.cancel() + } + + @Test + fun `background disconnect followed by foreground reconnect succeeds unless client disconnects`() = + runTest { + val lifecycleMonitor = TestLifecycleMonitor() + val networkMonitor = TestNetworkMonitor() + val downstreamSubscriptionManager = + StreamSubscriptionManager(TestLogger) + val monitor = + StreamNetworkAndLifeCycleMonitor( + logger = TestLogger, + lifecycleMonitor = lifecycleMonitor, + networkMonitor = networkMonitor, + mutableNetworkState = MutableStateFlow(StreamNetworkState.Unknown), + mutableLifecycleState = MutableStateFlow(StreamLifecycleState.Unknown), + subscriptionManager = downstreamSubscriptionManager, + ) + val recoveryEvaluator = + StreamConnectionRecoveryEvaluatorImpl(TestLogger, ImmediateSingleFlightProcessor()) + + val connectedUser = mockk(relaxed = true) + val connectedState = StreamConnectionState.Connected(connectedUser, "conn-42") + coEvery { tokenManager.loadIfAbsent() } returns + Result.success(StreamToken.fromString("tok")) + every { socketSession.subscribe(any(), any()) } returns + Result.success(mockk(relaxed = true)) + every { socketSession.disconnect() } returns Result.success(Unit) + coEvery { socketSession.connect(any()) } returnsMany + listOf(Result.success(connectedState), Result.success(connectedState)) + every { connectionIdHolder.setConnectionId("conn-42") } returns + Result.success("conn-42") + + val client = createClient(this, monitor, recoveryEvaluator) + + client.connect().onFailure {} + advanceUntilIdle() + + lifecycleMonitor.emitBackground() + advanceUntilIdle() + verify(exactly = 1) { socketSession.disconnect() } + assertEquals(StreamConnectionState.Disconnected(), connFlow.value) + + lifecycleMonitor.emitForeground() + networkMonitor.emitConnected(StreamNetworkInfo.Snapshot()) + advanceTimeBy(1000) + advanceUntilIdle() + + coVerify(exactly = 2) { socketSession.connect(any()) } + } + @Test fun `recovery error notifies subscribers`() = runTest { var networkListener: StreamNetworkAndLifecycleMonitorListener? = null