Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 63 additions & 13 deletions packages/approval-controller/src/ApprovalController.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,30 @@ describe('approval controller', () => {
}),
).toThrow(getOriginTypeCollisionError('bar.baz', 'myType'));
});

it('does not throw on origin and type collision if type excluded', () => {
approvalController = new ApprovalController({
messenger: getRestrictedMessenger(),
showApprovalRequest,
typesExcludedFromRateLimiting: ['myType'],
});

expect(() =>
approvalController.add({
id: 'foo',
origin: 'bar.baz',
type: 'myType',
}),
).not.toThrow();

expect(() =>
approvalController.add({
id: 'foo1',
origin: 'bar.baz',
type: 'myType',
}),
).not.toThrow();
});
});

// otherwise tested by 'add' above
Expand Down Expand Up @@ -406,6 +430,21 @@ describe('approval controller', () => {
approvalController.getApprovalCount({ type: 'type3' }),
).toStrictEqual(0);
});

it('gets the count when specifying origin and type with type excluded from rate limiting', () => {
approvalController = new ApprovalController({
messenger: getRestrictedMessenger(),
showApprovalRequest: sinon.spy(),
typesExcludedFromRateLimiting: [TYPE],
});

addWithCatch({ id: '1', origin: 'origin1', type: TYPE });
addWithCatch({ id: '2', origin: 'origin1', type: TYPE });

expect(
approvalController.getApprovalCount({ origin: 'origin1', type: TYPE }),
).toStrictEqual(2);
});
});

describe('getTotalApprovalCount', () => {
Expand Down Expand Up @@ -434,6 +473,30 @@ describe('approval controller', () => {
approvalController.clear(new EthereumRpcError(1, 'clear'));
expect(approvalController.getTotalApprovalCount()).toStrictEqual(0);
});

it('gets the total approval count with type excluded from rate limiting', () => {
const approvalController = new ApprovalController({
messenger: getRestrictedMessenger(),
showApprovalRequest: sinon.spy(),
typesExcludedFromRateLimiting: ['type0'],
});
expect(approvalController.getTotalApprovalCount()).toStrictEqual(0);

const addWithCatch = (args: any) =>
approvalController.add(args).catch(() => undefined);
Comment thread
matthewwalsh0 marked this conversation as resolved.

addWithCatch({ id: '1', origin: 'origin1', type: 'type0' });
expect(approvalController.getTotalApprovalCount()).toStrictEqual(1);

addWithCatch({ id: '2', origin: 'origin1', type: 'type0' });
expect(approvalController.getTotalApprovalCount()).toStrictEqual(2);

approvalController.reject('2', new Error('foo'));
expect(approvalController.getTotalApprovalCount()).toStrictEqual(1);

approvalController.clear(new EthereumRpcError(1, 'clear'));
expect(approvalController.getTotalApprovalCount()).toStrictEqual(0);
});
});

describe('has', () => {
Expand Down Expand Up @@ -841,19 +904,6 @@ describe('approval controller', () => {
});
});

// TODO: Stop using private methods in tests
describe('_isEmptyOrigin', () => {
it('handles non-existing origin', () => {
const approvalController = new ApprovalController({
messenger: getRestrictedMessenger(),
showApprovalRequest: sinon.spy(),
});
expect(() =>
(approvalController as any)._isEmptyOrigin('kaplar'),
).not.toThrow();
});
});

