diff --git a/packages/assets-controllers/src/NftDetectionController.test.ts b/packages/assets-controllers/src/NftDetectionController.test.ts index e745804afc..bc7fc12068 100644 --- a/packages/assets-controllers/src/NftDetectionController.test.ts +++ b/packages/assets-controllers/src/NftDetectionController.test.ts @@ -16,6 +16,10 @@ type ApprovalActions = AddApprovalRequest; const controllerName = 'NftController' as const; +const flushPromises = () => { + return new Promise(jest.requireActual('timers').setImmediate); +}; + describe('NftDetectionController', () => { let nftDetection: NftDetectionController; let preferences: PreferencesController; @@ -294,19 +298,26 @@ describe('NftDetectionController', () => { testNftDetection.startPollingByNetworkClientId('mainnet', { address: '0x1', }); + await Promise.all([jest.advanceTimersByTime(0), flushPromises]); + expect(spy.mock.calls).toHaveLength(1); await Promise.all([ - jest.advanceTimersByTime(DEFAULT_INTERVAL), - Promise.resolve(), + jest.advanceTimersByTime(DEFAULT_INTERVAL / 2), + flushPromises(), ]); expect(spy.mock.calls).toHaveLength(1); await Promise.all([ - jest.advanceTimersByTime(DEFAULT_INTERVAL), - Promise.resolve(), + jest.advanceTimersByTime(DEFAULT_INTERVAL / 2), + flushPromises(), ]); expect(spy.mock.calls).toHaveLength(2); + await Promise.all([ + jest.advanceTimersByTime(DEFAULT_INTERVAL), + flushPromises(), + ]); expect(spy.mock.calls).toMatchObject([ ['mainnet', '0x1'], ['mainnet', '0x1'], + ['mainnet', '0x1'], ]); nftDetection.stopAllPolling(); jest.runOnlyPendingTimers(); diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index efdaf23b4a..c41165b50e 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -123,6 +123,10 @@ const setupTokenListController = ( return { tokenList, tokenListMessenger }; }; +const flushPromises = () => { + return new Promise(jest.requireActual('timers').setImmediate); +}; + describe('TokenDetectionController', () => { let tokenDetection: TokenDetectionController; let preferences: PreferencesController; @@ -611,7 +615,7 @@ describe('TokenDetectionController', () => { }); await Promise.all([ jest.advanceTimersByTime(DEFAULT_INTERVAL), - Promise.resolve(), + flushPromises(), ]); expect(spy.mock.calls).toMatchObject([ [{ networkClientId: 'mainnet', accountAddress: '0x1' }], diff --git a/packages/assets-controllers/src/TokenListController.test.ts b/packages/assets-controllers/src/TokenListController.test.ts index eafdddb2cc..e123f928fd 100644 --- a/packages/assets-controllers/src/TokenListController.test.ts +++ b/packages/assets-controllers/src/TokenListController.test.ts @@ -1368,20 +1368,28 @@ describe('TokenListController', () => { ); controller.startPollingByNetworkClientId('goerli'); - jest.advanceTimersByTime(pollingIntervalTime / 2); + jest.advanceTimersByTime(0); await flushPromises(); - expect(fetchTokenListByChainIdSpy).toHaveBeenCalledTimes(0); - jest.advanceTimersByTime(pollingIntervalTime / 2); - await flushPromises(); - expect(fetchTokenListByChainIdSpy).toHaveBeenCalledTimes(1); + await Promise.all([ + jest.advanceTimersByTime(pollingIntervalTime / 2), + flushPromises(), + ]); + expect(fetchTokenListByChainIdSpy).toHaveBeenCalledTimes(1); + await Promise.all([ + jest.advanceTimersByTime(pollingIntervalTime / 2), + jest.runOnlyPendingTimers(), + flushPromises(), + ]); + + expect(fetchTokenListByChainIdSpy).toHaveBeenCalledTimes(2); await Promise.all([ jest.advanceTimersByTime(pollingIntervalTime), flushPromises(), ]); await Promise.all([jest.runOnlyPendingTimers(), flushPromises()]); - expect(fetchTokenListByChainIdSpy).toHaveBeenCalledTimes(2); + expect(fetchTokenListByChainIdSpy).toHaveBeenCalledTimes(3); }); it('should update tokenList state and tokensChainsCache', async () => { @@ -1441,7 +1449,7 @@ describe('TokenListController', () => { expect(controller.state).toStrictEqual(startingState); // start polling for sepolia - await controller.startPollingByNetworkClientId('sepolia'); + const pollingToken = controller.startPollingByNetworkClientId('sepolia'); // wait a polling interval jest.advanceTimersByTime(pollingIntervalTime); await flushPromises(); @@ -1457,10 +1465,10 @@ describe('TokenListController', () => { data: sampleSepoliaTokensChainCache, }, }); + controller.stopPollingByPollingToken(pollingToken); + // start polling for binance - await controller.startPollingByNetworkClientId( - 'binance-network-client-id', - ); + controller.startPollingByNetworkClientId('binance-network-client-id'); jest.advanceTimersByTime(pollingIntervalTime); await flushPromises(); diff --git a/packages/gas-fee-controller/src/GasFeeController.test.ts b/packages/gas-fee-controller/src/GasFeeController.test.ts index 0f3bb97ea2..6cd2b3433a 100644 --- a/packages/gas-fee-controller/src/GasFeeController.test.ts +++ b/packages/gas-fee-controller/src/GasFeeController.test.ts @@ -983,7 +983,7 @@ describe('GasFeeController', () => { }); describe('polling (by networkClientId)', () => { - it('should call determineGasFeeCalculations (via _executePoll) with a URL that contains the chainId corresponding to the networkClientId after the interval passed via the constructor', async () => { + it('should call determineGasFeeCalculations (via _executePoll) with a URL that contains the chainId corresponding to the networkClientId immedaitely and after each interval passed via the constructor', async () => { const pollingInterval = 10000; await setupGasFeeController({ getIsEIP1559Compatible: jest.fn().mockResolvedValue(false), @@ -1013,10 +1013,20 @@ describe('GasFeeController', () => { }); gasFeeController.startPollingByNetworkClientId('goerli'); + await clock.tickAsync(0); + expect(mockedDetermineGasFeeCalculations).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + fetchGasEstimatesUrl: `https://some-eip-1559-endpoint/${convertHexToDecimal( + ChainId.goerli, + )}`, + }), + ); await clock.tickAsync(pollingInterval / 2); - expect(mockedDetermineGasFeeCalculations).not.toHaveBeenCalled(); + expect(mockedDetermineGasFeeCalculations).toHaveBeenCalledTimes(1); await clock.tickAsync(pollingInterval / 2); - expect(mockedDetermineGasFeeCalculations).toHaveBeenCalledWith( + expect(mockedDetermineGasFeeCalculations).toHaveBeenNthCalledWith( + 2, expect.objectContaining({ fetchGasEstimatesUrl: `https://some-eip-1559-endpoint/${convertHexToDecimal( ChainId.goerli, diff --git a/packages/polling-controller/src/PollingController.test.ts b/packages/polling-controller/src/PollingController.test.ts index 3a376c55b3..4f18520bd5 100644 --- a/packages/polling-controller/src/PollingController.test.ts +++ b/packages/polling-controller/src/PollingController.test.ts @@ -13,7 +13,7 @@ const createExecutePollMock = () => { describe('PollingController', () => { describe('start', () => { - it('should start polling if not polling', () => { + it('should start polling if not polling', async () => { jest.useFakeTimers(); class MyGasFeeController extends PollingController { @@ -28,13 +28,16 @@ describe('PollingController', () => { state: { foo: 'bar' }, }); controller.startPollingByNetworkClientId('mainnet'); + jest.advanceTimersByTime(0); + await Promise.resolve(); + expect(controller._executePoll).toHaveBeenCalledTimes(1); jest.advanceTimersByTime(TICK_TIME); controller.stopAllPolling(); - expect(controller._executePoll).toHaveBeenCalledTimes(1); + expect(controller._executePoll).toHaveBeenCalledTimes(2); }); }); describe('stop', () => { - it('should stop polling when called with a valid polling that was the only active pollingToken for a given networkClient', () => { + it('should stop polling when called with a valid polling that was the only active pollingToken for a given networkClient', async () => { jest.useFakeTimers(); class MyGasFeeController extends PollingController { _executePoll = createExecutePollMock(); @@ -48,10 +51,13 @@ describe('PollingController', () => { state: { foo: 'bar' }, }); const pollingToken = controller.startPollingByNetworkClientId('mainnet'); + jest.advanceTimersByTime(0); + await Promise.resolve(); + expect(controller._executePoll).toHaveBeenCalledTimes(1); jest.advanceTimersByTime(TICK_TIME); controller.stopPollingByPollingToken(pollingToken); jest.advanceTimersByTime(TICK_TIME); - expect(controller._executePoll).toHaveBeenCalledTimes(1); + expect(controller._executePoll).toHaveBeenCalledTimes(2); controller.stopAllPolling(); }); it('should not stop polling if called with one of multiple active polling tokens for a given networkClient', async () => { @@ -68,13 +74,16 @@ describe('PollingController', () => { state: { foo: 'bar' }, }); const pollingToken1 = controller.startPollingByNetworkClientId('mainnet'); + jest.advanceTimersByTime(0); + await Promise.resolve(); controller.startPollingByNetworkClientId('mainnet'); + expect(controller._executePoll).toHaveBeenCalledTimes(1); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); controller.stopPollingByPollingToken(pollingToken1); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); - expect(controller._executePoll).toHaveBeenCalledTimes(2); + expect(controller._executePoll).toHaveBeenCalledTimes(3); controller.stopAllPolling(); }); it('should error if no pollingToken is passed', () => { @@ -116,8 +125,8 @@ describe('PollingController', () => { controller.stopAllPolling(); }); }); - describe('poll', () => { - it('should call _executePoll if polling', async () => { + describe('startPollingByNetworkClientId', () => { + it('should call _executePoll immediately and on interval if polling', async () => { jest.useFakeTimers(); class MyGasFeeController extends PollingController { @@ -132,13 +141,16 @@ describe('PollingController', () => { state: { foo: 'bar' }, }); controller.startPollingByNetworkClientId('mainnet'); + jest.advanceTimersByTime(0); + await Promise.resolve(); + expect(controller._executePoll).toHaveBeenCalledTimes(1); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); - expect(controller._executePoll).toHaveBeenCalledTimes(2); + expect(controller._executePoll).toHaveBeenCalledTimes(3); }); - it('should continue calling _executePoll when start is called again with the same networkClientId', async () => { + it('should call _executePoll immediately once and continue calling _executePoll on interval when start is called again with the same networkClientId', async () => { jest.useFakeTimers(); class MyGasFeeController extends PollingController { @@ -153,12 +165,17 @@ describe('PollingController', () => { state: { foo: 'bar' }, }); controller.startPollingByNetworkClientId('mainnet'); + jest.advanceTimersByTime(0); + await Promise.resolve(); controller.startPollingByNetworkClientId('mainnet'); + jest.advanceTimersByTime(0); + await Promise.resolve(); + expect(controller._executePoll).toHaveBeenCalledTimes(1); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); - expect(controller._executePoll).toHaveBeenCalledTimes(2); + expect(controller._executePoll).toHaveBeenCalledTimes(3); controller.stopAllPolling(); }); it('should publish "pollingComplete" when stop is called', async () => { @@ -196,18 +213,15 @@ describe('PollingController', () => { name: 'PollingController', state: { foo: 'bar' }, }); - controller.setIntervalLength(TICK_TIME * 3); + controller.setIntervalLength(TICK_TIME); controller.startPollingByNetworkClientId('mainnet'); - jest.advanceTimersByTime(TICK_TIME); + jest.advanceTimersByTime(0); await Promise.resolve(); - expect(controller._executePoll).not.toHaveBeenCalled(); - jest.advanceTimersByTime(TICK_TIME); - await Promise.resolve(); - expect(controller._executePoll).not.toHaveBeenCalled(); - jest.advanceTimersByTime(TICK_TIME); + expect(controller._executePoll).toHaveBeenCalledTimes(1); + jest.advanceTimersByTime(TICK_TIME / 2); await Promise.resolve(); expect(controller._executePoll).toHaveBeenCalledTimes(1); - jest.advanceTimersByTime(TICK_TIME * 3); + jest.advanceTimersByTime(TICK_TIME / 2); await Promise.resolve(); expect(controller._executePoll).toHaveBeenCalledTimes(2); }); @@ -229,15 +243,33 @@ describe('PollingController', () => { address: '0x1', }); controller.startPollingByNetworkClientId('mainnet', { address: '0x2' }); + jest.advanceTimersByTime(0); + await Promise.resolve(); controller.startPollingByNetworkClientId('sepolia', { address: '0x2' }); + jest.advanceTimersByTime(0); + await Promise.resolve(); + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); - expect(controller._executePoll).toHaveBeenCalledTimes(3); + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ]); controller.stopPollingByPollingToken(pollToken1); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); - expect(controller._executePoll).toHaveBeenCalledTimes(5); expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], ['mainnet', { address: '0x1' }], ['mainnet', { address: '0x2' }], ['sepolia', { address: '0x2' }], @@ -261,12 +293,22 @@ describe('PollingController', () => { state: { foo: 'bar' }, }); controller.startPollingByNetworkClientId('mainnet'); + jest.advanceTimersByTime(0); + await Promise.resolve(); controller.startPollingByNetworkClientId('rinkeby'); + jest.advanceTimersByTime(0); + await Promise.resolve(); + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['rinkeby', {}], + ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); expect(controller._executePoll.mock.calls).toMatchObject([ ['mainnet', {}], ['rinkeby', {}], + ['mainnet', {}], + ['rinkeby', {}], ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); @@ -275,6 +317,8 @@ describe('PollingController', () => { ['rinkeby', {}], ['mainnet', {}], ['rinkeby', {}], + ['mainnet', {}], + ['rinkeby', {}], ]); controller.stopAllPolling(); }); @@ -295,20 +339,34 @@ describe('PollingController', () => { }); controller.setIntervalLength(TICK_TIME * 2); controller.startPollingByNetworkClientId('mainnet'); + jest.advanceTimersByTime(0); + await Promise.resolve(); + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); controller.startPollingByNetworkClientId('sepolia'); - expect(controller._executePoll.mock.calls).toMatchObject([]); + jest.advanceTimersByTime(0); + await Promise.resolve(); + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['sepolia', {}], + ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); expect(controller._executePoll.mock.calls).toMatchObject([ ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); expect(controller._executePoll.mock.calls).toMatchObject([ ['mainnet', {}], ['sepolia', {}], + ['mainnet', {}], + ['sepolia', {}], ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); @@ -316,6 +374,8 @@ describe('PollingController', () => { ['mainnet', {}], ['sepolia', {}], ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); @@ -324,6 +384,8 @@ describe('PollingController', () => { ['sepolia', {}], ['mainnet', {}], ['sepolia', {}], + ['mainnet', {}], + ['sepolia', {}], ]); }); }); diff --git a/packages/polling-controller/src/PollingController.ts b/packages/polling-controller/src/PollingController.ts index a1335fb888..6cdad16519 100644 --- a/packages/polling-controller/src/PollingController.ts +++ b/packages/polling-controller/src/PollingController.ts @@ -139,21 +139,25 @@ function PollingControllerMixin(Base: TBase) { #poll(networkClientId: NetworkClientId, options: Json) { const key = getKey(networkClientId, options); - if (this.#intervalIds[key]) { - clearTimeout(this.#intervalIds[key]); + const interval = this.#intervalIds[key]; + if (interval) { + clearTimeout(interval); delete this.#intervalIds[key]; } // setTimeout is not `await`ing this async function, which is expected // We're just using async here for improved stack traces // eslint-disable-next-line @typescript-eslint/no-misused-promises - this.#intervalIds[key] = setTimeout(async () => { - try { - await this._executePoll(networkClientId, options); - } catch (error) { - console.error(error); - } - this.#poll(networkClientId, options); - }, this.#intervalLength); + this.#intervalIds[key] = setTimeout( + async () => { + try { + await this._executePoll(networkClientId, options); + } catch (error) { + console.error(error); + } + this.#poll(networkClientId, options); + }, + interval ? this.#intervalLength : 0, + ); } /**