Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import io.getstream.android.core.api.subscribe.StreamSubscription
import io.getstream.android.core.api.subscribe.StreamSubscriptionManager
import io.getstream.android.core.api.utils.flatMap
import io.getstream.android.core.api.utils.onTokenError
import io.getstream.android.core.api.utils.runCatchingCancellable
import io.getstream.android.core.api.utils.update
import io.getstream.android.core.internal.observers.StreamNetworkAndLifeCycleMonitor
import io.getstream.android.core.internal.observers.StreamNetworkAndLifecycleMonitorListener
Expand Down Expand Up @@ -133,8 +134,10 @@ internal class StreamClientImpl<T>(
.subscribe(networkAndLifecycleMonitorListener, retentionOptions)
.getOrThrow()
}
tokenManager
.loadIfAbsent()
// Network and Lifecycle manager must start first
networkAndLifeCycleMonitor
.start()
.flatMap { tokenManager.loadIfAbsent() }
.flatMap { token -> connectSocketSession(token) }
.fold(
onSuccess = { connected ->
Expand All @@ -150,9 +153,6 @@ internal class StreamClientImpl<T>(
Result.failure(error)
},
)
.flatMap { connectedUser ->
networkAndLifeCycleMonitor.start().map { connectedUser }
}
.getOrThrow()
}

Expand Down Expand Up @@ -211,7 +211,13 @@ internal class StreamClientImpl<T>(

is Recovery.Disconnect<*> -> {
logger.v { "[recovery] Disconnecting: $recovery" }
socketSession.disconnect().notifyFailure(subscriptionManager)
mutableConnectionState.update(StreamConnectionState.Disconnected())
runCatchingCancellable {
singleFlight.cancel(connectKey).getOrThrow()
connectionIdHolder.clear().getOrThrow()
socketSession.disconnect().getOrThrow()
}
.notifyFailure(subscriptionManager)
}

is Recovery.Error -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package io.getstream.android.core.internal.client

import io.getstream.android.core.api.authentication.StreamTokenManager
import io.getstream.android.core.api.log.StreamLogger
import io.getstream.android.core.api.model.StreamTypedKey
import io.getstream.android.core.api.model.connection.StreamConnectedUser
import io.getstream.android.core.api.model.connection.StreamConnectionState
import io.getstream.android.core.api.model.connection.lifecycle.StreamLifecycleState
Expand All @@ -31,24 +32,35 @@ import io.getstream.android.core.api.model.exceptions.StreamEndpointErrorData
import io.getstream.android.core.api.model.exceptions.StreamEndpointException
import io.getstream.android.core.api.model.value.StreamToken
import io.getstream.android.core.api.model.value.StreamUserId
import io.getstream.android.core.api.observers.lifecycle.StreamLifecycleListener
import io.getstream.android.core.api.observers.lifecycle.StreamLifecycleMonitor
import io.getstream.android.core.api.observers.network.StreamNetworkMonitor
import io.getstream.android.core.api.observers.network.StreamNetworkMonitorListener
import io.getstream.android.core.api.processing.StreamSerialProcessingQueue
import io.getstream.android.core.api.processing.StreamSingleFlightProcessor
import io.getstream.android.core.api.recovery.StreamConnectionRecoveryEvaluator
import io.getstream.android.core.api.socket.StreamConnectionIdHolder
import io.getstream.android.core.api.socket.listeners.StreamClientListener
import io.getstream.android.core.api.subscribe.StreamSubscription
import io.getstream.android.core.api.subscribe.StreamSubscriptionManager
import io.getstream.android.core.api.subscribe.StreamSubscriptionManager.Options
import io.getstream.android.core.internal.observers.StreamNetworkAndLifeCycleMonitor
import io.getstream.android.core.internal.observers.StreamNetworkAndLifecycleMonitorListener
import io.getstream.android.core.internal.recovery.StreamConnectionRecoveryEvaluatorImpl
import io.getstream.android.core.internal.socket.StreamSocketSession
import io.getstream.android.core.testing.TestLogger
import io.mockk.*
import kotlin.time.ExperimentalTime
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.runTest
import org.bouncycastle.util.test.SimpleTest.runTest
import org.junit.Assert.*
import org.junit.Before
import org.junit.Test
Expand Down Expand Up @@ -156,6 +168,99 @@ class StreamClientIImplTest {
}
}