describe('actions', () => {
it('addApprovalRequest: shouldShowRequest = true', async () => {
const messenger = new ControllerMessenger<
Expand Down
52 changes: 30 additions & 22 deletions packages/approval-controller/src/ApprovalController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ type ApprovalControllerOptions = {
messenger: ApprovalControllerMessenger;
showApprovalRequest: ShowApprovalRequest;
state?: Partial<ApprovalControllerState>;
typesExcludedFromRateLimiting?: string[];
};

/**
Expand All @@ -174,10 +175,12 @@ export class ApprovalController extends BaseControllerV2<
> {
private _approvals: Map<string, ApprovalCallbacks>;

private _origins: Map<string, Set<string>>;
private _origins: Map<string, Map<string, number>>;

private _showApprovalRequest: () => void;

private _typesExcludedFromRateLimiting: string[];

/**
* Construct an Approval controller.
*
Expand All @@ -186,11 +189,13 @@ export class ApprovalController extends BaseControllerV2<
* the request can be displayed to the user.
* @param options.messenger - The restricted controller messenger for the Approval controller.
* @param options.state - The initial controller state.
* @param options.typesExcludedFromRateLimiting - Array of aproval types which allow multiple pending approval requests from the same origin.
*/
constructor({
messenger,
showApprovalRequest,
state = {},
typesExcludedFromRateLimiting = [],
}: ApprovalControllerOptions) {
super({
name: controllerName,
Expand All @@ -202,6 +207,7 @@ export class ApprovalController extends BaseControllerV2<
this._approvals = new Map();
this._origins = new Map();
this._showApprovalRequest = showApprovalRequest;
this._typesExcludedFromRateLimiting = typesExcludedFromRateLimiting;
this.registerMessageHandlers();
}

Expand Down Expand Up @@ -333,11 +339,13 @@ export class ApprovalController extends BaseControllerV2<
const { origin, type: _type } = opts;

if (origin && _type) {
return Number(Boolean(this._origins.get(origin)?.has(_type)));
return this._origins.get(origin)?.get(_type) || 0;
}

if (origin) {
return this._origins.get(origin)?.size || 0;
return Array.from(
(this._origins.get(origin) || new Map()).values(),
).reduce((total, value) => total + value, 0);
}

// Only "type" was specified
Expand Down Expand Up @@ -395,7 +403,7 @@ export class ApprovalController extends BaseControllerV2<

// Check origin and type pair if type also specified
if (_type) {
return Boolean(this._origins.get(origin)?.has(_type));
return Boolean(this._origins.get(origin)?.get(_type));
}
return this._origins.has(origin);
}
Expand Down Expand Up @@ -487,7 +495,10 @@ export class ApprovalController extends BaseControllerV2<
): Promise<unknown> {
this._validateAddParams(id, origin, type, requestData, requestState);

if (this._origins.get(origin)?.has(type)) {
if (
!this._typesExcludedFromRateLimiting.includes(type) &&
this.has({ origin, type })
) {
throw ethErrors.rpc.resourceUnavailable(
getAlreadyPendingMessage(origin, type),
);
Expand Down Expand Up @@ -551,12 +562,15 @@ export class ApprovalController extends BaseControllerV2<
* @param type - The type associated with the approval request.
*/
private _addPendingApprovalOrigin(origin: string, type: string): void {
const originSet = this._origins.get(origin) || new Set();
originSet.add(type);
let originMap = this._origins.get(origin);

if (!this._origins.has(origin)) {
this._origins.set(origin, originSet);
if (!originMap) {
originMap = new Map();
this._origins.set(origin, originMap);
}

const currentValue = originMap.get(type) || 0;
originMap.set(type, currentValue + 1);
}

/**
Expand Down Expand Up @@ -610,9 +624,14 @@ export class ApprovalController extends BaseControllerV2<
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const { origin, type } = this.state.pendingApprovals[id]!;

(this._origins.get(origin) as Set<string>).delete(type);
if (this._isEmptyOrigin(origin)) {
const originMap = this._origins.get(origin) as Map<string, number>;
const originTotalCount = this.getApprovalCount({ origin });
const originTypeCount = originMap.get(type) as number;

if (originTotalCount === 1) {
this._origins.delete(origin);
} else {
originMap.set(type, originTypeCount - 1);
}

this.update((draftState) => {
Expand Down Expand Up @@ -640,16 +659,5 @@ export class ApprovalController extends BaseControllerV2<
this._delete(id);
return callbacks;
}

/**
* Checks whether there are any approvals associated with the given
* origin.
*
* @param origin - The origin to check.
* @returns True if the origin has no approvals, false otherwise.
*/
private _isEmptyOrigin(origin: string): boolean {
return !this._origins.get(origin)?.size;
}
}
export default ApprovalController;