Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions packages/transaction-pay-controller/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- **BREAKING:** Always retrieve quote if using Relay strategy and required token is Arbitrum USDC, even if payment token matches ([#7146](https://github.com/MetaMask/core/pull/7146))
- Change `getStrategy` constructor option from asynchronous to synchronous.

## [5.0.0]

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ describe('TransactionPayController', () => {
createController();

expect(
await messenger.call(
messenger.call(
'TransactionPayController:getStrategy',
TRANSACTION_META_MOCK,
),
Expand All @@ -87,12 +87,12 @@ describe('TransactionPayController', () => {
it('returns callback value if provided', async () => {
new TransactionPayController({
getDelegationTransaction: jest.fn(),
getStrategy: async () => TransactionPayStrategy.Test,
getStrategy: () => TransactionPayStrategy.Test,
messenger,
});

expect(
await messenger.call(
messenger.call(
'TransactionPayController:getStrategy',
TRANSACTION_META_MOCK,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export class TransactionPayController extends BaseController<

readonly #getStrategy?: (
transaction: TransactionMeta,
) => Promise<TransactionPayStrategy>;
) => TransactionPayStrategy;

constructor({
getDelegationTransaction,
Expand Down Expand Up @@ -139,7 +139,7 @@ export class TransactionPayController extends BaseController<

this.messenger.registerActionHandler(
'TransactionPayController:getStrategy',
this.#getStrategy ?? (async () => TransactionPayStrategy.Relay),
this.#getStrategy ?? (() => TransactionPayStrategy.Relay),
);

this.messenger.registerActionHandler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ describe('TransactionPayPublishHook', () => {
},
} as TransactionPayControllerState);

getStrategyMock.mockResolvedValue(TransactionPayStrategy.Test);
getStrategyMock.mockReturnValue(TransactionPayStrategy.Test);
});

it('executes strategy with quotes', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export class TransactionPayPublishHook {
return EMPTY_RESULT;
}

const strategy = await getStrategy(this.#messenger, transactionMeta);
const strategy = getStrategy(this.#messenger, transactionMeta);

return await strategy.execute({
isSmartTransaction: this.#isSmartTransaction,
Expand Down
6 changes: 2 additions & 4 deletions packages/transaction-pay-controller/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export type TransactionPayControllerGetDelegationTransactionAction = {
/** Action to get the pay strategy type used for a transaction. */
export type TransactionPayControllerGetStrategyAction = {
type: `${typeof CONTROLLER_NAME}:getStrategy`;
handler: (transaction: TransactionMeta) => Promise<TransactionPayStrategy>;
handler: (transaction: TransactionMeta) => TransactionPayStrategy;
};

/** Action to update the payment token for a transaction. */
Expand Down Expand Up @@ -104,9 +104,7 @@ export type TransactionPayControllerOptions = {
getDelegationTransaction: GetDelegationTransactionCallback;

/** Callback to select the PayStrategy for a transaction. */
getStrategy?: (
transaction: TransactionMeta,
) => Promise<TransactionPayStrategy>;
getStrategy?: (transaction: TransactionMeta) => TransactionPayStrategy;

/** Controller messenger. */
messenger: TransactionPayControllerMessenger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ describe('Quotes Utils', () => {
jest.resetAllMocks();
jest.clearAllTimers();

getStrategyMock.mockResolvedValue({
getStrategyMock.mockReturnValue({
execute: jest.fn(),
getQuotes: getQuotesMock,
getBatchTransactions: getBatchTransactionsMock,
Expand Down
2 changes: 1 addition & 1 deletion packages/transaction-pay-controller/src/utils/quotes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ async function getQuotes(
messenger: TransactionPayControllerMessenger,
) {
const { id: transactionId } = transaction;
const strategy = await getStrategy(messenger as never, transaction);
const strategy = getStrategy(messenger as never, transaction);
let quotes: TransactionPayQuote<Json>[] | undefined = [];

try {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import { updateSourceAmounts } from './source-amounts';
import { getTokenFiatRate } from './token';
import type { TransactionPaymentToken } from '..';
import { getTransaction } from './transaction';
import { TransactionPayStrategy, type TransactionPaymentToken } from '..';
import {
ARBITRUM_USDC_ADDRESS,
CHAIN_ID_ARBITRUM,
} from '../strategy/relay/constants';
import { getMessengerMock } from '../tests/messenger-mock';
import type { TransactionData, TransactionPayRequiredToken } from '../types';

jest.mock('./token');
jest.mock('./transaction');

const PAYMENT_TOKEN_MOCK: TransactionPaymentToken = {
address: '0x123',
Expand Down Expand Up @@ -37,11 +44,15 @@ const TRANSACTION_ID_MOCK = '123-456';

describe('Source Amounts Utils', () => {
const getTokenFiatRateMock = jest.mocked(getTokenFiatRate);
const getTransactionMock = jest.mocked(getTransaction);
const { messenger, getStrategyMock } = getMessengerMock();

beforeEach(() => {
jest.resetAllMocks();

getTokenFiatRateMock.mockReturnValue({ fiatRate: '2.0', usdRate: '3.0' });
getStrategyMock.mockReturnValue(TransactionPayStrategy.Test);
getTransactionMock.mockReturnValue({ id: TRANSACTION_ID_MOCK } as never);
});

describe('updateSourceAmounts', () => {
Expand All @@ -52,7 +63,7 @@ describe('Source Amounts Utils', () => {
tokens: [TRANSACTION_TOKEN_MOCK],
};

updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, {} as never);
updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, messenger);

expect(transactionData.sourceAmounts).toStrictEqual([
{
Expand All @@ -76,11 +87,35 @@ describe('Source Amounts Utils', () => {
],
};

updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, {} as never);
updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, messenger);

expect(transactionData.sourceAmounts).toStrictEqual([]);
});

it('does not return empty array if payment token matches but hyperliquid deposit and relay strategy', () => {
getStrategyMock.mockReturnValue(TransactionPayStrategy.Relay);

const transactionData: TransactionData = {
isLoading: false,
paymentToken: {
...PAYMENT_TOKEN_MOCK,
address: ARBITRUM_USDC_ADDRESS,
chainId: CHAIN_ID_ARBITRUM,
},
tokens: [
{
...TRANSACTION_TOKEN_MOCK,
address: ARBITRUM_USDC_ADDRESS,
chainId: CHAIN_ID_ARBITRUM,
},
],
};

updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, messenger);

expect(transactionData.sourceAmounts).toHaveLength(1);
});

it('returns empty array if skipIfBalance and has balance', () => {
const transactionData: TransactionData = {
isLoading: false,
Expand All @@ -94,7 +129,7 @@ describe('Source Amounts Utils', () => {
],
};

updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, {} as never);
updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, messenger);

expect(transactionData.sourceAmounts).toStrictEqual([]);
});
Expand All @@ -108,7 +143,7 @@ describe('Source Amounts Utils', () => {

getTokenFiatRateMock.mockReturnValue(undefined);

updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, {} as never);
updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, messenger);

expect(transactionData.sourceAmounts).toStrictEqual([]);
});
Expand All @@ -125,7 +160,7 @@ describe('Source Amounts Utils', () => {
],
};

updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, {} as never);
updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, messenger);

expect(transactionData.sourceAmounts).toStrictEqual([]);
});
Expand All @@ -136,7 +171,7 @@ describe('Source Amounts Utils', () => {
tokens: [TRANSACTION_TOKEN_MOCK],
};

updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, {} as never);
updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, messenger);

expect(transactionData.sourceAmounts).toBeUndefined();
});
Expand All @@ -148,14 +183,14 @@ describe('Source Amounts Utils', () => {
tokens: [],
};

updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, {} as never);
updateSourceAmounts(TRANSACTION_ID_MOCK, transactionData, messenger);

expect(transactionData.sourceAmounts).toBeUndefined();
});

// eslint-disable-next-line jest/expect-expect
it('does nothing if no transaction data', () => {
updateSourceAmounts(TRANSACTION_ID_MOCK, undefined, {} as never);
updateSourceAmounts(TRANSACTION_ID_MOCK, undefined, messenger);
});
});
});
64 changes: 58 additions & 6 deletions packages/transaction-pay-controller/src/utils/source-amounts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@ import { createModuleLogger } from '@metamask/utils';
import { BigNumber } from 'bignumber.js';

import { getTokenFiatRate } from './token';
import { getTransaction } from './transaction';
import type {
TransactionPayControllerMessenger,
TransactionPaymentToken,
} from '..';
import { TransactionPayStrategy } from '..';
import type { TransactionMeta } from '../../../transaction-controller/src';
import { projectLogger } from '../logger';
import {
ARBITRUM_USDC_ADDRESS,
CHAIN_ID_ARBITRUM,
} from '../strategy/relay/constants';
import type {
TransactionPaySourceAmount,
TransactionData,
Expand Down Expand Up @@ -38,7 +45,9 @@ export function updateSourceAmounts(
}

const sourceAmounts = tokens
.map((t) => calculateSourceAmount(paymentToken, t, messenger))
.map((t) =>
calculateSourceAmount(paymentToken, t, messenger, transactionId),
)
.filter(Boolean) as TransactionPaySourceAmount[];

log('Updated source amounts', { transactionId, sourceAmounts });
Expand All @@ -52,12 +61,14 @@ export function updateSourceAmounts(
* @param paymentToken - Selected payment token.
* @param token - Target token to cover.
* @param messenger - Controller messenger.
* @param transactionId - ID of the transaction.
* @returns The source amount or undefined if calculation failed.
*/
function calculateSourceAmount(
paymentToken: TransactionPaymentToken,
token: TransactionPayRequiredToken,
messenger: TransactionPayControllerMessenger,
transactionId: string,
): TransactionPaySourceAmount | undefined {
const paymentTokenFiatRate = getTokenFiatRate(
messenger,
Expand All @@ -71,18 +82,22 @@ function calculateSourceAmount(

const hasBalance = new BigNumber(token.balanceRaw).gte(token.amountRaw);

const isSameTokenSelected =
token.address.toLowerCase() === paymentToken.address.toLowerCase() &&
token.chainId === paymentToken.chainId;

if (token.skipIfBalance && hasBalance) {
log('Skipping token as sufficient balance', {
tokenAddress: token.address,
});
return undefined;
}

if (isSameTokenSelected) {
const strategy = getStrategyType(transactionId, messenger);

const isSameTokenSelected =
token.address.toLowerCase() === paymentToken.address.toLowerCase() &&
token.chainId === paymentToken.chainId;

const isAlwaysRequired = isQuoteAlwaysRequired(token, strategy);

if (isSameTokenSelected && !isAlwaysRequired) {
log('Skipping token as same as payment token');
return undefined;
}
Expand All @@ -108,3 +123,40 @@ function calculateSourceAmount(
targetTokenAddress: token.address,
};
}

/**
* Determine if a quote is always required for a token and strategy.
*
* @param token - Target token.
* @param strategy - Payment strategy.
* @returns True if a quote is always required, false otherwise.
*/
function isQuoteAlwaysRequired(
token: TransactionPayRequiredToken,
strategy: TransactionPayStrategy,
) {
const isHyperliquidDeposit =
token.chainId === CHAIN_ID_ARBITRUM &&
token.address.toLowerCase() === ARBITRUM_USDC_ADDRESS.toLowerCase();

return strategy === TransactionPayStrategy.Relay && isHyperliquidDeposit;
}

/**
* Get the strategy type for a transaction.
*
* @param transactionId - ID of the transaction.
* @param messenger - Controller messenger.
* @returns Payment strategy type.
*/
function getStrategyType(
transactionId: string,
messenger: TransactionPayControllerMessenger,
) {
const transaction = getTransaction(
transactionId,
messenger,
) as TransactionMeta;

return messenger.call('TransactionPayController:getStrategy', transaction);
}
Loading
Loading