Skip to content

Commit

Permalink
When mutating a state that involves reading the state first, it must …
Browse files Browse the repository at this point in the history
…happen within the same work item to ensure correctness in a multi-threaded environment (#2986)
  • Loading branch information
laevandus committed Jan 31, 2024
1 parent 5b516c8 commit de8ae78
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
### 🐞 Fixed
- Fix message link preview showing empty space when no metadata available [#2984](https://github.com/GetStream/stream-chat-swift/pull/2984)
- Fix threading issue in `ConnectionRepository` [#2985](https://github.com/GetStream/stream-chat-swift/pull/2985)
- Fix threading issues in `AuthenticationRepository` [#2986](https://github.com/GetStream/stream-chat-swift/pull/2986)

## StreamChatUI
### ✅ Added
Expand Down
55 changes: 26 additions & 29 deletions Sources/StreamChat/Repositories/AuthenticationRepository.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class AuthenticationRepository {
private var _consecutiveRefreshFailures: Int = 0
private var _currentUserId: UserId?
private var _currentToken: Token?
/// Retry timing strategy for refreshing an expired token
private var _tokenExpirationRetryStrategy: RetryStrategy
private var _tokenProvider: TokenProvider?
private var _tokenRequestCompletions: [(Error?) -> Void] = []
private var _tokenWaiters: [String: (Result<Token, Error>) -> Void] = [:]
Expand All @@ -54,8 +56,7 @@ class AuthenticationRepository {
}

private var consecutiveRefreshFailures: Int {
get { tokenQueue.sync { _consecutiveRefreshFailures } }
set { tokenQueue.async(flags: .barrier) { self._consecutiveRefreshFailures = newValue }}
tokenQueue.sync { _consecutiveRefreshFailures }
}

private(set) var currentUserId: UserId? {
Expand All @@ -75,25 +76,12 @@ class AuthenticationRepository {
get { tokenQueue.sync { _tokenProvider } }
set { tokenQueue.async(flags: .barrier) { self._tokenProvider = newValue }}
}

private var tokenRequestCompletions: [(Error?) -> Void] {
get { tokenQueue.sync { _tokenRequestCompletions } }
set { tokenQueue.async(flags: .barrier) { self._tokenRequestCompletions = newValue }}
}

/// An array of requests waiting for the token
private(set) var tokenWaiters: [String: (Result<Token, Error>) -> Void] {
get { tokenQueue.sync { _tokenWaiters } }
set { tokenQueue.async(flags: .barrier) { self._tokenWaiters = newValue }}
}


weak var delegate: AuthenticationRepositoryDelegate?

private let apiClient: APIClient
private let databaseContainer: DatabaseContainer
private let connectionRepository: ConnectionRepository
/// Retry timing strategy for refreshing an expired token
private var tokenExpirationRetryStrategy: RetryStrategy
private let timerType: Timer.Type

init(
Expand All @@ -106,7 +94,7 @@ class AuthenticationRepository {
self.apiClient = apiClient
self.databaseContainer = databaseContainer
self.connectionRepository = connectionRepository
self.tokenExpirationRetryStrategy = tokenExpirationRetryStrategy
_tokenExpirationRetryStrategy = tokenExpirationRetryStrategy
self.timerType = timerType

fetchCurrentUser()
Expand Down Expand Up @@ -251,7 +239,9 @@ class AuthenticationRepository {
}

let waiterToken = String.newUniqueId
tokenWaiters[waiterToken] = completion
tokenQueue.async(flags: .barrier) {
self._tokenWaiters[waiterToken] = completion
}

let globalQueue = DispatchQueue.global()
timerType.schedule(timeInterval: timeout, queue: globalQueue) { [weak self] in
Expand All @@ -275,7 +265,7 @@ class AuthenticationRepository {
}

private func updateToken(token: Token?, notifyTokenWaiters: Bool) {
let waiters: [String: (Result<Token, Error>) -> Void] = tokenQueue.sync {
let waiters: [String: (Result<Token, Error>) -> Void] = tokenQueue.sync(flags: .barrier) {
_currentToken = token
_currentUserId = token?.userId
guard notifyTokenWaiters else { return [:] }
Expand All @@ -295,12 +285,17 @@ class AuthenticationRepository {

private func scheduleTokenFetch(isRetry: Bool, userInfo: UserInfo?, tokenProvider: @escaping TokenProvider, completion: @escaping (Error?) -> Void) {
guard !isGettingToken || isRetry else {
tokenRequestCompletions.append(completion)
tokenQueue.async(flags: .barrier) {
self._tokenRequestCompletions.append(completion)
}
return
}

let interval = tokenQueue.sync(flags: .barrier) {
_tokenExpirationRetryStrategy.getDelayAfterTheFailure()
}
timerType.schedule(
timeInterval: tokenExpirationRetryStrategy.getDelayAfterTheFailure(),
timeInterval: interval,
queue: .main
) { [weak self] in
log.debug("Firing timer for a new token request", subsystems: .authentication)
Expand All @@ -309,7 +304,9 @@ class AuthenticationRepository {
}

private func getToken(isRetry: Bool, userInfo: UserInfo?, tokenProvider: @escaping TokenProvider, completion: @escaping (Error?) -> Void) {
tokenRequestCompletions.append(completion)
tokenQueue.async(flags: .barrier) {
self._tokenRequestCompletions.append(completion)
}
guard !isGettingToken || isRetry else {
log.debug("Trying to get a token while already getting one", subsystems: .authentication)
return
Expand All @@ -325,7 +322,7 @@ class AuthenticationRepository {
log.debug("Successfully retrieved token", subsystems: .authentication)
}

let completionBlocks: [(Error?) -> Void]? = self.tokenQueue.sync {
let completionBlocks: [(Error?) -> Void]? = self.tokenQueue.sync(flags: .barrier) {
self._isGettingToken = false
let completions = self._tokenRequestCompletions
self._tokenRequestCompletions = []
Expand All @@ -351,7 +348,9 @@ class AuthenticationRepository {

let retryFetchIfPossible: (Error?) -> Void = { [weak self] error in
guard let self = self else { return }
self.consecutiveRefreshFailures += 1
self.tokenQueue.async(flags: .barrier) {
self._consecutiveRefreshFailures += 1
}
guard self.consecutiveRefreshFailures < Constants.maximumTokenRefreshAttempts else {
onCompletion(error ?? ClientError.TooManyFailedTokenRefreshAttempts())
return
Expand All @@ -366,7 +365,9 @@ class AuthenticationRepository {
switch result {
case let .success(newToken) where !newToken.isExpired:
onTokenReceived(newToken)
self?.tokenExpirationRetryStrategy.resetConsecutiveFailures()
self?.tokenQueue.sync(flags: .barrier) {
self?._tokenExpirationRetryStrategy.resetConsecutiveFailures()
}
case .success:
retryFetchIfPossible(nil)
case let .failure(error):
Expand Down Expand Up @@ -401,10 +402,6 @@ class AuthenticationRepository {
}
}
}

private func invalidateTokenWaiter(_ waiter: WaiterToken) {
tokenWaiters[waiter] = nil
}
}

extension ClientError {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ final class AuthenticationRepository_Tests: XCTestCase {
DispatchQueue.concurrentPerform(iterations: 100) { _ in
_ = repository.tokenProvider
}
DispatchQueue.concurrentPerform(iterations: 100) { _ in
_ = repository.tokenWaiters
}
}

func test_currentUserId_isNil_whenNoPreviousSession() {
Expand Down Expand Up @@ -912,6 +909,26 @@ final class AuthenticationRepository_Tests: XCTestCase {
XCTAssertEqual(connectionRepository.updateWebSocketEndpointToken, token)
XCTAssertNil(connectionRepository.updateWebSocketEndpointUserInfo)
}

func test_refreshToken_triggersCompletions_whenConcurrentlyCalled() throws {
let delegate = AuthenticationRepositoryDelegateMock()
delegate.isCapturingStatistics = false
let userId = "user1"
let newUserInfo = UserInfo(id: userId)
let newToken = Token.unique(userId: userId)
repository.delegate = delegate
let error = testPrepareEnvironmentAfterConnect(existingToken: nil, newUserInfo: newUserInfo, newToken: newToken)
XCTAssertNil(error)

let iteration = 100
let expectations = (0..<iteration).map { XCTestExpectation(description: "\($0)") }
DispatchQueue.concurrentPerform(iterations: iteration) { index in
repository.refreshToken { _ in
expectations[index].fulfill()
}
}
wait(for: expectations, timeout: defaultTimeout)
}

// MARK: Provide Token

Expand Down Expand Up @@ -981,18 +998,17 @@ final class AuthenticationRepository_Tests: XCTestCase {
XCTAssertEqual(result?.value, token)
}

func test_provideToken_doesNotDeadlock() {
DispatchQueue.concurrentPerform(iterations: 100) { _ in
func test_provideToken_triggersCompletions_whenConcurrentlyCalled() {
let iteration = 100
let expectations = (0..<iteration).map { XCTestExpectation(description: "\($0)") }
DispatchQueue.concurrentPerform(iterations: iteration) { index in
repository.provideToken(timeout: 0) { _ in
self.repository.tokenWaiters.forEach { _ in }
expectations[index].fulfill()
}
}

DispatchQueue.concurrentPerform(iterations: 100) { _ in
repository.tokenWaiters.forEach { _ in }
}
wait(for: expectations, timeout: defaultTimeout)
}

// MARK: EnvironmentState

func test_environmentState_nilCurrentUserId() {
Expand Down Expand Up @@ -1103,14 +1119,18 @@ private class AuthenticationRepositoryDelegateMock: AuthenticationRepositoryDele
var newState: EnvironmentState?
var logoutCallCount: Int = 0
var newStateCalls: Int = 0
var isCapturingStatistics = true

func didFinishSettingUpAuthenticationEnvironment(for state: EnvironmentState) {
guard isCapturingStatistics else { return }
newStateCalls += 1
newState = state
}

func logOutUser(completion: @escaping () -> Void) {
logoutCallCount += 1
if isCapturingStatistics {
logoutCallCount += 1
}
completion()
}
}
Expand Down

0 comments on commit de8ae78

Please sign in to comment.