diff --git a/package.json b/package.json index 8e44eb37c33..7970a9f5d26 100644 --- a/package.json +++ b/package.json @@ -102,6 +102,7 @@ "eslint-plugin-prettier": "^5.1.3", "eslint-plugin-promise": "^6.1.1", "eslint-plugin-unused-imports": "^3.0.0", + "expect": "^29.7.0", "glob": "^10.3.10", "husky": "^9.0.11", "jest": "^29.7.0", diff --git a/packages/storage/__tests__/providers/s3/apis/copy.test.ts b/packages/storage/__tests__/providers/s3/apis/copy.test.ts index 260e38d4863..55547ae8e7c 100644 --- a/packages/storage/__tests__/providers/s3/apis/copy.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/copy.test.ts @@ -14,6 +14,7 @@ import { CopyWithPathInput, CopyWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -186,11 +187,14 @@ describe('copy API', () => { }); expect(key).toEqual(destinationKey); expect(copyObject).toHaveBeenCalledTimes(1); - expect(copyObject).toHaveBeenCalledWith(copyObjectClientConfig, { - ...copyObjectClientBaseParams, - CopySource: expectedSourceKey, - Key: expectedDestinationKey, - }); + await expect(copyObject).toBeLastCalledWithConfigAndInput( + copyObjectClientConfig, + { + ...copyObjectClientBaseParams, + CopySource: expectedSourceKey, + Key: expectedDestinationKey, + }, + ); }); }, ); @@ -239,11 +243,14 @@ describe('copy API', () => { }); expect(path).toEqual(expectedDestinationPath); expect(copyObject).toHaveBeenCalledTimes(1); - expect(copyObject).toHaveBeenCalledWith(copyObjectClientConfig, { - ...copyObjectClientBaseParams, - CopySource: `${bucket}/${expectedSourcePath}`, - Key: expectedDestinationPath, - }); + await expect(copyObject).toBeLastCalledWithConfigAndInput( + copyObjectClientConfig, + { + ...copyObjectClientBaseParams, + CopySource: `${bucket}/${expectedSourcePath}`, + Key: expectedDestinationPath, + }, + ); }, ); }); @@ -269,11 +276,14 @@ describe('copy API', () => { }); } catch (error: any) { expect(copyObject).toHaveBeenCalledTimes(1); - expect(copyObject).toHaveBeenCalledWith(copyObjectClientConfig, { - ...copyObjectClientBaseParams, - CopySource: `${bucket}/public/${missingSourceKey}`, - Key: `public/${destinationKey}`, - }); + await expect(copyObject).toBeLastCalledWithConfigAndInput( + copyObjectClientConfig, + { + ...copyObjectClientBaseParams, + CopySource: `${bucket}/public/${missingSourceKey}`, + Key: `public/${destinationKey}`, + }, + ); expect(error.$metadata.httpStatusCode).toBe(404); } }); diff --git a/packages/storage/__tests__/providers/s3/apis/downloadData.test.ts b/packages/storage/__tests__/providers/s3/apis/downloadData.test.ts index 721720679c0..57d402b1f24 100644 --- a/packages/storage/__tests__/providers/s3/apis/downloadData.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/downloadData.test.ts @@ -23,6 +23,7 @@ import { ItemWithKey, ItemWithPath, } from '../../../../src/providers/s3/types/outputs'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('../../../../src/providers/s3/utils'); @@ -142,7 +143,7 @@ describe('downloadData with key', () => { body: 'body', }); expect(getObject).toHaveBeenCalledTimes(1); - expect(getObject).toHaveBeenCalledWith( + await expect(getObject).toBeLastCalledWithConfigAndInput( { credentials, region, @@ -288,7 +289,7 @@ describe('downloadData with path', () => { body: 'body', }); expect(getObject).toHaveBeenCalledTimes(1); - expect(getObject).toHaveBeenCalledWith( + await expect(getObject).toBeLastCalledWithConfigAndInput( { credentials, region, diff --git a/packages/storage/__tests__/providers/s3/apis/getProperties.test.ts b/packages/storage/__tests__/providers/s3/apis/getProperties.test.ts index 3b2ca3cae58..bb5a5b957a7 100644 --- a/packages/storage/__tests__/providers/s3/apis/getProperties.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/getProperties.test.ts @@ -12,6 +12,7 @@ import { GetPropertiesWithPathInput, GetPropertiesWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -145,7 +146,10 @@ describe('getProperties with key', () => { ...expectedResult, }); expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith(config, headObjectOptions); + await expect(headObject).toBeLastCalledWithConfigAndInput( + config, + headObjectOptions, + ); }, ); }); @@ -166,7 +170,7 @@ describe('getProperties with key', () => { await getPropertiesWrapper({ key: inputKey }); } catch (error: any) { expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith( + await expect(headObject).toBeLastCalledWithConfigAndInput( { credentials, region: 'region', @@ -265,7 +269,10 @@ describe('Happy cases: With path', () => { ...expectedResult, }); expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith(config, headObjectOptions); + await expect(headObject).toBeLastCalledWithConfigAndInput( + config, + headObjectOptions, + ); }, ); }); @@ -286,7 +293,7 @@ describe('Happy cases: With path', () => { await getPropertiesWrapper({ path: inputPath }); } catch (error: any) { expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith( + await expect(headObject).toBeLastCalledWithConfigAndInput( { credentials, region: 'region', diff --git a/packages/storage/__tests__/providers/s3/apis/getUrl.test.ts b/packages/storage/__tests__/providers/s3/apis/getUrl.test.ts index 428f8f2034c..994f4a0b648 100644 --- a/packages/storage/__tests__/providers/s3/apis/getUrl.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/getUrl.test.ts @@ -15,6 +15,7 @@ import { GetUrlWithPathInput, GetUrlWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -31,8 +32,8 @@ jest.mock('@aws-amplify/core', () => ({ const bucket = 'bucket'; const region = 'region'; -const mockFetchAuthSession = Amplify.Auth.fetchAuthSession as jest.Mock; -const mockGetConfig = Amplify.getConfig as jest.Mock; +const mockFetchAuthSession = jest.mocked(Amplify.Auth.fetchAuthSession); +const mockGetConfig = jest.mocked(Amplify.getConfig); const credentials: AWSCredentials = { accessKeyId: 'accessKeyId', sessionToken: 'sessionToken', @@ -68,7 +69,7 @@ describe('getUrl test with key', () => { }; const key = 'key'; beforeEach(() => { - (headObject as jest.MockedFunction).mockResolvedValue({ + jest.mocked(headObject).mockResolvedValue({ ContentLength: 100, ContentType: 'text/plain', ETag: 'etag', @@ -76,11 +77,7 @@ describe('getUrl test with key', () => { Metadata: { meta: 'value' }, $metadata: {} as any, }); - ( - getPresignedGetObjectUrl as jest.MockedFunction< - typeof getPresignedGetObjectUrl - > - ).mockResolvedValue(mockURL); + jest.mocked(getPresignedGetObjectUrl).mockResolvedValue(mockURL); }); afterEach(() => { jest.clearAllMocks(); @@ -131,7 +128,10 @@ describe('getUrl test with key', () => { }; expect(getPresignedGetObjectUrl).toHaveBeenCalledTimes(1); expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith(config, headObjectOptions); + await expect(headObject).toBeLastCalledWithConfigAndInput( + config, + headObjectOptions, + ); expect({ url, expiresAt }).toEqual(expectedResult); }, ); @@ -187,7 +187,7 @@ describe('getUrl test with path', () => { userAgentValue: expect.any(String), }; beforeEach(() => { - (headObject as jest.MockedFunction).mockResolvedValue({ + jest.mocked(headObject).mockResolvedValue({ ContentLength: 100, ContentType: 'text/plain', ETag: 'etag', @@ -195,11 +195,7 @@ describe('getUrl test with path', () => { Metadata: { meta: 'value' }, $metadata: {} as any, }); - ( - getPresignedGetObjectUrl as jest.MockedFunction< - typeof getPresignedGetObjectUrl - > - ).mockResolvedValue(mockURL); + jest.mocked(getPresignedGetObjectUrl).mockResolvedValue(mockURL); }); afterEach(() => { jest.clearAllMocks(); @@ -229,7 +225,10 @@ describe('getUrl test with path', () => { }); expect(getPresignedGetObjectUrl).toHaveBeenCalledTimes(1); expect(headObject).toHaveBeenCalledTimes(1); - expect(headObject).toHaveBeenCalledWith(config, headObjectOptions); + await expect(headObject).toBeLastCalledWithConfigAndInput( + config, + headObjectOptions, + ); expect({ url, expiresAt }).toEqual({ url: mockURL, expiresAt: expect.any(Date), diff --git a/packages/storage/__tests__/providers/s3/apis/list.test.ts b/packages/storage/__tests__/providers/s3/apis/list.test.ts index 76f4d3a7881..9629129d7a2 100644 --- a/packages/storage/__tests__/providers/s3/apis/list.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/list.test.ts @@ -16,6 +16,7 @@ import { ListPaginateWithPathInput, ListPaginateWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -172,11 +173,14 @@ describe('list API', () => { }); expect(response.nextToken).toEqual(nextToken); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - MaxKeys: 1000, - Prefix: expectedKey, - }); + await expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + MaxKeys: 1000, + Prefix: expectedKey, + }, + ); }); }); @@ -210,12 +214,15 @@ describe('list API', () => { }); expect(response.nextToken).toEqual(nextToken); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - Prefix: expectedKey, - ContinuationToken: nextToken, - MaxKeys: customPageSize, - }); + await expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + Prefix: expectedKey, + ContinuationToken: nextToken, + MaxKeys: customPageSize, + }, + ); }); }); @@ -236,11 +243,15 @@ describe('list API', () => { expect(response.items).toEqual([]); expect(response.nextToken).toEqual(undefined); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - MaxKeys: 1000, - Prefix: expectedKey, - }); + expect(listObjectsV2).toHaveBeenCalledTimes(1); + await expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + MaxKeys: 1000, + Prefix: expectedKey, + }, + ); }); }); @@ -268,8 +279,8 @@ describe('list API', () => { // listing three times for three pages expect(listObjectsV2).toHaveBeenCalledTimes(3); - // first input recieves undefined as the Continuation Token - expect(listObjectsV2).toHaveBeenNthCalledWith( + // first input receives undefined as the Continuation Token + await expect(listObjectsV2).toHaveBeenNthCalledWithConfigAndInput( 1, listObjectClientConfig, { @@ -279,8 +290,8 @@ describe('list API', () => { ContinuationToken: undefined, }, ); - // last input recieves TEST_TOKEN as the Continuation Token - expect(listObjectsV2).toHaveBeenNthCalledWith( + // last input receives TEST_TOKEN as the Continuation Token + await expect(listObjectsV2).toHaveBeenNthCalledWithConfigAndInput( 3, listObjectClientConfig, { @@ -346,11 +357,14 @@ describe('list API', () => { }); expect(response.nextToken).toEqual(nextToken); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - MaxKeys: 1000, - Prefix: resolvePath(inputPath), - }); + await expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + MaxKeys: 1000, + Prefix: resolvePath(inputPath), + }, + ); }, ); @@ -385,12 +399,15 @@ describe('list API', () => { }); expect(response.nextToken).toEqual(nextToken); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - Prefix: resolvePath(inputPath), - ContinuationToken: nextToken, - MaxKeys: customPageSize, - }); + await expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + Prefix: resolvePath(inputPath), + ContinuationToken: nextToken, + MaxKeys: customPageSize, + }, + ); }, ); @@ -406,11 +423,15 @@ describe('list API', () => { expect(response.items).toEqual([]); expect(response.nextToken).toEqual(undefined); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - MaxKeys: 1000, - Prefix: resolvePath(path), - }); + expect(listObjectsV2).toHaveBeenCalledTimes(1); + await expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + MaxKeys: 1000, + Prefix: resolvePath(path), + }, + ); }, ); @@ -437,8 +458,8 @@ describe('list API', () => { // listing three times for three pages expect(listObjectsV2).toHaveBeenCalledTimes(3); - // first input recieves undefined as the Continuation Token - expect(listObjectsV2).toHaveBeenNthCalledWith( + // first input receives undefined as the Continuation Token + await expect(listObjectsV2).toHaveBeenNthCalledWithConfigAndInput( 1, listObjectClientConfig, { @@ -448,8 +469,8 @@ describe('list API', () => { ContinuationToken: undefined, }, ); - // last input recieves TEST_TOKEN as the Continuation Token - expect(listObjectsV2).toHaveBeenNthCalledWith( + // last input receives TEST_TOKEN as the Continuation Token + await expect(listObjectsV2).toHaveBeenNthCalledWithConfigAndInput( 3, listObjectClientConfig, { @@ -479,11 +500,14 @@ describe('list API', () => { } catch (error: any) { expect.assertions(3); expect(listObjectsV2).toHaveBeenCalledTimes(1); - expect(listObjectsV2).toHaveBeenCalledWith(listObjectClientConfig, { - Bucket: bucket, - MaxKeys: 1000, - Prefix: 'public/', - }); + await expect(listObjectsV2).toBeLastCalledWithConfigAndInput( + listObjectClientConfig, + { + Bucket: bucket, + MaxKeys: 1000, + Prefix: 'public/', + }, + ); expect(error.$metadata.httpStatusCode).toBe(404); } }); diff --git a/packages/storage/__tests__/providers/s3/apis/remove.test.ts b/packages/storage/__tests__/providers/s3/apis/remove.test.ts index 61745b54455..ca1107f0912 100644 --- a/packages/storage/__tests__/providers/s3/apis/remove.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/remove.test.ts @@ -13,6 +13,7 @@ import { RemoveWithPathInput, RemoveWithPathOutput, } from '../../../../src/providers/s3/types'; +import './testUtils'; jest.mock('../../../../src/providers/s3/utils/client'); jest.mock('@aws-amplify/core', () => ({ @@ -105,10 +106,13 @@ describe('remove API', () => { }); expect(key).toEqual(inputKey); expect(deleteObject).toHaveBeenCalledTimes(1); - expect(deleteObject).toHaveBeenCalledWith(deleteObjectClientConfig, { - Bucket: bucket, - Key: expectedKey, - }); + await expect(deleteObject).toBeLastCalledWithConfigAndInput( + deleteObjectClientConfig, + { + Bucket: bucket, + Key: expectedKey, + }, + ); }); }); }); @@ -144,10 +148,13 @@ describe('remove API', () => { const { path } = await removeWrapper({ path: inputPath }); expect(path).toEqual(resolvedPath); expect(deleteObject).toHaveBeenCalledTimes(1); - expect(deleteObject).toHaveBeenCalledWith(deleteObjectClientConfig, { - Bucket: bucket, - Key: resolvedPath, - }); + await expect(deleteObject).toBeLastCalledWithConfigAndInput( + deleteObjectClientConfig, + { + Bucket: bucket, + Key: resolvedPath, + }, + ); }); }); }); @@ -170,10 +177,13 @@ describe('remove API', () => { await remove({ key }); } catch (error: any) { expect(deleteObject).toHaveBeenCalledTimes(1); - expect(deleteObject).toHaveBeenCalledWith(deleteObjectClientConfig, { - Bucket: bucket, - Key: `public/${key}`, - }); + await expect(deleteObject).toBeLastCalledWithConfigAndInput( + deleteObjectClientConfig, + { + Bucket: bucket, + Key: `public/${key}`, + }, + ); expect(error.$metadata.httpStatusCode).toBe(404); } }); diff --git a/packages/storage/__tests__/providers/s3/apis/testUtils.ts b/packages/storage/__tests__/providers/s3/apis/testUtils.ts new file mode 100644 index 00000000000..75f5dd823b2 --- /dev/null +++ b/packages/storage/__tests__/providers/s3/apis/testUtils.ts @@ -0,0 +1,131 @@ +import { AWSCredentials } from '@aws-amplify/core/internals/utils'; +import { expect } from '@jest/globals'; +import { type MatcherFunction } from 'expect'; + +const toBeLastCalledWithConfigAndInput: MatcherFunction< + [config: { credentials: unknown }, input: unknown] +> = async function toBeLastCalledWithConfigAndInput( + actualHandler, + expectedConfig, + expectedInput, +) { + if (!jest.isMockFunction(actualHandler)) { + return { + message: () => + `expected custom client handler to be a mock function, got ${actualHandler}`, + pass: false, + }; + } + const actualConfig = actualHandler.mock.lastCall?.[0]; + const actualInput = actualHandler.mock.lastCall?.[1]; + const actualConfigWithResolvedCredentials = + typeof actualConfig?.credentials === 'function' + ? { + ...actualConfig, + credentials: await actualConfig.credentials(), + } + : actualConfig; + if ( + this.equals(actualConfigWithResolvedCredentials, expectedConfig) && + this.equals(actualInput, expectedInput) + ) { + return { + message: () => '', + pass: true, + }; + } + + return { + message: () => + `expected ${JSON.stringify(actualConfig)} to equal ${JSON.stringify(expectedConfig)} and ${JSON.stringify(actualInput)} to equal ${JSON.stringify(expectedInput)}`, + pass: false, + }; +}; + +const toHaveBeenNthCalledWithConfigAndInput: MatcherFunction< + [nthCall: number, config: unknown, input: unknown] +> = async function toHaveBeenNthCalledWithConfigAndInput( + actualHandler, + nthCall, + expectedConfig, + expectedInput, +) { + if (!jest.isMockFunction(actualHandler)) { + return { + message: () => + `expected custom client handler to be a mock function, got ${actualHandler}`, + pass: false, + }; + } + const actualConfig = actualHandler.mock.calls[nthCall - 1]?.[0]; + const actualInput = actualHandler.mock.calls[nthCall - 1]?.[1]; + const actualConfigWithResolvedCredentials = + typeof actualConfig?.credentials === 'function' + ? { + ...actualConfig, + credentials: await actualConfig.credentials(), + } + : actualConfig; + if ( + this.equals(actualConfigWithResolvedCredentials, expectedConfig) && + this.equals(actualInput, expectedInput) + ) { + return { + message: () => '', + pass: true, + }; + } + + return { + message: () => + `expected ${JSON.stringify(actualConfig)} to equal ${JSON.stringify(expectedConfig)} and ${JSON.stringify(actualInput)} to equal ${JSON.stringify(expectedInput)}`, + pass: false, + }; +}; + +expect.extend({ + toBeLastCalledWithConfigAndInput, + toHaveBeenNthCalledWithConfigAndInput, +}); + +interface ConfigType { + credentials: AWSCredentials | (() => Promise); +} + +declare global { + namespace jest { + interface AsymmetricMatchers { + toBeLastCalledWithConfigAndInput( + expectedConfig: C, + expectedInput: any, + ): void; + toHaveBeenNthCalledWithConfigAndInput( + nthCall: number, + expectedConfig: C, + expectedInput: any, + ): void; + } + interface Matchers { + /** + * Asynchronously asserts mocked custom client handler to be last called with expected config and input. + * If the actual client config has a credential that is a provider function, it will be resolved to static + * credential object and matched against the supplied config credentials. + */ + toBeLastCalledWithConfigAndInput( + expectedConfig: C, + expectedInput: any, + ): Promise; + + /** + * Asynchronously asserts mocked custom client handler to be Nth called with expected config and input. + * If the actual client config has a credential that is a provider function, it will be resolved to static + * credential object and matched against the supplied config credentials. + */ + toHaveBeenNthCalledWithConfigAndInput( + nthCall: number, + expectedConfig: C, + expectedInput: any, + ): Promise; + } + } +} diff --git a/packages/storage/__tests__/providers/s3/apis/uploadData/multipartHandlers.test.ts b/packages/storage/__tests__/providers/s3/apis/uploadData/multipartHandlers.test.ts index d8300dde305..c40e5c83de6 100644 --- a/packages/storage/__tests__/providers/s3/apis/uploadData/multipartHandlers.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/uploadData/multipartHandlers.test.ts @@ -21,6 +21,7 @@ import { UPLOADS_STORAGE_KEY } from '../../../../../src/providers/s3/utils/const import { byteLength } from '../../../../../src/providers/s3/apis/uploadData/byteLength'; import { CanceledError } from '../../../../../src/errors/CanceledError'; import { StorageOptions } from '../../../../../src/types'; +import '../testUtils'; jest.mock('@aws-amplify/core'); jest.mock('../../../../../src/providers/s3/utils/client'); @@ -40,12 +41,12 @@ const defaultCacheKey = '8388608_application/octet-stream_bucket_public_key'; const testPath = 'testPath/object'; const testPathCacheKey = `8388608_${defaultContentType}_${bucket}_custom_${testPath}`; -const mockCreateMultipartUpload = createMultipartUpload as jest.Mock; -const mockUploadPart = uploadPart as jest.Mock; -const mockCompleteMultipartUpload = completeMultipartUpload as jest.Mock; -const mockAbortMultipartUpload = abortMultipartUpload as jest.Mock; -const mockListParts = listParts as jest.Mock; -const mockHeadObject = headObject as jest.Mock; +const mockCreateMultipartUpload = jest.mocked(createMultipartUpload); +const mockUploadPart = jest.mocked(uploadPart); +const mockCompleteMultipartUpload = jest.mocked(completeMultipartUpload); +const mockAbortMultipartUpload = jest.mocked(abortMultipartUpload); +const mockListParts = jest.mocked(listParts); +const mockHeadObject = jest.mocked(headObject); const disableAssertionFlag = true; @@ -55,20 +56,23 @@ const mockMultipartUploadSuccess = (disableAssertion?: boolean) => { let totalSize = 0; mockCreateMultipartUpload.mockResolvedValueOnce({ UploadId: 'uploadId', + $metadata: {}, }); + // @ts-expect-error Special mock to make uploadPart return input part number mockUploadPart.mockImplementation(async (s3Config, input) => { if (!disableAssertion) { expect(input.UploadId).toEqual('uploadId'); } // mock 2 invocation of onProgress callback to simulate progress - s3Config?.onUploadProgress({ - transferredBytes: input.Body.byteLength / 2, - totalBytes: input.Body.byteLength, + const body = input.Body as ArrayBuffer; + s3Config?.onUploadProgress?.({ + transferredBytes: body.byteLength / 2, + totalBytes: body.byteLength, }); - s3Config?.onUploadProgress({ - transferredBytes: input.Body.byteLength, - totalBytes: input.Body.byteLength, + s3Config?.onUploadProgress?.({ + transferredBytes: body.byteLength, + totalBytes: body.byteLength, }); totalSize += byteLength(input.Body)!; @@ -80,9 +84,11 @@ const mockMultipartUploadSuccess = (disableAssertion?: boolean) => { }); mockCompleteMultipartUpload.mockResolvedValueOnce({ ETag: 'etag', + $metadata: {}, }); mockHeadObject.mockResolvedValueOnce({ ContentLength: totalSize, + $metadata: {}, }); }; @@ -91,8 +97,10 @@ const mockMultipartUploadCancellation = ( ) => { mockCreateMultipartUpload.mockImplementation(async () => ({ UploadId: 'uploadId', + $metadata: {}, })); + // @ts-expect-error Only need partial mock mockUploadPart.mockImplementation(async ({ abortSignal }, { PartNumber }) => { beforeUploadPartResponseCallback?.(); if (abortSignal?.aborted) { @@ -105,10 +113,13 @@ const mockMultipartUploadCancellation = ( }; }); - mockAbortMultipartUpload.mockResolvedValueOnce({}); + mockAbortMultipartUpload.mockResolvedValueOnce({ + $metadata: {}, + }); // Mock resumed upload and completed upload successfully mockCompleteMultipartUpload.mockResolvedValueOnce({ ETag: 'etag', + $metadata: {}, }); }; @@ -194,7 +205,9 @@ describe('getMultipartUploadHandlers with key', () => { options: options as StorageOptions, }); const result = await multipartUploadJob(); - expect(mockCreateMultipartUpload).toHaveBeenCalledWith( + await expect( + mockCreateMultipartUpload, + ).toBeLastCalledWithConfigAndInput( expect.objectContaining({ credentials, region, @@ -258,7 +271,9 @@ describe('getMultipartUploadHandlers with key', () => { expect(mockCreateMultipartUpload).toHaveBeenCalledTimes(1); expect(mockUploadPart).toHaveBeenCalledTimes(10_000); expect(mockCompleteMultipartUpload).toHaveBeenCalledTimes(1); - expect(mockUploadPart.mock.calls[0][1].Body.byteLength).toEqual(10 * MB); // The part size should be adjusted from default 5MB to 10MB. + expect( + (mockUploadPart.mock.calls[0][1].Body as ArrayBuffer).byteLength, + ).toEqual(10 * MB); // The part size should be adjusted from default 5MB to 10MB. }); it('should throw error when remote and local file sizes do not match upon completed upload', async () => { @@ -267,6 +282,7 @@ describe('getMultipartUploadHandlers with key', () => { mockHeadObject.mockReset(); mockHeadObject.mockResolvedValue({ ContentLength: 1, + $metadata: {}, }); const { multipartUploadJob } = getMultipartUploadHandlers( @@ -318,6 +334,7 @@ describe('getMultipartUploadHandlers with key', () => { mockUploadPart.mockReset(); mockUploadPart.mockResolvedValueOnce({ ETag: `etag-1`, + // @ts-expect-error Special mock to make uploadPart return input part number. PartNumber: 1, }); mockUploadPart.mockRejectedValueOnce(new Error('error')); @@ -370,7 +387,7 @@ describe('getMultipartUploadHandlers with key', () => { }), ); mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -388,7 +405,7 @@ describe('getMultipartUploadHandlers with key', () => { it('should cache the upload with file including file lastModified property', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -423,7 +440,7 @@ describe('getMultipartUploadHandlers with key', () => { }), ); mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -441,7 +458,7 @@ describe('getMultipartUploadHandlers with key', () => { it('should cache upload task if new upload task is created', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -468,7 +485,7 @@ describe('getMultipartUploadHandlers with key', () => { it('should remove from cache if upload task is completed', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -490,7 +507,7 @@ describe('getMultipartUploadHandlers with key', () => { it('should remove from cache if upload task is canceled', async () => { expect.assertions(2); mockMultipartUploadSuccess(disableAssertionFlag); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -612,6 +629,7 @@ describe('getMultipartUploadHandlers with key', () => { ); mockListParts.mockResolvedValue({ Parts: [{ PartNumber: 1 }], + $metadata: {}, }); const onProgress = jest.fn(); @@ -701,7 +719,9 @@ describe('getMultipartUploadHandlers with path', () => { data: twoPartsPayload, }); const result = await multipartUploadJob(); - expect(mockCreateMultipartUpload).toHaveBeenCalledWith( + await expect( + mockCreateMultipartUpload, + ).toBeLastCalledWithConfigAndInput( expect.objectContaining({ credentials, region, @@ -765,7 +785,9 @@ describe('getMultipartUploadHandlers with path', () => { expect(mockCreateMultipartUpload).toHaveBeenCalledTimes(1); expect(mockUploadPart).toHaveBeenCalledTimes(10_000); expect(mockCompleteMultipartUpload).toHaveBeenCalledTimes(1); - expect(mockUploadPart.mock.calls[0][1].Body.byteLength).toEqual(10 * MB); // The part size should be adjusted from default 5MB to 10MB. + expect( + (mockUploadPart.mock.calls[0][1].Body as ArrayBuffer).byteLength, + ).toEqual(10 * MB); // The part size should be adjusted from default 5MB to 10MB. }); it('should throw error when remote and local file sizes do not match upon completed upload', async () => { @@ -774,6 +796,7 @@ describe('getMultipartUploadHandlers with path', () => { mockHeadObject.mockReset(); mockHeadObject.mockResolvedValue({ ContentLength: 1, + $metadata: {}, }); const { multipartUploadJob } = getMultipartUploadHandlers( @@ -825,6 +848,7 @@ describe('getMultipartUploadHandlers with path', () => { mockUploadPart.mockReset(); mockUploadPart.mockResolvedValueOnce({ ETag: `etag-1`, + // @ts-expect-error Special mock to make uploadPart return input part number. PartNumber: 1, }); mockUploadPart.mockRejectedValueOnce(new Error('error')); @@ -877,7 +901,7 @@ describe('getMultipartUploadHandlers with path', () => { }), ); mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -895,7 +919,7 @@ describe('getMultipartUploadHandlers with path', () => { it('should cache the upload with file including file lastModified property', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -933,7 +957,7 @@ describe('getMultipartUploadHandlers with path', () => { }), ); mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -951,7 +975,7 @@ describe('getMultipartUploadHandlers with path', () => { it('should cache upload task if new upload task is created', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -976,7 +1000,7 @@ describe('getMultipartUploadHandlers with path', () => { it('should remove from cache if upload task is completed', async () => { mockMultipartUploadSuccess(); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -998,7 +1022,7 @@ describe('getMultipartUploadHandlers with path', () => { it('should remove from cache if upload task is canceled', async () => { expect.assertions(2); mockMultipartUploadSuccess(disableAssertionFlag); - mockListParts.mockResolvedValueOnce({ Parts: [] }); + mockListParts.mockResolvedValueOnce({ Parts: [], $metadata: {} }); const size = 8 * MB; const { multipartUploadJob } = getMultipartUploadHandlers( { @@ -1120,6 +1144,7 @@ describe('getMultipartUploadHandlers with path', () => { ); mockListParts.mockResolvedValue({ Parts: [{ PartNumber: 1 }], + $metadata: {}, }); const onProgress = jest.fn(); diff --git a/packages/storage/__tests__/providers/s3/apis/uploadData/putObjectJob.test.ts b/packages/storage/__tests__/providers/s3/apis/uploadData/putObjectJob.test.ts index 3d76fe3776e..335e804c0ea 100644 --- a/packages/storage/__tests__/providers/s3/apis/uploadData/putObjectJob.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/uploadData/putObjectJob.test.ts @@ -7,6 +7,7 @@ import { Amplify } from '@aws-amplify/core'; import { putObject } from '../../../../../src/providers/s3/utils/client'; import { calculateContentMd5 } from '../../../../../src/providers/s3/utils'; import { putObjectJob } from '../../../../../src/providers/s3/apis/uploadData/putObjectJob'; +import '../testUtils'; jest.mock('../../../../../src/providers/s3/utils/client'); jest.mock('../../../../../src/providers/s3/utils', () => { @@ -35,14 +36,14 @@ const credentials: AWSCredentials = { secretAccessKey: 'secretAccessKey', }; const identityId = 'identityId'; -const mockFetchAuthSession = Amplify.Auth.fetchAuthSession as jest.Mock; -const mockPutObject = putObject as jest.Mock; +const mockFetchAuthSession = jest.mocked(Amplify.Auth.fetchAuthSession); +const mockPutObject = jest.mocked(putObject); mockFetchAuthSession.mockResolvedValue({ credentials, identityId, }); -(Amplify.getConfig as jest.Mock).mockReturnValue({ +jest.mocked(Amplify.getConfig).mockReturnValue({ Storage: { S3: { bucket: 'bucket', @@ -53,10 +54,15 @@ mockFetchAuthSession.mockResolvedValue({ mockPutObject.mockResolvedValue({ ETag: 'eTag', VersionId: 'versionId', + $metadata: {}, }); /* TODO Remove suite when `key` parameter is removed */ describe('putObjectJob with key', () => { + beforeEach(() => { + mockPutObject.mockClear(); + }); + it('should supply the correct parameters to putObject API handler', async () => { const abortController = new AbortController(); const inputKey = 'key'; @@ -92,7 +98,8 @@ describe('putObjectJob with key', () => { metadata: { key: 'value' }, size: undefined, }); - expect(mockPutObject).toHaveBeenCalledWith( + expect(mockPutObject).toHaveBeenCalledTimes(1); + await expect(mockPutObject).toBeLastCalledWithConfigAndInput( { credentials, region: 'region', @@ -135,6 +142,10 @@ describe('putObjectJob with key', () => { }); describe('putObjectJob with path', () => { + beforeEach(() => { + mockPutObject.mockClear(); + }); + test.each([ { path: testPath, @@ -180,7 +191,8 @@ describe('putObjectJob with path', () => { metadata: { key: 'value' }, size: undefined, }); - expect(mockPutObject).toHaveBeenCalledWith( + expect(mockPutObject).toHaveBeenCalledTimes(1); + await expect(mockPutObject).toBeLastCalledWithConfigAndInput( { credentials, region: 'region', diff --git a/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts b/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts index 5db56bed1ed..e26cb63b6c7 100644 --- a/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts @@ -53,20 +53,27 @@ describe('resolveS3ConfigAndInput', () => { }, }); - it('should call fetchAuthSession with forceRefresh false for credentials and identityId', async () => { + it('should call fetchAuthSession for credentials and identityId', async () => { + expect.assertions(1); await resolveS3ConfigAndInput(Amplify, {}); - expect(mockFetchAuthSession).toHaveBeenCalledWith({ - forceRefresh: false, - }); + expect(mockFetchAuthSession).toHaveBeenCalled(); }); it('should throw if credentials are not available', async () => { - mockFetchAuthSession.mockResolvedValueOnce({ + expect.assertions(1); + mockFetchAuthSession.mockResolvedValue({ identityId: targetIdentityId, }); - await expect(resolveS3ConfigAndInput(Amplify, {})).rejects.toMatchObject( - validationErrorMap[StorageValidationErrorCode.NoCredentials], - ); + const { + s3Config: { credentials: credentialsProvider }, + } = await resolveS3ConfigAndInput(Amplify, {}); + if (typeof credentialsProvider === 'function') { + await expect(credentialsProvider()).rejects.toMatchObject( + validationErrorMap[StorageValidationErrorCode.NoCredentials], + ); + } else { + fail('Expect credentials to be a function'); + } }); it('should throw if identityId is not available', async () => { diff --git a/packages/storage/__tests__/providers/s3/utils/client/S3/cases/shared.ts b/packages/storage/__tests__/providers/s3/utils/client/S3/cases/shared.ts index 6a1d8431825..76414b86d8f 100644 --- a/packages/storage/__tests__/providers/s3/utils/client/S3/cases/shared.ts +++ b/packages/storage/__tests__/providers/s3/utils/client/S3/cases/shared.ts @@ -17,10 +17,18 @@ export const expectedMetadata = { httpStatusCode: 200, }; +const region = 'us-east-1'; +const staticCredentials = { + accessKeyId: 'key', + secretAccessKey: 'secret', +}; + export const defaultConfig = { - region: 'us-east-1', - credentials: { - accessKeyId: 'key', - secretAccessKey: 'secret', - }, + region, + credentials: async () => staticCredentials, +}; + +export const defaultConfigWithStaticCredentials = { + region, + credentials: staticCredentials, }; diff --git a/packages/storage/__tests__/providers/s3/utils/client/S3/getPresignedGetObjectUrl.test.ts b/packages/storage/__tests__/providers/s3/utils/client/S3/getPresignedGetObjectUrl.test.ts index 7fbfcdeb3a1..93bd3963606 100644 --- a/packages/storage/__tests__/providers/s3/utils/client/S3/getPresignedGetObjectUrl.test.ts +++ b/packages/storage/__tests__/providers/s3/utils/client/S3/getPresignedGetObjectUrl.test.ts @@ -5,7 +5,7 @@ import { presignUrl } from '@aws-amplify/core/internals/aws-client-utils'; import { getPresignedGetObjectUrl } from '../../../../../../src/providers/s3/utils/client'; -import { defaultConfig } from './cases/shared'; +import { defaultConfigWithStaticCredentials } from './cases/shared'; jest.mock('@aws-amplify/core/internals/aws-client-utils', () => { const original = jest.requireActual( @@ -25,8 +25,8 @@ describe('serializeGetObjectRequest', () => { it('should return get object API request', async () => { const actual = await getPresignedGetObjectUrl( { - ...defaultConfig, - signingRegion: defaultConfig.region, + ...defaultConfigWithStaticCredentials, + signingRegion: defaultConfigWithStaticCredentials.region, signingService: 's3', expiration: 900, userAgentValue: 'UA', @@ -38,7 +38,7 @@ describe('serializeGetObjectRequest', () => { ); const actualUrl = actual; expect(actualUrl.hostname).toEqual( - `bucket.s3.${defaultConfig.region}.amazonaws.com`, + `bucket.s3.${defaultConfigWithStaticCredentials.region}.amazonaws.com`, ); expect(actualUrl.pathname).toEqual('/key'); expect(actualUrl.searchParams.get('X-Amz-Expires')).toEqual('900'); @@ -51,8 +51,8 @@ describe('serializeGetObjectRequest', () => { it('should call presignUrl with uriEscapePath param set to false', async () => { await getPresignedGetObjectUrl( { - ...defaultConfig, - signingRegion: defaultConfig.region, + ...defaultConfigWithStaticCredentials, + signingRegion: defaultConfigWithStaticCredentials.region, signingService: 's3', expiration: 900, userAgentValue: 'UA', diff --git a/packages/storage/src/providers/s3/apis/internal/getUrl.ts b/packages/storage/src/providers/s3/apis/internal/getUrl.ts index a2de5d3f770..4f866ef80b3 100644 --- a/packages/storage/src/providers/s3/apis/internal/getUrl.ts +++ b/packages/storage/src/providers/s3/apis/internal/getUrl.ts @@ -46,7 +46,11 @@ export const getUrl = async ( let urlExpirationInSec = getUrlOptions?.expiresIn ?? DEFAULT_PRESIGN_EXPIRATION; - const awsCredExpiration = s3Config.credentials?.expiration; + const resolvedCredential = + typeof s3Config.credentials === 'function' + ? await s3Config.credentials() + : s3Config.credentials; + const awsCredExpiration = resolvedCredential.expiration; if (awsCredExpiration) { const awsCredExpirationInSec = Math.floor( (awsCredExpiration.getTime() - Date.now()) / 1000, @@ -64,6 +68,7 @@ export const getUrl = async ( url: await getPresignedGetObjectUrl( { ...s3Config, + credentials: resolvedCredential, expiration: urlExpirationInSec, }, { diff --git a/packages/storage/src/providers/s3/types/options.ts b/packages/storage/src/providers/s3/types/options.ts index 4d0af341f52..b2b7dfd0ddc 100644 --- a/packages/storage/src/providers/s3/types/options.ts +++ b/packages/storage/src/providers/s3/types/options.ts @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 import { StorageAccessLevel } from '@aws-amplify/core'; -import { AWSCredentials } from '@aws-amplify/core/internals/utils'; +import { SigningOptions } from '@aws-amplify/core/internals/aws-client-utils'; import { TransferProgressEvent } from '../../../types'; import { @@ -176,9 +176,8 @@ export type CopyDestinationOptionsWithKey = WriteOptions & { * * @internal */ -export interface ResolvedS3Config { - region: string; - credentials: AWSCredentials; +export interface ResolvedS3Config + extends Pick { customEndpoint?: string; forcePathStyle?: boolean; useAccelerateEndpoint?: boolean; diff --git a/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts b/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts index 701c046d52f..ae7a185c93c 100644 --- a/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts +++ b/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts @@ -39,16 +39,29 @@ export const resolveS3ConfigAndInput = async ( amplify: AmplifyClassV6, apiOptions?: S3ApiOptions, ): Promise => { - // identityId is always cached in memory if forceRefresh is not set. So we can safely make calls here. - const { credentials, identityId } = await amplify.Auth.fetchAuthSession({ - forceRefresh: false, - }); - assertValidationError( - !!credentials, - StorageValidationErrorCode.NoCredentials, - ); + /** + * IdentityId is always cached in memory so we can safely make calls here. It + * should be stable even for unauthenticated users, regardless of credentials. + */ + const { identityId } = await amplify.Auth.fetchAuthSession(); assertValidationError(!!identityId, StorageValidationErrorCode.NoIdentityId); + /** + * A credentials provider function instead of a static credentials object is + * used because the long-running tasks like multipart upload may span over the + * credentials expiry. Auth.fetchAuthSession() automatically refreshes the + * credentials if they are expired. + */ + const credentialsProvider = async () => { + const { credentials } = await amplify.Auth.fetchAuthSession(); + assertValidationError( + !!credentials, + StorageValidationErrorCode.NoCredentials, + ); + + return credentials; + }; + const { bucket, region, dangerouslyConnectToHttpEndpointForTesting } = amplify.getConfig()?.Storage?.S3 ?? {}; assertValidationError(!!bucket, StorageValidationErrorCode.NoBucket); @@ -72,7 +85,7 @@ export const resolveS3ConfigAndInput = async ( return { s3Config: { - credentials, + credentials: credentialsProvider, region, useAccelerateEndpoint: apiOptions?.useAccelerateEndpoint, ...(dangerouslyConnectToHttpEndpointForTesting