Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 78 additions & 22 deletions packages/rate-limit-controller/src/RateLimitController.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ import {
const name = 'RateLimitController';

const implementations = {
showNativeNotification: jest.fn(),
apiWithoutCustomLimit: {
method: jest.fn(),
},
apiWithCustomLimit: {
method: jest.fn(),
rateLimitCount: 2,
rateLimitTimeout: 3000,
},
};

type RateLimitedApis = typeof implementations;
Expand Down Expand Up @@ -56,7 +63,8 @@ describe('RateLimitController', () => {
});

afterEach(() => {
implementations.showNativeNotification.mockClear();
implementations.apiWithoutCustomLimit.method.mockClear();
implementations.apiWithCustomLimit.method.mockClear();
jest.useRealTimers();
});

Expand All @@ -74,30 +82,30 @@ describe('RateLimitController', () => {
await unrestricted.call(
'RateLimitController:call',
origin,
'showNativeNotification',
'apiWithoutCustomLimit',
origin,
message,
),
).toBeUndefined();

expect(implementations.showNativeNotification).toHaveBeenCalledWith(
expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledWith(
origin,
message,
);
});

it('uses showNativeNotification to show a notification', async () => {
it('uses apiWithoutCustomLimit method', async () => {
const messenger = getRestrictedMessenger();

const controller = new RateLimitController({
implementations,
messenger,
});
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
await controller.call(origin, 'apiWithoutCustomLimit', origin, message),
).toBeUndefined();

expect(implementations.showNativeNotification).toHaveBeenCalledWith(
expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledWith(
origin,
message,
);
Expand All @@ -112,16 +120,41 @@ describe('RateLimitController', () => {
});

expect(
await controller.call(origin, 'showNativeNotification', origin, message),
await controller.call(origin, 'apiWithoutCustomLimit', origin, message),
).toBeUndefined();

await expect(
controller.call(origin, 'showNativeNotification', origin, message),
controller.call(origin, 'apiWithoutCustomLimit', origin, message),
).rejects.toThrow(
`"showNativeNotification" is currently rate-limited. Please try again later`,
`"apiWithoutCustomLimit" is currently rate-limited. Please try again later`,
);
expect(implementations.showNativeNotification).toHaveBeenCalledTimes(1);
expect(implementations.showNativeNotification).toHaveBeenCalledWith(

expect(
await controller.call(origin, 'apiWithCustomLimit', origin, message),
).toBeUndefined();

expect(
await controller.call(origin, 'apiWithCustomLimit', origin, message),
).toBeUndefined();

await expect(
controller.call(origin, 'apiWithCustomLimit', origin, message),
).rejects.toThrow(
`"apiWithCustomLimit" is currently rate-limited. Please try again later`,
);

expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledTimes(
1,
);

expect(implementations.apiWithCustomLimit.method).toHaveBeenCalledTimes(2);

expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledWith(
origin,
message,
);

expect(implementations.apiWithCustomLimit.method).toHaveBeenCalledWith(
origin,
message,
);
Expand All @@ -135,14 +168,37 @@ describe('RateLimitController', () => {
rateLimitCount: 1,
});
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
await controller.call(origin, 'apiWithoutCustomLimit', origin, message),
).toBeUndefined();
jest.runAllTimers();
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
await controller.call(origin, 'apiWithoutCustomLimit', origin, message),
).toBeUndefined();
expect(implementations.showNativeNotification).toHaveBeenCalledTimes(2);
expect(implementations.showNativeNotification).toHaveBeenCalledWith(

expect(
await controller.call(origin, 'apiWithCustomLimit', origin, message),
).toBeUndefined();

expect(
await controller.call(origin, 'apiWithCustomLimit', origin, message),
).toBeUndefined();

jest.runAllTimers();

expect(
await controller.call(origin, 'apiWithCustomLimit', origin, message),
).toBeUndefined();

expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledTimes(
2,
);
expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledWith(
origin,
message,
);

expect(implementations.apiWithCustomLimit.method).toHaveBeenCalledTimes(3);
expect(implementations.apiWithCustomLimit.method).toHaveBeenCalledWith(
origin,
message,
);
Expand All @@ -156,19 +212,19 @@ describe('RateLimitController', () => {
rateLimitCount: 2,
});
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
await controller.call(origin, 'apiWithoutCustomLimit', origin, message),
).toBeUndefined();
jest.advanceTimersByTime(2500);
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
await controller.call(origin, 'apiWithoutCustomLimit', origin, message),
).toBeUndefined();
expect(controller.state.requests.showNativeNotification[origin]).toBe(2);
expect(controller.state.requests.apiWithoutCustomLimit[origin]).toBe(2);
jest.advanceTimersByTime(2500);
expect(controller.state.requests.showNativeNotification[origin]).toBe(0);
expect(controller.state.requests.apiWithoutCustomLimit[origin]).toBe(0);
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
await controller.call(origin, 'apiWithoutCustomLimit', origin, message),
).toBeUndefined();
jest.advanceTimersByTime(2500);
expect(controller.state.requests.showNativeNotification[origin]).toBe(1);
expect(controller.state.requests.apiWithoutCustomLimit[origin]).toBe(1);
});
});
49 changes: 30 additions & 19 deletions packages/rate-limit-controller/src/RateLimitController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,55 @@ import {
RestrictedControllerMessenger,
} from '@metamask/base-controller';

