Skip to content

Commit

Permalink
chore(property-provider): refactor memoize to use arrow functions (#1281
Browse files Browse the repository at this point in the history
)
  • Loading branch information
trivikr committed Jun 22, 2020
1 parent 6f5ddfc commit 45edaed
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 105 deletions.
2 changes: 1 addition & 1 deletion packages/credential-provider-node/src/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ describe("defaultProvider", () => {

expect(await provider()).toEqual(creds);

expect(provider()).toBe(provider());
expect(provider()).toStrictEqual(provider());

expect(await provider()).toEqual(creds);
expect((fromEnv() as any).mock.calls.length).toBe(1);
Expand Down
153 changes: 100 additions & 53 deletions packages/property-provider/src/memoize.spec.ts
Original file line number Diff line number Diff line change
@@ -1,79 +1,126 @@
import { memoize } from "./memoize";
import { Provider } from "@aws-sdk/types";

describe("memoize", () => {
let provider: jest.Mock;
const mockReturn = "foo";
const repeatTimes = 10;

beforeEach(() => {
provider = jest.fn().mockResolvedValue(mockReturn);
});

afterEach(() => {
jest.clearAllMocks();
});

describe("static memoization", () => {
it("should cache the resolved provider", async () => {
const provider = jest.fn().mockResolvedValue("foo");
expect.assertions(repeatTimes * 2);

const memoized = memoize(provider);

expect(await memoized()).toEqual("foo");
expect(provider.mock.calls.length).toBe(1);
expect(await memoized()).toEqual("foo");
expect(provider.mock.calls.length).toBe(1);
for (const index in [...Array(repeatTimes).keys()]) {
expect(await memoized()).toStrictEqual(mockReturn);
expect(provider).toHaveBeenCalledTimes(1);
}
});

it("should always return the same promise", () => {
const provider = jest.fn().mockResolvedValue("foo");
expect.assertions(repeatTimes * 2);

const memoized = memoize(provider);
const result = memoized();

expect(memoized()).toBe(result);
for (const index in [...Array(repeatTimes).keys()]) {
expect(memoized()).toStrictEqual(result);
expect(provider).toHaveBeenCalledTimes(1);
}
});
});

describe("refreshing memoization", () => {
it("should not reinvoke the underlying provider while isExpired returns `false`", async () => {
const provider = jest.fn().mockResolvedValue("foo");
const isExpired = jest.fn().mockReturnValue(false);
const memoized = memoize(provider, isExpired);

const checkCount = 10;
for (let i = 0; i < checkCount; i++) {
expect(await memoized()).toBe("foo");
}
let isExpired: jest.Mock;
let requiresRefresh: jest.Mock;

expect(isExpired.mock.calls.length).toBe(checkCount);
expect(provider.mock.calls.length).toBe(1);
beforeEach(() => {
isExpired = jest.fn().mockReturnValue(true);
requiresRefresh = jest.fn().mockReturnValue(false);
});

it("should reinvoke the underlying provider when isExpired returns `true`", async () => {
const provider = jest.fn().mockResolvedValue("foo");
const isExpired = jest.fn().mockReturnValue(false);
const memoized = memoize(provider, isExpired);

const checkCount = 10;
for (let i = 0; i < checkCount; i++) {
expect(await memoized()).toBe("foo");
}

expect(isExpired.mock.calls.length).toBe(checkCount);
expect(provider.mock.calls.length).toBe(1);

isExpired.mockReturnValueOnce(true);
for (let i = 0; i < checkCount; i++) {
expect(await memoized()).toBe("foo");
}

expect(isExpired.mock.calls.length).toBe(checkCount * 2);
expect(provider.mock.calls.length).toBe(2);
describe("should not reinvoke the underlying provider while isExpired returns `false`", () => {
const isExpiredFalseTest = async (requiresRefresh?: any) => {
isExpired.mockReturnValue(false);
const memoized = memoize(provider, isExpired, requiresRefresh);

for (const index in [...Array(repeatTimes).keys()]) {
expect(await memoized()).toEqual(mockReturn);
}

expect(isExpired).toHaveBeenCalledTimes(repeatTimes);
if (requiresRefresh) {
expect(requiresRefresh).toHaveBeenCalledTimes(repeatTimes);
}
expect(provider).toHaveBeenCalledTimes(1);
};

it("when requiresRefresh is not passed", async () => {
return isExpiredFalseTest();
});

it("when requiresRefresh returns true", () => {
requiresRefresh.mockReturnValue(true);
return isExpiredFalseTest(requiresRefresh);
});
});

it("should return the same promise for invocations 2-infinity if `requiresRefresh` returns `false`", async () => {
const provider = jest.fn().mockResolvedValue("foo");
const isExpired = jest.fn().mockReturnValue(true);
const requiresRefresh = jest.fn().mockReturnValue(false);

const memoized = memoize(provider, isExpired, requiresRefresh);
expect(await memoized()).toBe("foo");
const set = new Set<Promise<string>>();

const checkCount = 10;
for (let i = 0; i < checkCount; i++) {
set.add(memoized());
}
describe("should reinvoke the underlying provider when isExpired returns `true`", () => {
const isExpiredTrueTest = async (requiresRefresh?: any) => {
const memoized = memoize(provider, isExpired, requiresRefresh);

for (const index in [...Array(repeatTimes).keys()]) {
expect(await memoized()).toEqual(mockReturn);
}

expect(isExpired).toHaveBeenCalledTimes(repeatTimes);
if (requiresRefresh) {
expect(requiresRefresh).toHaveBeenCalledTimes(repeatTimes);
}
expect(provider).toHaveBeenCalledTimes(repeatTimes + 1);
};

it("when requiresRefresh is not passed", () => {
return isExpiredTrueTest();
});

it("when requiresRefresh returns true", () => {
requiresRefresh.mockReturnValue(true);
return isExpiredTrueTest(requiresRefresh);
});
});

expect(set.size).toBe(1);
describe("should return the same promise for invocations 2-infinity if `requiresRefresh` returns `false`", () => {
const requiresRefreshFalseTest = async () => {
const memoized = memoize(provider, isExpired, requiresRefresh);
const result = memoized();
expect(await result).toBe(mockReturn);

for (const index in [...Array(repeatTimes).keys()]) {
expect(memoized()).toStrictEqual(result);
expect(provider).toHaveBeenCalledTimes(1);
}

expect(requiresRefresh).toHaveBeenCalledTimes(1);
expect(isExpired).not.toHaveBeenCalled();
};

it("when isExpired returns true", () => {
return requiresRefreshFalseTest();
});

it("when isExpired returns false", () => {
isExpired.mockReturnValue(false);
return requiresRefreshFalseTest();
});
});
});
});
101 changes: 50 additions & 51 deletions packages/property-provider/src/memoize.ts
Original file line number Diff line number Diff line change
@@ -1,48 +1,50 @@
import { Provider } from "@aws-sdk/types";

/**
*
* Decorates a provider function with either static memoization.
*
* To create a statically memoized provider, supply a provider as the only
* argument to this function. The provider will be invoked once, and all
* invocations of the provider returned by `memoize` will return the same
* promise object.
*
* @param provider The provider whose result should be cached indefinitely.
*/
export function memoize<T>(provider: Provider<T>): Provider<T>;
interface MemoizeOverload {
/**
*
* Decorates a provider function with either static memoization.
*
* To create a statically memoized provider, supply a provider as the only
* argument to this function. The provider will be invoked once, and all
* invocations of the provider returned by `memoize` will return the same
* promise object.
*
* @param provider The provider whose result should be cached indefinitely.
*/
<T>(provider: Provider<T>): Provider<T>;

/**
* Decorates a provider function with refreshing memoization.
*
* @param provider The provider whose result should be cached.
* @param isExpired A function that will evaluate the resolved value and
* determine if it is expired. For example, when
* memoizing AWS credential providers, this function
* should return `true` when the credential's
* expiration is in the past (or very near future) and
* `false` otherwise.
* @param requiresRefresh A function that will evaluate the resolved value and
* determine if it represents static value or one that
* will eventually need to be refreshed. For example,
* AWS credentials that have no defined expiration will
* never need to be refreshed, so this function would
* return `true` if the credentials resolved by the
* underlying provider had an expiration and `false`
* otherwise.
*/
export function memoize<T>(
provider: Provider<T>,
isExpired: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean
): Provider<T>;
/**
* Decorates a provider function with refreshing memoization.
*
* @param provider The provider whose result should be cached.
* @param isExpired A function that will evaluate the resolved value and
* determine if it is expired. For example, when
* memoizing AWS credential providers, this function
* should return `true` when the credential's
* expiration is in the past (or very near future) and
* `false` otherwise.
* @param requiresRefresh A function that will evaluate the resolved value and
* determine if it represents static value or one that
* will eventually need to be refreshed. For example,
* AWS credentials that have no defined expiration will
* never need to be refreshed, so this function would
* return `true` if the credentials resolved by the
* underlying provider had an expiration and `false`
* otherwise.
*/
<T>(
provider: Provider<T>,
isExpired: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean
): Provider<T>;
}

export function memoize<T>(
export const memoize: MemoizeOverload = <T>(
provider: Provider<T>,
isExpired?: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean
): Provider<T> {
): Provider<T> => {
if (isExpired === undefined) {
// This is a static memoization; no need to incorporate refreshing
const result = provider();
Expand All @@ -52,22 +54,19 @@ export function memoize<T>(
let result = provider();
let isConstant: boolean = false;

return () => {
return async () => {
if (isConstant) {
return result;
}

return result.then(resolved => {
if (requiresRefresh && !requiresRefresh(resolved)) {
isConstant = true;
return resolved;
}

if (isExpired(resolved)) {
return (result = provider());
}

const resolved = await result;
if (requiresRefresh && !requiresRefresh(resolved)) {
isConstant = true;
return resolved;
});
}
if (isExpired(resolved)) {
return (result = provider());
}
return resolved;
};
}
};

0 comments on commit 45edaed

Please sign in to comment.