Skip to content

Commit

Permalink
fix(DataStore): dataStore cannot connect to model's sync subscription…
Browse files Browse the repository at this point in the history
…s (AWS_LAMBDA auth type) aws-amplify#3549
  • Loading branch information
MuniekMg committed Mar 7, 2024
1 parent b96fbda commit 8fa7fed
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ public struct DataStoreConfiguration {

/// Selective sync expressions
public let syncExpressions: [DataStoreSyncExpression]

/// Use syncExpressions as subscriptions filter
public let subscriptionFiltering: Bool

/// Authorization mode strategy
public var authModeStrategyType: AuthModeStrategyType
Expand All @@ -79,6 +82,7 @@ public struct DataStoreConfiguration {
syncMaxRecords: UInt,
syncPageSize: UInt,
syncExpressions: [DataStoreSyncExpression],
subscriptionFiltering: Bool = false,
authModeStrategy: AuthModeStrategyType = .default,
disableSubscriptions: @escaping () -> Bool) {
self.errorHandler = errorHandler
Expand All @@ -87,6 +91,7 @@ public struct DataStoreConfiguration {
self.syncMaxRecords = syncMaxRecords
self.syncPageSize = syncPageSize
self.syncExpressions = syncExpressions
self.subscriptionFiltering = subscriptionFiltering
self.authModeStrategyType = authModeStrategy
self.disableSubscriptions = disableSubscriptions
}
Expand All @@ -97,13 +102,15 @@ public struct DataStoreConfiguration {
syncMaxRecords: UInt,
syncPageSize: UInt,
syncExpressions: [DataStoreSyncExpression],
subscriptionFiltering: Bool = false,
authModeStrategy: AuthModeStrategyType = .default) {
self.errorHandler = errorHandler
self.conflictHandler = conflictHandler
self.syncInterval = syncInterval
self.syncMaxRecords = syncMaxRecords
self.syncPageSize = syncPageSize
self.syncExpressions = syncExpressions
self.subscriptionFiltering = subscriptionFiltering
self.authModeStrategyType = authModeStrategy
self.disableSubscriptions = { false }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class RemoteSyncEngine: RemoteSyncEngineBehavior {
authModeStrategy: resolvedAuthStrategy)
// swiftlint:disable line_length
let reconciliationQueueFactory = reconciliationQueueFactory ??
AWSIncomingEventReconciliationQueue.init(modelSchemas:api:storageAdapter:syncExpressions:auth:authModeStrategy:modelReconciliationQueueFactory:disableSubscriptions:)
AWSIncomingEventReconciliationQueue.init(modelSchemas:api:storageAdapter:syncExpressions:auth:authModeStrategy:modelReconciliationQueueFactory:subscriptionFiltering:disableSubscriptions:)
// swiftlint:enable line_length
let initialSyncOrchestratorFactory = initialSyncOrchestratorFactory ??
AWSInitialSyncOrchestrator.init(dataStoreConfiguration:authModeStrategy:api:reconciliationQueue:storageAdapter:)
Expand Down Expand Up @@ -291,6 +291,7 @@ class RemoteSyncEngine: RemoteSyncEngineBehavior {
auth,
authModeStrategy,
nil,
dataStoreConfiguration.subscriptionFiltering,
dataStoreConfiguration.disableSubscriptions)
reconciliationQueueSink = reconciliationQueue?
.publisher
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Combine
import Foundation

typealias DisableSubscriptions = () -> Bool
typealias SubscriptionFiltering = Bool

// Used for testing:
typealias IncomingEventReconciliationQueueFactory =
Expand All @@ -21,6 +22,7 @@ typealias IncomingEventReconciliationQueueFactory =
AuthCategoryBehavior?,
AuthModeStrategy,
ModelReconciliationQueueFactory?,
SubscriptionFiltering,
@escaping DisableSubscriptions
) async -> IncomingEventReconciliationQueue

Expand Down Expand Up @@ -52,6 +54,7 @@ final class AWSIncomingEventReconciliationQueue: IncomingEventReconciliationQueu
auth: AuthCategoryBehavior? = nil,
authModeStrategy: AuthModeStrategy,
modelReconciliationQueueFactory: ModelReconciliationQueueFactory? = nil,
subscriptionFiltering: Bool,
disableSubscriptions: @escaping () -> Bool = { false }) async {
self.modelSchemasCount = modelSchemas.count
self.modelReconciliationQueueSinks.set([:])
Expand Down Expand Up @@ -101,6 +104,7 @@ final class AWSIncomingEventReconciliationQueue: IncomingEventReconciliationQueu
modelPredicate,
auth,
authModeStrategy,
subscriptionFiltering,
subscriptionsDisabled ? OperationDisabledIncomingSubscriptionEventPublisher() : nil)

reconciliationQueues.with { reconciliationQueues in
Expand Down Expand Up @@ -208,14 +212,15 @@ extension AWSIncomingEventReconciliationQueue: DefaultLogger {
// MARK: - Static factory
extension AWSIncomingEventReconciliationQueue {
static let factory: IncomingEventReconciliationQueueFactory = {
modelSchemas, api, storageAdapter, syncExpressions, auth, authModeStrategy, _, disableSubscriptions in
modelSchemas, api, storageAdapter, syncExpressions, auth, authModeStrategy, _, subscriptionFiltering, disableSubscriptions in
await AWSIncomingEventReconciliationQueue(modelSchemas: modelSchemas,
api: api,
storageAdapter: storageAdapter,
syncExpressions: syncExpressions,
auth: auth,
authModeStrategy: authModeStrategy,
modelReconciliationQueueFactory: nil,
subscriptionFiltering: subscriptionFiltering,
disableSubscriptions: disableSubscriptions)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ final class AWSIncomingSubscriptionEventPublisher: IncomingSubscriptionEventPubl
api: APICategoryGraphQLBehaviorExtended,
modelPredicate: QueryPredicate?,
auth: AuthCategoryBehavior?,
authModeStrategy: AuthModeStrategy) async {
authModeStrategy: AuthModeStrategy,
subscriptionFiltering: Bool) async {
self.subscriptionEventSubject = PassthroughSubject<IncomingSubscriptionEventPublisherEvent, DataStoreError>()
self.asyncEvents = await IncomingAsyncSubscriptionEventPublisher(modelSchema: modelSchema,
api: api,
modelPredicate: modelPredicate,
auth: auth,
authModeStrategy: authModeStrategy)
authModeStrategy: authModeStrategy,
subscriptionFiltering: subscriptionFiltering)

self.mapper = IncomingAsyncSubscriptionEventToAnyModelMapper()
asyncEvents.subscribe(subscriber: mapper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
modelPredicate: QueryPredicate?,
auth: AuthCategoryBehavior?,
authModeStrategy: AuthModeStrategy,
subscriptionFiltering: Bool,
awsAuthService: AWSAuthServiceBehavior? = nil) async {
self.onCreateConnected = false
self.onUpdateConnected = false
Expand Down Expand Up @@ -84,6 +85,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
api: api,
auth: auth,
awsAuthService: self.awsAuthService,
subscriptionFiltering: subscriptionFiltering,
authTypeProvider: onCreateAuthTypeProvider),
maxRetries: onCreateAuthTypeProvider.count,
resultListener: genericCompletionListenerHandler) { nextRequest, wrappedCompletion in
Expand All @@ -106,6 +108,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
api: api,
auth: auth,
awsAuthService: self.awsAuthService,
subscriptionFiltering: subscriptionFiltering,
authTypeProvider: onUpdateAuthTypeProvider),
maxRetries: onUpdateAuthTypeProvider.count,
resultListener: genericCompletionListenerHandler) { nextRequest, wrappedCompletion in
Expand All @@ -128,6 +131,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
api: api,
auth: auth,
awsAuthService: self.awsAuthService,
subscriptionFiltering: subscriptionFiltering,
authTypeProvider: onDeleteAuthTypeProvider),
maxRetries: onUpdateAuthTypeProvider.count,
resultListener: genericCompletionListenerHandler) { nextRequest, wrappedCompletion in
Expand Down Expand Up @@ -308,16 +312,18 @@ extension IncomingAsyncSubscriptionEventPublisher {
api: APICategoryGraphQLBehaviorExtended,
auth: AuthCategoryBehavior?,
awsAuthService: AWSAuthServiceBehavior,
subscriptionFiltering: Bool,
authTypeProvider: AWSAuthorizationTypeIterator) -> RetryableGraphQLOperation<Payload>.RequestFactory {
var authTypes = authTypeProvider

return {
return await IncomingAsyncSubscriptionEventPublisher.makeAPIRequest(for: modelSchema,
where: predicate,
subscriptionType: subscriptionType,
api: api,
auth: auth,
authType: authTypes.next(),
awsAuthService: awsAuthService)
return await IncomingAsyncSubscriptionEventPublisher.makeAPIRequest(for: modelSchema,
where: subscriptionFiltering ? predicate : nil,
subscriptionType: subscriptionType,
api: api,
auth: auth,
authType: authTypes.next(),
awsAuthService: awsAuthService)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ typealias ModelReconciliationQueueFactory = (
QueryPredicate?,
AuthCategoryBehavior?,
AuthModeStrategy,
SubscriptionFiltering,
IncomingSubscriptionEventPublisher?
) async -> ModelReconciliationQueue

Expand Down Expand Up @@ -83,6 +84,7 @@ final class AWSModelReconciliationQueue: ModelReconciliationQueue {
modelPredicate: QueryPredicate?,
auth: AuthCategoryBehavior?,
authModeStrategy: AuthModeStrategy,
subscriptionFiltering: Bool,
incomingSubscriptionEvents: IncomingSubscriptionEventPublisher? = nil) async {

self.modelSchema = modelSchema
Expand All @@ -108,7 +110,8 @@ final class AWSModelReconciliationQueue: ModelReconciliationQueue {
api: api,
modelPredicate: modelPredicate,
auth: auth,
authModeStrategy: authModeStrategy
authModeStrategy: authModeStrategy,
subscriptionFiltering: subscriptionFiltering
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class AWSIncomingEventReconciliationQueueTests: XCTestCase {
storageAdapter: storageAdapter,
syncExpressions: [],
authModeStrategy: AWSDefaultAuthModeStrategy(),
modelReconciliationQueueFactory: modelReconciliationQueueFactory)
modelReconciliationQueueFactory: modelReconciliationQueueFactory,
subscriptionFiltering: false)
}

// This test case attempts to hit a race condition, and may be required to execute multiple times
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ final class IncomingAsyncSubscriptionEventPublisherTests: XCTestCase {
modelPredicate: nil,
auth: nil,
authModeStrategy: AWSDefaultAuthModeStrategy(),
subscriptionFiltering: false,
awsAuthService: nil)
let mapper = IncomingAsyncSubscriptionEventToAnyModelMapper()
asyncEvents.subscribe(subscriber: mapper)
Expand Down Expand Up @@ -73,6 +74,7 @@ final class IncomingAsyncSubscriptionEventPublisherTests: XCTestCase {
modelPredicate: nil,
auth: nil,
authModeStrategy: AWSDefaultAuthModeStrategy(),
subscriptionFiltering: false,
awsAuthService: nil)
let mapper = IncomingAsyncSubscriptionEventToAnyModelMapper()
asyncEvents.subscribe(subscriber: mapper)
Expand Down Expand Up @@ -107,10 +109,10 @@ final class IncomingAsyncSubscriptionEventPublisherTests: XCTestCase {
sink.cancel()
}

/// Given: IncomingAsyncSubscriptionEventPublisher initilized with modelPredicate
/// Given: IncomingAsyncSubscriptionEventPublisher initilized with modelPredicate and subscriptionFiltering enabled
/// When: IncomingAsyncSubscriptionEventPublisher subscribes to onCreate, onUpdate, onDelete events
/// Then: IncomingAsyncSubscriptionEventPublisher provides correct filters in subscriptions request
func testModelPredicateAsSubscribtionsFilter() async throws {
func testSubscriptionFilteringEnabledModelPredicateAsSubscribtionsFilter() async throws {

let id1 = UUID().uuidString
let id2 = UUID().uuidString
Expand Down Expand Up @@ -174,8 +176,70 @@ final class IncomingAsyncSubscriptionEventPublisherTests: XCTestCase {
]),
auth: nil,
authModeStrategy: AWSDefaultAuthModeStrategy(),
subscriptionFiltering: true,
awsAuthService: nil)

await fulfillment(of: [correctFilterOnCreate, correctFilterOnUpdate, correctFilterOnDelete], timeout: 1)
}

/// Given: IncomingAsyncSubscriptionEventPublisher initilized with modelPredicate and subscriptionFiltering disabled
/// When: IncomingAsyncSubscriptionEventPublisher subscribes to onCreate, onUpdate, onDelete events
/// Then: IncomingAsyncSubscriptionEventPublisher has no filters in subscriptions request
func testSubscriptionFilteringDisabledModelPredicateIgnoredInSubscribtions() async throws {

let id1 = UUID().uuidString
let id2 = UUID().uuidString

let noFilterOnCreate = expectation(description: "Correct no filter in onCreate request")
let noFilterOnUpdate = expectation(description: "Correct no filter in onUpdate request")
let noFilterOnDelete = expectation(description: "Correct no filter in onDelete request")

func validateVariables(_ variables: [String: Any]?) -> Bool {
guard variables == nil else {
XCTFail("The request contains variables with subscriptionFiltering disabled")
return false
}

return true
}

let responder = SubscribeRequestListenerResponder<MutationSync<AnyModel>> { request, _, _ in
if request.document.contains("onCreatePost") {
if validateVariables(request.variables) {
noFilterOnCreate.fulfill()
}

} else if request.document.contains("onUpdatePost") {
if validateVariables(request.variables) {
noFilterOnUpdate.fulfill()
}

} else if request.document.contains("onDeletePost") {
if validateVariables(request.variables) {
noFilterOnDelete.fulfill()
}

} else {
XCTFail("Unexpected request: \(request.document)")
}

return nil
}

apiPlugin.responders[.subscribeRequestListener] = responder

_ = await IncomingAsyncSubscriptionEventPublisher(
modelSchema: Post.schema,
api: apiPlugin,
modelPredicate: QueryPredicateGroup(type: .or, predicates: [
Post.keys.id.eq(id1),
Post.keys.id.eq(id2)
]),
auth: nil,
authModeStrategy: AWSDefaultAuthModeStrategy(),
subscriptionFiltering: false,
awsAuthService: nil)

await fulfillment(of: [noFilterOnCreate, noFilterOnUpdate, noFilterOnDelete], timeout: 1)
}
}
Loading

0 comments on commit 8fa7fed

Please sign in to comment.