diff --git a/packages/assets-controller/CHANGELOG.md b/packages/assets-controller/CHANGELOG.md index 06d1e87c326..6838457ed3f 100644 --- a/packages/assets-controller/CHANGELOG.md +++ b/packages/assets-controller/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- Refactored `BalanceFetcher` and `RpcDataSource` to ensure the correct `assetId` is used for EVM native assets that are not ETH ([#8284](https://github.com/MetaMask/core/pull/8284)) + ## [3.1.0] ### Changed diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.test.ts b/packages/assets-controller/src/data-sources/RpcDataSource.test.ts index c6ce6eaf9ce..1355065a592 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.test.ts @@ -476,7 +476,7 @@ describe('RpcDataSource', () => { const nativeAssetId = 'eip155:1/slip44:60' as Caip19AssetId; await withController(async ({ controller }) => { jest - .spyOn(BalanceFetcher.prototype, 'fetchBalancesForTokens') + .spyOn(BalanceFetcher.prototype, 'fetchBalancesForAssets') .mockResolvedValue({ chainId: MOCK_CHAIN_ID_HEX, accountId: MOCK_ACCOUNT_ID, @@ -565,7 +565,7 @@ describe('RpcDataSource', () => { it('initializes assetsBalance[accountId] in catch when first fetch for account throws', async () => { await withController(async ({ controller }) => { jest - .spyOn(BalanceFetcher.prototype, 'fetchBalancesForTokens') + .spyOn(BalanceFetcher.prototype, 'fetchBalancesForAssets') .mockRejectedValue(new Error('RPC unavailable')); const request = createDataRequest(); const response = await controller.fetch(request); @@ -759,12 +759,12 @@ describe('RpcDataSource', () => { ); }); - it('passes custom ERC20 token addresses to BalanceFetcher', async () => { + it('passes custom ERC20 asset entries (plus native) to BalanceFetcher', async () => { const customAssetId = 'eip155:1/erc20:0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48' as Caip19AssetId; const fetchSpy = jest - .spyOn(BalanceFetcher.prototype, 'fetchBalancesForTokens') + .spyOn(BalanceFetcher.prototype, 'fetchBalancesForAssets') .mockResolvedValue(createBalanceFetchResult()); await withController(async ({ controller }) => { @@ -777,9 +777,16 @@ describe('RpcDataSource', () => { MOCK_CHAIN_ID_HEX, MOCK_ACCOUNT_ID, MOCK_ADDRESS, - ['0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'], - { includeNative: true }, - [], + [ + { + assetId: `${MOCK_CHAIN_ID_CAIP}/slip44:60`, + address: '0x0000000000000000000000000000000000000000', + }, + expect.objectContaining({ + assetId: customAssetId, + address: '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', + }), + ], ); }); @@ -793,7 +800,7 @@ describe('RpcDataSource', () => { 'eip155:137/erc20:0x2791Bca1f2de4661ED88A30C99A7a9449Aa84174' as Caip19AssetId; const fetchSpy = jest - .spyOn(BalanceFetcher.prototype, 'fetchBalancesForTokens') + .spyOn(BalanceFetcher.prototype, 'fetchBalancesForAssets') .mockResolvedValue(createBalanceFetchResult()); await withController(async ({ controller }) => { @@ -806,9 +813,16 @@ describe('RpcDataSource', () => { MOCK_CHAIN_ID_HEX, MOCK_ACCOUNT_ID, MOCK_ADDRESS, - ['0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'], - { includeNative: true }, - [], + [ + { + assetId: `${MOCK_CHAIN_ID_CAIP}/slip44:60`, + address: '0x0000000000000000000000000000000000000000', + }, + expect.objectContaining({ + assetId: matchingAsset, + address: '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', + }), + ], ); }); diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.ts b/packages/assets-controller/src/data-sources/RpcDataSource.ts index da913d14ac9..f5f00cde2c1 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.ts @@ -41,11 +41,11 @@ import type { } from './evm-rpc-services'; import type { Address, + AssetFetchEntry, Provider as RpcProvider, TokenListState, BalanceFetchResult, TokenDetectionResult, - TokenFetchInfo, } from './evm-rpc-services/types'; import type { AssetsControllerGetStateAction, @@ -62,6 +62,7 @@ import type { Middleware, } from '../types'; import { normalizeAssetId } from '../utils'; +import { ZERO_ADDRESS } from '../utils/constants'; const CONTROLLER_NAME = 'RpcDataSource'; const DEFAULT_BALANCE_INTERVAL = 30_000; // 30 seconds @@ -900,9 +901,15 @@ export class RpcDataSource extends AbstractDataSource< for (const chainId of chainsForAccount) { const hexChainId = caipChainIdToHex(chainId); - // Extract ERC20 token addresses from customAssets for this chain - const customTokenAddresses: Address[] = []; + // Build a single AssetFetchEntry[] for native + custom ERC-20s + const nativeAssetId = this.#buildNativeAssetId(chainId); + const assetsToFetch: AssetFetchEntry[] = [ + { assetId: nativeAssetId, address: ZERO_ADDRESS }, + ]; + if (request.customAssets) { + const existingMetadata = this.#getExistingAssetsMetadata(); + for (const assetId of request.customAssets) { try { const parsed = parseCaipAssetType(assetId); @@ -911,7 +918,18 @@ export class RpcDataSource extends AbstractDataSource< assetChainId === chainId && parsed.assetNamespace === 'erc20' ) { - customTokenAddresses.push(parsed.assetReference as Address); + const tokenAddress = + parsed.assetReference.toLowerCase() as Address; + const normalizedId = normalizeAssetId(assetId); + const decimals = + existingMetadata[normalizedId]?.decimals ?? + this.#getTokenMetadataFromTokenList(normalizedId)?.decimals; + + assetsToFetch.push({ + assetId, + address: tokenAddress, + decimals, + }); } } catch { // Skip unparseable asset IDs @@ -920,19 +938,11 @@ export class RpcDataSource extends AbstractDataSource< } try { - const tokenInfos = this.#tokenFetchInfosForCustomErc20s( - chainId, - customTokenAddresses, - ); - - // Use BalanceFetcher for batched balance fetching - const result = await this.#balanceFetcher.fetchBalancesForTokens( + const result = await this.#balanceFetcher.fetchBalancesForAssets( hexChainId, accountId, address as Address, - customTokenAddresses, - { includeNative: true }, - tokenInfos, + assetsToFetch, ); if (!assetsBalance[accountId]) { @@ -992,7 +1002,6 @@ export class RpcDataSource extends AbstractDataSource< if (!assetsBalance[accountId]) { assetsBalance[accountId] = {}; } - const nativeAssetId = this.#buildNativeAssetId(chainId); assetsBalance[accountId][nativeAssetId] = { amount: '0' }; // Even on error, include native token metadata @@ -1354,39 +1363,6 @@ export class RpcDataSource extends AbstractDataSource< return nativeAssetIdentifiers[chainId] ?? `${chainId}/slip44:60`; } - /** - * Build token infos for custom ERC-20s when decimals are already known from - * state or token list so BalanceFetcher can format balances; unknown decimals - * are left out and resolved in `fetch` / `#handleBalanceUpdate`. - * - * @param caipChainId - CAIP-2 chain id (e.g. `eip155:1`). - * @param tokenAddresses - ERC-20 contract addresses on that chain. - * @returns Token fetch infos that include only entries with known decimals. - */ - #tokenFetchInfosForCustomErc20s( - caipChainId: ChainId, - tokenAddresses: Address[], - ): TokenFetchInfo[] { - const existingMetadata = this.#getExistingAssetsMetadata(); - const infos: TokenFetchInfo[] = []; - - for (const tokenAddress of tokenAddresses) { - const { reference } = parseCaipChainId(caipChainId); - const rawAssetId = - `eip155:${reference}/erc20:${tokenAddress.toLowerCase()}` as Caip19AssetId; - const assetId = normalizeAssetId(rawAssetId); - const decimals = - existingMetadata[assetId]?.decimals ?? - this.#getTokenMetadataFromTokenList(assetId)?.decimals; - - if (decimals !== undefined) { - infos.push({ address: tokenAddress, decimals }); - } - } - - return infos; - } - /** * Get existing assets metadata from AssetsController state. * Used to include metadata for ERC20 tokens when returning balance updates. diff --git a/packages/assets-controller/src/data-sources/evm-rpc-services/clients/MulticallClient.ts b/packages/assets-controller/src/data-sources/evm-rpc-services/clients/MulticallClient.ts index 7fa65e28dd3..d3e767dae22 100644 --- a/packages/assets-controller/src/data-sources/evm-rpc-services/clients/MulticallClient.ts +++ b/packages/assets-controller/src/data-sources/evm-rpc-services/clients/MulticallClient.ts @@ -1,6 +1,7 @@ import { Interface } from '@ethersproject/abi'; import type { Hex } from '@metamask/utils'; +import { ZERO_ADDRESS } from '../../../utils/constants'; import type { Address, BalanceOfRequest, @@ -79,12 +80,6 @@ const erc20Interface = new Interface(ERC20_ABI); // CONSTANTS // ============================================================================= -/** - * Zero address constant for native token. - */ -const ZERO_ADDRESS: Address = - '0x0000000000000000000000000000000000000000' as Address; - /** * Multicall3 contract addresses by chain ID. * Source: https://github.com/mds1/multicall/blob/main/deployments.json diff --git a/packages/assets-controller/src/data-sources/evm-rpc-services/index.ts b/packages/assets-controller/src/data-sources/evm-rpc-services/index.ts index e47693f7299..5b30f3b7d36 100644 --- a/packages/assets-controller/src/data-sources/evm-rpc-services/index.ts +++ b/packages/assets-controller/src/data-sources/evm-rpc-services/index.ts @@ -1,5 +1,6 @@ export type { Address, + AssetFetchEntry, AssetsBalanceState, ChainId, GetProviderFunction, diff --git a/packages/assets-controller/src/data-sources/evm-rpc-services/services/BalanceFetcher.test.ts b/packages/assets-controller/src/data-sources/evm-rpc-services/services/BalanceFetcher.test.ts index d953496f0e3..780be4a4ff0 100644 --- a/packages/assets-controller/src/data-sources/evm-rpc-services/services/BalanceFetcher.test.ts +++ b/packages/assets-controller/src/data-sources/evm-rpc-services/services/BalanceFetcher.test.ts @@ -1,3 +1,5 @@ +import type { CaipAssetType } from '@metamask/utils'; + import { BalanceFetcher } from './BalanceFetcher'; import type { BalanceFetcherConfig, @@ -7,10 +9,10 @@ import type { import type { MulticallClient } from '../clients'; import type { Address, + AssetFetchEntry, AssetsBalanceState, BalanceOfResponse, ChainId, - TokenFetchInfo, } from '../types'; // ============================================================================= @@ -30,13 +32,30 @@ const ZERO_ADDRESS: Address = const MAINNET_CHAIN_ID: ChainId = '0x1' as ChainId; const POLYGON_CHAIN_ID: ChainId = '0x89' as ChainId; -/** Decimals for TEST_TOKEN_1 (USDC) / TEST_TOKEN_2 (USDT) in fetch tests */ -const TEST_TOKEN_1_WITH_DECIMALS: TokenFetchInfo = { - address: TEST_TOKEN_1, +const NATIVE_ETH_ASSET_ID = 'eip155:1/slip44:60' as CaipAssetType; +const TOKEN_1_ASSET_ID = + `eip155:1/erc20:${TEST_TOKEN_1.toLowerCase()}` as CaipAssetType; +const TOKEN_2_ASSET_ID = + `eip155:1/erc20:${TEST_TOKEN_2.toLowerCase()}` as CaipAssetType; + +const NATIVE_ETH_ENTRY: AssetFetchEntry = { + assetId: NATIVE_ETH_ASSET_ID, + address: ZERO_ADDRESS, +}; +const TOKEN_1_ENTRY: AssetFetchEntry = { + assetId: TOKEN_1_ASSET_ID, + address: TEST_TOKEN_1.toLowerCase() as Address, +}; +const TOKEN_1_ENTRY_WITH_DECIMALS: AssetFetchEntry = { + ...TOKEN_1_ENTRY, decimals: 6, }; -const TEST_TOKEN_2_WITH_DECIMALS: TokenFetchInfo = { - address: TEST_TOKEN_2, +const TOKEN_2_ENTRY: AssetFetchEntry = { + assetId: TOKEN_2_ASSET_ID, + address: TEST_TOKEN_2.toLowerCase() as Address, +}; +const TOKEN_2_ENTRY_WITH_DECIMALS: AssetFetchEntry = { + ...TOKEN_2_ENTRY, decimals: 6, }; @@ -151,7 +170,6 @@ describe('BalanceFetcher', () => { config: { defaultBatchSize: 100, defaultTimeoutMs: 60000, - includeNativeByDefault: false, pollingInterval: 45000, }, }, @@ -183,35 +201,42 @@ describe('BalanceFetcher', () => { describe('setOnBalanceUpdate', () => { it('sets the balance update callback', async () => { - await withController(async ({ controller, mockMulticallClient }) => { - const mockCallback = jest.fn(); - controller.setOnBalanceUpdate(mockCallback); - - mockMulticallClient.batchBalanceOf.mockResolvedValue([ - createMockBalanceResponse( - ZERO_ADDRESS, - TEST_ACCOUNT, - true, - '1000000000000000000', - ), - ]); + const mockState = createMockAssetsBalanceState(TEST_ACCOUNT_ID, { + [NATIVE_ETH_ASSET_ID]: { amount: '0' }, + }); - const input: BalancePollingInput = { - chainId: MAINNET_CHAIN_ID, - accountId: TEST_ACCOUNT_ID, - accountAddress: TEST_ACCOUNT, - }; + await withController( + { assetsBalanceState: mockState }, + async ({ controller, mockMulticallClient }) => { + const mockCallback = jest.fn(); + controller.setOnBalanceUpdate(mockCallback); - await controller._executePoll(input); + mockMulticallClient.batchBalanceOf.mockResolvedValue([ + createMockBalanceResponse( + ZERO_ADDRESS, + TEST_ACCOUNT, + true, + '1000000000000000000', + ), + ]); - expect(mockCallback).toHaveBeenCalledWith( - expect.objectContaining({ + const input: BalancePollingInput = { chainId: MAINNET_CHAIN_ID, accountId: TEST_ACCOUNT_ID, - balances: expect.any(Array), - }), - ); - }); + accountAddress: TEST_ACCOUNT, + }; + + await controller._executePoll(input); + + expect(mockCallback).toHaveBeenCalledWith( + expect.objectContaining({ + chainId: MAINNET_CHAIN_ID, + accountId: TEST_ACCOUNT_ID, + balances: expect.any(Array), + }), + ); + }, + ); }); it('does not call callback when balances are empty', async () => { @@ -285,55 +310,8 @@ describe('BalanceFetcher', () => { }); }); - describe('getTokensToFetch', () => { - it('returns empty array when no balances exist', async () => { - await withController(async ({ controller }) => { - const tokens = controller.getTokensToFetch( - MAINNET_CHAIN_ID, - TEST_ACCOUNT_ID, - ); - expect(tokens).toStrictEqual([]); - }); - }); - - it('returns tokens from assetsBalance state', async () => { - const mockState = createMockAssetsBalanceState(TEST_ACCOUNT_ID, { - [`eip155:1/erc20:${TEST_TOKEN_1}`]: { amount: '100' }, - [`eip155:1/erc20:${TEST_TOKEN_2}`]: { amount: '200' }, - }); - - await withController( - { assetsBalanceState: mockState }, - async ({ controller }) => { - const tokens = controller.getTokensToFetch( - MAINNET_CHAIN_ID, - TEST_ACCOUNT_ID, - ); - expect(tokens).toHaveLength(2); - }, - ); - }); - - it('returns empty array when chain has no tokens', async () => { - const mockState = createMockAssetsBalanceState(TEST_ACCOUNT_ID, { - [`eip155:1/erc20:${TEST_TOKEN_1}`]: { amount: '100' }, - }); - - await withController( - { assetsBalanceState: mockState }, - async ({ controller }) => { - const tokens = controller.getTokensToFetch( - POLYGON_CHAIN_ID, - TEST_ACCOUNT_ID, - ); - expect(tokens).toStrictEqual([]); - }, - ); - }); - }); - - describe('fetchBalancesForTokens', () => { - it('fetches balances for specified token addresses', async () => { + describe('fetchBalancesForAssets', () => { + it('fetches balances for specified asset entries', async () => { await withController(async ({ controller, mockMulticallClient }) => { mockMulticallClient.batchBalanceOf.mockResolvedValue([ createMockBalanceResponse( @@ -350,13 +328,11 @@ describe('BalanceFetcher', () => { ), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - undefined, - [TEST_TOKEN_1_WITH_DECIMALS], + [NATIVE_ETH_ENTRY, TOKEN_1_ENTRY_WITH_DECIMALS], ); expect(result.balances).toHaveLength(2); @@ -364,7 +340,7 @@ describe('BalanceFetcher', () => { }); }); - it('creates correct CAIP-19 asset ID for native token', async () => { + it('preserves the native asset ID provided by the caller', async () => { await withController(async ({ controller, mockMulticallClient }) => { mockMulticallClient.batchBalanceOf.mockResolvedValue([ createMockBalanceResponse( @@ -375,45 +351,43 @@ describe('BalanceFetcher', () => { ), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [], - { includeNative: true }, + [NATIVE_ETH_ENTRY], ); - expect(result.balances[0].assetId).toBe('eip155:1/slip44:60'); + expect(result.balances[0].assetId).toBe(NATIVE_ETH_ASSET_ID); }); }); - it('creates correct CAIP-19 asset ID for ERC-20 token', async () => { + it('preserves a non-ETH native asset ID (e.g. Avalanche)', async () => { + const avaxNativeAssetId = 'eip155:43114/slip44:9005' as CaipAssetType; + await withController(async ({ controller, mockMulticallClient }) => { mockMulticallClient.batchBalanceOf.mockResolvedValue([ createMockBalanceResponse( - TEST_TOKEN_1, + ZERO_ADDRESS, TEST_ACCOUNT, true, - '1000000000', + '5000000000000000000', ), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, - [TEST_TOKEN_1_WITH_DECIMALS], + [{ assetId: avaxNativeAssetId, address: ZERO_ADDRESS }], ); - expect(result.balances[0].assetId).toBe( - 'eip155:1/erc20:0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', - ); + expect(result.balances[0].assetId).toBe(avaxNativeAssetId); + expect(result.balances[0].formattedBalance).toBe('5'); }); }); - it('includes ERC-20 raw balance when tokenInfos are omitted (decimals resolved downstream)', async () => { + it('preserves the ERC-20 asset ID provided by the caller', async () => { await withController(async ({ controller, mockMulticallClient }) => { mockMulticallClient.batchBalanceOf.mockResolvedValue([ createMockBalanceResponse( @@ -424,23 +398,18 @@ describe('BalanceFetcher', () => { ), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, + [TOKEN_1_ENTRY_WITH_DECIMALS], ); - expect(result.balances).toHaveLength(1); - expect(result.balances[0].decimals).toBeUndefined(); - expect(result.balances[0].balance).toBe('1000000000'); - expect(result.balances[0].formattedBalance).toBe('1000000000'); - expect(result.failedAddresses).toHaveLength(0); + expect(result.balances[0].assetId).toBe(TOKEN_1_ASSET_ID); }); }); - it('includes ERC-20 raw balance when token info has no decimals', async () => { + it('includes ERC-20 raw balance when decimals omitted (resolved downstream)', async () => { await withController(async ({ controller, mockMulticallClient }) => { mockMulticallClient.batchBalanceOf.mockResolvedValue([ createMockBalanceResponse( @@ -451,44 +420,32 @@ describe('BalanceFetcher', () => { ), ]); - const tokenInfoMissingDecimals = { - address: TEST_TOKEN_1, - } as TokenFetchInfo; - - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, - [tokenInfoMissingDecimals], + [TOKEN_1_ENTRY], ); expect(result.balances).toHaveLength(1); expect(result.balances[0].decimals).toBeUndefined(); + expect(result.balances[0].balance).toBe('1000000000'); expect(result.balances[0].formattedBalance).toBe('1000000000'); expect(result.failedAddresses).toHaveLength(0); }); }); - it('includes ERC-20 balance when token info has zero decimals', async () => { + it('includes ERC-20 balance when entry has zero decimals', async () => { await withController(async ({ controller, mockMulticallClient }) => { mockMulticallClient.batchBalanceOf.mockResolvedValue([ createMockBalanceResponse(TEST_TOKEN_1, TEST_ACCOUNT, true, '42'), ]); - const tokenInfoZeroDecimals: TokenFetchInfo = { - address: TEST_TOKEN_1, - decimals: 0, - }; - - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, - [tokenInfoZeroDecimals], + [{ ...TOKEN_1_ENTRY, decimals: 0 }], ); expect(result.balances).toHaveLength(1); @@ -503,12 +460,11 @@ describe('BalanceFetcher', () => { createMockBalanceResponse(TEST_TOKEN_1, TEST_ACCOUNT, false), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, + [TOKEN_1_ENTRY], ); expect(result.balances).toHaveLength(0); @@ -516,30 +472,73 @@ describe('BalanceFetcher', () => { }); }); - it('returns empty result when no tokens to fetch', async () => { - await withController( - { config: { includeNativeByDefault: false } }, - async ({ controller, mockMulticallClient }) => { - const result = await controller.fetchBalancesForTokens( - MAINNET_CHAIN_ID, - TEST_ACCOUNT_ID, + it('returns empty result when no entries provided', async () => { + await withController(async ({ controller, mockMulticallClient }) => { + const result = await controller.fetchBalancesForAssets( + MAINNET_CHAIN_ID, + TEST_ACCOUNT_ID, + TEST_ACCOUNT, + [], + ); + + expect(result).toStrictEqual({ + chainId: MAINNET_CHAIN_ID, + accountId: TEST_ACCOUNT_ID, + accountAddress: TEST_ACCOUNT, + balances: [], + failedAddresses: [], + timestamp: 1700000000000, + }); + + expect(mockMulticallClient.batchBalanceOf).not.toHaveBeenCalled(); + }); + }); + + it('derives hex chainId from asset entries', async () => { + await withController(async ({ controller, mockMulticallClient }) => { + mockMulticallClient.batchBalanceOf.mockResolvedValue([ + createMockBalanceResponse( + ZERO_ADDRESS, TEST_ACCOUNT, - [], - { includeNative: false }, - ); + true, + '1000000000000000000', + ), + ]); - expect(result).toStrictEqual({ - chainId: MAINNET_CHAIN_ID, - accountId: TEST_ACCOUNT_ID, - accountAddress: TEST_ACCOUNT, - balances: [], - failedAddresses: [], - timestamp: 1700000000000, - }); + const result = await controller.fetchBalancesForAssets( + MAINNET_CHAIN_ID, + TEST_ACCOUNT_ID, + TEST_ACCOUNT, + [NATIVE_ETH_ENTRY], + ); - expect(mockMulticallClient.batchBalanceOf).not.toHaveBeenCalled(); - }, - ); + expect(result.chainId).toBe(MAINNET_CHAIN_ID); + }); + }); + + it('deduplicates entries with same address', async () => { + await withController(async ({ controller, mockMulticallClient }) => { + mockMulticallClient.batchBalanceOf.mockResolvedValue([ + createMockBalanceResponse( + ZERO_ADDRESS, + TEST_ACCOUNT, + true, + '1000000000000000000', + ), + ]); + + const result = await controller.fetchBalancesForAssets( + MAINNET_CHAIN_ID, + TEST_ACCOUNT_ID, + TEST_ACCOUNT, + [NATIVE_ETH_ENTRY, NATIVE_ETH_ENTRY], + ); + + expect(mockMulticallClient.batchBalanceOf).toHaveBeenCalledTimes(1); + const calls = mockMulticallClient.batchBalanceOf.mock.calls[0]; + expect(calls[1]).toHaveLength(1); + expect(result.balances).toHaveLength(1); + }); }); }); @@ -555,13 +554,11 @@ describe('BalanceFetcher', () => { ), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, - [{ address: TEST_TOKEN_1, decimals: 6, symbol: 'USDC' }], + [{ ...TOKEN_1_ENTRY, decimals: 6 }], ); expect(result.balances[0].formattedBalance).toBe('1234.56789'); @@ -574,13 +571,11 @@ describe('BalanceFetcher', () => { createMockBalanceResponse(TEST_TOKEN_1, TEST_ACCOUNT, true, '0'), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, - [TEST_TOKEN_1_WITH_DECIMALS], + [TOKEN_1_ENTRY_WITH_DECIMALS], ); expect(result.balances[0].formattedBalance).toBe('0'); @@ -598,13 +593,11 @@ describe('BalanceFetcher', () => { ), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, - [TEST_TOKEN_1_WITH_DECIMALS], + [TOKEN_1_ENTRY_WITH_DECIMALS], ); expect(result.balances[0].balance).toBe('0'); @@ -623,13 +616,11 @@ describe('BalanceFetcher', () => { ), ]); - const result = await controller.fetchBalancesForTokens( + const result = await controller.fetchBalancesForAssets( MAINNET_CHAIN_ID, TEST_ACCOUNT_ID, TEST_ACCOUNT, - [TEST_TOKEN_1], - { includeNative: false }, - [TEST_TOKEN_1_WITH_DECIMALS], + [TOKEN_1_ENTRY_WITH_DECIMALS], ); expect(result.balances[0].formattedBalance).toBe('invalid-balance'); @@ -639,61 +630,63 @@ describe('BalanceFetcher', () => { describe('batching behavior', () => { it('uses custom batch size from options', async () => { - await withController(async ({ controller, mockMulticallClient }) => { - mockMulticallClient.batchBalanceOf.mockResolvedValue([ - createMockBalanceResponse( - TEST_TOKEN_1, - TEST_ACCOUNT, - true, - '1000000000', - ), - ]); - - await controller.fetchBalancesForTokens( - MAINNET_CHAIN_ID, - TEST_ACCOUNT_ID, - TEST_ACCOUNT, - [TEST_TOKEN_1, TEST_TOKEN_2], - { includeNative: false, batchSize: 1 }, - [TEST_TOKEN_1_WITH_DECIMALS, TEST_TOKEN_2_WITH_DECIMALS], - ); - - expect(mockMulticallClient.batchBalanceOf).toHaveBeenCalledTimes(2); - }); - }); - - it('accumulates results across multiple batches', async () => { - await withController(async ({ controller, mockMulticallClient }) => { - mockMulticallClient.batchBalanceOf - .mockResolvedValueOnce([ + await withController( + { config: { defaultBatchSize: 1 } }, + async ({ controller, mockMulticallClient }) => { + mockMulticallClient.batchBalanceOf.mockResolvedValue([ createMockBalanceResponse( TEST_TOKEN_1, TEST_ACCOUNT, true, - '1000000', - ), - ]) - .mockResolvedValueOnce([ - createMockBalanceResponse( - TEST_TOKEN_2, - TEST_ACCOUNT, - true, - '2000000', + '1000000000', ), ]); - const result = await controller.fetchBalancesForTokens( - MAINNET_CHAIN_ID, - TEST_ACCOUNT_ID, - TEST_ACCOUNT, - [TEST_TOKEN_1, TEST_TOKEN_2], - { includeNative: false, batchSize: 1 }, - [TEST_TOKEN_1_WITH_DECIMALS, TEST_TOKEN_2_WITH_DECIMALS], - ); + await controller.fetchBalancesForAssets( + MAINNET_CHAIN_ID, + TEST_ACCOUNT_ID, + TEST_ACCOUNT, + [TOKEN_1_ENTRY_WITH_DECIMALS, TOKEN_2_ENTRY_WITH_DECIMALS], + ); - expect(mockMulticallClient.batchBalanceOf).toHaveBeenCalledTimes(2); - expect(result.balances).toHaveLength(2); - }); + expect(mockMulticallClient.batchBalanceOf).toHaveBeenCalledTimes(2); + }, + ); + }); + + it('accumulates results across multiple batches', async () => { + await withController( + { config: { defaultBatchSize: 1 } }, + async ({ controller, mockMulticallClient }) => { + mockMulticallClient.batchBalanceOf + .mockResolvedValueOnce([ + createMockBalanceResponse( + TEST_TOKEN_1, + TEST_ACCOUNT, + true, + '1000000', + ), + ]) + .mockResolvedValueOnce([ + createMockBalanceResponse( + TEST_TOKEN_2, + TEST_ACCOUNT, + true, + '2000000', + ), + ]); + + const result = await controller.fetchBalancesForAssets( + MAINNET_CHAIN_ID, + TEST_ACCOUNT_ID, + TEST_ACCOUNT, + [TOKEN_1_ENTRY_WITH_DECIMALS, TOKEN_2_ENTRY_WITH_DECIMALS], + ); + + expect(mockMulticallClient.batchBalanceOf).toHaveBeenCalledTimes(2); + expect(result.balances).toHaveLength(2); + }, + ); }); }); }); diff --git a/packages/assets-controller/src/data-sources/evm-rpc-services/services/BalanceFetcher.ts b/packages/assets-controller/src/data-sources/evm-rpc-services/services/BalanceFetcher.ts index 0de7d311238..4c4a6a4b429 100644 --- a/packages/assets-controller/src/data-sources/evm-rpc-services/services/BalanceFetcher.ts +++ b/packages/assets-controller/src/data-sources/evm-rpc-services/services/BalanceFetcher.ts @@ -1,26 +1,24 @@ import { StaticIntervalPollingControllerOnly } from '@metamask/polling-controller'; -import type { CaipAssetType } from '@metamask/utils'; +import { parseCaipAssetType } from '@metamask/utils'; +import { ZERO_ADDRESS } from '../../../utils/constants'; import type { MulticallClient } from '../clients'; import type { AccountId, Address, AssetBalance, + AssetFetchEntry, AssetsBalanceState, - BalanceFetchOptions, BalanceFetchResult, BalanceOfRequest, BalanceOfResponse, + CaipAssetType, ChainId, - TokenFetchInfo, } from '../types'; import { reduceInBatchesSerially } from '../utils'; const DEFAULT_BALANCE_INTERVAL = 30_000; // 30 seconds -const ZERO_ADDRESS: Address = - '0x0000000000000000000000000000000000000000' as Address; - /** * Minimal messenger interface for BalanceFetcher. */ @@ -31,7 +29,6 @@ export type BalanceFetcherMessenger = { export type BalanceFetcherConfig = { defaultBatchSize?: number; defaultTimeoutMs?: number; - includeNativeByDefault?: boolean; /** Polling interval in ms (default: 30s) */ pollingInterval?: number; }; @@ -58,6 +55,11 @@ export type OnBalanceUpdateCallback = ( /** * BalanceFetcher - Fetches token balances via multicall. * Extends StaticIntervalPollingControllerOnly for built-in polling support. + * + * Callers provide CAIP-19 asset IDs; the fetcher extracts on-chain addresses + * (or uses the zero address for native assets) and maps multicall responses + * back to the original asset IDs. This ensures the returned balance entries + * always carry the correct identifier regardless of chain. */ export class BalanceFetcher extends StaticIntervalPollingControllerOnly() { readonly #multicallClient: MulticallClient; @@ -79,7 +81,6 @@ export class BalanceFetcher extends StaticIntervalPollingControllerOnly { - const result = await this.fetchBalances( + const result = await this.#fetchBalances( input.chainId, input.accountId, input.accountAddress, @@ -113,97 +114,113 @@ export class BalanceFetcher extends StaticIntervalPollingControllerOnly(); - - for (const assetId of Object.keys(accountBalances)) { - // Only process ERC20 tokens on the current chain - if (assetId.startsWith(caipChainPrefix) && assetId.includes('/erc20:')) { - // Parse token address from CAIP-19: eip155:1/erc20:0x... - const tokenAddress = assetId.split('/erc20:')[1] as Address; - if (tokenAddress) { - const lowerAddress = tokenAddress.toLowerCase(); - if (!tokenMap.has(lowerAddress)) { - tokenMap.set(lowerAddress, { - address: tokenAddress, - symbol: '', - }); - } + // This is safe because we are filtring with an accountId that is for evm balances only + const chainIdDecimal = parseInt(chainId, 16).toString(); + + const assetsToFetch = new Map(); + + for (const assetId of Object.keys(accountBalances) as CaipAssetType[]) { + const { + chain: { reference: chainReference }, + assetNamespace, + assetReference, + } = parseCaipAssetType(assetId); + + if (chainReference === chainIdDecimal) { + const assetIdLowerCase = assetId.toLowerCase(); + if (assetsToFetch.has(assetIdLowerCase)) { + continue; } + + const isNative = assetNamespace === 'slip44'; + const tokenAddress = isNative + ? ZERO_ADDRESS + : (assetReference.toLowerCase() as Address); + + assetsToFetch.set(assetIdLowerCase, { + assetId, + address: tokenAddress, + }); } } - return Array.from(tokenMap.values()); + return Array.from(assetsToFetch.values()); } - async fetchBalances( + /** + * Fetch balances for assets already tracked in state for the given + * account and chain. + * + * @param chainId - Hex chain ID. + * @param accountId - Account UUID. + * @param accountAddress - On-chain address of the account. + * @returns Balance fetch result. + */ + async #fetchBalances( chainId: ChainId, accountId: AccountId, accountAddress: Address, - options?: BalanceFetchOptions, ): Promise { - const tokens = this.getTokensToFetch(chainId, accountId); - const tokenAddresses = tokens.map((token) => token.address); + const assets = this.#getAssetsToFetch(chainId, accountId); - return this.fetchBalancesForTokens( + return this.fetchBalancesForAssets( chainId, accountId, accountAddress, - tokenAddresses, - options, - tokens, + assets, ); } - async fetchBalancesForTokens( + /** + * Fetch balances for the given assets via multicall. + * + * Each entry bundles a CAIP-19 asset ID with its on-chain address and + * optional decimals. + * + * @param chainId - Hex chain ID. + * @param accountId - Account UUID. + * @param accountAddress - On-chain address of the account. + * @param assets - Asset fetch entries to fetch balances for. + * @returns Balance fetch result. + */ + async fetchBalancesForAssets( chainId: ChainId, accountId: AccountId, accountAddress: Address, - tokenAddresses: Address[], - options?: BalanceFetchOptions, - tokenInfos?: TokenFetchInfo[], + assets: AssetFetchEntry[], ): Promise { - const batchSize = options?.batchSize ?? this.#config.defaultBatchSize; - const includeNative = - options?.includeNative ?? this.#config.includeNativeByDefault; const timestamp = Date.now(); - const tokenInfoMap = new Map(); - if (tokenInfos) { - for (const info of tokenInfos) { - tokenInfoMap.set(info.address.toLowerCase(), info); - } - } - + // Build a single map keyed by lowercase address that holds all info + // needed to match multicall responses back to their original entries. const balanceRequests: BalanceOfRequest[] = []; + const entryByAddress = new Map(); - if (includeNative) { - balanceRequests.push({ - tokenAddress: ZERO_ADDRESS, - accountAddress, - }); - } + for (const entry of assets) { + const lowerAddress = entry.address.toLowerCase(); + if (entryByAddress.has(lowerAddress)) { + continue; // deduplicate + } - for (const tokenAddress of tokenAddresses) { - balanceRequests.push({ - tokenAddress, - accountAddress, - }); + entryByAddress.set(lowerAddress, entry); + balanceRequests.push({ tokenAddress: entry.address, accountAddress }); } if (balanceRequests.length === 0) { @@ -227,7 +244,7 @@ export class BalanceFetcher extends StaticIntervalPollingControllerOnly({ values: balanceRequests, - batchSize, + batchSize: this.#config.defaultBatchSize, initialResult: { balances: [], failedAddresses: [], @@ -244,7 +261,7 @@ export class BalanceFetcher extends StaticIntervalPollingControllerOnly, + entryByAddress: Map, ): { balances: AssetBalance[]; failedAddresses: Address[]; @@ -280,31 +297,29 @@ export class BalanceFetcher extends StaticIntervalPollingControllerOnly