Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(DataStore): dataStore cannot connect to model's sync subscriptions (AWS_LAMBDA auth type) #3550

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public protocol AnyGraphQLOperation {
associatedtype Success
associatedtype Failure: Error
typealias ResultListener = (Result<Success, Failure>) -> Void
typealias ErrorListener = (Failure) -> Void
}

/// Abastraction for a retryable GraphQLOperation.
Expand All @@ -24,6 +25,7 @@ public protocol RetryableGraphQLOperationBehavior: Operation, DefaultLogger {
typealias RequestFactory = () async -> GraphQLRequest<Payload>
typealias OperationFactory = (GraphQLRequest<Payload>, @escaping OperationResultListener) -> OperationType
typealias OperationResultListener = OperationType.ResultListener
typealias OperationErrorListener = OperationType.ErrorListener

/// Operation unique identifier
var id: UUID { get }
Expand All @@ -45,9 +47,12 @@ public protocol RetryableGraphQLOperationBehavior: Operation, DefaultLogger {
var operationFactory: OperationFactory { get }

var resultListener: OperationResultListener { get }

var errorListener: OperationErrorListener { get }

init(requestFactory: @escaping RequestFactory,
maxRetries: Int,
errorListener: @escaping OperationErrorListener,
resultListener: @escaping OperationResultListener,
_ operationFactory: @escaping OperationFactory)

Expand All @@ -71,6 +76,11 @@ extension RetryableGraphQLOperationBehavior {
attempts += 1
log.debug("[\(id)] - Try [\(attempts)/\(maxRetries)]")
let wrappedResultListener: OperationResultListener = { result in
if case let .failure(error) = result {
// Give an operation a chance to prepare itself for a retry after a failure
self.errorListener(error)
}

if case let .failure(error) = result, self.shouldRetry(error: error as? APIError) {
self.log.debug("\(error)")
Task {
Expand Down Expand Up @@ -103,17 +113,20 @@ public final class RetryableGraphQLOperation<Payload: Decodable>: Operation, Ret
public var attempts: Int = 0
public var requestFactory: RequestFactory
public var underlyingOperation: AtomicValue<GraphQLOperation<Payload>?> = AtomicValue(initialValue: nil)
public var errorListener: OperationErrorListener
public var resultListener: OperationResultListener
public var operationFactory: OperationFactory

public init(requestFactory: @escaping RequestFactory,
maxRetries: Int,
errorListener: @escaping OperationErrorListener,
resultListener: @escaping OperationResultListener,
_ operationFactory: @escaping OperationFactory) {
self.id = UUID()
self.maxRetries = max(1, maxRetries)
self.requestFactory = requestFactory
self.operationFactory = operationFactory
self.errorListener = errorListener
self.resultListener = resultListener
}

Expand Down Expand Up @@ -154,17 +167,21 @@ public final class RetryableGraphQLSubscriptionOperation<Payload: Decodable>: Op
public var attempts: Int = 0
public var underlyingOperation: AtomicValue<GraphQLSubscriptionOperation<Payload>?> = AtomicValue(initialValue: nil)
public var requestFactory: RequestFactory
public var errorListener: OperationErrorListener
public var resultListener: OperationResultListener
public var operationFactory: OperationFactory

private var retriedRTFErrors: [RTFError: Bool] = [:]

public init(requestFactory: @escaping RequestFactory,
maxRetries: Int,
errorListener: @escaping OperationErrorListener,
resultListener: @escaping OperationResultListener,
_ operationFactory: @escaping OperationFactory) {
self.id = UUID()
self.maxRetries = max(1, maxRetries)
self.requestFactory = requestFactory
self.operationFactory = operationFactory
self.errorListener = errorListener
self.resultListener = resultListener
}
public override func main() {
Expand All @@ -178,9 +195,35 @@ public final class RetryableGraphQLSubscriptionOperation<Payload: Decodable>: Op
}

public func shouldRetry(error: APIError?) -> Bool {
return attempts < maxRetries
}
guard case let .operationError(_, _, underlyingError) = error else {
return false
}

if let authError = underlyingError as? AuthError {
switch authError {
case .signedOut, .notAuthorized:
return attempts < maxRetries
default:
return false
}
}

if let rtfError = RTFError(description: error.debugDescription) {

// Do not retry the same RTF error more than once
guard retriedRTFErrors[rtfError] == nil else { return false }
retriedRTFErrors[rtfError] = true

// maxRetries represent the number of auth types to attempt.
// (maxRetries is set to the number of auth types to attempt in multi-auth rules scenarios)
// Increment by 1 to account for that as this is not a "change auth" retry attempt
maxRetries += 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is maxRetries updated here to account for the attempt being incremented for the RTF retry case? So the overall attempt to maxRetries counts will continue to represent the number of auth types to attempt? (maxRetries is set to the number of auth types to attempt in multi-auth rules scenarios).


return true
}

return false
}
}

// MARK: GraphQLOperation - GraphQLSubscriptionOperation + AnyGraphQLOperation
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import Foundation

public enum RTFError: CaseIterable {
case unknownField
case maxAttributes
case maxCombinations
case repeatedFieldname
case notGroup
case fieldNotInType

private var uniqueMessagePart: String {
switch self {
case .unknownField:
return "UnknownArgument: Unknown field argument filter"
case .maxAttributes:
return "Filters exceed maximum attributes limit"
case .maxCombinations:
return "Filters combination exceed maximum limit"
case .repeatedFieldname:
return "filter uses same fieldName multiple time"
case .notGroup:
return "The variables input contains a field name 'not'"
case .fieldNotInType:
return "The variables input contains a field that is not defined for input object type"
}
}

/// Init RTF error based on error's debugDescription value
public init?(description: String) {
guard
let rtfError = RTFError.allCases.first(where: { description.contains($0.uniqueMessagePart) })
else {
return nil
}

self = rtfError
}
}