Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Wait for endpoint creation to identify user #13353

Merged
merged 3 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/core/src/providers/pinpoint/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ export {
PinpointServiceOptions,
UpdateEndpointException,
} from './types';
export { resolveEndpointId } from './utils';
export { getEndpointId, resolveEndpointId } from './utils';
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import { updateEndpoint } from '@aws-amplify/core/internals/providers/pinpoint';
import {
getEndpointId,
updateEndpoint,
} from '@aws-amplify/core/internals/providers/pinpoint';
import { assertIsInitialized } from '../../../../../src/pushNotifications/errors/errorHelpers';
import { identifyUser } from '../../../../../src/pushNotifications/providers/pinpoint/apis/identifyUser.native';
import { IdentifyUserInput } from '../../../../../src/pushNotifications/providers/pinpoint/types';
Expand All @@ -11,6 +14,7 @@ import {
} from '../../../../../src/pushNotifications/utils';
import {
getChannelType,
getInflightDeviceRegistration,
resolveConfig,
} from '../../../../../src/pushNotifications/providers/pinpoint/utils';
import {
Expand All @@ -32,11 +36,14 @@ describe('identifyUser (native)', () => {
// assert mocks
const mockAssertIsInitialized = assertIsInitialized as jest.Mock;
const mockGetChannelType = getChannelType as jest.Mock;
const mockUpdateEndpoint = updateEndpoint as jest.Mock;
const mockGetEndpointId = getEndpointId as jest.Mock;
const mockGetInflightDeviceRegistration =
getInflightDeviceRegistration as jest.Mock;
const mockGetPushNotificationUserAgentString =
getPushNotificationUserAgentString as jest.Mock;
const mockResolveConfig = resolveConfig as jest.Mock;
const mockResolveCredentials = resolveCredentials as jest.Mock;
const mockUpdateEndpoint = updateEndpoint as jest.Mock;

beforeAll(() => {
mockGetChannelType.mockReturnValue(channelType);
Expand All @@ -47,7 +54,9 @@ describe('identifyUser (native)', () => {

afterEach(() => {
mockAssertIsInitialized.mockReset();
mockGetEndpointId.mockReset();
mockUpdateEndpoint.mockReset();
mockGetInflightDeviceRegistration.mockClear();
});

it('must be initialized', async () => {
Expand Down Expand Up @@ -111,4 +120,24 @@ describe('identifyUser (native)', () => {
};
await expect(identifyUser(input)).rejects.toBeDefined();
});

it('awaits device registration promise when endpoint is not present', async () => {
const input: IdentifyUserInput = {
userId: 'user-id',
userProfile: {},
};
mockGetEndpointId.mockResolvedValue(undefined);
await identifyUser(input);
expect(mockGetInflightDeviceRegistration).toHaveBeenCalled();
});

it('does not await device registration promise when endpoint is present', async () => {
const input: IdentifyUserInput = {
userId: 'user-id',
userProfile: {},
};
mockGetEndpointId.mockResolvedValue('endpoint-id');
await identifyUser(input);
expect(mockGetInflightDeviceRegistration).not.toHaveBeenCalled();
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ import {
resolveCredentials,
setToken,
} from '../../../../../src/pushNotifications/utils';
import { resolveConfig } from '../../../../../src/pushNotifications//providers/pinpoint/utils';
import {
rejectInflightDeviceRegistration,
resolveConfig,
resolveInflightDeviceRegistration,
} from '../../../../../src/pushNotifications//providers/pinpoint/utils';
import {
completionHandlerId,
credentials,
Expand Down Expand Up @@ -56,8 +60,12 @@ describe('initializePushNotifications (native)', () => {
const mockGetToken = getToken as jest.Mock;
const mockInitialize = initialize as jest.Mock;
const mockIsInitialized = isInitialized as jest.Mock;
const mockRejectInflightDeviceRegistration =
rejectInflightDeviceRegistration as jest.Mock;
const mockResolveCredentials = resolveCredentials as jest.Mock;
const mockResolveConfig = resolveConfig as jest.Mock;
const mockResolveInflightDeviceRegistration =
resolveInflightDeviceRegistration as jest.Mock;
const mockSetToken = setToken as jest.Mock;
const mockNotifyEventListeners = notifyEventListeners as jest.Mock;
const mockNotifyEventListenersAndAwaitHandlers =
Expand Down Expand Up @@ -114,6 +122,8 @@ describe('initializePushNotifications (native)', () => {
mockEventListenerRemover.remove.mockClear();
mockNotifyEventListeners.mockClear();
mockNotifyEventListenersAndAwaitHandlers.mockClear();
mockRejectInflightDeviceRegistration.mockClear();
mockResolveInflightDeviceRegistration.mockClear();
});

it('only enables once', () => {
Expand Down Expand Up @@ -236,29 +246,29 @@ describe('initializePushNotifications (native)', () => {

describe('token received', () => {
it('registers and calls token received listener', done => {
expect.assertions(6);
mockGetToken.mockReturnValue(undefined);
mockAddTokenEventListener.mockImplementation(
async (heardEvent, handler) => {
if (heardEvent === NativeEvent.TOKEN_RECEIVED) {
await handler(pushToken);
expect(mockAddTokenEventListener).toHaveBeenCalledWith(
NativeEvent.TOKEN_RECEIVED,
expect.any(Function),
);
expect(mockSetToken).toHaveBeenCalledWith(pushToken);
expect(mockNotifyEventListeners).toHaveBeenCalledWith(
'tokenReceived',
pushToken,
);
expect(mockUpdateEndpoint).toHaveBeenCalled();
expect(mockResolveInflightDeviceRegistration).toHaveBeenCalled();
expect(mockRejectInflightDeviceRegistration).not.toHaveBeenCalled();
done();
}
},
);
mockUpdateEndpoint.mockImplementation(() => {
expect(mockUpdateEndpoint).toHaveBeenCalled();
done();
});
initializePushNotifications();

expect(mockAddTokenEventListener).toHaveBeenCalledWith(
NativeEvent.TOKEN_RECEIVED,
expect.any(Function),
);
expect(mockSetToken).toHaveBeenCalledWith(pushToken);
expect(mockNotifyEventListeners).toHaveBeenCalledWith(
'tokenReceived',
pushToken,
);
});

it('should not be invoke token received listener with the same token twice', () => {
Expand Down Expand Up @@ -292,13 +302,18 @@ describe('initializePushNotifications (native)', () => {
});

it('throws if device registration fails', done => {
expect.assertions(3);
mockUpdateEndpoint.mockImplementation(() => {
throw new Error();
});
mockAddTokenEventListener.mockImplementation(
async (heardEvent, handler) => {
if (heardEvent === NativeEvent.TOKEN_RECEIVED) {
await expect(handler(pushToken)).rejects.toThrow();
expect(
mockResolveInflightDeviceRegistration,
).not.toHaveBeenCalled();
expect(mockRejectInflightDeviceRegistration).toHaveBeenCalled();
done();
}
},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import {
getInflightDeviceRegistration,
rejectInflightDeviceRegistration,
resolveInflightDeviceRegistration,
} from '../../../../../src/pushNotifications/providers/pinpoint/utils/inflightDeviceRegistration';
import { InflightDeviceRegistration } from '../../../../../src/pushNotifications/providers/pinpoint/types';

describe('inflightDeviceRegistration', () => {
describe('resolveInflightDeviceRegistration', () => {
let getInflightDeviceRegistration: () => InflightDeviceRegistration;
let resolveInflightDeviceRegistration: () => void;
jest.isolateModules(() => {
({
getInflightDeviceRegistration,
resolveInflightDeviceRegistration,
} = require('../../../../../src/pushNotifications/providers/pinpoint/utils/inflightDeviceRegistration'));
});

it('creates a pending promise on module load', () => {
expect(getInflightDeviceRegistration()).toBeDefined();
});

it('should resolve the promise', async () => {
const blockedFunction = jest.fn();
const promise = getInflightDeviceRegistration()?.then(() => {
blockedFunction();
});

expect(blockedFunction).not.toHaveBeenCalled();
resolveInflightDeviceRegistration();
await promise;
expect(blockedFunction).toHaveBeenCalled();
});

it('should have released the promise from memory', () => {
expect(getInflightDeviceRegistration()).toBeUndefined();
});
});

describe('rejectInflightDeviceRegistration', () => {
let getInflightDeviceRegistration: () => InflightDeviceRegistration;
let rejectInflightDeviceRegistration: (underlyingError: unknown) => void;
jest.isolateModules(() => {
({
getInflightDeviceRegistration,
rejectInflightDeviceRegistration,
} = require('../../../../../src/pushNotifications/providers/pinpoint/utils/inflightDeviceRegistration'));
});

it('creates a pending promise on module load', () => {
expect(getInflightDeviceRegistration()).toBeDefined();
});

it('should reject the promise', async () => {
const underlyingError = new Error('underlying-error');
const blockedFunction = jest.fn();
const promise = getInflightDeviceRegistration()?.then(() => {
blockedFunction();
});

expect(blockedFunction).not.toHaveBeenCalled();
rejectInflightDeviceRegistration(underlyingError);
await expect(promise).rejects.toMatchObject({
name: 'DeviceRegistrationFailed',
underlyingError,
});
expect(blockedFunction).not.toHaveBeenCalled();
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@
// SPDX-License-Identifier: Apache-2.0

import { PushNotificationAction } from '@aws-amplify/core/internals/utils';
import { updateEndpoint } from '@aws-amplify/core/internals/providers/pinpoint';
import {
getEndpointId,
updateEndpoint,
} from '@aws-amplify/core/internals/providers/pinpoint';

import { assertIsInitialized } from '../../../errors/errorHelpers';
import {
getPushNotificationUserAgentString,
resolveCredentials,
} from '../../../utils';
import { getChannelType, resolveConfig } from '../utils';
import {
getChannelType,
getInflightDeviceRegistration,
resolveConfig,
} from '../utils';
import { IdentifyUser } from '../types';

export const identifyUser: IdentifyUser = async ({
Expand All @@ -21,6 +28,10 @@ export const identifyUser: IdentifyUser = async ({
const { credentials, identityId } = await resolveCredentials();
const { appId, region } = resolveConfig();
const { address, optOut, userAttributes } = options ?? {};
if (!(await getEndpointId(appId, 'PushNotification'))) {
// if there is no cached endpoint id, wait for successful endpoint creation before continuing
await getInflightDeviceRegistration();
}
await updateEndpoint({
address,
channelType: getChannelType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import {
import {
createMessageEventRecorder,
getChannelType,
rejectInflightDeviceRegistration,
resolveConfig,
resolveInflightDeviceRegistration,
} from '../utils';

const {
Expand Down Expand Up @@ -203,16 +205,24 @@ const addAnalyticsListeners = (): void => {
const registerDevice = async (address: string): Promise<void> => {
const { credentials, identityId } = await resolveCredentials();
const { appId, region } = resolveConfig();
await updateEndpoint({
address,
appId,
category: 'PushNotification',
credentials,
region,
channelType: getChannelType(),
identityId,
userAgentValue: getPushNotificationUserAgentString(
PushNotificationAction.InitializePushNotifications,
),
});
try {
await updateEndpoint({
address,
appId,
category: 'PushNotification',
credentials,
region,
channelType: getChannelType(),
identityId,
userAgentValue: getPushNotificationUserAgentString(
PushNotificationAction.InitializePushNotifications,
),
});
// always resolve inflight device registration promise here even though the promise is only awaited on by
// `identifyUser` when no endpoint is found in the cache
resolveInflightDeviceRegistration();
} catch (underlyingError) {
rejectInflightDeviceRegistration(underlyingError);
throw underlyingError;
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,8 @@ export {
OnTokenReceivedOutput,
} from './outputs';
export { IdentifyUserOptions } from './options';
export { ChannelType } from './pushNotifications';
export {
ChannelType,
InflightDeviceRegistration,
InflightDeviceRegistrationResolver,
} from './pushNotifications';
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,13 @@

import { updateEndpoint } from '@aws-amplify/core/internals/providers/pinpoint';

import { PushNotificationError } from '../../../errors';

export type ChannelType = Parameters<typeof updateEndpoint>[0]['channelType'];

export type InflightDeviceRegistration = Promise<void> | undefined;

export interface InflightDeviceRegistrationResolver {
resolve?(): void;
reject?(error: PushNotificationError): void;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,9 @@
export { createMessageEventRecorder } from './createMessageEventRecorder';
export { getAnalyticsEvent } from './getAnalyticsEvent';
export { getChannelType } from './getChannelType';
export {
getInflightDeviceRegistration,
rejectInflightDeviceRegistration,
resolveInflightDeviceRegistration,
} from './inflightDeviceRegistration';
export { resolveConfig } from './resolveConfig';