Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix invalid token triggering token refresh in an infinite loop #3056

Merged
merged 3 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Fix token provider retrying after calling disconnect [#3052](https://github.com/GetStream/stream-chat-swift/pull/3052)
- Fix connect user never completing when disconnecting after token provider fails [#3052](https://github.com/GetStream/stream-chat-swift/pull/3052)
- Fix current user cache not deleted on logout causing unread count issues after switching users [#3055](https://github.com/GetStream/stream-chat-swift/pull/3055)
- Fix invalid token triggering token refresh in an infinite loop [#3056](https://github.com/GetStream/stream-chat-swift/pull/3056)

# [4.49.0](https://github.com/GetStream/stream-chat-swift/releases/tag/4.49.0)
_February 27, 2024_
Expand Down
9 changes: 6 additions & 3 deletions Sources/StreamChat/ChatClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,12 @@ extension ChatClient: AuthenticationRepositoryDelegate {

extension ChatClient: ConnectionStateDelegate {
func webSocketClient(_ client: WebSocketClient, didUpdateConnectionState state: WebSocketConnectionState) {
connectionRepository.handleConnectionUpdate(state: state, onInvalidToken: { [weak self] in
self?.refreshToken(completion: nil)
})
connectionRepository.handleConnectionUpdate(
state: state,
onExpiredToken: { [weak self] in
self?.refreshToken(completion: nil)
}
)
connectionRecoveryHandler?.webSocketClient(client, didUpdateConnectionState: state)
}
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/StreamChat/Errors/ErrorPayload.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public struct ErrorPayload: LocalizedError, Codable, CustomDebugStringConvertibl
}

/// https://getstream.io/chat/docs/ios-swift/api_errors_response/
private enum StreamErrorCode {
enum StreamErrorCode {
/// Usually returned when trying to perform an API call without a token.
static let accessKeyInvalid = 2
static let expiredToken = 40
Expand Down
7 changes: 4 additions & 3 deletions Sources/StreamChat/Repositories/ConnectionRepository.swift
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ConnectionRepository {

func handleConnectionUpdate(
state: WebSocketConnectionState,
onInvalidToken: () -> Void
onExpiredToken: () -> Void
) {
connectionStatus = .init(webSocketConnectionState: state)

Expand All @@ -140,9 +140,10 @@ class ConnectionRepository {
case let .connected(connectionId: id):
shouldNotifyConnectionIdWaiters = true
connectionId = id

case let .disconnected(source) where source.serverError?.isInvalidTokenError == true:
onInvalidToken()
if source.serverError?.isExpiredTokenError == true {
onExpiredToken()
}
shouldNotifyConnectionIdWaiters = false
connectionId = nil
case .disconnected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy {
static let updateWebSocketEndpointUserId = "updateWebSocketEndpoint(with:)"
static let completeConnectionIdWaiters = "completeConnectionIdWaiters(connectionId:)"
static let provideConnectionId = "provideConnectionId(timeout:completion:)"
static let handleConnectionUpdate = "handleConnectionUpdate(state:onInvalidToken:)"
static let handleConnectionUpdate = "handleConnectionUpdate(state:onExpiredToken:)"
}

var recordedFunctions: [String] = []
Expand All @@ -28,7 +28,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy {
var updateWebSocketEndpointUserInfo: UserInfo?
var completeWaitersConnectionId: ConnectionId?
var connectionUpdateState: WebSocketConnectionState?
var simulateInvalidTokenOnConnectionUpdate = false
var simulateExpiredTokenOnConnectionUpdate = false

convenience init() {
self.init(isClientInActiveMode: true,
Expand Down Expand Up @@ -100,11 +100,11 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy {
record()
}

override func handleConnectionUpdate(state: WebSocketConnectionState, onInvalidToken: () -> Void) {
override func handleConnectionUpdate(state: WebSocketConnectionState, onExpiredToken: () -> Void) {
record()
connectionUpdateState = state
if simulateInvalidTokenOnConnectionUpdate {
onInvalidToken()
if simulateExpiredTokenOnConnectionUpdate {
onExpiredToken()
}
}

Expand All @@ -117,7 +117,7 @@ final class ConnectionRepository_Mock: ConnectionRepository, Spy {

disconnectResult = nil
disconnectSource = nil
simulateInvalidTokenOnConnectionUpdate = false
simulateExpiredTokenOnConnectionUpdate = false
connectionUpdateState = nil
completeWaitersConnectionId = nil
updateWebSocketEndpointToken = nil
Expand Down
4 changes: 2 additions & 2 deletions Tests/StreamChatTests/ChatClient_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -758,14 +758,14 @@ final class ChatClient_Tests: XCTestCase {
XCTAssertNotCall(AuthenticationRepository_Mock.Signature.refreshToken, on: authenticationRepository)
}

func test_webSocketClientStateUpdate_calls_connectionRepository_invalidToken() throws {
func test_webSocketClientStateUpdate_calls_connectionRepository_expiredToken() throws {
let client = ChatClient(config: inMemoryStorageConfig, environment: testEnv.environment)
let webSocketClient = try XCTUnwrap(client.webSocketClient)
let connectionRepository = try XCTUnwrap(client.connectionRepository as? ConnectionRepository_Mock)
let authenticationRepository = try XCTUnwrap(client.authenticationRepository as? AuthenticationRepository_Mock)

let state = WebSocketConnectionState.disconnected(source: .systemInitiated)
connectionRepository.simulateInvalidTokenOnConnectionUpdate = true
connectionRepository.simulateExpiredTokenOnConnectionUpdate = true
client.webSocketClient(webSocketClient, didUpdateConnectionState: state)

XCTAssertCall(ConnectionRepository_Mock.Signature.handleConnectionUpdate, on: connectionRepository)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ final class ConnectionRepository_Tests: XCTestCase {
]

for (webSocketState, connectionStatus) in pairs {
repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {})
repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {})
XCTAssertEqual(repository.connectionStatus, connectionStatus)
}
}
Expand Down Expand Up @@ -326,7 +326,7 @@ final class ConnectionRepository_Tests: XCTestCase {
}
}

repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {})
repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {})

if shouldNotify {
waitForExpectations(timeout: defaultTimeout)
Expand Down Expand Up @@ -357,45 +357,61 @@ final class ConnectionRepository_Tests: XCTestCase {
repository.completeConnectionIdWaiters(connectionId: originalConnectionId)
XCTAssertEqual(repository.connectionId, originalConnectionId)

repository.handleConnectionUpdate(state: webSocketState, onInvalidToken: {})
repository.handleConnectionUpdate(state: webSocketState, onExpiredToken: {})

XCTAssertEqual(repository.connectionId, newConnectionIdValue)
}
}

func test_handleConnectionUpdate_whenInvalidToken_shouldExecuteInvalidTokenBlock() {
let expectation = self.expectation(description: "Invalid Token Block Executed")
func test_handleConnectionUpdate_whenExpiredToken_shouldExecuteExpiredTokenBlock() {
let expectation = self.expectation(description: "Expired Token Block Not Executed")
let expiredTokenError = ClientError(with: ErrorPayload(
code: StreamErrorCode.expiredToken,
message: .unique,
statusCode: .unique
))

repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: expiredTokenError)), onExpiredToken: {
expectation.fulfill()
})

waitForExpectations(timeout: defaultTimeout)
}

func test_handleConnectionUpdate_whenInvalidToken_shouldNotExecuteExpiredTokenBlock() {
let expectation = self.expectation(description: "Expired Token Block Not Executed")
expectation.isInverted = true
let invalidTokenError = ClientError(with: ErrorPayload(
code: .random(in: ClosedRange.tokenInvalidErrorCodes),
code: StreamErrorCode.invalidTokenSignature,
message: .unique,
statusCode: .unique
))

repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: invalidTokenError)), onInvalidToken: {
repository.handleConnectionUpdate(state: .disconnected(source: .serverInitiated(error: invalidTokenError)), onExpiredToken: {
expectation.fulfill()
})

waitForExpectations(timeout: defaultTimeout)
}

func test_handleConnectionUpdate_whenInvalidToken_whenDisconnecting_shouldNOTExecuteInvalidTokenBlock() {
func test_handleConnectionUpdate_whenInvalidToken_whenDisconnecting_shouldNOTExecuteRefreshTokenBlock() {
// We only want to refresh the token when it is actually disconnected, not while it is disconnecting, otherwise we trigger refresh token twice.
let invalidTokenError = ClientError(with: ErrorPayload(
code: .random(in: ClosedRange.tokenInvalidErrorCodes),
message: .unique,
statusCode: .unique
))

repository.handleConnectionUpdate(state: .disconnecting(source: .serverInitiated(error: invalidTokenError)), onInvalidToken: {
repository.handleConnectionUpdate(state: .disconnecting(source: .serverInitiated(error: invalidTokenError)), onExpiredToken: {
XCTFail("Should not execute invalid token block")
})
}

func test_handleConnectionUpdate_whenNoError_shouldNOTExecuteInvalidTokenBlock() {
func test_handleConnectionUpdate_whenNoError_shouldNOTExecuteRefreshTokenBlock() {
let states: [WebSocketConnectionState] = [.connecting, .initialized, .connected(connectionId: .newUniqueId), .waitingForConnectionId]

for state in states {
repository.handleConnectionUpdate(state: state, onInvalidToken: {
repository.handleConnectionUpdate(state: state, onExpiredToken: {
XCTFail("Should not execute invalid token block")
})
}
Expand Down Expand Up @@ -499,7 +515,10 @@ final class ConnectionRepository_Tests: XCTestCase {
func test_completeConnectionIdWaiters_nil_connectionId() {
// Set initial connectionId
let initialConnectionId = "initial-connection-id"
repository.handleConnectionUpdate(state: .connected(connectionId: initialConnectionId), onInvalidToken: {})
repository.handleConnectionUpdate(
state: .connected(connectionId: initialConnectionId),
onExpiredToken: {}
)
XCTAssertEqual(repository.connectionId, initialConnectionId)

repository.completeConnectionIdWaiters(connectionId: nil)
Expand Down