diff --git a/packages/subscription-controller/CHANGELOG.md b/packages/subscription-controller/CHANGELOG.md index 94c10fff777..79412d81fbd 100644 --- a/packages/subscription-controller/CHANGELOG.md +++ b/packages/subscription-controller/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Updated `submitShieldSubscriptionCryptoApproval` to handle change payment method transaction if subscription already existed ([#7231](https://github.com/MetaMask/core/pull/7231)) - Bump `@metamask/transaction-controller` from `^62.0.0` to `^62.3.0` ([#7215](https://github.com/MetaMask/core/pull/7215), [#7220](https://github.com/MetaMask/core/pull/7220), [#7236](https://github.com/MetaMask/core/pull/7236)) - Move peer dependencies for controller and service packages to direct dependencies ([#7209](https://github.com/MetaMask/core/pull/7209)) - The dependencies moved are: diff --git a/packages/subscription-controller/src/SubscriptionController.test.ts b/packages/subscription-controller/src/SubscriptionController.test.ts index eac28dde529..3b8ea5f9f1c 100644 --- a/packages/subscription-controller/src/SubscriptionController.test.ts +++ b/packages/subscription-controller/src/SubscriptionController.test.ts @@ -1891,9 +1891,12 @@ describe('SubscriptionController', () => { status: SUBSCRIPTION_STATUSES.trialing, }); - mockService.getSubscriptions.mockResolvedValue( - MOCK_GET_SUBSCRIPTIONS_RESPONSE, - ); + mockService.getSubscriptions + .mockResolvedValueOnce({ + subscriptions: [], + trialedProducts: [], + }) + .mockResolvedValue(MOCK_GET_SUBSCRIPTIONS_RESPONSE); // Create a shield subscription approval transaction const txMeta = { @@ -2095,5 +2098,103 @@ describe('SubscriptionController', () => { }, ); }); + + it('should update payment method when user has active subscription', async () => { + await withController( + { + state: { + pricing: MOCK_PRICE_INFO_RESPONSE, + trialedProducts: [], + subscriptions: [MOCK_SUBSCRIPTION], + lastSelectedPaymentMethod: { + [PRODUCT_TYPES.SHIELD]: { + type: PAYMENT_TYPES.byCrypto, + paymentTokenAddress: '0xtoken', + paymentTokenSymbol: 'USDT', + plan: RECURRING_INTERVALS.month, + }, + }, + }, + }, + async ({ controller, mockService }) => { + mockService.updatePaymentMethodCrypto.mockResolvedValue({}); + mockService.getSubscriptions.mockResolvedValue( + MOCK_GET_SUBSCRIPTIONS_RESPONSE, + ); + + const txMeta = { + ...generateMockTxMeta(), + type: TransactionType.shieldSubscriptionApprove, + chainId: '0x1' as Hex, + rawTx: '0x123', + txParams: { + data: '0x456', + from: '0x1234567890123456789012345678901234567890', + to: '0xtoken', + }, + status: TransactionStatus.submitted, + }; + + await controller.submitShieldSubscriptionCryptoApproval(txMeta); + + expect(mockService.updatePaymentMethodCrypto).toHaveBeenCalledTimes( + 1, + ); + expect( + mockService.startSubscriptionWithCrypto, + ).not.toHaveBeenCalled(); + }, + ); + }); + + it('should throw error when subscription status is not valid for crypto approval', async () => { + await withController( + { + state: { + pricing: MOCK_PRICE_INFO_RESPONSE, + trialedProducts: [], + subscriptions: [], + lastSelectedPaymentMethod: { + [PRODUCT_TYPES.SHIELD]: { + type: PAYMENT_TYPES.byCrypto, + paymentTokenAddress: '0xtoken', + paymentTokenSymbol: 'USDT', + plan: RECURRING_INTERVALS.month, + }, + }, + }, + }, + async ({ controller, mockService }) => { + mockService.getSubscriptions.mockResolvedValue({ + subscriptions: [ + { + ...MOCK_SUBSCRIPTION, + status: SUBSCRIPTION_STATUSES.incomplete, + }, + ], + trialedProducts: [], + }); + + const txMeta = { + ...generateMockTxMeta(), + type: TransactionType.shieldSubscriptionApprove, + chainId: '0x1' as Hex, + rawTx: '0x123', + txParams: { + data: '0x456', + from: '0x1234567890123456789012345678901234567890', + to: '0xtoken', + }, + status: TransactionStatus.submitted, + }; + + await expect( + controller.submitShieldSubscriptionCryptoApproval(txMeta), + ).rejects.toThrow( + SubscriptionControllerErrorMessage.SubscriptionNotValidForCryptoApproval, + ); + }, + ); + }); }); }); diff --git a/packages/subscription-controller/src/SubscriptionController.ts b/packages/subscription-controller/src/SubscriptionController.ts index 6a654d39fed..ffdb2c2cdbc 100644 --- a/packages/subscription-controller/src/SubscriptionController.ts +++ b/packages/subscription-controller/src/SubscriptionController.ts @@ -33,10 +33,12 @@ import type { CachedLastSelectedPaymentMethod, SubmitSponsorshipIntentsMethodParams, RecurringInterval, + SubscriptionStatus, } from './types'; import { PAYMENT_TYPES, PRODUCT_TYPES, + SUBSCRIPTION_STATUSES, type ISubscriptionService, type PricingResponse, type ProductType, @@ -509,26 +511,50 @@ export class SubscriptionController extends StaticIntervalPollingController()< lastSelectedPaymentMethod[PRODUCT_TYPES.SHIELD]; this.#assertIsPaymentMethodCrypto(lastSelectedPaymentMethodShield); - const isTrialed = trialedProducts?.includes(PRODUCT_TYPES.SHIELD); - const productPrice = this.#getProductPriceByProductAndPlan( PRODUCT_TYPES.SHIELD, lastSelectedPaymentMethodShield.plan, ); + const isTrialed = trialedProducts?.includes(PRODUCT_TYPES.SHIELD); + // get the latest subscriptions state to check if the user has an active shield subscription + await this.getSubscriptions(); + const currentSubscription = this.state.subscriptions.find((subscription) => + subscription.products.some((p) => p.name === PRODUCT_TYPES.SHIELD), + ); + + this.#assertValidSubscriptionStateForCryptoApproval({ + product: PRODUCT_TYPES.SHIELD, + }); + // if shield subscription exists, this transaction is for changing payment method + const isChangePaymentMethod = Boolean(currentSubscription); + + if (isChangePaymentMethod) { + await this.updatePaymentMethod({ + paymentType: PAYMENT_TYPES.byCrypto, + subscriptionId: (currentSubscription as Subscription).id, + chainId, + payerAddress: txMeta.txParams.from as Hex, + tokenSymbol: lastSelectedPaymentMethodShield.paymentTokenSymbol, + rawTransaction: rawTx as Hex, + recurringInterval: productPrice.interval, + billingCycles: productPrice.minBillingCycles, + }); + } else { + const params = { + products: [PRODUCT_TYPES.SHIELD], + isTrialRequested: !isTrialed, + recurringInterval: productPrice.interval, + billingCycles: productPrice.minBillingCycles, + chainId, + payerAddress: txMeta.txParams.from as Hex, + tokenSymbol: lastSelectedPaymentMethodShield.paymentTokenSymbol, + rawTransaction: rawTx as Hex, + isSponsored, + useTestClock: lastSelectedPaymentMethodShield.useTestClock, + }; + await this.startSubscriptionWithCrypto(params); + } - const params = { - products: [PRODUCT_TYPES.SHIELD], - isTrialRequested: !isTrialed, - recurringInterval: productPrice.interval, - billingCycles: productPrice.minBillingCycles, - chainId, - payerAddress: txMeta.txParams.from as Hex, - tokenSymbol: lastSelectedPaymentMethodShield.paymentTokenSymbol, - rawTransaction: rawTx as Hex, - isSponsored, - useTestClock: lastSelectedPaymentMethodShield.useTestClock, - }; - await this.startSubscriptionWithCrypto(params); // update the subscriptions state after subscription created in server await this.getSubscriptions(); } @@ -799,6 +825,34 @@ export class SubscriptionController extends StaticIntervalPollingController()< return productPrice; } + #assertValidSubscriptionStateForCryptoApproval({ + product, + }: { + product: ProductType; + }) { + const subscription = this.state.subscriptions.find((sub) => + sub.products.some((p) => p.name === product), + ); + + const isValid = + !subscription || + ( + [ + SUBSCRIPTION_STATUSES.pastDue, + SUBSCRIPTION_STATUSES.unpaid, + SUBSCRIPTION_STATUSES.paused, + SUBSCRIPTION_STATUSES.provisional, + SUBSCRIPTION_STATUSES.active, + SUBSCRIPTION_STATUSES.trialing, + ] as SubscriptionStatus[] + ).includes(subscription.status); + if (!isValid) { + throw new Error( + SubscriptionControllerErrorMessage.SubscriptionNotValidForCryptoApproval, + ); + } + } + #assertIsUserNotSubscribed({ products }: { products: ProductType[] }) { const subscription = this.state.subscriptions.find((sub) => sub.products.some((p) => products.includes(p.name)), diff --git a/packages/subscription-controller/src/constants.ts b/packages/subscription-controller/src/constants.ts index d5bdc8ed0f7..a86a1e22dc0 100644 --- a/packages/subscription-controller/src/constants.ts +++ b/packages/subscription-controller/src/constants.ts @@ -45,6 +45,7 @@ export enum SubscriptionControllerErrorMessage { PaymentTokenAddressAndSymbolRequiredForCrypto = `${controllerName} - Payment token address and symbol are required for crypto payment`, PaymentMethodNotCrypto = `${controllerName} - Payment method is not crypto`, ProductPriceNotFound = `${controllerName} - Product price not found`, + SubscriptionNotValidForCryptoApproval = `${controllerName} - Subscription is not valid for crypto approval`, } export const DEFAULT_POLLING_INTERVAL = 5 * 60 * 1_000; // 5 minutes