Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 100 additions & 9 deletions modules/sdk-coin-sol/src/sol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
* @prettier
*/

import { TOKEN_2022_PROGRAM_ID, TOKEN_PROGRAM_ID } from '@solana/spl-token';
import BigNumber from 'bignumber.js';
import * as base58 from 'bs58';
import * as _ from 'lodash';
import * as request from 'superagent';

import {
AuditDecryptedKeyParams,
BaseBroadcastTransactionOptions,
BaseBroadcastTransactionResult,
BaseCoin,
Expand All @@ -16,6 +20,7 @@ import {
EDDSAMethods,
EDDSAMethodTypes,
Environments,
ITokenEnablement,
KeyPair,
Memo,
MethodNotImplementedError,
Expand All @@ -27,29 +32,28 @@ import {
MPCTx,
MPCTxs,
MPCUnsignedTx,
MultisigType,
multisigTypes,
OvcInput,
OvcOutput,
ParsedTransaction,
PopulatedIntent,
PrebuildTransactionWithIntentOptions,
PresignTransactionOptions,
PublicKey,
RecoveryTxRequest,
SignedTransaction,
SignTransactionOptions,
TokenEnablement,
TokenEnablementConfig,
TransactionExplanation,
TransactionParams,
TransactionRecipient,
VerifyAddressOptions,
VerifyTransactionOptions,
MultisigType,
multisigTypes,
AuditDecryptedKeyParams,
PopulatedIntent,
PrebuildTransactionWithIntentOptions,
} from '@bitgo/sdk-core';
import { auditEddsaPrivateKey, getDerivationPath } from '@bitgo/sdk-lib-mpc';
import { BaseNetwork, CoinFamily, coins, BaseCoin as StaticsBaseCoin } from '@bitgo/statics';
import * as _ from 'lodash';
import * as request from 'superagent';
import { BaseNetwork, CoinFamily, coins, SolCoin, BaseCoin as StaticsBaseCoin } from '@bitgo/statics';
import { KeyPair as SolKeyPair, Transaction, TransactionBuilder, TransactionBuilderFactory } from './lib';
import {
getAssociatedTokenAccountAddress,
Expand All @@ -60,7 +64,6 @@ import {
isValidPublicKey,
validateRawTransaction,
} from './lib/utils';
import { TOKEN_2022_PROGRAM_ID, TOKEN_PROGRAM_ID } from '@solana/spl-token';

export const DEFAULT_SCAN_FACTOR = 20; // default number of receive addresses to scan for funds

Expand Down Expand Up @@ -173,6 +176,7 @@ export interface SolConsolidationRecoveryOptions extends MPCConsolidationRecover
}

const HEX_REGEX = /^[0-9a-fA-F]+$/;
const BLIND_SIGNING_TX_TYPES_TO_CHECK = { enabletoken: 'AssociatedTokenAccountInitialization' };

export class Sol extends BaseCoin {
protected readonly _staticsCoin: Readonly<StaticsBaseCoin>;
Expand Down Expand Up @@ -233,6 +237,84 @@ export class Sol extends BaseCoin {
return Math.pow(10, this._staticsCoin.decimalPlaces);
}

verifyTxType(expectedTypeFromUserParams: string, actualTypeFromDecoded: string | undefined): void {
const matchFromUserToDecodedType = BLIND_SIGNING_TX_TYPES_TO_CHECK[expectedTypeFromUserParams];
if (matchFromUserToDecodedType !== actualTypeFromDecoded) {
throw new Error(
`Invalid transaction type on token enablement: expected "${matchFromUserToDecodedType}", got "${actualTypeFromDecoded}".`
);
}
}

throwIfMissingTokenEnablementsOrReturn(explanation: TransactionExplanation): ITokenEnablement[] {
if (!explanation.tokenEnablements || explanation.tokenEnablements.length === 0)
throw new Error('Missing tx token enablements data on token enablement tx prebuild');
return explanation.tokenEnablements;
}

throwIfMissingEnableTokenConfigOrReturn(txParams: TransactionParams): TokenEnablement[] {
if (!txParams.enableTokens || txParams.enableTokens.length === 0) throw new Error('Missing enable token config');
return txParams.enableTokens;
}

verifyTokenName(tokenEnablementsPrebuild: ITokenEnablement[], enableTokensConfig: TokenEnablement[]): void {
enableTokensConfig.forEach((enableTokenConfig) => {
const expectedTokenName = enableTokenConfig.name;
tokenEnablementsPrebuild.forEach((tokenEnablement) => {
if (!tokenEnablement.tokenName) throw new Error('Missing token name on token enablement tx');
if (tokenEnablement.tokenName !== expectedTokenName)
throw new Error(
`Invalid token name: expected ${expectedTokenName}, got ${tokenEnablement.tokenName} on token enablement tx`
);
});
});
}

async verifyTokenAddress(
tokenEnablementsPrebuild: ITokenEnablement[],
enableTokensConfig: TokenEnablement[]
): Promise<void> {
for (const enableTokenConfig of enableTokensConfig) {
const expectedTokenAddress = enableTokenConfig.address;
const expectedTokenName = enableTokenConfig.name;

if (!expectedTokenAddress) throw new Error('Missing token address on token enablement tx');
if (!expectedTokenName) throw new Error('Missing token name on token enablement tx');

for (const tokenEnablement of tokenEnablementsPrebuild) {
let tokenMintAddress: Readonly<SolCoin> | undefined;
try {
tokenMintAddress = getSolTokenFromTokenName(expectedTokenName);
} catch {
throw new Error(`Unable to derive ATA for token address: ${expectedTokenAddress}`);
}
if (
!tokenMintAddress ||
tokenMintAddress.tokenAddress === undefined ||
tokenMintAddress.programId === undefined
) {
throw new Error(`Unable to get token mint address for ${expectedTokenName}`);
}
let ata: string;
try {
ata = await getAssociatedTokenAccountAddress(
tokenMintAddress.tokenAddress,
expectedTokenAddress,
true,
tokenMintAddress.programId
);
} catch {
throw new Error(`Unable to derive ATA for token address: ${expectedTokenAddress}`);
}
if (ata !== tokenEnablement.address) {
throw new Error(
`Invalid token address: expected ${ata}, got ${tokenEnablement.address} on token enablement tx`
);
}
}
}
}

async verifyTransaction(params: SolVerifyTransactionOptions): Promise<boolean> {
// asset name to transfer amount map
const totalAmount: Record<string, BigNumber> = {};
Expand Down Expand Up @@ -261,6 +343,15 @@ export class Sol extends BaseCoin {
transaction.fromRawTransaction(rawTxBase64);
const explainedTx = transaction.explainTransaction();

if (txParams.type === 'enabletoken' && verificationOptions?.verifyTokenEnablement) {
this.verifyTxType(txParams.type, explainedTx.type);
const tokenEnablementsPrebuild = this.throwIfMissingTokenEnablementsOrReturn(explainedTx);
const enableTokensConfig = this.throwIfMissingEnableTokenConfigOrReturn(txParams);

this.verifyTokenName(tokenEnablementsPrebuild, enableTokensConfig);
await this.verifyTokenAddress(tokenEnablementsPrebuild, enableTokensConfig);
}

// users do not input recipients for consolidation requests as they are generated by the server
if (txParams.recipients !== undefined) {
const filteredRecipients = txParams.recipients?.map((recipient) =>
Expand Down
Loading