diff --git a/packages/approval-controller/src/ApprovalController.test.ts b/packages/approval-controller/src/ApprovalController.test.ts index 3d03070c60..68a33be934 100644 --- a/packages/approval-controller/src/ApprovalController.test.ts +++ b/packages/approval-controller/src/ApprovalController.test.ts @@ -1,15 +1,21 @@ import { errorCodes, EthereumRpcError } from 'eth-rpc-errors'; -import * as sinon from 'sinon'; import { ControllerMessenger } from '@metamask/base-controller'; import { ApprovalController, ApprovalControllerActions, ApprovalControllerEvents, ApprovalControllerMessenger, + StartFlowOptions, } from './ApprovalController'; -import { ApprovalRequestNoResultSupportError } from './errors'; +import { + ApprovalRequestNoResultSupportError, + EndInvalidFlowError, + NoApprovalFlowsError, +} from './errors'; + +const PENDING_APPROVALS_STORE_KEY = 'pendingApprovals'; +const APPROVAL_FLOWS_STORE_KEY = 'approvalFlows'; -const STORE_KEY = 'pendingApprovals'; const TYPE = 'TYPE'; const ID_MOCK = 'TestId'; const ORIGIN_MOCK = 'TestOrigin'; @@ -40,26 +46,18 @@ function getRestrictedMessenger() { } describe('approval controller', () => { - beforeEach(() => { - sinon.useFakeTimers(1); - }); + let approvalController: ApprovalController; + let showApprovalRequest: jest.Mock; - afterEach(() => { - sinon.restore(); + beforeEach(() => { + showApprovalRequest = jest.fn(); + approvalController = new ApprovalController({ + messenger: getRestrictedMessenger(), + showApprovalRequest, + }); }); describe('add', () => { - let approvalController: ApprovalController; - let showApprovalRequest: sinon.SinonSpy; - - beforeEach(() => { - showApprovalRequest = sinon.spy(); - approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest, - }); - }); - it('validates input', () => { expect(() => approvalController.add({ id: null, origin: 'bar.baz' } as any), @@ -124,13 +122,15 @@ describe('approval controller', () => { approvalController.has({ origin: 'bar.baz', type: TYPE }), ).toStrictEqual(true); - expect(approvalController.state[STORE_KEY]).toStrictEqual({ + expect( + approvalController.state[PENDING_APPROVALS_STORE_KEY], + ).toStrictEqual({ foo: { id: 'foo', origin: 'bar.baz', requestData: null, requestState: null, - time: 1, + time: expect.any(Number), type: TYPE, expectsResult: true, }, @@ -146,7 +146,9 @@ describe('approval controller', () => { }), ).not.toThrow(); - const id = Object.keys(approvalController.state[STORE_KEY])[0]; + const id = Object.keys( + approvalController.state[PENDING_APPROVALS_STORE_KEY], + )[0]; expect(id && typeof id === 'string').toStrictEqual(true); }); @@ -163,9 +165,9 @@ describe('approval controller', () => { expect(approvalController.has({ id: 'foo' })).toStrictEqual(true); expect(approvalController.has({ origin: 'bar.baz' })).toStrictEqual(true); expect(approvalController.has({ type: 'myType' })).toStrictEqual(true); - expect(approvalController.state[STORE_KEY].foo.requestData).toStrictEqual( - { foo: 'bar' }, - ); + expect( + approvalController.state[PENDING_APPROVALS_STORE_KEY].foo.requestData, + ).toStrictEqual({ foo: 'bar' }); }); it('adds correctly specified entry with request state', () => { @@ -182,7 +184,7 @@ describe('approval controller', () => { expect(approvalController.has({ origin: 'bar.baz' })).toStrictEqual(true); expect(approvalController.has({ type: 'myType' })).toStrictEqual(true); expect( - approvalController.state[STORE_KEY].foo.requestState, + approvalController.state[PENDING_APPROVALS_STORE_KEY].foo.requestState, ).toStrictEqual({ foo: 'bar' }); }); @@ -265,12 +267,6 @@ describe('approval controller', () => { // otherwise tested by 'add' above describe('addAndShowApprovalRequest', () => { it('addAndShowApprovalRequest', () => { - const showApprovalSpy = sinon.spy(); - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: showApprovalSpy, - }); - const result = approvalController.addAndShowApprovalRequest({ id: 'foo', origin: 'bar.baz', @@ -278,17 +274,12 @@ describe('approval controller', () => { requestData: { foo: 'bar' }, }); expect(result instanceof Promise).toStrictEqual(true); - expect(showApprovalSpy.calledOnce).toStrictEqual(true); + expect(showApprovalRequest).toHaveBeenCalledTimes(1); }); }); describe('get', () => { it('gets entry', () => { - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - approvalController.add({ id: 'foo', origin: 'bar.baz', @@ -302,17 +293,12 @@ describe('approval controller', () => { requestData: null, requestState: null, type: 'myType', - time: 1, + time: expect.any(Number), expectsResult: true, }); }); it('returns undefined for non-existing entry', () => { - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - approvalController.add({ id: 'foo', origin: 'bar.baz', type: 'type' }); expect(approvalController.get('fizz')).toBeUndefined(); @@ -324,15 +310,9 @@ describe('approval controller', () => { }); describe('getApprovalCount', () => { - let approvalController: ApprovalController; let addWithCatch: (args: any) => void; beforeEach(() => { - approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - addWithCatch = (args: any) => approvalController.add(args).catch(() => undefined); }); @@ -454,7 +434,7 @@ describe('approval controller', () => { it('gets the count when specifying origin and type with type excluded from rate limiting', () => { approvalController = new ApprovalController({ messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), + showApprovalRequest, typesExcludedFromRateLimiting: [TYPE], }); @@ -469,10 +449,6 @@ describe('approval controller', () => { describe('getTotalApprovalCount', () => { it('gets the total approval count', () => { - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); expect(approvalController.getTotalApprovalCount()).toStrictEqual(0); const addWithCatch = (args: any) => @@ -495,9 +471,9 @@ describe('approval controller', () => { }); it('gets the total approval count with type excluded from rate limiting', () => { - const approvalController = new ApprovalController({ + approvalController = new ApprovalController({ messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), + showApprovalRequest, typesExcludedFromRateLimiting: ['type0'], }); expect(approvalController.getTotalApprovalCount()).toStrictEqual(0); @@ -520,15 +496,6 @@ describe('approval controller', () => { }); describe('has', () => { - let approvalController: ApprovalController; - - beforeEach(() => { - approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - }); - it('validates input', () => { expect(() => approvalController.has()).toThrow( getInvalidHasParamsError(), @@ -619,17 +586,12 @@ describe('approval controller', () => { }); describe('resolve', () => { - let approvalController: ApprovalController; let numDeletions: number; - let deleteSpy: sinon.SinonSpy; + let deleteSpy: jest.SpyInstance; beforeEach(() => { - approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); // TODO: Stop using private methods in tests - deleteSpy = sinon.spy(approvalController as any, '_delete'); + deleteSpy = jest.spyOn(approvalController as any, '_delete'); numDeletions = 0; }); @@ -645,7 +607,7 @@ describe('approval controller', () => { const result = await approvalPromise; expect(result).toStrictEqual('success'); - expect(deleteSpy.callCount).toStrictEqual(numDeletions); + expect(deleteSpy).toHaveBeenCalledTimes(numDeletions); }); it('resolves multiple approval promises out of order', async () => { @@ -671,29 +633,24 @@ describe('approval controller', () => { result = await approvalPromise1; expect(result).toStrictEqual('success1'); - expect(deleteSpy.callCount).toStrictEqual(numDeletions); + expect(deleteSpy).toHaveBeenCalledTimes(numDeletions); }); it('throws on unknown id', () => { expect(() => approvalController.accept('foo')).toThrow( getIdNotFoundError('foo'), ); - expect(deleteSpy.callCount).toStrictEqual(numDeletions); + expect(deleteSpy).toHaveBeenCalledTimes(numDeletions); }); }); describe('reject', () => { - let approvalController: ApprovalController; let numDeletions: number; - let deleteSpy: sinon.SinonSpy; + let deleteSpy: jest.SpyInstance; beforeEach(() => { - approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); // TODO: Stop using private methods in tests - deleteSpy = sinon.spy(approvalController as any, '_delete'); + deleteSpy = jest.spyOn(approvalController as any, '_delete'); numDeletions = 0; }); @@ -706,7 +663,7 @@ describe('approval controller', () => { }); approvalController.reject('foo', new Error('failure')); await expect(approvalPromise).rejects.toThrow('failure'); - expect(deleteSpy.callCount).toStrictEqual(numDeletions); + expect(deleteSpy).toHaveBeenCalledTimes(numDeletions); }); it('rejects multiple approval promises out of order', async () => { @@ -727,24 +684,19 @@ describe('approval controller', () => { approvalController.reject('foo1', new Error('failure1')); await expect(rejectionPromise2).rejects.toThrow('failure2'); await expect(rejectionPromise1).rejects.toThrow('failure1'); - expect(deleteSpy.callCount).toStrictEqual(numDeletions); + expect(deleteSpy).toHaveBeenCalledTimes(numDeletions); }); it('throws on unknown id', () => { expect(() => approvalController.reject('foo', new Error('bar'))).toThrow( getIdNotFoundError('foo'), ); - expect(deleteSpy.callCount).toStrictEqual(numDeletions); + expect(deleteSpy).toHaveBeenCalledTimes(numDeletions); }); }); describe('accept', () => { it('resolves accept promise when success callback is called', async () => { - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - const approvalPromise = approvalController.add({ id: ID_MOCK, origin: ORIGIN_MOCK, @@ -766,11 +718,6 @@ describe('approval controller', () => { }); it('rejects accept promise when error callback is called', async () => { - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - const approvalPromise = approvalController.add({ id: ID_MOCK, origin: ORIGIN_MOCK, @@ -792,11 +739,6 @@ describe('approval controller', () => { }); it('resolves request promise with empty result callbacks if accept does not wait for result', async () => { - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - const approvalPromise = approvalController.add({ id: ID_MOCK, origin: ORIGIN_MOCK, @@ -813,11 +755,6 @@ describe('approval controller', () => { }); it('throws if accept wants to wait but request does not expect result', async () => { - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - approvalController.add({ id: ID_MOCK, origin: ORIGIN_MOCK, @@ -834,11 +771,6 @@ describe('approval controller', () => { describe('accept and reject', () => { it('accepts and rejects multiple approval promises out of order', async () => { - const approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - const promise1 = approvalController.add({ id: 'foo1', origin: 'bar.baz', @@ -888,15 +820,6 @@ describe('approval controller', () => { }); describe('clear', () => { - let approvalController: ApprovalController; - - beforeEach(() => { - approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - }); - it('does nothing if state is already empty', () => { expect(() => approvalController.clear(new EthereumRpcError(1, 'clear')), @@ -904,7 +827,7 @@ describe('approval controller', () => { }); it('deletes existing entries', async () => { - const rejectSpy = sinon.spy(approvalController, 'reject'); + const rejectSpy = jest.spyOn(approvalController, 'reject'); approvalController .add({ id: 'foo2', origin: 'bar.baz', type: 'myType' }) @@ -916,8 +839,10 @@ describe('approval controller', () => { approvalController.clear(new EthereumRpcError(1, 'clear')); - expect(approvalController.state[STORE_KEY]).toStrictEqual({}); - expect(rejectSpy.callCount).toStrictEqual(2); + expect( + approvalController.state[PENDING_APPROVALS_STORE_KEY], + ).toStrictEqual({}); + expect(rejectSpy).toHaveBeenCalledTimes(2); }); it('rejects existing entries with a caller-specified error', async () => { @@ -932,18 +857,19 @@ describe('approval controller', () => { new EthereumRpcError(1000, 'foo'), ); }); - }); - describe('updateRequestState', () => { - let approvalController: ApprovalController; + it('does not clear approval flows', async () => { + approvalController.startFlow(); - beforeEach(() => { - approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); + approvalController.clear(new EthereumRpcError(1, 'clear')); + + expect(approvalController.state[APPROVAL_FLOWS_STORE_KEY]).toHaveLength( + 1, + ); }); + }); + describe('updateRequestState', () => { it('updates the request state of a given approval request', () => { approvalController .add({ @@ -978,15 +904,6 @@ describe('approval controller', () => { // they are heavily dependent upon it. // TODO: Stop using private methods in tests describe('_delete', () => { - let approvalController: ApprovalController; - - beforeEach(() => { - approvalController = new ApprovalController({ - messenger: getRestrictedMessenger(), - showApprovalRequest: sinon.spy(), - }); - }); - it('deletes entry', () => { approvalController.add({ id: 'foo', origin: 'bar.baz', type: 'type' }); @@ -996,7 +913,7 @@ describe('approval controller', () => { !approvalController.has({ id: 'foo' }) && !approvalController.has({ type: 'type' }) && !approvalController.has({ origin: 'bar.baz' }) && - !approvalController.state[STORE_KEY].foo, + !approvalController.state[PENDING_APPROVALS_STORE_KEY].foo, ).toStrictEqual(true); }); @@ -1024,13 +941,12 @@ describe('approval controller', () => { ApprovalControllerActions, ApprovalControllerEvents >(); - const showApprovalSpy = sinon.spy(); - const approvalController = new ApprovalController({ + approvalController = new ApprovalController({ messenger: messenger.getRestricted({ name: controllerName, }) as ApprovalControllerMessenger, - showApprovalRequest: showApprovalSpy, + showApprovalRequest, }); messenger.call( @@ -1038,7 +954,7 @@ describe('approval controller', () => { { id: 'foo', origin: 'bar.baz', type: TYPE }, true, ); - expect(showApprovalSpy.calledOnce).toStrictEqual(true); + expect(showApprovalRequest).toHaveBeenCalledTimes(1); expect(approvalController.has({ id: 'foo' })).toStrictEqual(true); }); @@ -1047,13 +963,12 @@ describe('approval controller', () => { ApprovalControllerActions, ApprovalControllerEvents >(); - const showApprovalSpy = sinon.spy(); - const approvalController = new ApprovalController({ + approvalController = new ApprovalController({ messenger: messenger.getRestricted({ name: controllerName, }) as ApprovalControllerMessenger, - showApprovalRequest: showApprovalSpy, + showApprovalRequest, }); messenger.call( @@ -1061,7 +976,7 @@ describe('approval controller', () => { { id: 'foo', origin: 'bar.baz', type: TYPE }, false, ); - expect(showApprovalSpy.notCalled).toStrictEqual(true); + expect(showApprovalRequest).toHaveBeenCalledTimes(0); expect(approvalController.has({ id: 'foo' })).toStrictEqual(true); }); @@ -1071,11 +986,11 @@ describe('approval controller', () => { ApprovalControllerEvents >(); - const approvalController = new ApprovalController({ + approvalController = new ApprovalController({ messenger: messenger.getRestricted({ name: controllerName, }) as ApprovalControllerMessenger, - showApprovalRequest: sinon.spy(), + showApprovalRequest, }); approvalController.add({ @@ -1095,6 +1010,57 @@ describe('approval controller', () => { }); }); }); + + describe('startFlow', () => { + it.each([ + ['no options passed', undefined], + ['partial options passed', {}], + ['options passed', { id: 'id' }], + ])( + 'adds flow to state and calls showApprovalRequest with %s', + (_, approvalFlowOptions?: StartFlowOptions) => { + const result = approvalController.startFlow(approvalFlowOptions); + + const expectedFlow = { + id: approvalFlowOptions?.id ?? expect.any(String), + }; + expect(result).toStrictEqual(expectedFlow); + expect(showApprovalRequest).toHaveBeenCalledTimes(1); + expect(approvalController.state[APPROVAL_FLOWS_STORE_KEY]).toHaveLength( + 1, + ); + expect( + approvalController.state[APPROVAL_FLOWS_STORE_KEY][0], + ).toStrictEqual(expectedFlow); + }, + ); + }); + + describe('endFlow', () => { + it('fails to end flow if no flow exists', () => { + expect(() => approvalController.endFlow({ id: 'id' })).toThrow( + NoApprovalFlowsError, + ); + }); + + it('fails to end flow if id does not correspond the current flow', () => { + approvalController.startFlow({ id: 'id' }); + + expect(() => approvalController.endFlow({ id: 'wrong-id' })).toThrow( + EndInvalidFlowError, + ); + }); + + it('ends flow if id corresponds with the current flow', () => { + approvalController.startFlow({ id: 'id' }); + + approvalController.endFlow({ id: 'id' }); + + expect(approvalController.state[APPROVAL_FLOWS_STORE_KEY]).toHaveLength( + 0, + ); + }); + }); }); // helpers diff --git a/packages/approval-controller/src/ApprovalController.ts b/packages/approval-controller/src/ApprovalController.ts index 715771176d..f2ac8c6a78 100644 --- a/packages/approval-controller/src/ApprovalController.ts +++ b/packages/approval-controller/src/ApprovalController.ts @@ -5,10 +5,12 @@ import { BaseControllerV2, RestrictedControllerMessenger, } from '@metamask/base-controller'; -import { Json } from '@metamask/utils'; +import { Json, OptionalField } from '@metamask/utils'; import { ApprovalRequestNotFoundError, ApprovalRequestNoResultSupportError, + EndInvalidFlowError, + NoApprovalFlowsError, } from './errors'; const controllerName = 'ApprovalController'; @@ -66,14 +68,22 @@ export type ApprovalRequest = { type ShowApprovalRequest = () => void | Promise; +type ApprovalFlow = { + id: string; +}; + +export type ApprovalFlowState = ApprovalFlow; + export type ApprovalControllerState = { pendingApprovals: Record>>; pendingApprovalCount: number; + approvalFlows: ApprovalFlowState[]; }; const stateMetadata = { pendingApprovals: { persist: false, anonymous: true }, pendingApprovalCount: { persist: false, anonymous: false }, + approvalFlows: { persist: false, anonymous: false }, }; const getAlreadyPendingMessage = (origin: string, type: string) => @@ -83,6 +93,7 @@ const getDefaultState = (): ApprovalControllerState => { return { pendingApprovals: {}, pendingApprovalCount: 0, + approvalFlows: [], }; }; @@ -183,6 +194,22 @@ export type AddResult = { resultCallbacks?: AcceptResultCallbacks; }; +export type StartFlowOptions = OptionalField; + +export type ApprovalFlowStartResult = ApprovalFlow; + +export type EndFlowOptions = Pick; + +export type StartFlow = { + type: `${typeof controllerName}:startFlow`; + handler: ApprovalController['startFlow']; +}; + +export type EndFlow = { + type: `${typeof controllerName}:endFlow`; + handler: ApprovalController['endFlow']; +}; + export type ApprovalControllerActions = | GetApprovalsState | ClearApprovalRequests @@ -190,7 +217,9 @@ export type ApprovalControllerActions = | HasApprovalRequest | AcceptRequest | RejectRequest - | UpdateRequestState; + | UpdateRequestState + | StartFlow + | EndFlow; export type ApprovalStateChange = { type: `${typeof controllerName}:stateChange`; @@ -305,6 +334,16 @@ export class ApprovalController extends BaseControllerV2< `${controllerName}:updateRequestState` as const, this.updateRequestState.bind(this), ); + + this.messagingSystem.registerActionHandler( + `${controllerName}:startFlow` as const, + this.startFlow.bind(this), + ); + + this.messagingSystem.registerActionHandler( + `${controllerName}:endFlow` as const, + this.endFlow.bind(this), + ); } /** @@ -590,7 +629,10 @@ export class ApprovalController extends BaseControllerV2< this.reject(id, rejectionError); } this._origins.clear(); - this.update(() => getDefaultState()); + this.update((draftState) => { + draftState.pendingApprovals = {}; + draftState.pendingApprovalCount = 0; + }); } /** @@ -612,6 +654,51 @@ export class ApprovalController extends BaseControllerV2< }); } + /** + * Starts a new approval flow. + * + * @param opts - Options bag. + * @param opts.id - The id of the approval flow. + * @returns The object containing the approval flow id. + */ + startFlow(opts: StartFlowOptions = {}): ApprovalFlowStartResult { + const id = opts.id ?? nanoid(); + const finalOptions = { id }; + + this.update((draftState) => { + draftState.approvalFlows.push(finalOptions); + }); + + this._showApprovalRequest(); + + return { id }; + } + + /** + * Ends the current approval flow. + * + * @param opts - Options bag. + * @param opts.id - The id of the approval flow that will be finished. + */ + endFlow({ id }: EndFlowOptions) { + if (!this.state.approvalFlows.length) { + throw new NoApprovalFlowsError(); + } + + const currentFlow = this.state.approvalFlows.slice(-1)[0]; + + if (id !== currentFlow.id) { + throw new EndInvalidFlowError( + id, + this.state.approvalFlows.map((flow) => flow.id), + ); + } + + this.update((draftState) => { + draftState.approvalFlows.pop(); + }); + } + /** * Implementation of add operation. * diff --git a/packages/approval-controller/src/errors.ts b/packages/approval-controller/src/errors.ts index a7a213e3df..c7bf22da35 100644 --- a/packages/approval-controller/src/errors.ts +++ b/packages/approval-controller/src/errors.ts @@ -11,3 +11,19 @@ export class ApprovalRequestNoResultSupportError extends Error { ); } } + +export class NoApprovalFlowsError extends Error { + constructor() { + super(`No approval flows found.`); + } +} + +export class EndInvalidFlowError extends Error { + constructor(id: string, flowIds: string[]) { + super( + `Attempted to end flow with id '${id}' which does not match current flow with id '${ + flowIds.slice(-1)[0] + }'. All Flows: ${flowIds.join(', ')}`, + ); + } +}