Skip to content

Commit

Permalink
fix(DataStore): dataStore cannot connect to model's sync subscription…
Browse files Browse the repository at this point in the history
…s (AWS_LAMBDA auth type) aws-amplify#3549
  • Loading branch information
MuniekMg committed Apr 11, 2024
1 parent e70a3f5 commit e0398e4
Show file tree
Hide file tree
Showing 8 changed files with 348 additions and 118 deletions.
24 changes: 12 additions & 12 deletions Amplify/Categories/API/Operation/RetryableGraphQLOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ public final class RetryableGraphQLSubscriptionOperation<Payload: Decodable>: Op
public var errorListener: OperationErrorListener
public var resultListener: OperationResultListener
public var operationFactory: OperationFactory
private var filterLimitRetried: Bool = false
private var retriedRTFErrors: [RTFError: Bool] = [:]

public init(requestFactory: @escaping RequestFactory,
maxRetries: Int,
Expand All @@ -195,9 +195,7 @@ public final class RetryableGraphQLSubscriptionOperation<Payload: Decodable>: Op
}

public func shouldRetry(error: APIError?) -> Bool {
// return attempts < maxRetries

guard case let .operationError(errorDescription, recoverySuggestion, underlyingError) = error else {
guard case let .operationError(_, _, underlyingError) = error else {
return false
}

Expand All @@ -209,16 +207,18 @@ public final class RetryableGraphQLSubscriptionOperation<Payload: Decodable>: Op
return false
}
}

