diff --git a/packages/subscription-controller/CHANGELOG.md b/packages/subscription-controller/CHANGELOG.md index b8b2c5eba5..fc4fb2f3ee 100644 --- a/packages/subscription-controller/CHANGELOG.md +++ b/packages/subscription-controller/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `cancelSubscription`: Cancel user active subscription. - `startShieldSubscriptionWithCard`: start shield subscription via card (with trial option) ([#6300](https://github.com/MetaMask/core/pull/6300)) - Add `getPricing` method ([#6356](https://github.com/MetaMask/core/pull/6356)) +- Add methods `startSubscriptionWithCrypto` and `getCryptoApproveTransactionParams` method ([#6456](https://github.com/MetaMask/core/pull/6456)) - Added `triggerAccessTokenRefresh` to trigger an access token refresh ([#6374](https://github.com/MetaMask/core/pull/6374)) [Unreleased]: https://github.com/MetaMask/core/ diff --git a/packages/subscription-controller/package.json b/packages/subscription-controller/package.json index 6761e9f52d..e147744aa5 100644 --- a/packages/subscription-controller/package.json +++ b/packages/subscription-controller/package.json @@ -48,6 +48,7 @@ }, "dependencies": { "@metamask/base-controller": "^8.3.0", + "@metamask/controller-utils": "^11.12.0", "@metamask/utils": "^11.4.2" }, "devDependencies": { @@ -56,7 +57,6 @@ "@types/jest": "^27.4.1", "deepmerge": "^4.2.2", "jest": "^27.5.1", - "nock": "^13.3.1", "ts-jest": "^27.1.4", "typedoc": "^0.24.8", "typedoc-plugin-missing-exports": "^2.0.0", diff --git a/packages/subscription-controller/src/SubscriptionController.test.ts b/packages/subscription-controller/src/SubscriptionController.test.ts index 63af795857..feca9b5c8e 100644 --- a/packages/subscription-controller/src/SubscriptionController.test.ts +++ b/packages/subscription-controller/src/SubscriptionController.test.ts @@ -14,7 +14,14 @@ import { type SubscriptionControllerOptions, type SubscriptionControllerState, } from './SubscriptionController'; -import type { Subscription, PricingResponse } from './types'; +import type { + Subscription, + PricingResponse, + ProductPricing, + PricingPaymentMethod, + StartCryptoSubscriptionRequest, + StartCryptoSubscriptionResponse, +} from './types'; import { PaymentType, ProductType, @@ -29,8 +36,8 @@ const MOCK_SUBSCRIPTION: Subscription = { { name: ProductType.SHIELD, id: 'prod_shield_basic', - currency: 'USD', - amount: 9.99, + currency: 'usd', + amount: 900, }, ], currentPeriodStart: '2024-01-01T00:00:00Z', @@ -42,6 +49,43 @@ const MOCK_SUBSCRIPTION: Subscription = { }, }; +const MOCK_PRODUCT_PRICE: ProductPricing = { + name: ProductType.SHIELD, + prices: [ + { + interval: RecurringInterval.month, + currency: 'usd', + unitAmount: 900, + unitDecimals: 2, + trialPeriodDays: 0, + minBillingCycles: 1, + }, + ], +}; + +const MOCK_PRICING_PAYMENT_METHOD: PricingPaymentMethod = { + type: PaymentType.byCrypto, + chains: [ + { + chainId: '0x1', + paymentAddress: '0xspender', + tokens: [ + { + address: '0xtoken', + symbol: 'USDT', + decimals: 18, + conversionRate: { usd: '1.0' }, + }, + ], + }, + ], +}; + +const MOCK_PRICE_INFO_RESPONSE: PricingResponse = { + products: [MOCK_PRODUCT_PRICE], + paymentMethods: [MOCK_PRICING_PAYMENT_METHOD], +}; + /** * Creates a custom subscription messenger, in case tests need different permissions * @@ -113,12 +157,14 @@ function createMockSubscriptionService() { const mockCancelSubscription = jest.fn(); const mockStartSubscriptionWithCard = jest.fn(); const mockGetPricing = jest.fn(); + const mockStartSubscriptionWithCrypto = jest.fn(); const mockService = { getSubscriptions: mockGetSubscriptions, cancelSubscription: mockCancelSubscription, startSubscriptionWithCard: mockStartSubscriptionWithCard, getPricing: mockGetPricing, + startSubscriptionWithCrypto: mockStartSubscriptionWithCrypto, }; return { @@ -127,6 +173,7 @@ function createMockSubscriptionService() { mockCancelSubscription, mockStartSubscriptionWithCard, mockGetPricing, + mockStartSubscriptionWithCrypto, }; } @@ -137,6 +184,7 @@ type WithControllerCallback = (params: { controller: SubscriptionController; initialState: SubscriptionControllerState; messenger: SubscriptionControllerMessenger; + baseMessenger: Messenger; mockService: ReturnType['mockService']; mockPerformSignOut: jest.Mock; }) => Promise | ReturnValue; @@ -157,7 +205,8 @@ async function withController( ...args: WithControllerArgs ) { const [{ ...rest }, fn] = args.length === 2 ? args : [{}, args[0]]; - const { messenger, mockPerformSignOut } = createMockSubscriptionMessenger(); + const { messenger, mockPerformSignOut, baseMessenger } = + createMockSubscriptionMessenger(); const { mockService } = createMockSubscriptionService(); const controller = new SubscriptionController({ @@ -170,6 +219,7 @@ async function withController( controller, initialState: controller.state, messenger, + baseMessenger, mockService, mockPerformSignOut, }); @@ -489,6 +539,44 @@ describe('SubscriptionController', () => { }); }); + describe('startCryptoSubscription', () => { + it('should start crypto subscription successfully when user is not subscribed', async () => { + await withController( + { + state: { + subscriptions: [], + }, + }, + async ({ controller, mockService }) => { + const request: StartCryptoSubscriptionRequest = { + products: [ProductType.SHIELD], + isTrialRequested: false, + recurringInterval: RecurringInterval.month, + billingCycles: 3, + chainId: '0x1', + payerAddress: '0x0000000000000000000000000000000000000001', + tokenSymbol: 'USDC', + rawTransaction: '0xdeadbeef', + }; + + const response: StartCryptoSubscriptionResponse = { + subscriptionId: 'sub_crypto_123', + status: SubscriptionStatus.active, + }; + + mockService.startSubscriptionWithCrypto.mockResolvedValue(response); + + const result = await controller.startSubscriptionWithCrypto(request); + + expect(result).toStrictEqual(response); + expect(mockService.startSubscriptionWithCrypto).toHaveBeenCalledWith( + request, + ); + }, + ); + }); + }); + describe('integration scenarios', () => { it('should handle complete subscription lifecycle with updated logic', async () => { await withController(async ({ controller, mockService }) => { @@ -547,6 +635,181 @@ describe('SubscriptionController', () => { }); }); + describe('getCryptoApproveTransactionParams', () => { + it('returns transaction params for crypto approve transaction', async () => { + await withController(async ({ controller, mockService }) => { + // Provide product pricing and crypto payment info with unitDecimals small to avoid integer div to 0 + mockService.getPricing.mockResolvedValue(MOCK_PRICE_INFO_RESPONSE); + + const result = await controller.getCryptoApproveTransactionParams({ + chainId: '0x1', + paymentTokenAddress: '0xtoken', + productType: ProductType.SHIELD, + interval: RecurringInterval.month, + }); + + expect(result).toStrictEqual({ + approveAmount: '9000000000000000000', + paymentAddress: '0xspender', + paymentTokenAddress: '0xtoken', + chainId: '0x1', + }); + }); + }); + + it('throws when product price not found', async () => { + await withController(async ({ controller, mockService }) => { + mockService.getPricing.mockResolvedValue({ + products: [], + paymentMethods: [], + }); + + await expect( + controller.getCryptoApproveTransactionParams({ + chainId: '0x1', + paymentTokenAddress: '0xtoken', + productType: ProductType.SHIELD, + interval: RecurringInterval.month, + }), + ).rejects.toThrow('Product price not found'); + }); + }); + + it('throws when price not found for interval', async () => { + await withController(async ({ controller, mockService }) => { + mockService.getPricing.mockResolvedValue({ + products: [ + { + name: ProductType.SHIELD, + prices: [ + { + interval: RecurringInterval.year, + currency: 'usd', + unitAmount: 10, + unitDecimals: 18, + trialPeriodDays: 0, + minBillingCycles: 1, + }, + ], + }, + ], + paymentMethods: [], + }); + + await expect( + controller.getCryptoApproveTransactionParams({ + chainId: '0x1', + paymentTokenAddress: '0xtoken', + productType: ProductType.SHIELD, + interval: RecurringInterval.month, + }), + ).rejects.toThrow('Price not found'); + }); + }); + + it('throws when chains payment info not found', async () => { + await withController(async ({ controller, mockService }) => { + mockService.getPricing.mockResolvedValue({ + ...MOCK_PRICE_INFO_RESPONSE, + paymentMethods: [ + { + type: PaymentType.byCard, + }, + ], + }); + + await expect( + controller.getCryptoApproveTransactionParams({ + chainId: '0x1', + paymentTokenAddress: '0xtoken', + productType: ProductType.SHIELD, + interval: RecurringInterval.month, + }), + ).rejects.toThrow('Chains payment info not found'); + }); + }); + + it('throws when invalid chain id', async () => { + await withController(async ({ controller, mockService }) => { + mockService.getPricing.mockResolvedValue({ + ...MOCK_PRICE_INFO_RESPONSE, + paymentMethods: [ + { + type: PaymentType.byCrypto, + chains: [ + { + chainId: '0x2', + paymentAddress: '0xspender', + tokens: [], + }, + ], + }, + ], + }); + + await expect( + controller.getCryptoApproveTransactionParams({ + chainId: '0x1', + paymentTokenAddress: '0xtoken', + productType: ProductType.SHIELD, + interval: RecurringInterval.month, + }), + ).rejects.toThrow('Invalid chain id'); + }); + }); + + it('throws when invalid token address', async () => { + await withController(async ({ controller, mockService }) => { + mockService.getPricing.mockResolvedValue(MOCK_PRICE_INFO_RESPONSE); + + await expect( + controller.getCryptoApproveTransactionParams({ + chainId: '0x1', + paymentTokenAddress: '0xtoken-invalid', + productType: ProductType.SHIELD, + interval: RecurringInterval.month, + }), + ).rejects.toThrow('Invalid token address'); + }); + }); + + it('throws when conversion rate not found', async () => { + await withController(async ({ controller, mockService }) => { + // Valid product and chain/token, but token lacks conversion rate for currency + mockService.getPricing.mockResolvedValue({ + ...MOCK_PRICE_INFO_RESPONSE, + paymentMethods: [ + { + type: PaymentType.byCrypto, + chains: [ + { + chainId: '0x1', + paymentAddress: '0xspender', + tokens: [ + { + address: '0xtoken', + decimals: 18, + conversionRate: {}, + }, + ], + }, + ], + }, + ], + }); + + await expect( + controller.getCryptoApproveTransactionParams({ + chainId: '0x1', + paymentTokenAddress: '0xtoken', + productType: ProductType.SHIELD, + interval: RecurringInterval.month, + }), + ).rejects.toThrow('Conversion rate not found'); + }); + }); + }); + describe('triggerAuthTokenRefresh', () => { it('should trigger auth token refresh', async () => { await withController(async ({ controller, mockPerformSignOut }) => { diff --git a/packages/subscription-controller/src/SubscriptionController.ts b/packages/subscription-controller/src/SubscriptionController.ts index 0327710c0f..6853f380e9 100644 --- a/packages/subscription-controller/src/SubscriptionController.ts +++ b/packages/subscription-controller/src/SubscriptionController.ts @@ -11,7 +11,15 @@ import { controllerName, SubscriptionControllerErrorMessage, } from './constants'; +import type { + GetCryptoApproveTransactionRequest, + GetCryptoApproveTransactionResponse, + ProductPrice, + StartCryptoSubscriptionRequest, + TokenPaymentInfo, +} from './types'; import { + PaymentType, SubscriptionStatus, type ISubscriptionService, type PricingResponse, @@ -41,6 +49,14 @@ export type SubscriptionControllerGetPricingAction = { type: `${typeof controllerName}:getPricing`; handler: SubscriptionController['getPricing']; }; +export type SubscriptionControllerGetCryptoApproveTransactionParamsAction = { + type: `${typeof controllerName}:getCryptoApproveTransactionParams`; + handler: SubscriptionController['getCryptoApproveTransactionParams']; +}; +export type SubscriptionControllerStartSubscriptionWithCryptoAction = { + type: `${typeof controllerName}:startSubscriptionWithCrypto`; + handler: SubscriptionController['startSubscriptionWithCrypto']; +}; export type SubscriptionControllerGetStateAction = ControllerGetStateAction< typeof controllerName, @@ -51,7 +67,9 @@ export type SubscriptionControllerActions = | SubscriptionControllerCancelSubscriptionAction | SubscriptionControllerStartShieldSubscriptionWithCardAction | SubscriptionControllerGetPricingAction - | SubscriptionControllerGetStateAction; + | SubscriptionControllerGetStateAction + | SubscriptionControllerGetCryptoApproveTransactionParamsAction + | SubscriptionControllerStartSubscriptionWithCryptoAction; export type AllowedActions = | AuthenticationController.AuthenticationControllerGetBearerToken @@ -151,7 +169,6 @@ export class SubscriptionController extends BaseController< }); this.#subscriptionService = subscriptionService; - this.#registerMessageHandlers(); } @@ -179,6 +196,16 @@ export class SubscriptionController extends BaseController< 'SubscriptionController:getPricing', this.getPricing.bind(this), ); + + this.messagingSystem.registerActionHandler( + 'SubscriptionController:getCryptoApproveTransactionParams', + this.getCryptoApproveTransactionParams.bind(this), + ); + + this.messagingSystem.registerActionHandler( + 'SubscriptionController:startSubscriptionWithCrypto', + this.startSubscriptionWithCrypto.bind(this), + ); } /** @@ -230,6 +257,122 @@ export class SubscriptionController extends BaseController< return response; } + async startSubscriptionWithCrypto(request: StartCryptoSubscriptionRequest) { + this.#assertIsUserNotSubscribed({ products: request.products }); + return await this.#subscriptionService.startSubscriptionWithCrypto(request); + } + + /** + * Get transaction params to create crypto approve transaction for subscription payment + * + * @param request - The request object + * @param request.chainId - The chain ID + * @param request.tokenAddress - The address of the token + * @param request.productType - The product type + * @param request.interval - The interval + * @returns The crypto approve transaction params + */ + async getCryptoApproveTransactionParams( + request: GetCryptoApproveTransactionRequest, + ): Promise { + const pricing = await this.getPricing(); + const product = pricing.products.find( + (p) => p.name === request.productType, + ); + if (!product) { + throw new Error('Product price not found'); + } + + const price = product.prices.find((p) => p.interval === request.interval); + if (!price) { + throw new Error('Price not found'); + } + + const chainsPaymentInfo = pricing.paymentMethods.find( + (t) => t.type === PaymentType.byCrypto, + ); + if (!chainsPaymentInfo) { + throw new Error('Chains payment info not found'); + } + const chainPaymentInfo = chainsPaymentInfo.chains?.find( + (t) => t.chainId === request.chainId, + ); + if (!chainPaymentInfo) { + throw new Error('Invalid chain id'); + } + const tokenPaymentInfo = chainPaymentInfo.tokens.find( + (t) => t.address === request.paymentTokenAddress, + ); + if (!tokenPaymentInfo) { + throw new Error('Invalid token address'); + } + + const tokenApproveAmount = this.#getTokenApproveAmount( + price, + tokenPaymentInfo, + ); + + return { + approveAmount: tokenApproveAmount.toString(), + paymentAddress: chainPaymentInfo.paymentAddress, + paymentTokenAddress: request.paymentTokenAddress, + chainId: request.chainId, + }; + } + + /** + * Calculate total subscription price amount from price info + * e.g: $8 per month * 12 months min billing cycles = $96 + * + * @param price - The price info + * @returns The price amount + */ + #getSubscriptionPriceAmount(price: ProductPrice) { + // no need to use BigInt since max unitDecimals are always 2 for price + const amount = + (price.unitAmount / 10 ** price.unitDecimals) * price.minBillingCycles; + return amount; + } + + /** + * Calculate token approve amount from price info + * + * @param price - The price info + * @param tokenPaymentInfo - The token price info + * @returns The token approve amount + */ + #getTokenApproveAmount( + price: ProductPrice, + tokenPaymentInfo: TokenPaymentInfo, + ) { + const conversionRate = + tokenPaymentInfo.conversionRate[ + price.currency as keyof typeof tokenPaymentInfo.conversionRate + ]; + if (!conversionRate) { + throw new Error('Conversion rate not found'); + } + // conversion rate is a float string e.g: "1.0" + // We need to handle float conversion rates with integer math for BigInt. + // We'll scale the conversion rate to an integer by multiplying by 10^4. + // conversionRate is in usd decimal. In most currencies, we only care about 2 decimals (cents) + // So, scale must be max of 10 ** 4 (most exchanges trade with max 4 decimals of usd) + // This allows us to avoid floating point math and keep precision. + const SCALE = 10n ** 4n; + const conversionRateScaled = + BigInt(Math.round(Number(conversionRate) * Number(SCALE))) / SCALE; + // price of the product + const priceAmount = this.#getSubscriptionPriceAmount(price); + const priceAmountScaled = + BigInt(Math.round(priceAmount * Number(SCALE))) / SCALE; + + const tokenDecimal = BigInt(10) ** BigInt(tokenPaymentInfo.decimals); + + const tokenAmount = + (priceAmountScaled * tokenDecimal) / conversionRateScaled; + return tokenAmount; + } + #assertIsUserNotSubscribed({ products }: { products: ProductType[] }) { if ( this.state.subscriptions.find((subscription) => diff --git a/packages/subscription-controller/src/SubscriptionService.test.ts b/packages/subscription-controller/src/SubscriptionService.test.ts index 83a04c3b7e..d5197f8124 100644 --- a/packages/subscription-controller/src/SubscriptionService.test.ts +++ b/packages/subscription-controller/src/SubscriptionService.test.ts @@ -1,4 +1,4 @@ -import nock, { cleanAll, isDone } from 'nock'; +import { handleFetch } from '@metamask/controller-utils'; import { Env, @@ -9,6 +9,7 @@ import { SubscriptionServiceError } from './errors'; import { SubscriptionService } from './SubscriptionService'; import type { StartSubscriptionRequest, + StartCryptoSubscriptionRequest, Subscription, PricingResponse, } from './types'; @@ -19,6 +20,11 @@ import { SubscriptionStatus, } from './types'; +// Mock the handleFetch function +jest.mock('@metamask/controller-utils', () => ({ + handleFetch: jest.fn(), +})); + // Mock data const MOCK_SUBSCRIPTION: Subscription = { id: 'sub_123456789', @@ -26,7 +32,7 @@ const MOCK_SUBSCRIPTION: Subscription = { { name: ProductType.SHIELD, id: 'prod_shield_basic', - currency: 'USD', + currency: 'usd', amount: 9.99, }, ], @@ -41,11 +47,6 @@ const MOCK_SUBSCRIPTION: Subscription = { const MOCK_ACCESS_TOKEN = 'mock-access-token-12345'; -const MOCK_ERROR_RESPONSE = { - message: 'Subscription not found', - error: 'NOT_FOUND', -}; - const MOCK_START_SUBSCRIPTION_REQUEST: StartSubscriptionRequest = { products: [ProductType.SHIELD], isTrialRequested: true, @@ -61,19 +62,14 @@ const MOCK_START_SUBSCRIPTION_RESPONSE = { * * @param params - The parameters object * @param [params.env] - The environment to use for the config - * @param [params.fetchFn] - The fetch function to use for the config * @returns The mock configuration object */ -function createMockConfig({ - env = Env.DEV, - fetchFn = fetch, -}: { env?: Env; fetchFn?: typeof fetch } = {}) { +function createMockConfig({ env = Env.DEV }: { env?: Env } = {}) { return { env, auth: { getAccessToken: jest.fn().mockResolvedValue(MOCK_ACCESS_TOKEN), }, - fetchFn, }; } @@ -107,8 +103,8 @@ function withMockSubscriptionService( } describe('SubscriptionService', () => { - afterEach(() => { - cleanAll(); + beforeEach(() => { + jest.clearAllMocks(); }); describe('constructor', () => { @@ -132,35 +128,29 @@ describe('SubscriptionService', () => { describe('getSubscriptions', () => { it('should fetch subscriptions successfully', async () => { - await withMockSubscriptionService( - async ({ service, testUrl, config }) => { - nock(testUrl) - .get('/api/v1/subscriptions') - .matchHeader('Authorization', `Bearer ${MOCK_ACCESS_TOKEN}`) - .reply(200, { - customerId: 'cus_1', - subscriptions: [MOCK_SUBSCRIPTION], - trialedProducts: [], - }); - - const result = await service.getSubscriptions(); - - expect(result).toStrictEqual({ - customerId: 'cus_1', - subscriptions: [MOCK_SUBSCRIPTION], - trialedProducts: [], - }); - expect(config.auth.getAccessToken).toHaveBeenCalledTimes(1); - expect(isDone()).toBe(true); - }, - ); + await withMockSubscriptionService(async ({ service, config }) => { + (handleFetch as jest.Mock).mockResolvedValue({ + customerId: 'cus_1', + subscriptions: [MOCK_SUBSCRIPTION], + trialedProducts: [], + }); + + const result = await service.getSubscriptions(); + + expect(result).toStrictEqual({ + customerId: 'cus_1', + subscriptions: [MOCK_SUBSCRIPTION], + trialedProducts: [], + }); + expect(config.auth.getAccessToken).toHaveBeenCalledTimes(1); + }); }); it('should throw SubscriptionServiceError for error responses', async () => { - await withMockSubscriptionService(async ({ service, testUrl }) => { - nock(testUrl) - .get('/api/v1/subscriptions') - .reply(404, MOCK_ERROR_RESPONSE); + await withMockSubscriptionService(async ({ service }) => { + (handleFetch as jest.Mock).mockRejectedValue( + new Error('Network error'), + ); await expect(service.getSubscriptions()).rejects.toThrow( SubscriptionServiceError, @@ -169,10 +159,10 @@ describe('SubscriptionService', () => { }); it('should throw SubscriptionServiceError for network errors', async () => { - await withMockSubscriptionService(async ({ service, testUrl }) => { - nock(testUrl) - .get('/api/v1/subscriptions') - .replyWithError('Network error'); + await withMockSubscriptionService(async ({ service }) => { + (handleFetch as jest.Mock).mockRejectedValue( + new Error('Network error'), + ); await expect(service.getSubscriptions()).rejects.toThrow( SubscriptionServiceError, @@ -192,64 +182,57 @@ describe('SubscriptionService', () => { }); it('should handle null exceptions in catch block', async () => { - const fetchMock = jest.fn().mockRejectedValueOnce(null); - const config = createMockConfig({ fetchFn: fetchMock }); + const config = createMockConfig({}); const service = new SubscriptionService(config); + (handleFetch as jest.Mock).mockRejectedValue(null); await expect( service.cancelSubscription({ subscriptionId: 'sub_123456789' }), ).rejects.toThrow(SubscriptionServiceError); }); + + it('should handle non-Error exceptions in catch block', async () => { + await withMockSubscriptionService(async ({ service }) => { + // Mock handleFetch to throw null (not an Error instance) + (handleFetch as jest.Mock).mockRejectedValue(null); + + await expect(service.getSubscriptions()).rejects.toThrow( + SubscriptionServiceError, + ); + }); + }); }); describe('cancelSubscription', () => { it('should cancel subscription successfully', async () => { - await withMockSubscriptionService( - async ({ service, testUrl, config }) => { - nock(testUrl) - .delete('/api/v1/subscriptions/sub_123456789') - .matchHeader('Authorization', `Bearer ${MOCK_ACCESS_TOKEN}`) - .reply(200, {}); - - await service.cancelSubscription({ subscriptionId: 'sub_123456789' }); - - expect(config.auth.getAccessToken).toHaveBeenCalledTimes(1); - expect(isDone()).toBe(true); - }, - ); - }); + await withMockSubscriptionService(async ({ service, config }) => { + (handleFetch as jest.Mock).mockResolvedValue({}); - it('should throw SubscriptionServiceError for error responses', async () => { - await withMockSubscriptionService(async ({ service, testUrl }) => { - nock(testUrl) - .delete('/api/v1/subscriptions/sub_123456789') - .reply(400, MOCK_ERROR_RESPONSE); + await service.cancelSubscription({ subscriptionId: 'sub_123456789' }); - await expect( - service.cancelSubscription({ subscriptionId: 'sub_123456789' }), - ).rejects.toThrow(/Subscription not found/u); + expect(config.auth.getAccessToken).toHaveBeenCalledTimes(1); }); }); it('should throw SubscriptionServiceError for network errors', async () => { - await withMockSubscriptionService(async ({ service, testUrl }) => { - nock(testUrl) - .delete('/api/v1/subscriptions/sub_123456789') - .replyWithError('Network error'); + await withMockSubscriptionService(async ({ service }) => { + (handleFetch as jest.Mock).mockRejectedValue( + new Error('Network error'), + ); await expect( service.cancelSubscription({ subscriptionId: 'sub_123456789' }), - ).rejects.toThrow(/Network error/u); + ).rejects.toThrow(SubscriptionServiceError); }); }); }); describe('startSubscription', () => { it('should start subscription successfully', async () => { - await withMockSubscriptionService(async ({ service, testUrl }) => { - nock(testUrl) - .post('/api/v1/subscriptions/card', MOCK_START_SUBSCRIPTION_REQUEST) - .reply(200, MOCK_START_SUBSCRIPTION_RESPONSE); + await withMockSubscriptionService(async ({ service }) => { + (handleFetch as jest.Mock).mockResolvedValue( + MOCK_START_SUBSCRIPTION_RESPONSE, + ); const result = await service.startSubscriptionWithCard( MOCK_START_SUBSCRIPTION_REQUEST, @@ -262,16 +245,15 @@ describe('SubscriptionService', () => { it('should start subscription without trial', async () => { const config = createMockConfig(); const service = new SubscriptionService(config); - const testUrl = getTestUrl(Env.DEV); const request: StartSubscriptionRequest = { products: [ProductType.SHIELD], isTrialRequested: false, recurringInterval: RecurringInterval.month, }; - nock(testUrl) - .post('/api/v1/subscriptions/card', request) - .reply(200, MOCK_START_SUBSCRIPTION_RESPONSE); + (handleFetch as jest.Mock).mockResolvedValue( + MOCK_START_SUBSCRIPTION_RESPONSE, + ); const result = await service.startSubscriptionWithCard(request); @@ -293,6 +275,34 @@ describe('SubscriptionService', () => { }); }); + describe('startCryptoSubscription', () => { + it('should start crypto subscription successfully', async () => { + await withMockSubscriptionService(async ({ service }) => { + const request: StartCryptoSubscriptionRequest = { + products: [ProductType.SHIELD], + isTrialRequested: false, + recurringInterval: RecurringInterval.month, + billingCycles: 3, + chainId: '0x1', + payerAddress: '0x0000000000000000000000000000000000000001', + tokenSymbol: 'USDC', + rawTransaction: '0xdeadbeef', + }; + + const response = { + subscriptionId: 'sub_crypto_123', + status: SubscriptionStatus.active, + }; + + (handleFetch as jest.Mock).mockResolvedValue(response); + + const result = await service.startSubscriptionWithCrypto(request); + + expect(result).toStrictEqual(response); + }); + }); + }); + describe('getPricing', () => { const mockPricingResponse: PricingResponse = { products: [], @@ -302,9 +312,8 @@ describe('SubscriptionService', () => { it('should fetch pricing successfully', async () => { const config = createMockConfig(); const service = new SubscriptionService(config); - const testUrl = getTestUrl(Env.DEV); - nock(testUrl).get('/api/v1/pricing').reply(200, mockPricingResponse); + (handleFetch as jest.Mock).mockResolvedValue(mockPricingResponse); const result = await service.getPricing(); diff --git a/packages/subscription-controller/src/SubscriptionService.ts b/packages/subscription-controller/src/SubscriptionService.ts index 3bd784869a..225de019ee 100644 --- a/packages/subscription-controller/src/SubscriptionService.ts +++ b/packages/subscription-controller/src/SubscriptionService.ts @@ -1,3 +1,5 @@ +import { handleFetch } from '@metamask/controller-utils'; + import { getEnvUrls, SubscriptionControllerErrorMessage, @@ -9,6 +11,8 @@ import type { GetSubscriptionsResponse, ISubscriptionService, PricingResponse, + StartCryptoSubscriptionRequest, + StartCryptoSubscriptionResponse, StartSubscriptionRequest, StartSubscriptionResponse, } from './types'; @@ -16,28 +20,19 @@ import type { export type SubscriptionServiceConfig = { env: Env; auth: AuthUtils; - fetchFn: typeof globalThis.fetch; -}; - -type ErrorMessage = { - message: string; - error: string; }; export const SUBSCRIPTION_URL = (env: Env, path: string) => - `${getEnvUrls(env).subscriptionApiUrl}/api/v1/${path}`; + `${getEnvUrls(env).subscriptionApiUrl}/v1/${path}`; export class SubscriptionService implements ISubscriptionService { readonly #env: Env; - readonly #fetch: typeof globalThis.fetch; - public authUtils: AuthUtils; constructor(config: SubscriptionServiceConfig) { this.#env = config.env; this.authUtils = config.auth; - this.#fetch = config.fetchFn; } async getSubscriptions(): Promise { @@ -63,6 +58,13 @@ export class SubscriptionService implements ISubscriptionService { return await this.#makeRequest(path, 'POST', request); } + async startSubscriptionWithCrypto( + request: StartCryptoSubscriptionRequest, + ): Promise { + const path = 'subscriptions/crypto'; + return await this.#makeRequest(path, 'POST', request); + } + async #makeRequest( path: string, method: 'GET' | 'POST' | 'DELETE' | 'PUT' | 'PATCH' = 'GET', @@ -72,7 +74,7 @@ export class SubscriptionService implements ISubscriptionService { const headers = await this.#getAuthorizationHeader(); const url = new URL(SUBSCRIPTION_URL(this.#env, path)); - const response = await this.#fetch(url.toString(), { + const response = await handleFetch(url.toString(), { method, headers: { 'Content-Type': 'application/json', @@ -81,13 +83,7 @@ export class SubscriptionService implements ISubscriptionService { body: body ? JSON.stringify(body) : undefined, }); - const responseBody = await response.json(); - if (!response.ok) { - const { message, error } = responseBody as ErrorMessage; - throw new Error(`HTTP error message: ${message}, error: ${error}`); - } - - return responseBody as Result; + return response; } catch (e) { const errorMessage = e instanceof Error ? e.message : JSON.stringify(e ?? 'unknown error'); diff --git a/packages/subscription-controller/src/constants.test.ts b/packages/subscription-controller/src/constants.test.ts index d190b62582..a2277bc61a 100644 --- a/packages/subscription-controller/src/constants.test.ts +++ b/packages/subscription-controller/src/constants.test.ts @@ -1,30 +1,8 @@ -import { Env, getEnvUrls, controllerName } from './constants'; +import type { Env } from './constants'; +import { getEnvUrls, controllerName } from './constants'; describe('constants', () => { describe('getEnvUrls', () => { - it('should return correct URLs for dev environment', () => { - const result = getEnvUrls(Env.DEV); - expect(result).toStrictEqual({ - subscriptionApiUrl: - 'https://subscription-service.dev-api.cx.metamask.io', - }); - }); - - it('should return correct URLs for uat environment', () => { - const result = getEnvUrls(Env.UAT); - expect(result).toStrictEqual({ - subscriptionApiUrl: - 'https://subscription-service.uat-api.cx.metamask.io', - }); - }); - - it('should return correct URLs for prd environment', () => { - const result = getEnvUrls(Env.PRD); - expect(result).toStrictEqual({ - subscriptionApiUrl: 'https://subscription-service.api.cx.metamask.io', - }); - }); - it('should throw error for invalid environment', () => { // Type assertion to test invalid environment const invalidEnv = 'invalid' as Env; diff --git a/packages/subscription-controller/src/constants.ts b/packages/subscription-controller/src/constants.ts index fd736d53ba..a3b78dcffd 100644 --- a/packages/subscription-controller/src/constants.ts +++ b/packages/subscription-controller/src/constants.ts @@ -12,13 +12,13 @@ type EnvUrlsEntry = { const ENV_URLS: Record = { dev: { - subscriptionApiUrl: 'https://subscription-service.dev-api.cx.metamask.io', + subscriptionApiUrl: 'https://subscription.dev-api.cx.metamask.io', }, uat: { - subscriptionApiUrl: 'https://subscription-service.uat-api.cx.metamask.io', + subscriptionApiUrl: 'https://subscription.uat-api.cx.metamask.io', }, prd: { - subscriptionApiUrl: 'https://subscription-service.api.cx.metamask.io', + subscriptionApiUrl: 'https://subscription.api.cx.metamask.io', }, }; diff --git a/packages/subscription-controller/src/index.ts b/packages/subscription-controller/src/index.ts index 4f20c55206..a03990f47b 100644 --- a/packages/subscription-controller/src/index.ts +++ b/packages/subscription-controller/src/index.ts @@ -6,6 +6,8 @@ export type { SubscriptionControllerCancelSubscriptionAction, SubscriptionControllerStartShieldSubscriptionWithCardAction, SubscriptionControllerGetPricingAction, + SubscriptionControllerGetCryptoApproveTransactionParamsAction, + SubscriptionControllerStartSubscriptionWithCryptoAction, SubscriptionControllerGetStateAction, SubscriptionControllerMessenger, SubscriptionControllerOptions, @@ -19,7 +21,14 @@ export type { Subscription, AuthUtils, ISubscriptionService, - PaymentMethod, + StartCryptoSubscriptionRequest, + StartCryptoSubscriptionResponse, + StartSubscriptionRequest, + StartSubscriptionResponse, + GetCryptoApproveTransactionRequest, + GetCryptoApproveTransactionResponse, + RecurringInterval, + SubscriptionStatus, PaymentType, Product, ProductType, @@ -27,6 +36,7 @@ export type { ProductPricing, TokenPaymentInfo, ChainPaymentInfo, + Currency, PricingPaymentMethod, PricingResponse, } from './types'; diff --git a/packages/subscription-controller/src/types.ts b/packages/subscription-controller/src/types.ts index e007ace29a..81f39e5d04 100644 --- a/packages/subscription-controller/src/types.ts +++ b/packages/subscription-controller/src/types.ts @@ -1,11 +1,16 @@ +import type { Hex } from '@metamask/utils'; + export enum ProductType { SHIELD = 'shield', } +/** only usd for now */ +export type Currency = 'usd'; + export type Product = { name: ProductType; id: string; - currency: string; + currency: Currency; amount: number; }; @@ -19,15 +24,6 @@ export enum RecurringInterval { year = 'year', } -export type PaymentMethod = { - type: PaymentType; - crypto?: { - payerAddress: string; - chainId: string; - tokenSymbol: string; - }; -}; - export enum SubscriptionStatus { // Initial states incomplete = 'incomplete', @@ -57,7 +53,16 @@ export type Subscription = { currentPeriodEnd: string; // ISO 8601 status: SubscriptionStatus; interval: RecurringInterval; - paymentMethod: PaymentMethod; + paymentMethod: SubscriptionPaymentMethod; +}; + +export type SubscriptionPaymentMethod = { + type: PaymentType; + crypto?: { + payerAddress: Hex; + chainId: Hex; + tokenSymbol: string; + }; }; export type GetSubscriptionsResponse = { @@ -76,36 +81,61 @@ export type StartSubscriptionResponse = { checkoutSessionUrl: string; }; +export type StartCryptoSubscriptionRequest = { + products: ProductType[]; + isTrialRequested: boolean; + recurringInterval: RecurringInterval; + billingCycles: number; + chainId: Hex; + payerAddress: Hex; + /** + * e.g. "USDC" + */ + tokenSymbol: string; + rawTransaction: Hex; +}; + +export type StartCryptoSubscriptionResponse = { + subscriptionId: string; + status: SubscriptionStatus; +}; + export type AuthUtils = { getAccessToken: () => Promise; }; export type ProductPrice = { - interval: string; // "month" | "year" - unitAmount: string; // amount in the smallest unit of the currency, e.g., cents + interval: RecurringInterval; + unitAmount: number; // amount in the smallest unit of the currency, e.g., cents unitDecimals: number; // number of decimals for the smallest unit of the currency - currency: string; // "usd" + /** only usd for now */ + currency: Currency; trialPeriodDays: number; minBillingCycles: number; }; export type ProductPricing = { - name: string; + name: ProductType; prices: ProductPrice[]; }; export type TokenPaymentInfo = { symbol: string; - address: string; + address: Hex; decimals: number; + /** + * example: { + usd: '1.0', + }, + */ conversionRate: { usd: string; }; }; export type ChainPaymentInfo = { - chainId: string; - paymentAddress: string; + chainId: Hex; + paymentAddress: Hex; tokens: TokenPaymentInfo[]; }; @@ -119,6 +149,36 @@ export type PricingResponse = { paymentMethods: PricingPaymentMethod[]; }; +export type GetCryptoApproveTransactionRequest = { + /** + * payment chain ID + */ + chainId: Hex; + /** + * Payment token address + */ + paymentTokenAddress: Hex; + productType: ProductType; + interval: RecurringInterval; +}; + +export type GetCryptoApproveTransactionResponse = { + /** + * The amount to approve + * e.g: "100000000" + */ + approveAmount: string; + /** + * The contract address (spender) + */ + paymentAddress: Hex; + /** + * The payment token address + */ + paymentTokenAddress: Hex; + chainId: Hex; +}; + export type ISubscriptionService = { getSubscriptions(): Promise; cancelSubscription(request: { subscriptionId: string }): Promise; @@ -126,4 +186,7 @@ export type ISubscriptionService = { request: StartSubscriptionRequest, ): Promise; getPricing(): Promise; + startSubscriptionWithCrypto( + request: StartCryptoSubscriptionRequest, + ): Promise; }; diff --git a/yarn.lock b/yarn.lock index 8bd5c3e5e8..d411324729 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4652,12 +4652,12 @@ __metadata: dependencies: "@metamask/auto-changelog": "npm:^3.4.4" "@metamask/base-controller": "npm:^8.3.0" + "@metamask/controller-utils": "npm:^11.12.0" "@metamask/profile-sync-controller": "npm:^24.0.0" "@metamask/utils": "npm:^11.4.2" "@types/jest": "npm:^27.4.1" deepmerge: "npm:^4.2.2" jest: "npm:^27.5.1" - nock: "npm:^13.3.1" ts-jest: "npm:^27.1.4" typedoc: "npm:^0.24.8" typedoc-plugin-missing-exports: "npm:^2.0.0"