Skip to content

Commit

Permalink
Fix invalid token triggering token refresh in an infinite loop (#3056)
Browse files Browse the repository at this point in the history
* Fix invalid token calling token refresh in an infinite loop

* Update CHANGELOG.md

* Update CHANGELOG.md
  • Loading branch information
nuno-vieira committed Mar 5, 2024
1 parent 35c366d commit 44527c8
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Fix token provider retrying after calling disconnect [#3052](https://github.com/GetStream/stream-chat-swift/pull/3052)
- Fix connect user never completing when disconnecting after token provider fails [#3052](https://github.com/GetStream/stream-chat-swift/pull/3052)
- Fix current user cache not deleted on logout causing unread count issues after switching users [#3055](https://github.com/GetStream/stream-chat-swift/pull/3055)
- Fix invalid token triggering token refresh in an infinite loop [#3056](https://github.com/GetStream/stream-chat-swift/pull/3056)

# [4.49.0](https://github.com/GetStream/stream-chat-swift/releases/tag/4.49.0)
_February 27, 2024_
Expand Down
9 changes: 6 additions & 3 deletions Sources/StreamChat/ChatClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,12 @@ extension ChatClient: AuthenticationRepositoryDelegate {

extension ChatClient: ConnectionStateDelegate {
func webSocketClient(_ client: WebSocketClient, didUpdateConnectionState state: WebSocketConnectionState) {
connectionRepository.handleConnectionUpdate(state: state, onInvalidToken: { [weak self] in
self?.refreshToken(completion: nil)
})
connectionRepository.handleConnectionUpdate(
state: state,
onExpiredToken: { [weak self] in
self?.refreshToken(completion: nil)
}
)
connectionRecoveryHandler?.webSocketClient(client, didUpdateConnectionState: state)
}
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/StreamChat/Errors/ErrorPayload.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public struct ErrorPayload: LocalizedError, Codable, CustomDebugStringConvertibl
}

/// https://getstream.io/chat/docs/ios-swift/api_errors_response/
private enum StreamErrorCode {
enum StreamErrorCode {
/// Usually returned when trying to perform an API call without a token.
static let accessKeyInvalid = 2
static let expiredToken = 40
Expand Down
7 changes: 4 additions & 3 deletions Sources/StreamChat/Repositories/ConnectionRepository.swift
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ConnectionRepository {

func handleConnectionUpdate(
state: WebSocketConnectionState,
onInvalidToken: () -> Void
onExpiredToken: () -> Void
) {
connectionStatus = .init(webSocketConnectionState: state)

Expand All @@ -140,9 +140,10 @@ class ConnectionRepository {
case let .connected(connectionId: id):
shouldNotifyConnectionIdWaiters = true
connectionId = id

case let .disconnected(source) where source.serverError?.isInvalidTokenError == true:
onInvalidToken()
if source.serverError?.isExpiredTokenError == true {
onExpiredToken()
}
shouldNotifyConnectionIdWaiters = false
connectionId = nil
case .disconnected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy {
static let updateWebSocketEndpointUserId = "updateWebSocketEndpoint(with:)"
static let completeConnectionIdWaiters = "completeConnectionIdWaiters(connectionId:)"
static let provideConnectionId = "provideConnectionId(timeout:completion:)"
static let handleConnectionUpdate = "handleConnectionUpdate(state:onInvalidToken:)"
static let handleConnectionUpdate = "handleConnectionUpdate(state:onExpiredToken:)"
}

var recordedFunctions: [String] = []
Expand All @@ -28,7 +28,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy {
var updateWebSocketEndpointUserInfo: UserInfo?
var completeWaitersConnectionId: ConnectionId?
var connectionUpdateState: WebSocketConnectionState?
var simulateInvalidTokenOnConnectionUpdate = false
var simulateExpiredTokenOnConnectionUpdate = false

convenience init() {
self.init(isClientInActiveMode: true,
Expand Down Expand Up @@ -100,11 +100,11 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy {
record()
}

override func handleConnectionUpdate(state: WebSocketConnectionState, onInvalidToken: () -> Void) {
override func handleConnectionUpdate(state: WebSocketConnectionState, onExpiredToken: () -> Void) {
record()
connectionUpdateState = state
if simulateInvalidTokenOnConnectionUpdate {
onInvalidToken()
if simulateExpiredTokenOnConnectionUpdate {
onExpiredToken()
}
}

Expand All @@ -117,7 +117,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy {

disconnectResult = nil
disconnectSource = nil
simulateInvalidTokenOnConnectionUpdate = false
simulateExpiredTokenOnConnectionUpdate = false
connectionUpdateState = nil
completeWaitersConnectionId = nil
updateWebSocketEndpointToken = nil
Expand Down
4 changes: 2 additions & 2 deletions Tests/StreamChatTests/ChatClient_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -758,14 +758,14 @@ final class ChatClient_Tests: XCTestCase {
XCTAssertNotCall(AuthenticationRepository_Mock.Signature.refreshToken, on: authenticationRepository)
}

func test_webSocketClientStateUpdate_calls_connectionRepository_invalidToken() throws {
func test_webSocketClientStateUpdate_calls_connectionRepository_expiredToken() throws {
let client = ChatClient(config: inMemoryStorageConfig, environment: testEnv.environment)
let webSocketClient = try XCTUnwrap(client.webSocketClient)
let connectionRepository = try XCTUnwrap(client.connectionRepository as? ConnectionRepository_Mock)
let authenticationRepository = try XCTUnwrap(client.authenticationRepository as? AuthenticationRepository_Mock)

let state = WebSocketConnectionState.disconnected(source: .systemInitiated)
connectionRepository.simulateInvalidTokenOnConnectionUpdate = true
connectionRepository.simulateExpiredTokenOnConnectionUpdate = true
client.webSocketClient(webSocketClient, didUpdateConnectionState: state)

XCTAssertCall(ConnectionRepository_Mock.Signature.handleConnectionUpdate, on: connectionRepository)
Expand Down
43 changes: 31 additions & 12 deletions Tests/StreamChatTests/Repositories/ConnectionRepository_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ final class ConnectionRepository_Tests: XCTestCase {
]

for (webSocketState, connectionStatus) in pairs {
repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {})
repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {})
XCTAssertEqual(repository.connectionStatus, connectionStatus)
}
}
Expand Down Expand Up @@ -326,7 +326,7 @@ final class ConnectionRepository_Tests: XCTestCase {
}
}

repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {})
repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {})

if shouldNotify {
waitForExpectations(timeout: defaultTimeout)
Expand Down Expand Up @@ -357,45 +357,61 @@ final class ConnectionRepository_Tests: XCTestCase {
repository.completeConnectionIdWaiters(connectionId: originalConnectionId)
XCTAssertEqual(repository.connectionId, originalConnectionId)

repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {})
repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {})

XCTAssertEqual(repository.connectionId, newConnectionIdValue)
}
}

func test_handleConnectionUpdate_whenInvalidToken_shouldExecuteInvalidTokenBlock() {
let expectation = self.expectation(description: "Invalid Token Block Executed")
func test_handleConnectionUpdate_whenExpiredToken_shouldExecuteExpiredTokenBlock() {
let expectation = self.expectation(description: "Expired Token Block Not Executed")
let expiredTokenError = ClientError(with: ErrorPayload(
code: StreamErrorCode.expiredToken,
message: .unique,
statusCode: .unique
))

repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: expiredTokenError)), onExpiredToken: {
expectation.fulfill()
})

waitForExpectations(timeout: defaultTimeout)
}

func test_handleConnectionUpdate_whenInvalidToken_shouldNotExecuteExpiredTokenBlock() {
let expectation = self.expectation(description: "Expired Token Block Not Executed")
expectation.isInverted = true
let invalidTokenError = ClientError(with: ErrorPayload(
code: .random(in: ClosedRange.tokenInvalidErrorCodes),
code: StreamErrorCode.invalidTokenSignature,
message: .unique,
statusCode: .unique
))

repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: invalidTokenError)), onInvalidToken: {
repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: invalidTokenError)), onExpiredToken: {
expectation.fulfill()
})

waitForExpectations(timeout: defaultTimeout)
}

func test_handleConnectionUpdate_whenInvalidToken_whenDisconnecting_shouldNOTExecuteInvalidTokenBlock() {
func test_handleConnectionUpdate_whenInvalidToken_whenDisconnecting_shouldNOTExecuteRefreshTokenBlock() {
// We only want to refresh the token when it is actually disconnected, not while it is disconnecting, otherwise we trigger refresh token twice.
let invalidTokenError = ClientError(with: ErrorPayload(
code: .random(in: ClosedRange.tokenInvalidErrorCodes),
message: .unique,
statusCode: .unique
))

repository.handleConnectionUpdate(state: .disconnecting(source: .serverInitiated(error: invalidTokenError)), onInvalidToken: {
repository.handleConnectionUpdate(state: .disconnecting(source: .serverInitiated(error: invalidTokenError)), onExpiredToken: {
XCTFail("Should not execute invalid token block")
})
}

func test_handleConnectionUpdate_whenNoError_shouldNOTExecuteInvalidTokenBlock() {
func test_handleConnectionUpdate_whenNoError_shouldNOTExecuteRefreshTokenBlock() {
let states: [WebSocketConnectionState] = [.connecting, .initialized, .connected(connectionId: .newUniqueId), .waitingForConnectionId]

for state in states {
repository.handleConnectionUpdate(state: state, onInvalidToken: {
repository.handleConnectionUpdate(state: state, onExpiredToken: {
XCTFail("Should not execute invalid token block")
})
}
Expand Down Expand Up @@ -499,7 +515,10 @@ final class ConnectionRepository_Tests: XCTestCase {
func test_completeConnectionIdWaiters_nil_connectionId() {
// Set initial connectionId
let initialConnectionId = "initial-connection-id"
repository.handleConnectionUpdate(state: .connected(connectionId: initialConnectionId), onInvalidToken: {})
repository.handleConnectionUpdate(
state: .connected(connectionId: initialConnectionId),
onExpiredToken: {}
)
XCTAssertEqual(repository.connectionId, initialConnectionId)

repository.completeConnectionIdWaiters(connectionId: nil)
Expand Down

0 comments on commit 44527c8

Please sign in to comment.