// TODO: - How to distinguish errors?
// TODO: - Handle other errors
if error.debugDescription.contains("Filters combination exceed maximum limit 10 for subscription.") &&
filterLimitRetried == false {

if let rtfError = RTFError(description: error.debugDescription) {

// Just to be sure that endless retry won't happen
filterLimitRetried = true
// Do not retry the same RTF error more than once
guard retriedRTFErrors[rtfError] == nil else { return false }
retriedRTFErrors[rtfError] = true

// maxRetries represent the number of auth types to attempt.
// (maxRetries is set to the number of auth types to attempt in multi-auth rules scenarios)
// Increment by 1 to account for that as this is not a "change auth" retry attempt
maxRetries += 1

return true
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import Foundation

public enum RTFError: CaseIterable {
case unknownField
case maxAttributes
case maxCombinations
case repeatedFieldname
case notGroup
case fieldNotInType

Check warning on line 17 in Amplify/Categories/DataStore/Subscribe/DataStoreSubscriptionRTFError.swift

View workflow job for this annotation

GitHub Actions / run-swiftlint

Lines should not have trailing whitespace (trailing_whitespace)
private var uniqueMessagePart: String {
switch self {
case .unknownField:
return "UnknownArgument: Unknown field argument filter"
case .maxAttributes:
return "Filters exceed maximum attributes limit"
case .maxCombinations:
return "Filters combination exceed maximum limit"
case .repeatedFieldname:
return "filter uses same fieldName multiple time"
case .notGroup:
return "The variables input contains a field name 'not'"
case .fieldNotInType:
return "The variables input contains a field that is not defined for input object type"
}
}

Check warning on line 34 in Amplify/Categories/DataStore/Subscribe/DataStoreSubscriptionRTFError.swift

View workflow job for this annotation

GitHub Actions / run-swiftlint

Lines should not have trailing whitespace (trailing_whitespace)
/// Init RTF error based on error's debugDescription value
public init?(description: String) {
guard

Check warning on line 37 in Amplify/Categories/DataStore/Subscribe/DataStoreSubscriptionRTFError.swift

View workflow job for this annotation

GitHub Actions / run-swiftlint

Lines should not have trailing whitespace (trailing_whitespace)
let rtfError = RTFError.allCases.first(where: { description.contains($0.uniqueMessagePart) })
else {
return nil
}

Check warning on line 42 in Amplify/Categories/DataStore/Subscribe/DataStoreSubscriptionRTFError.swift

View workflow job for this annotation

GitHub Actions / run-swiftlint

Lines should not have trailing whitespace (trailing_whitespace)
self = rtfError
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import AWSPluginsCore
@testable import APIHostApp

final class GraphQLSubscriptionsTests: XCTestCase {
static let amplifyConfiguration = "testconfiguration/AWSAPIPluginV2Tests-amplifyconfiguration"
static let amplifyConfiguration = "AWSAPIPluginV2Tests-amplifyconfiguration"

override func setUp() async throws {
await Amplify.reset()
Expand Down Expand Up @@ -107,91 +107,15 @@ final class GraphQLSubscriptionsTests: XCTestCase {
XCTFail("Failed to create post"); return
}

await fulfillment(of: [onCreateCorrectPost1], timeout: TestCommonConstants.networkTimeout)
await fulfillment(of: [onCreateCorrectPost2], timeout: TestCommonConstants.networkTimeout)

subscription.cancel()
}

func testOnCreatePostSubscriptionWithTooManyFiltersFallbackToNoFilter() async throws {
let incorrectTitle = "other_title"
let incorrectPost1Id = UUID().uuidString

let correctTitle = "correct"
let correctPost1Id = UUID().uuidString
let correctPost2Id = UUID().uuidString

let connectedInvoked = expectation(description: "Connection established")
let onCreateCorrectPost1 = expectation(description: "Receioved onCreate for correctPost1")
let onCreateCorrectPost2 = expectation(description: "Receioved onCreate for correctPost2")

let modelType = Post.self
let filter: QueryPredicate = QueryPredicateGroup(type: .or, predicates:
(0...20).map {
modelType.keys.title.eq("\($0)")
}
await fulfillment(
of: [onCreateCorrectPost1, onCreateCorrectPost2],
timeout: TestCommonConstants.networkTimeout,
enforceOrder: true
)

let request = GraphQLRequest<MutationSyncResult>.subscription(to: modelType, where: filter, subscriptionType: .onCreate)

let subscription = Amplify.API.subscribe(request: request)
Task {
do {
for try await subscriptionEvent in subscription {
switch subscriptionEvent {
case .connection(let state):
switch state {
case .connected:
connectedInvoked.fulfill()

case .connecting, .disconnected:
break
}

case .data(let graphQLResponse):
switch graphQLResponse {
case .success(let mutationSync):
if mutationSync.model.id == correctPost1Id {
onCreateCorrectPost1.fulfill()

} else if mutationSync.model.id == correctPost2Id {
onCreateCorrectPost2.fulfill()

} else if mutationSync.model.id == incorrectPost1Id {
XCTFail("We should not receive onCreate for filtered out model!")
}

case .failure(let error):
XCTFail(error.errorDescription)
}
}
}

} catch {
XCTFail("Unexpected subscription failure: \(error)")
}
}

await fulfillment(of: [connectedInvoked], timeout: TestCommonConstants.networkTimeout)

guard try await createPost(id: incorrectPost1Id, title: incorrectTitle) != nil else {
XCTFail("Failed to create post"); return
}

guard try await createPost(id: correctPost1Id, title: correctTitle) != nil else {
XCTFail("Failed to create post"); return
}

guard try await createPost(id: correctPost2Id, title: correctTitle) != nil else {
XCTFail("Failed to create post"); return
}

await fulfillment(of: [onCreateCorrectPost1], timeout: TestCommonConstants.networkTimeout)
await fulfillment(of: [onCreateCorrectPost2], timeout: TestCommonConstants.networkTimeout)

subscription.cancel()
}

// MARK: Helpers

func createPost(id: String, title: String) async throws -> Post? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,16 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
authTypeProvider: { onCreateAuthType }),
maxRetries: onCreateAuthTypeProvider.count,
errorListener: { error in
// TODO: - How to distinguish errors?
// TODO: - Handle other errors
if error.debugDescription.contains("Filters combination exceed maximum limit 10 for subscription.") {

if let _ = RTFError(description: error.debugDescription) {
onCreateModelPredicate = nil

} else if case let .operationError(errorDescription, recoverySuggestion, underlyingError) = error,
let authError = underlyingError as? AuthError {

} else if case let .operationError(_, _, underlyingError) = error, let authError = underlyingError as? AuthError {
switch authError {
case .signedOut, .notAuthorized:
onCreateAuthType = onCreateAuthTypeProvider.next()
default:
return
break
}
}
},
Expand Down Expand Up @@ -132,19 +129,16 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
authTypeProvider: { onUpdateAuthType }),
maxRetries: onUpdateAuthTypeProvider.count,
errorListener: { error in
// TODO: - How to distinguish errors?
// TODO: - Handle other errors
if error.debugDescription.contains("Filters combination exceed maximum limit 10 for subscription.") {

if let _ = RTFError(description: error.debugDescription) {
onUpdateModelPredicate = nil

} else if case let .operationError(errorDescription, recoverySuggestion, underlyingError) = error,
let authError = underlyingError as? AuthError {

} else if case let .operationError(_, _, underlyingError) = error, let authError = underlyingError as? AuthError {
switch authError {
case .signedOut, .notAuthorized:
onUpdateAuthType = onUpdateAuthTypeProvider.next()
default:
return
break
}
}
},
Expand Down Expand Up @@ -174,19 +168,16 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
authTypeProvider: { onDeleteAuthType }),
maxRetries: onUpdateAuthTypeProvider.count,
errorListener: { error in
// TODO: - How to distinguish errors?
// TODO: - Handle other errors
if error.debugDescription.contains("Filters combination exceed maximum limit 10 for subscription.") {

if let _ = RTFError(description: error.debugDescription) {
onDeleteModelPredicate = nil

} else if case let .operationError(errorDescription, recoverySuggestion, underlyingError) = error,
let authError = underlyingError as? AuthError {

} else if case let .operationError(_, _, underlyingError) = error, let authError = underlyingError as? AuthError {
switch authError {
case .signedOut, .notAuthorized:
onDeleteAuthType = onDeleteAuthTypeProvider.next()
default:
return
break
}
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,91 @@ class SubscriptionEndToEndTests: SyncEngineIntegrationTestBase {
let deleteSyncData = await getMutationSync(forPostWithId: id)
XCTAssertNil(deleteSyncData)
}

/// Given: DataStore configured with syncExpressions which causes error "Validation error of type UnknownArgument: Unknown field argument filter @ \'onCreatePost\'" when connecting to sync subscriptions
/// When: Adding, editing, removing model
/// Then: Receives create, update, delete mutation
func testRestartsSubscriptionAfterFailureAndReceivesCreateMutateDelete() async throws {

// Filter all events to ensure they have this ID. This prevents us from overfulfilling on
// unrelated subscriptions
let id = UUID().uuidString

let originalContent = "Original content from SubscriptionTests at \(Date())"
let updatedContent = "UPDATED CONTENT from SubscriptionTests at \(Date())"

let createReceived = expectation(description: "createReceived")
let updateReceived = expectation(description: "updateReceived")
let deleteReceived = expectation(description: "deleteReceived")

let syncExpressions: [DataStoreSyncExpression] = [
.syncExpression(Post.schema) {
QueryPredicateGroup(type: .or, predicates: [
Post.keys.id.eq(id)
])
}
]

#if os(watchOS)
let dataStoreConfiguration = DataStoreConfiguration.custom(syncMaxRecords: 100, syncExpressions: syncExpressions, disableSubscriptions: { false })
#else
let dataStoreConfiguration = DataStoreConfiguration.custom(syncMaxRecords: 100, syncExpressions: syncExpressions)
#endif

await setUp(withModels: TestModelRegistration(), dataStoreConfiguration: dataStoreConfiguration)
try await startAmplifyAndWaitForSync()

var cancellables = Set<AnyCancellable>()
Amplify.Hub.publisher(for: .dataStore)
.filter { $0.eventName == HubPayload.EventName.DataStore.syncReceived }
.compactMap { $0.data as? MutationEvent }
.filter { $0.modelId == id }
.map(\.mutationType)
.sink {
switch $0 {
case GraphQLMutationType.create.rawValue:
createReceived.fulfill()
case GraphQLMutationType.update.rawValue:
updateReceived.fulfill()
case GraphQLMutationType.delete.rawValue:
deleteReceived.fulfill()
default:
break
}
}
.store(in: &cancellables)

// Act: send create mutation
try await sendCreateRequest(withId: id, content: originalContent)
await fulfillment(of: [createReceived], timeout: 10)
// Assert
let createSyncData = await getMutationSync(forPostWithId: id)
XCTAssertNotNil(createSyncData)
let createdPost = createSyncData?.model.instance as? Post
XCTAssertNotNil(createdPost)
XCTAssertEqual(createdPost?.content, originalContent)
XCTAssertEqual(createSyncData?.syncMetadata.version, 1)
XCTAssertEqual(createSyncData?.syncMetadata.deleted, false)

// Act: send update mutation
try await sendUpdateRequest(forId: id, content: updatedContent, version: 1)
await fulfillment(of: [updateReceived], timeout: 10)
// Assert
let updateSyncData = await getMutationSync(forPostWithId: id)
XCTAssertNotNil(updateSyncData)
let updatedPost = updateSyncData?.model.instance as? Post
XCTAssertNotNil(updatedPost)
XCTAssertEqual(updatedPost?.content, updatedContent)
XCTAssertEqual(updateSyncData?.syncMetadata.version, 2)
XCTAssertEqual(updateSyncData?.syncMetadata.deleted, false)

// Act: send delete mutation
try await sendDeleteRequest(forId: id, version: 2)
await fulfillment(of: [deleteReceived], timeout: 10)
// Assert
let deleteSyncData = await getMutationSync(forPostWithId: id)
XCTAssertNil(deleteSyncData)
}

// MARK: - Utilities

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class SyncEngineIntegrationV2TestBase: DataStoreTestBase {
// swiftlint:enable force_try
// swiftlint:enable force_cast

func setUp(withModels models: AmplifyModelRegistration, logLevel: LogLevel = .error) async {
func setUp(withModels models: AmplifyModelRegistration, syncExpressions: [DataStoreSyncExpression] = [], logLevel: LogLevel = .error) async {

Amplify.Logging.logLevel = logLevel

Expand All @@ -55,10 +55,17 @@ class SyncEngineIntegrationV2TestBase: DataStoreTestBase {
))
#if os(watchOS)
try Amplify.add(plugin: AWSDataStorePlugin(modelRegistration: models,
configuration: .custom(syncMaxRecords: 100, disableSubscriptions: { false })))
configuration: .custom(
syncMaxRecords: 100,
syncExpressions: syncExpressions,
disableSubscriptions: { false }
)))
#else
try Amplify.add(plugin: AWSDataStorePlugin(modelRegistration: models,
configuration: .custom(syncMaxRecords: 100)))
configuration: .custom(
syncMaxRecords: 100,
syncExpressions: syncExpressions
)))
#endif
} catch {
XCTFail(String(describing: error))
Expand Down
Loading

0 comments on commit e0398e4

Please sign in to comment.