private class TestLifecycleMonitor : StreamLifecycleMonitor {
private val listeners = mutableSetOf<StreamLifecycleListener>()
private var started = false

override fun start(): Result<Unit> = Result.success(Unit).also { started = true }

override fun stop(): Result<Unit> =
Result.success(Unit).also {
started = false
listeners.clear()
}

override fun subscribe(
listener: StreamLifecycleListener,
options: Options,
): Result<StreamSubscription> {
listeners += listener
return Result.success(
object : StreamSubscription {
override fun cancel() {
listeners -= listener
}
}
)
}

override fun getCurrentState(): StreamLifecycleState = StreamLifecycleState.Unknown

fun emitBackground() {
if (!started) return
listeners.forEach { it.onBackground() }
}

fun emitForeground() {
if (!started) return
listeners.forEach { it.onForeground() }
}
}

private class TestNetworkMonitor : StreamNetworkMonitor {
private var listener: StreamNetworkMonitorListener? = null
private var started = false

override fun start(): Result<Unit> = Result.success(Unit).also { started = true }

override fun stop(): Result<Unit> =
Result.success(Unit).also {
started = false
listener = null
}

override fun subscribe(
listener: StreamNetworkMonitorListener,
options: Options,
): Result<StreamSubscription> {
this.listener = listener
return Result.success(
object : StreamSubscription {
override fun cancel() {
if (this@TestNetworkMonitor.listener === listener) {
this@TestNetworkMonitor.listener = null
}
}
}
)
}

fun emitConnected(snapshot: StreamNetworkInfo.Snapshot?) {
if (!started) return
runBlocking { listener?.onNetworkConnected(snapshot) }
}

fun emitLost(permanent: Boolean) {
if (!started) return
runBlocking { listener?.onNetworkLost(permanent) }
}
}

private class ImmediateSingleFlightProcessor : StreamSingleFlightProcessor {
override suspend fun <T> run(key: StreamTypedKey<T>, block: suspend () -> T): Result<T> =
runCatching {
block()
}

override fun <T> has(key: StreamTypedKey<T>): Boolean = false

override fun <T> cancel(key: StreamTypedKey<T>): Result<Unit> = Result.success(Unit)

override fun clear(cancelRunning: Boolean): Result<Unit> = Result.success(Unit)

override fun stop(): Result<Unit> = Result.success(Unit)
}

