Skip to content

Commit

Permalink
feat(property-provider): memoize() supports force refresh (#3413)
Browse files Browse the repository at this point in the history
* feat(property-provider): memoize() supports force refresh

* chore(property-provider): update unit test
  • Loading branch information
AllanZhengYP committed Mar 16, 2022
1 parent 4fd26e4 commit a79f962
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 24 deletions.
4 changes: 2 additions & 2 deletions packages/credential-provider-node/src/defaultProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { fromSSO, FromSSOInit } from "@aws-sdk/credential-provider-sso";
import { fromTokenFile, FromTokenFileInit } from "@aws-sdk/credential-provider-web-identity";
import { chain, CredentialsProviderError, memoize } from "@aws-sdk/property-provider";
import { ENV_PROFILE, loadSharedConfigFiles } from "@aws-sdk/shared-ini-file-loader";
import { CredentialProvider } from "@aws-sdk/types";
import { Credentials, MemoizedProvider } from "@aws-sdk/types";

import { remoteProvider } from "./remoteProvider";

Expand Down Expand Up @@ -46,7 +46,7 @@ import { remoteProvider } from "./remoteProvider";
*/
export const defaultProvider = (
init: FromIniInit & RemoteProviderInit & FromProcessInit & FromSSOInit & FromTokenFileInit = {}
): CredentialProvider => {
): MemoizedProvider<Credentials> => {
const options = {
profile: process.env[ENV_PROFILE],
...init,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { EndpointCache } from "@aws-sdk/endpoint-cache";
import { Credentials, Provider } from "@aws-sdk/types";
import { Credentials, MemoizedProvider, Provider } from "@aws-sdk/types";

export interface EndpointDiscoveryInputConfig {}

export interface PreviouslyResolved {
isCustomEndpoint: boolean;
credentials: Provider<Credentials>;
credentials: MemoizedProvider<Credentials>;
endpointDiscoveryEnabledProvider: Provider<boolean | undefined>;
}

Expand Down
3 changes: 2 additions & 1 deletion packages/middleware-sdk-ec2/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ import {
InitializeHandlerOptions,
InitializeHandlerOutput,
InitializeMiddleware,
MemoizedProvider,
MetadataBearer,
Pluggable,
Provider,
} from "@aws-sdk/types";
import { formatUrl } from "@aws-sdk/util-format-url";

interface PreviouslyResolved {
credentials: Provider<Credentials>;
credentials: MemoizedProvider<Credentials>;
endpoint: Provider<Endpoint>;
region: Provider<string>;
sha256: HashConstructor;
Expand Down
3 changes: 2 additions & 1 deletion packages/middleware-sdk-rds/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
InitializeHandlerOptions,
InitializeHandlerOutput,
InitializeMiddleware,
MemoizedProvider,
MetadataBearer,
Pluggable,
Provider,
Expand All @@ -28,7 +29,7 @@ const sourceIdToCommandKeyMap: { [key: string]: string } = {
const version = "2014-10-31";

interface PreviouslyResolved {
credentials: Provider<Credentials>;
credentials: MemoizedProvider<Credentials>;
endpoint: Provider<Endpoint>;
region: Provider<string>;
sha256: HashConstructor;
Expand Down
13 changes: 9 additions & 4 deletions packages/middleware-signing/src/configurations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
Credentials,
HashConstructor,
Logger,
MemoizedProvider,
Provider,
RegionInfo,
RegionInfoProvider,
Expand Down Expand Up @@ -75,7 +76,7 @@ export interface SigV4AuthInputConfig {
}

interface PreviouslyResolved {
credentialDefaultProvider: (input: any) => Provider<Credentials>;
credentialDefaultProvider: (input: any) => MemoizedProvider<Credentials>;
region: string | Provider<string>;
regionInfoProvider: RegionInfoProvider;
signingName?: string;
Expand All @@ -86,7 +87,7 @@ interface PreviouslyResolved {
}

interface SigV4PreviouslyResolved {
credentialDefaultProvider: (input: any) => Provider<Credentials>;
credentialDefaultProvider: (input: any) => MemoizedProvider<Credentials>;
region: string | Provider<string>;
signingName: string;
sha256: HashConstructor;
Expand All @@ -96,8 +97,10 @@ interface SigV4PreviouslyResolved {
export interface AwsAuthResolvedConfig {
/**
* Resolved value for input config {@link AwsAuthInputConfig.credentials}
* This provider MAY memoize the loaded credentials for certain period.
* See {@link MemoizedProvider} for more information.
*/
credentials: Provider<Credentials>;
credentials: MemoizedProvider<Credentials>;
/**
* Resolved value for input config {@link AwsAuthInputConfig.signer}
*/
Expand Down Expand Up @@ -211,7 +214,9 @@ const normalizeProvider = <T>(input: T | Provider<T>): Provider<T> => {
return input as Provider<T>;
};

const normalizeCredentialProvider = (credentials: Credentials | Provider<Credentials>): Provider<Credentials> => {
const normalizeCredentialProvider = (
credentials: Credentials | Provider<Credentials>
): MemoizedProvider<Credentials> => {
if (typeof credentials === "function") {
return memoize(
credentials,
Expand Down
48 changes: 45 additions & 3 deletions packages/property-provider/src/memoize.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ describe("memoize", () => {
expect(await memoized()).toBe("Retry");
expect(provider).toBeCalledTimes(2);
});

it("should retry provider if forceRefresh parameter is used", async () => {
provider
.mockReset()
.mockResolvedValueOnce("1st")
.mockResolvedValueOnce("2nd")
.mockRejectedValueOnce("Should not call 3rd time");
const memoized = memoize(provider);
expect(await memoized()).toBe("1st");
expect(await memoized()).toBe("1st");
expect(await memoized({ forceRefresh: true })).toBe("2nd");
expect(await memoized()).toBe("2nd");
expect(provider).toBeCalledTimes(2);
});
});

describe("refreshing memoization", () => {
Expand Down Expand Up @@ -115,7 +129,27 @@ describe("memoize", () => {
});
});

describe("should return the same promise for invocations 2-infinity if `requiresRefresh` returns `false`", () => {
describe("when called with forceRefresh set to `true`", () => {
it("should reinvoke the underlying provider even if isExpired returns false", async () => {
const memoized = memoize(provider, isExpired, requiresRefresh);
isExpired.mockReturnValue(false);
for (const _ in [...Array(repeatTimes).keys()]) {
expect(await memoized({ forceRefresh: true })).toEqual(mockReturn);
}
expect(provider).toHaveBeenCalledTimes(repeatTimes);
});

it("should reinvoke the underlying provider even if requiresRefresh returns false", async () => {
const memoized = memoize(provider, isExpired, requiresRefresh);
requiresRefresh.mockReturnValue(false);
for (const _ in [...Array(repeatTimes).keys()]) {
expect(await memoized({ forceRefresh: true })).toEqual(mockReturn);
}
expect(provider).toHaveBeenCalledTimes(repeatTimes);
});
});

describe("when `requiresRefresh` returns `false`", () => {
const requiresRefreshFalseTest = async () => {
const memoized = memoize(provider, isExpired, requiresRefresh);
const result = memoized();
Expand All @@ -130,14 +164,22 @@ describe("memoize", () => {
expect(isExpired).not.toHaveBeenCalled();
};

it("when isExpired returns true", () => {
it("should return the same promise for invocations 2-infinity if isExpired returns true", () => {
return requiresRefreshFalseTest();
});

it("when isExpired returns false", () => {
it("should return the same promise for invocations 2-infinity if isExpired returns false", () => {
isExpired.mockReturnValue(false);
return requiresRefreshFalseTest();
});

it("should re-evaluate `requiresRefresh` after force refresh", async () => {
const memoized = memoize(provider, isExpired, requiresRefresh);
for (const _ in [...Array(repeatTimes).keys()]) {
expect(await memoized({ forceRefresh: true })).toStrictEqual(mockReturn);
}
expect(requiresRefresh).toBeCalledTimes(repeatTimes);
});
});

describe("should not make extra request for concurrent calls", () => {
Expand Down
22 changes: 11 additions & 11 deletions packages/property-provider/src/memoize.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Provider } from "@aws-sdk/types";
import { MemoizedProvider, Provider } from "@aws-sdk/types";

interface MemoizeOverload {
/**
Expand All @@ -12,7 +12,7 @@ interface MemoizeOverload {
*
* @param provider The provider whose result should be cached indefinitely.
*/
<T>(provider: Provider<T>): Provider<T>;
<T>(provider: Provider<T>): MemoizedProvider<T>;

/**
* Decorates a provider function with refreshing memoization.
Expand All @@ -37,17 +37,18 @@ interface MemoizeOverload {
provider: Provider<T>,
isExpired: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean
): Provider<T>;
): MemoizedProvider<T>;
}

export const memoize: MemoizeOverload = <T>(
provider: Provider<T>,
isExpired?: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean
): Provider<T> => {
): MemoizedProvider<T> => {
let resolved: T;
let pending: Promise<T> | undefined;
let hasResult: boolean;
let isConstant = false;
// Wrapper over supplied provider with side effect to handle concurrent invocation.
const coalesceProvider: Provider<T> = async () => {
if (!pending) {
Expand All @@ -56,26 +57,25 @@ export const memoize: MemoizeOverload = <T>(
try {
resolved = await pending;
hasResult = true;
isConstant = false;
} finally {
pending = undefined;
}
return resolved;
};

if (isExpired === undefined) {
// This is a static memoization; no need to incorporate refreshing
return async () => {
if (!hasResult) {
// This is a static memoization; no need to incorporate refreshing unless using forceRefresh;
return async (options) => {
if (!hasResult || options?.forceRefresh) {
resolved = await coalesceProvider();
}
return resolved;
};
}

let isConstant = false;

return async () => {
if (!hasResult) {
return async (options) => {
if (!hasResult || options?.forceRefresh) {
resolved = await coalesceProvider();
}
if (isConstant) {
Expand Down
18 changes: 18 additions & 0 deletions packages/types/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ export interface Provider<T> {
(): Promise<T>;
}

/**
* A function that, when invoked, returns a promise that will be fulfilled with
* a value of type T. It memoizes the result from the previous invocation
* instead of calling the underlying resources every time.
*
* You can force the provider to refresh the memoized value by invoke the
* function with optional parameter hash with `forceRefresh` boolean key and
* value `true`.
*
* @example A function that reads credentials from IMDS service that could
* return expired credentials. The SDK will keep using the expired credentials
* until an unretryable service error requiring a force refresh of the
* credentials.
*/
export interface MemoizedProvider<T> {
(options?: { forceRefresh?: boolean }): Promise<T>;
}

/**
* A function that, given a request body, determines the
* length of the body. This is used to determine the Content-Length
Expand Down

0 comments on commit a79f962

Please sign in to comment.