diff --git a/packages/rate-limit-controller/src/RateLimitController.test.ts b/packages/rate-limit-controller/src/RateLimitController.test.ts index fc63790ad76..0e90163d34a 100644 --- a/packages/rate-limit-controller/src/RateLimitController.test.ts +++ b/packages/rate-limit-controller/src/RateLimitController.test.ts @@ -11,7 +11,14 @@ import { const name = 'RateLimitController'; const implementations = { - showNativeNotification: jest.fn(), + apiWithoutCustomLimit: { + method: jest.fn(), + }, + apiWithCustomLimit: { + method: jest.fn(), + rateLimitCount: 2, + rateLimitTimeout: 3000, + }, }; type RateLimitedApis = typeof implementations; @@ -56,7 +63,8 @@ describe('RateLimitController', () => { }); afterEach(() => { - implementations.showNativeNotification.mockClear(); + implementations.apiWithoutCustomLimit.method.mockClear(); + implementations.apiWithCustomLimit.method.mockClear(); jest.useRealTimers(); }); @@ -74,19 +82,19 @@ describe('RateLimitController', () => { await unrestricted.call( 'RateLimitController:call', origin, - 'showNativeNotification', + 'apiWithoutCustomLimit', origin, message, ), ).toBeUndefined(); - expect(implementations.showNativeNotification).toHaveBeenCalledWith( + expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledWith( origin, message, ); }); - it('uses showNativeNotification to show a notification', async () => { + it('uses apiWithoutCustomLimit method', async () => { const messenger = getRestrictedMessenger(); const controller = new RateLimitController({ @@ -94,10 +102,10 @@ describe('RateLimitController', () => { messenger, }); expect( - await controller.call(origin, 'showNativeNotification', origin, message), + await controller.call(origin, 'apiWithoutCustomLimit', origin, message), ).toBeUndefined(); - expect(implementations.showNativeNotification).toHaveBeenCalledWith( + expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledWith( origin, message, ); @@ -112,16 +120,41 @@ describe('RateLimitController', () => { }); expect( - await controller.call(origin, 'showNativeNotification', origin, message), + await controller.call(origin, 'apiWithoutCustomLimit', origin, message), ).toBeUndefined(); await expect( - controller.call(origin, 'showNativeNotification', origin, message), + controller.call(origin, 'apiWithoutCustomLimit', origin, message), ).rejects.toThrow( - `"showNativeNotification" is currently rate-limited. Please try again later`, + `"apiWithoutCustomLimit" is currently rate-limited. Please try again later`, ); - expect(implementations.showNativeNotification).toHaveBeenCalledTimes(1); - expect(implementations.showNativeNotification).toHaveBeenCalledWith( + + expect( + await controller.call(origin, 'apiWithCustomLimit', origin, message), + ).toBeUndefined(); + + expect( + await controller.call(origin, 'apiWithCustomLimit', origin, message), + ).toBeUndefined(); + + await expect( + controller.call(origin, 'apiWithCustomLimit', origin, message), + ).rejects.toThrow( + `"apiWithCustomLimit" is currently rate-limited. Please try again later`, + ); + + expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledTimes( + 1, + ); + + expect(implementations.apiWithCustomLimit.method).toHaveBeenCalledTimes(2); + + expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledWith( + origin, + message, + ); + + expect(implementations.apiWithCustomLimit.method).toHaveBeenCalledWith( origin, message, ); @@ -135,14 +168,37 @@ describe('RateLimitController', () => { rateLimitCount: 1, }); expect( - await controller.call(origin, 'showNativeNotification', origin, message), + await controller.call(origin, 'apiWithoutCustomLimit', origin, message), ).toBeUndefined(); jest.runAllTimers(); expect( - await controller.call(origin, 'showNativeNotification', origin, message), + await controller.call(origin, 'apiWithoutCustomLimit', origin, message), ).toBeUndefined(); - expect(implementations.showNativeNotification).toHaveBeenCalledTimes(2); - expect(implementations.showNativeNotification).toHaveBeenCalledWith( + + expect( + await controller.call(origin, 'apiWithCustomLimit', origin, message), + ).toBeUndefined(); + + expect( + await controller.call(origin, 'apiWithCustomLimit', origin, message), + ).toBeUndefined(); + + jest.runAllTimers(); + + expect( + await controller.call(origin, 'apiWithCustomLimit', origin, message), + ).toBeUndefined(); + + expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledTimes( + 2, + ); + expect(implementations.apiWithoutCustomLimit.method).toHaveBeenCalledWith( + origin, + message, + ); + + expect(implementations.apiWithCustomLimit.method).toHaveBeenCalledTimes(3); + expect(implementations.apiWithCustomLimit.method).toHaveBeenCalledWith( origin, message, ); @@ -156,19 +212,19 @@ describe('RateLimitController', () => { rateLimitCount: 2, }); expect( - await controller.call(origin, 'showNativeNotification', origin, message), + await controller.call(origin, 'apiWithoutCustomLimit', origin, message), ).toBeUndefined(); jest.advanceTimersByTime(2500); expect( - await controller.call(origin, 'showNativeNotification', origin, message), + await controller.call(origin, 'apiWithoutCustomLimit', origin, message), ).toBeUndefined(); - expect(controller.state.requests.showNativeNotification[origin]).toBe(2); + expect(controller.state.requests.apiWithoutCustomLimit[origin]).toBe(2); jest.advanceTimersByTime(2500); - expect(controller.state.requests.showNativeNotification[origin]).toBe(0); + expect(controller.state.requests.apiWithoutCustomLimit[origin]).toBe(0); expect( - await controller.call(origin, 'showNativeNotification', origin, message), + await controller.call(origin, 'apiWithoutCustomLimit', origin, message), ).toBeUndefined(); jest.advanceTimersByTime(2500); - expect(controller.state.requests.showNativeNotification[origin]).toBe(1); + expect(controller.state.requests.apiWithoutCustomLimit[origin]).toBe(1); }); }); diff --git a/packages/rate-limit-controller/src/RateLimitController.ts b/packages/rate-limit-controller/src/RateLimitController.ts index 221ba8ec50d..0f5791324e0 100644 --- a/packages/rate-limit-controller/src/RateLimitController.ts +++ b/packages/rate-limit-controller/src/RateLimitController.ts @@ -5,12 +5,24 @@ import { RestrictedControllerMessenger, } from '@metamask/base-controller'; +/** + * @type RateLimitedApi + * @property method - The method that is rate-limited. + * @property rateLimitTimeout - The time window in which the rate limit is applied (in ms). + * @property rateLimitCount - The amount of calls an origin can make in the rate limit time window. + */ +export type RateLimitedApi = { + method: (...args: any[]) => any; + rateLimitTimeout?: number; + rateLimitCount?: number; +}; + /** * @type RateLimitState * @property requests - Object containing number of requests in a given interval for each origin and api type combination */ export type RateLimitState< - RateLimitedApis extends Record any>, + RateLimitedApis extends Record, > = { requests: Record>; }; @@ -18,32 +30,30 @@ export type RateLimitState< const name = 'RateLimitController'; export type RateLimitStateChange< - RateLimitedApis extends Record any>, + RateLimitedApis extends Record, > = { type: `${typeof name}:stateChange`; payload: [RateLimitState, Patch[]]; }; export type GetRateLimitState< - RateLimitedApis extends Record any>, + RateLimitedApis extends Record, > = { type: `${typeof name}:getState`; handler: () => RateLimitState; }; -export type CallApi< - RateLimitedApis extends Record any>, -> = { +export type CallApi> = { type: `${typeof name}:call`; handler: RateLimitController['call']; }; export type RateLimitControllerActions< - RateLimitedApis extends Record any>, + RateLimitedApis extends Record, > = GetRateLimitState | CallApi; export type RateLimitMessenger< - RateLimitedApis extends Record any>, + RateLimitedApis extends Record, > = RestrictedControllerMessenger< typeof name, RateLimitControllerActions, @@ -60,7 +70,7 @@ const metadata = { * Controller with logic for rate-limiting API endpoints per requesting origin. */ export class RateLimitController< - RateLimitedApis extends Record any>, + RateLimitedApis extends Record, > extends BaseController< typeof name, RateLimitState, @@ -116,7 +126,7 @@ export class RateLimitController< (( origin: string, type: keyof RateLimitedApis, - ...args: Parameters + ...args: Parameters ) => this.call(origin, type, ...args)) as any, ); } @@ -132,16 +142,16 @@ export class RateLimitController< async call( origin: string, type: ApiType, - ...args: Parameters - ): Promise> { + ...args: Parameters + ): Promise> { if (this.isRateLimited(type, origin)) { throw ethErrors.rpc.limitExceeded({ - message: `"${type}" is currently rate-limited. Please try again later.`, + message: `"${type.toString()}" is currently rate-limited. Please try again later.`, }); } this.recordRequest(type, origin); - const implementation = this.implementations[type]; + const implementation = this.implementations[type].method; if (!implementation) { throw new Error('Invalid api type'); @@ -158,7 +168,9 @@ export class RateLimitController< * @returns `true` if rate-limited, and `false` otherwise. */ private isRateLimited(api: keyof RateLimitedApis, origin: string) { - return this.state.requests[api][origin] >= this.rateLimitCount; + const rateLimitCount = + this.implementations[api].rateLimitCount ?? this.rateLimitCount; + return this.state.requests[api][origin] >= rateLimitCount; } /** @@ -168,15 +180,14 @@ export class RateLimitController< * @param origin - The origin trying to access the API. */ private recordRequest(api: keyof RateLimitedApis, origin: string) { + const rateLimitTimeout = + this.implementations[api].rateLimitTimeout ?? this.rateLimitTimeout; this.update((state) => { const previous = (state as any).requests[api][origin] ?? 0; (state as any).requests[api][origin] = previous + 1; if (previous === 0) { - setTimeout( - () => this.resetRequestCount(api, origin), - this.rateLimitTimeout, - ); + setTimeout(() => this.resetRequestCount(api, origin), rateLimitTimeout); } }); }