/**
* @type RateLimitedApi
* @property method - The method that is rate-limited.
* @property rateLimitTimeout - The time window in which the rate limit is applied (in ms).
* @property rateLimitCount - The amount of calls an origin can make in the rate limit time window.
*/
export type RateLimitedApi = {
method: (...args: any[]) => any;
rateLimitTimeout?: number;
rateLimitCount?: number;
};

/**
* @type RateLimitState
* @property requests - Object containing number of requests in a given interval for each origin and api type combination
*/
export type RateLimitState<
RateLimitedApis extends Record<string, (...args: any[]) => any>,
RateLimitedApis extends Record<string, RateLimitedApi>,
> = {
requests: Record<keyof RateLimitedApis, Record<string, number>>;
};

const name = 'RateLimitController';

export type RateLimitStateChange<
RateLimitedApis extends Record<string, (...args: any[]) => any>,
RateLimitedApis extends Record<string, RateLimitedApi>,
> = {
type: `${typeof name}:stateChange`;
payload: [RateLimitState<RateLimitedApis>, Patch[]];
};

export type GetRateLimitState<
RateLimitedApis extends Record<string, (...args: any[]) => any>,
RateLimitedApis extends Record<string, RateLimitedApi>,
> = {
type: `${typeof name}:getState`;
handler: () => RateLimitState<RateLimitedApis>;
};

export type CallApi<
RateLimitedApis extends Record<string, (...args: any[]) => any>,
> = {
export type CallApi<RateLimitedApis extends Record<string, RateLimitedApi>> = {
type: `${typeof name}:call`;
handler: RateLimitController<RateLimitedApis>['call'];
};

export type RateLimitControllerActions<
RateLimitedApis extends Record<string, (...args: any[]) => any>,
RateLimitedApis extends Record<string, RateLimitedApi>,
> = GetRateLimitState<RateLimitedApis> | CallApi<RateLimitedApis>;

export type RateLimitMessenger<
RateLimitedApis extends Record<string, (...args: any[]) => any>,
RateLimitedApis extends Record<string, RateLimitedApi>,
> = RestrictedControllerMessenger<
typeof name,
RateLimitControllerActions<RateLimitedApis>,
Expand All @@ -60,7 +70,7 @@ const metadata = {
* Controller with logic for rate-limiting API endpoints per requesting origin.
*/
export class RateLimitController<
RateLimitedApis extends Record<string, (...args: any[]) => any>,
RateLimitedApis extends Record<string, RateLimitedApi>,
> extends BaseController<
typeof name,
RateLimitState<RateLimitedApis>,
Expand Down Expand Up @@ -116,7 +126,7 @@ export class RateLimitController<
((
origin: string,
type: keyof RateLimitedApis,
...args: Parameters<RateLimitedApis[keyof RateLimitedApis]>
...args: Parameters<RateLimitedApis[keyof RateLimitedApis]['method']>
) => this.call(origin, type, ...args)) as any,
);
}
Expand All @@ -132,16 +142,16 @@ export class RateLimitController<
async call<ApiType extends keyof RateLimitedApis>(
origin: string,
type: ApiType,
...args: Parameters<RateLimitedApis[ApiType]>
): Promise<ReturnType<RateLimitedApis[ApiType]>> {
...args: Parameters<RateLimitedApis[ApiType]['method']>
): Promise<ReturnType<RateLimitedApis[ApiType]['method']>> {
if (this.isRateLimited(type, origin)) {
throw ethErrors.rpc.limitExceeded({
message: `"${type}" is currently rate-limited. Please try again later.`,
message: `"${type.toString()}" is currently rate-limited. Please try again later.`,
});
}
this.recordRequest(type, origin);

const implementation = this.implementations[type];
const implementation = this.implementations[type].method;

if (!implementation) {
throw new Error('Invalid api type');
Expand All @@ -158,7 +168,9 @@ export class RateLimitController<
* @returns `true` if rate-limited, and `false` otherwise.
*/
private isRateLimited(api: keyof RateLimitedApis, origin: string) {
return this.state.requests[api][origin] >= this.rateLimitCount;
const rateLimitCount =
this.implementations[api].rateLimitCount ?? this.rateLimitCount;
return this.state.requests[api][origin] >= rateLimitCount;
}

/**
Expand All @@ -168,15 +180,14 @@ export class RateLimitController<
* @param origin - The origin trying to access the API.
*/
private recordRequest(api: keyof RateLimitedApis, origin: string) {
const rateLimitTimeout =
this.implementations[api].rateLimitTimeout ?? this.rateLimitTimeout;
this.update((state) => {
const previous = (state as any).requests[api][origin] ?? 0;
(state as any).requests[api][origin] = previous + 1;

if (previous === 0) {
setTimeout(
() => this.resetRequestCount(api, origin),
this.rateLimitTimeout,
);
setTimeout(() => this.resetRequestCount(api, origin), rateLimitTimeout);
}
});
}
Expand Down