diff --git a/CHANGELOG.md b/CHANGELOG.md index 0eb84b790d2..6b36985afd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Fix token provider retrying after calling disconnect [#3052](https://github.com/GetStream/stream-chat-swift/pull/3052) - Fix connect user never completing when disconnecting after token provider fails [#3052](https://github.com/GetStream/stream-chat-swift/pull/3052) - Fix current user cache not deleted on logout causing unread count issues after switching users [#3055](https://github.com/GetStream/stream-chat-swift/pull/3055) +- Fix invalid token triggering token refresh in an infinite loop [#3056](https://github.com/GetStream/stream-chat-swift/pull/3056) # [4.49.0](https://github.com/GetStream/stream-chat-swift/releases/tag/4.49.0) _February 27, 2024_ diff --git a/Sources/StreamChat/ChatClient.swift b/Sources/StreamChat/ChatClient.swift index 3d72b1ca10d..106c6a87ce2 100644 --- a/Sources/StreamChat/ChatClient.swift +++ b/Sources/StreamChat/ChatClient.swift @@ -451,9 +451,12 @@ extension ChatClient: AuthenticationRepositoryDelegate { extension ChatClient: ConnectionStateDelegate { func webSocketClient(_ client: WebSocketClient, didUpdateConnectionState state: WebSocketConnectionState) { - connectionRepository.handleConnectionUpdate(state: state, onInvalidToken: { [weak self] in - self?.refreshToken(completion: nil) - }) + connectionRepository.handleConnectionUpdate( + state: state, + onExpiredToken: { [weak self] in + self?.refreshToken(completion: nil) + } + ) connectionRecoveryHandler?.webSocketClient(client, didUpdateConnectionState: state) } } diff --git a/Sources/StreamChat/Errors/ErrorPayload.swift b/Sources/StreamChat/Errors/ErrorPayload.swift index 73c4d58ed9c..c2b6c48c258 100644 --- a/Sources/StreamChat/Errors/ErrorPayload.swift +++ b/Sources/StreamChat/Errors/ErrorPayload.swift @@ -29,7 +29,7 @@ public struct ErrorPayload: LocalizedError, Codable, CustomDebugStringConvertibl } /// https://getstream.io/chat/docs/ios-swift/api_errors_response/ -private enum StreamErrorCode { +enum StreamErrorCode { /// Usually returned when trying to perform an API call without a token. static let accessKeyInvalid = 2 static let expiredToken = 40 diff --git a/Sources/StreamChat/Repositories/ConnectionRepository.swift b/Sources/StreamChat/Repositories/ConnectionRepository.swift index dda4066b30b..cc32d79f6f1 100644 --- a/Sources/StreamChat/Repositories/ConnectionRepository.swift +++ b/Sources/StreamChat/Repositories/ConnectionRepository.swift @@ -128,7 +128,7 @@ class ConnectionRepository { func handleConnectionUpdate( state: WebSocketConnectionState, - onInvalidToken: () -> Void + onExpiredToken: () -> Void ) { connectionStatus = .init(webSocketConnectionState: state) @@ -140,9 +140,10 @@ class ConnectionRepository { case let .connected(connectionId: id): shouldNotifyConnectionIdWaiters = true connectionId = id - case let .disconnected(source) where source.serverError?.isInvalidTokenError == true: - onInvalidToken() + if source.serverError?.isExpiredTokenError == true { + onExpiredToken() + } shouldNotifyConnectionIdWaiters = false connectionId = nil case .disconnected: diff --git a/TestTools/StreamChatTestTools/Mocks/StreamChat/ConnectionRepository_Mock.swift b/TestTools/StreamChatTestTools/Mocks/StreamChat/ConnectionRepository_Mock.swift index 7059254a056..89afc1336cb 100644 --- a/TestTools/StreamChatTestTools/Mocks/StreamChat/ConnectionRepository_Mock.swift +++ b/TestTools/StreamChatTestTools/Mocks/StreamChat/ConnectionRepository_Mock.swift @@ -15,7 +15,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy { static let updateWebSocketEndpointUserId = "updateWebSocketEndpoint(with:)" static let completeConnectionIdWaiters = "completeConnectionIdWaiters(connectionId:)" static let provideConnectionId = "provideConnectionId(timeout:completion:)" - static let handleConnectionUpdate = "handleConnectionUpdate(state:onInvalidToken:)" + static let handleConnectionUpdate = "handleConnectionUpdate(state:onExpiredToken:)" } var recordedFunctions: [String] = [] @@ -28,7 +28,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy { var updateWebSocketEndpointUserInfo: UserInfo? var completeWaitersConnectionId: ConnectionId? var connectionUpdateState: WebSocketConnectionState? - var simulateInvalidTokenOnConnectionUpdate = false + var simulateExpiredTokenOnConnectionUpdate = false convenience init() { self.init(isClientInActiveMode: true, @@ -100,11 +100,11 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy { record() } - override func handleConnectionUpdate(state: WebSocketConnectionState, onInvalidToken: () -> Void) { + override func handleConnectionUpdate(state: WebSocketConnectionState, onExpiredToken: () -> Void) { record() connectionUpdateState = state - if simulateInvalidTokenOnConnectionUpdate { - onInvalidToken() + if simulateExpiredTokenOnConnectionUpdate { + onExpiredToken() } } @@ -117,7 +117,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy { disconnectResult = nil disconnectSource = nil - simulateInvalidTokenOnConnectionUpdate = false + simulateExpiredTokenOnConnectionUpdate = false connectionUpdateState = nil completeWaitersConnectionId = nil updateWebSocketEndpointToken = nil diff --git a/Tests/StreamChatTests/ChatClient_Tests.swift b/Tests/StreamChatTests/ChatClient_Tests.swift index 37944f1c243..25e88047793 100644 --- a/Tests/StreamChatTests/ChatClient_Tests.swift +++ b/Tests/StreamChatTests/ChatClient_Tests.swift @@ -758,14 +758,14 @@ final class ChatClient_Tests: XCTestCase { XCTAssertNotCall(AuthenticationRepository_Mock.Signature.refreshToken, on: authenticationRepository) } - func test_webSocketClientStateUpdate_calls_connectionRepository_invalidToken() throws { + func test_webSocketClientStateUpdate_calls_connectionRepository_expiredToken() throws { let client = ChatClient(config: inMemoryStorageConfig, environment: testEnv.environment) let webSocketClient = try XCTUnwrap(client.webSocketClient) let connectionRepository = try XCTUnwrap(client.connectionRepository as? ConnectionRepository_Mock) let authenticationRepository = try XCTUnwrap(client.authenticationRepository as? AuthenticationRepository_Mock) let state = WebSocketConnectionState.disconnected(source: .systemInitiated) - connectionRepository.simulateInvalidTokenOnConnectionUpdate = true + connectionRepository.simulateExpiredTokenOnConnectionUpdate = true client.webSocketClient(webSocketClient, didUpdateConnectionState: state) XCTAssertCall(ConnectionRepository_Mock.Signature.handleConnectionUpdate, on: connectionRepository) diff --git a/Tests/StreamChatTests/Repositories/ConnectionRepository_Tests.swift b/Tests/StreamChatTests/Repositories/ConnectionRepository_Tests.swift index d71aed7ba42..a1b5ca9fcaa 100644 --- a/Tests/StreamChatTests/Repositories/ConnectionRepository_Tests.swift +++ b/Tests/StreamChatTests/Repositories/ConnectionRepository_Tests.swift @@ -280,7 +280,7 @@ final class ConnectionRepository_Tests: XCTestCase { ] for (webSocketState, connectionStatus) in pairs { - repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {}) + repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {}) XCTAssertEqual(repository.connectionStatus, connectionStatus) } } @@ -326,7 +326,7 @@ final class ConnectionRepository_Tests: XCTestCase { } } - repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {}) + repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {}) if shouldNotify { waitForExpectations(timeout: defaultTimeout) @@ -357,28 +357,44 @@ final class ConnectionRepository_Tests: XCTestCase { repository.completeConnectionIdWaiters(connectionId: originalConnectionId) XCTAssertEqual(repository.connectionId, originalConnectionId) - repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {}) + repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {}) XCTAssertEqual(repository.connectionId, newConnectionIdValue) } } - func test_handleConnectionUpdate_whenInvalidToken_shouldExecuteInvalidTokenBlock() { - let expectation = self.expectation(description: "Invalid Token Block Executed") + func test_handleConnectionUpdate_whenExpiredToken_shouldExecuteExpiredTokenBlock() { + let expectation = self.expectation(description: "Expired Token Block Not Executed") + let expiredTokenError = ClientError(with: ErrorPayload( + code: StreamErrorCode.expiredToken, + message: .unique, + statusCode: .unique + )) + + repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: expiredTokenError)), onExpiredToken: { + expectation.fulfill() + }) + + waitForExpectations(timeout: defaultTimeout) + } + + func test_handleConnectionUpdate_whenInvalidToken_shouldNotExecuteExpiredTokenBlock() { + let expectation = self.expectation(description: "Expired Token Block Not Executed") + expectation.isInverted = true let invalidTokenError = ClientError(with: ErrorPayload( - code: .random(in: ClosedRange.tokenInvalidErrorCodes), + code: StreamErrorCode.invalidTokenSignature, message: .unique, statusCode: .unique )) - repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: invalidTokenError)), onInvalidToken: { + repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: invalidTokenError)), onExpiredToken: { expectation.fulfill() }) waitForExpectations(timeout: defaultTimeout) } - func test_handleConnectionUpdate_whenInvalidToken_whenDisconnecting_shouldNOTExecuteInvalidTokenBlock() { + func test_handleConnectionUpdate_whenInvalidToken_whenDisconnecting_shouldNOTExecuteRefreshTokenBlock() { // We only want to refresh the token when it is actually disconnected, not while it is disconnecting, otherwise we trigger refresh token twice. let invalidTokenError = ClientError(with: ErrorPayload( code: .random(in: ClosedRange.tokenInvalidErrorCodes), @@ -386,16 +402,16 @@ final class ConnectionRepository_Tests: XCTestCase { statusCode: .unique )) - repository.handleConnectionUpdate(state: .disconnecting(source: .serverInitiated(error: invalidTokenError)), onInvalidToken: { + repository.handleConnectionUpdate(state: .disconnecting(source: .serverInitiated(error: invalidTokenError)), onExpiredToken: { XCTFail("Should not execute invalid token block") }) } - func test_handleConnectionUpdate_whenNoError_shouldNOTExecuteInvalidTokenBlock() { + func test_handleConnectionUpdate_whenNoError_shouldNOTExecuteRefreshTokenBlock() { let states: [WebSocketConnectionState] = [.connecting, .initialized, .connected(connectionId: .newUniqueId), .waitingForConnectionId] for state in states { - repository.handleConnectionUpdate(state: state, onInvalidToken: { + repository.handleConnectionUpdate(state: state, onExpiredToken: { XCTFail("Should not execute invalid token block") }) } @@ -499,7 +515,10 @@ final class ConnectionRepository_Tests: XCTestCase { func test_completeConnectionIdWaiters_nil_connectionId() { // Set initial connectionId let initialConnectionId = "initial-connection-id" - repository.handleConnectionUpdate(state: .connected(connectionId: initialConnectionId), onInvalidToken: {}) + repository.handleConnectionUpdate( + state: .connected(connectionId: initialConnectionId), + onExpiredToken: {} + ) XCTAssertEqual(repository.connectionId, initialConnectionId) repository.completeConnectionIdWaiters(connectionId: nil)