@Test
fun `connect short-circuits when already connected`() = runTest {
val connectedUser = mockk<StreamConnectedUser>(relaxed = true)
Expand Down Expand Up @@ -457,6 +562,143 @@ class StreamClientIImplTest {
assertTrue(recoveries.contains(expectedRecovery))
}

@Test
fun `recovery disconnects when backgrounding during long connect`() = runTest {
var networkListener: StreamNetworkAndLifecycleMonitorListener? = null
val networkMonitor = capturingNetworkMonitor { networkListener = it }
val recoveryEvaluator = mockk<StreamConnectionRecoveryEvaluator>()
val expectedRecovery = Recovery.Disconnect("background")
coEvery { recoveryEvaluator.evaluate(any(), any(), any()) } returns
Result.success(expectedRecovery)
coEvery { socketSession.disconnect() } returns Result.success(Unit)
coEvery { socketSession.subscribe(any<StreamClientListener>(), any()) } returns
Result.success(mockk(relaxed = true))
coEvery { tokenManager.loadIfAbsent() } returns
Result.success(StreamToken.fromString("tok"))
coEvery { socketSession.connect(any()) } coAnswers
{
suspendCancellableCoroutine<Result<StreamConnectionState.Connected>> {}
}

val client = createClient(this, networkMonitor, recoveryEvaluator)

val connectJob = launch { client.connect().onFailure {} }
advanceUntilIdle()

val listener = networkListener ?: error("Network listener not registered")
listener.onNetworkAndLifecycleState(
StreamNetworkState.Disconnected,
StreamLifecycleState.Background,
)
advanceUntilIdle()

coVerify(exactly = 1) { recoveryEvaluator.evaluate(any(), any(), any()) }
verify(exactly = 1) { socketSession.disconnect() }

connectJob.cancel()
}

@Test
fun `backgrounding while initial connect is pending cancels the session`() = runTest {
val lifecycleMonitor = TestLifecycleMonitor()
val networkMonitor = TestNetworkMonitor()
val downstreamSubscriptionManager =
StreamSubscriptionManager<StreamNetworkAndLifecycleMonitorListener>(TestLogger)
val monitor =
StreamNetworkAndLifeCycleMonitor(
logger = TestLogger,
lifecycleMonitor = lifecycleMonitor,
networkMonitor = networkMonitor,
mutableNetworkState = MutableStateFlow(StreamNetworkState.Unknown),
mutableLifecycleState = MutableStateFlow(StreamLifecycleState.Unknown),
subscriptionManager = downstreamSubscriptionManager,
)
val recoveryEvaluator =
StreamConnectionRecoveryEvaluatorImpl(TestLogger, ImmediateSingleFlightProcessor())

val firstConnectDeferred = CompletableDeferred<Result<StreamConnectionState.Connected>>()
val connectedUser = mockk<StreamConnectedUser>(relaxed = true)
val connectedState = StreamConnectionState.Connected(connectedUser, "conn-2")
coEvery { tokenManager.loadIfAbsent() } returns
Result.success(StreamToken.fromString("tok"))
every { socketSession.subscribe(any<StreamClientListener>(), any()) } returns
Result.success(mockk(relaxed = true))
every { socketSession.disconnect() } returns Result.success(Unit)
var firstCall = true
coEvery { socketSession.connect(any()) } coAnswers
{
if (firstCall) {
firstCall = false
connFlow.update { StreamConnectionState.Connecting.Opening(userId.rawValue) }
firstConnectDeferred.await()
} else {
Result.success(connectedState)
}
}

val client = createClient(this, monitor, recoveryEvaluator)

val connectJob = launch { client.connect().onFailure {} }
advanceUntilIdle()

lifecycleMonitor.emitBackground()
advanceUntilIdle()

verify(exactly = 1) { socketSession.disconnect() }

connectJob.cancel()
firstConnectDeferred.cancel()
}

@Test
fun `background disconnect followed by foreground reconnect succeeds unless client disconnects`() =
runTest {
val lifecycleMonitor = TestLifecycleMonitor()
val networkMonitor = TestNetworkMonitor()
val downstreamSubscriptionManager =
StreamSubscriptionManager<StreamNetworkAndLifecycleMonitorListener>(TestLogger)
val monitor =
StreamNetworkAndLifeCycleMonitor(
logger = TestLogger,
lifecycleMonitor = lifecycleMonitor,
networkMonitor = networkMonitor,
mutableNetworkState = MutableStateFlow(StreamNetworkState.Unknown),
mutableLifecycleState = MutableStateFlow(StreamLifecycleState.Unknown),
subscriptionManager = downstreamSubscriptionManager,
)
val recoveryEvaluator =
StreamConnectionRecoveryEvaluatorImpl(TestLogger, ImmediateSingleFlightProcessor())

val connectedUser = mockk<StreamConnectedUser>(relaxed = true)
val connectedState = StreamConnectionState.Connected(connectedUser, "conn-42")
coEvery { tokenManager.loadIfAbsent() } returns
Result.success(StreamToken.fromString("tok"))
every { socketSession.subscribe(any<StreamClientListener>(), any()) } returns
Result.success(mockk(relaxed = true))
every { socketSession.disconnect() } returns Result.success(Unit)
coEvery { socketSession.connect(any()) } returnsMany
listOf(Result.success(connectedState), Result.success(connectedState))
every { connectionIdHolder.setConnectionId("conn-42") } returns
Result.success("conn-42")

val client = createClient(this, monitor, recoveryEvaluator)

client.connect().onFailure {}
advanceUntilIdle()

lifecycleMonitor.emitBackground()
advanceUntilIdle()
verify(exactly = 1) { socketSession.disconnect() }
assertEquals(StreamConnectionState.Disconnected(), connFlow.value)

lifecycleMonitor.emitForeground()
networkMonitor.emitConnected(StreamNetworkInfo.Snapshot())
advanceTimeBy(1000)
advanceUntilIdle()

coVerify(exactly = 2) { socketSession.connect(any()) }
}

@Test
fun `recovery error notifies subscribers`() = runTest {
var networkListener: StreamNetworkAndLifecycleMonitorListener? = null
Expand Down