Skip to content

Commit

Permalink
Add retries mechanism to Authentication Repository
Browse files Browse the repository at this point in the history
  • Loading branch information
polqf committed Dec 15, 2022
1 parent 835367c commit c558462
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 53 deletions.
37 changes: 15 additions & 22 deletions Sources/StreamChat/APIClient/APIClient.swift
Expand Up @@ -14,7 +14,7 @@ class APIClient {

/// `APIClient` uses this object to decode the results of network requests.
let decoder: RequestDecoder

/// Used for reobtaining tokens when they expire and API client receives token expiration error
let tokenRefresher: (@escaping () -> Void) -> Void

Expand Down Expand Up @@ -45,12 +45,6 @@ class APIClient {

/// Shows whether the token is being refreshed at the moment
@Atomic private var isRefreshingToken: Bool = false

/// Amount of consecutive token refresh attempts
@Atomic private var tokenRefreshConsecutiveFailures: Int = 0

/// Maximum amount of consecutive token refresh attempts before failing
let maximumTokenRefreshAttempts = 10

/// Maximum amount of times a request can be retried
private let maximumRequestRetries = 3
Expand Down Expand Up @@ -166,11 +160,6 @@ class APIClient {
endpoint: Endpoint<Response>,
completion: @escaping (Result<Response, Error>) -> Void
) {
if tokenRefreshConsecutiveFailures > maximumTokenRefreshAttempts {
completion(.failure(ClientError.TooManyTokenRefreshAttempts()))
return
}

guard !isRefreshingToken else {
completion(.failure(ClientError.RefreshingToken()))
return
Expand Down Expand Up @@ -207,7 +196,6 @@ class APIClient {
response: response,
error: error
)
self.tokenRefreshConsecutiveFailures = 0
completion(.success(decodedResponse))
} catch {
if error is ClientError.ExpiredToken == false {
Expand Down Expand Up @@ -236,18 +224,11 @@ class APIClient {
completion(ClientError.RefreshingToken())
return
}
isRefreshingToken = true

// We stop the queue so no more operations are triggered during the refresh
operationQueue.isSuspended = true

// Increase the amount of consecutive failures
_tokenRefreshConsecutiveFailures.mutate { $0 += 1 }
enterTokenFetchMode()

tokenRefresher { [weak self] in
self?.isRefreshingToken = false
// We restart the queue now that token refresh is completed
self?.operationQueue.isSuspended = false
self?.exitTokenFetchMode()
completion(ClientError.TokenRefreshed())
}
}
Expand Down Expand Up @@ -299,6 +280,18 @@ class APIClient {
isInRecoveryMode = false
operationQueue.isSuspended = false
}

func enterTokenFetchMode() {
// We stop the queue so no more operations are triggered during the refresh
isRefreshingToken = true
operationQueue.isSuspended = true
}

func exitTokenFetchMode() {
// We restart the queue now that token refresh is completed
isRefreshingToken = false
operationQueue.isSuspended = false
}
}

extension URLRequest {
Expand Down
84 changes: 61 additions & 23 deletions Sources/StreamChat/Repositories/AuthenticationRepository.swift
Expand Up @@ -18,8 +18,20 @@ protocol AuthenticationRepositoryDelegate: AnyObject {
}

class AuthenticationRepository {
private enum Constants {
/// Maximum amount of consecutive token refresh attempts before failing
static let maximumTokenRefreshAttempts = 10
}

private let tokenQueue: DispatchQueue = DispatchQueue(label: "io.getstream.auth-repository", attributes: .concurrent)
private var _isGettingToken: Bool = false
private var _isGettingToken: Bool = false {
didSet {
guard oldValue != _isGettingToken else { return }
_isGettingToken ? apiClient.enterTokenFetchMode() : apiClient.exitTokenFetchMode()
}
}

private var _consecutiveRefreshFailures: Int = 0
private var _currentUserId: UserId?
private var _currentToken: Token?
private var _tokenProvider: TokenProvider?
Expand All @@ -31,6 +43,11 @@ class AuthenticationRepository {
set { tokenQueue.async(flags: .barrier) { self._isGettingToken = newValue }}
}

private var consecutiveRefreshFailures: Int {
get { tokenQueue.sync { _consecutiveRefreshFailures } }
set { tokenQueue.async(flags: .barrier) { self._consecutiveRefreshFailures = newValue }}
}

private(set) var currentUserId: UserId? {
get { tokenQueue.sync { _currentUserId } }
set { tokenQueue.async(flags: .barrier) { self._currentUserId = newValue }}
Expand Down Expand Up @@ -65,8 +82,6 @@ class AuthenticationRepository {
private let apiClient: APIClient
private let databaseContainer: DatabaseContainer
private let connectionRepository: ConnectionRepository
/// A timer that runs token refreshing job
private var tokenRetryTimer: TimerControl?
/// Retry timing strategy for refreshing an expired token
private var tokenExpirationRetryStrategy: RetryStrategy
private let timerType: Timer.Type
Expand Down Expand Up @@ -121,7 +136,7 @@ class AuthenticationRepository {
/// - tokenProvider: The block to be used to get a token.
func connectUser(userInfo: UserInfo?, tokenProvider: @escaping TokenProvider, completion: @escaping (Error?) -> Void) {
self.tokenProvider = tokenProvider
getToken(userInfo: userInfo, tokenProvider: tokenProvider, completion: completion)
scheduleTokenFetch(isRetry: false, userInfo: userInfo, tokenProvider: tokenProvider, completion: completion)
}

/// Establishes a connection for a guest user.
Expand Down Expand Up @@ -156,16 +171,7 @@ class AuthenticationRepository {
return
}

let tokenProviderCheckingSuccess: TokenProvider = { [weak self] completion in
tokenProvider { result in
if case .success = result {
self?.tokenExpirationRetryStrategy.resetConsecutiveFailures()
}
completion(result)
}
}

scheduleTokenFetch(userInfo: nil, tokenProvider: tokenProviderCheckingSuccess, completion: completion)
scheduleTokenFetch(isRetry: false, userInfo: nil, tokenProvider: tokenProvider, completion: completion)
}

func prepareEnvironment(
Expand Down Expand Up @@ -240,24 +246,24 @@ class AuthenticationRepository {
waiters.forEach { $0.value(token) }
}

private func scheduleTokenFetch(userInfo: UserInfo?, tokenProvider: @escaping TokenProvider, completion: @escaping (Error?) -> Void) {
guard !isGettingToken else {
private func scheduleTokenFetch(isRetry: Bool, userInfo: UserInfo?, tokenProvider: @escaping TokenProvider, completion: @escaping (Error?) -> Void) {
guard !isGettingToken || isRetry else {
tokenRequestCompletions.append(completion)
return
}

tokenRetryTimer = timerType.schedule(
timerType.schedule(
timeInterval: tokenExpirationRetryStrategy.getDelayAfterTheFailure(),
queue: .main
) { [weak self] in
log.debug("Firing timer for a new token request", subsystems: .authentication)
self?.getToken(userInfo: nil, tokenProvider: tokenProvider, completion: completion)
self?.getToken(isRetry: isRetry, userInfo: nil, tokenProvider: tokenProvider, completion: completion)
}
}

private func getToken(userInfo: UserInfo?, tokenProvider: @escaping TokenProvider, completion: @escaping (Error?) -> Void) {
private func getToken(isRetry: Bool, userInfo: UserInfo?, tokenProvider: @escaping TokenProvider, completion: @escaping (Error?) -> Void) {
tokenRequestCompletions.append(completion)
guard !isGettingToken else {
guard !isGettingToken || isRetry else {
log.debug("Trying to get a token while already getting one", subsystems: .authentication)
return
}
Expand All @@ -276,11 +282,17 @@ class AuthenticationRepository {
self._isGettingToken = false
let completions = self._tokenRequestCompletions
self._tokenRequestCompletions = []
self._consecutiveRefreshFailures = 0
return completions
}
completionBlocks?.forEach { $0(error) }
}

guard consecutiveRefreshFailures < Constants.maximumTokenRefreshAttempts else {
onCompletion(ClientError.TooManyFailedTokenRefreshAttempts())
return
}

let onTokenReceived: (Token) -> Void = { [weak self, weak connectionRepository] token in
self?.prepareEnvironment(userInfo: userInfo, newToken: token) { error in
// Errors thrown during `prepareEnvironment` cannot be recovered
Expand All @@ -297,13 +309,28 @@ class AuthenticationRepository {
}
}

let retryFetchIfPossible: () -> Void = { [weak self] in
guard let self = self else { return }
guard self.consecutiveRefreshFailures < Constants.maximumTokenRefreshAttempts else {
onCompletion(ClientError.TooManyFailedTokenRefreshAttempts())
return
}

self.consecutiveRefreshFailures += 1
self.scheduleTokenFetch(isRetry: true, userInfo: userInfo, tokenProvider: tokenProvider, completion: completion)
}

log.debug("Requesting a new token", subsystems: .authentication)
tokenProvider { result in
tokenProvider { [weak self] result in
switch result {
case let .success(newToken):
case let .success(newToken) where !newToken.isExpired:
onTokenReceived(newToken)
self?.tokenExpirationRetryStrategy.resetConsecutiveFailures()
case .success:
retryFetchIfPossible()
case let .failure(error):
onCompletion(error)
log.info("Failed fetching token with error: \(error)")
retryFetchIfPossible()
}
}
}
Expand Down Expand Up @@ -334,3 +361,14 @@ class AuthenticationRepository {
tokenWaiters[waiter] = nil
}
}

extension ClientError {
public class TooManyFailedTokenRefreshAttempts: ClientError {
override public var localizedDescription: String {
"""
Token fetch has failed more than 10 times.
Please make sure that your `tokenProvider` is correctly functioning.
"""
}
}
}
16 changes: 10 additions & 6 deletions Sources/StreamChat/Repositories/ConnectionRepository.swift
Expand Up @@ -149,15 +149,19 @@ class ConnectionRepository {
case let .connected(connectionId: id):
shouldNotifyConnectionIdWaiters = true
connectionId = id
case let .disconnected(source):
if let error = source.serverError,
error.isInvalidTokenError {
onInvalidToken()
shouldNotifyConnectionIdWaiters = false
} else {

case let .disconnecting(source) where source.serverError?.isInvalidTokenError == true,
let .disconnected(source) where source.serverError?.isInvalidTokenError == true:
onInvalidToken()
if case .disconnected = state {
shouldNotifyConnectionIdWaiters = true
} else {
shouldNotifyConnectionIdWaiters = false
}
connectionId = nil
case .disconnected:
shouldNotifyConnectionIdWaiters = true
connectionId = nil
case .initialized,
.connecting,
.disconnecting,
Expand Down
4 changes: 2 additions & 2 deletions StreamChatUITestsAppUITests/Tests/Authentication_Tests.swift
Expand Up @@ -23,7 +23,7 @@ final class Authentication_Tests: StreamTestCase {
userRobot.login()
}
THEN("app requests a token refresh") {
userRobot.assertConnectionStatus(.connected, timeout: 60)
userRobot.assertConnectionStatus(.connected)
}
}

Expand Down Expand Up @@ -101,7 +101,7 @@ final class Authentication_Tests: StreamTestCase {
AND("server returns an error") {}
AND("JWT generation recovers on server side") {}
THEN("app requests a token refresh a second time") {
userRobot.assertConnectionStatus(.connected, timeout: 60)
userRobot.assertConnectionStatus(.connected)
}
}
}

0 comments on commit c558462

Please sign in to comment.