Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions packages/agent-sdk/src/evm/client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { createPublicClient, http, PublicClient } from "viem";
import * as chains from "viem/chains";

type Chain = chains.Chain;

const CHAINS_BY_CHAIN_ID = Object.fromEntries(
Object.values(chains).map((chain) => [chain.id, chain]),
);

const getChainById = (chainId: number): Chain | undefined => {
return CHAINS_BY_CHAIN_ID[chainId];
};

export function getClientForChain(chainId: number): PublicClient {
const chain = getChainById(chainId);
if (!chain) {
throw new Error(`Chain with ID ${chainId} not found`);
}
return createPublicClient({
chain,
transport: http(chain.rpcUrls.default.http[0]),
});
}
9 changes: 5 additions & 4 deletions packages/agent-sdk/src/evm/erc20.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { erc20Abi } from "viem";
import { encodeFunctionData, type Address } from "viem";
import { getClient, type MetaTransaction } from "near-safe";
import type { MetaTransaction } from "near-safe";
import type { TokenInfo } from "./types";
import { getClientForChain } from "./client";

const NATIVE_ASSET = "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE";
const MAX_APPROVAL = BigInt(
Expand Down Expand Up @@ -49,7 +50,7 @@ export async function checkAllowance(
spender: Address,
chainId: number,
): Promise<bigint> {
return getClient(chainId).readContract({
return getClientForChain(chainId).readContract({
address: token,
abi: erc20Abi,
functionName: "allowance",
Expand Down Expand Up @@ -85,7 +86,7 @@ export async function getTokenDecimals(
chainId: number,
address: Address,
): Promise<number> {
const client = getClient(chainId);
const client = getClientForChain(chainId);
try {
return await client.readContract({
address,
Expand All @@ -101,7 +102,7 @@ export async function getTokenSymbol(
chainId: number,
address: Address,
): Promise<string> {
const client = getClient(chainId);
const client = getClientForChain(chainId);
try {
return await client.readContract({
address,
Expand Down
4 changes: 2 additions & 2 deletions packages/agent-sdk/src/evm/safe.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { type Address, checksumAddress, parseUnits } from "viem";
import { type UserToken, ZerionAPI } from "zerion-sdk";
import { scientificToDecimal } from "../misc";
import { getClient } from "near-safe";
import { getClientForChain } from "./client";

export interface TokenBalance {
tokenAddress: string | null; // null for native token
Expand Down Expand Up @@ -54,7 +54,7 @@ export async function getSafeBalances(
address: Address,
zerionKey?: string,
): Promise<TokenBalance[]> {
const client = await getClient(chainId);
const client = await getClientForChain(chainId);
const codeAt = await client.getCode({ address });
if (!codeAt) {
// Not a Safe - Get balances from Zerion.
Expand Down
18 changes: 10 additions & 8 deletions packages/agent-sdk/tests/evm/erc20.spec.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { type Address, erc20Abi } from "viem";
import { getClient } from "near-safe";
import {
erc20Transfer,
erc20Approve,
Expand All @@ -8,9 +7,12 @@ import {
getTokenDecimals,
getTokenSymbol,
} from "../../src";
import { getClientForChain } from "../../src/evm/client";

// Mock the external dependencies
jest.mock("near-safe");
jest.mock("../../src/evm/client", () => ({
getClientForChain: jest.fn(),
}));

describe("ERC20 Utilities", () => {
const mockAddress = "0x1234567890123456789012345678901234567890" as Address;
Expand Down Expand Up @@ -67,7 +69,7 @@ describe("ERC20 Utilities", () => {
const mockClient = {
readContract: jest.fn().mockResolvedValue(BigInt(1000)),
};
(getClient as jest.Mock).mockReturnValue(mockClient);
(getClientForChain as jest.Mock).mockReturnValue(mockClient);

const result = await checkAllowance(
mockAddress,
Expand All @@ -94,7 +96,7 @@ describe("ERC20 Utilities", () => {
.mockResolvedValueOnce(18) // decimals
.mockResolvedValueOnce("TEST"), // symbol
};
(getClient as jest.Mock).mockReturnValue(mockClient);
(getClientForChain as jest.Mock).mockReturnValue(mockClient);

const result = await getTokenInfo(mockChainId, mockAddress);

Expand All @@ -111,7 +113,7 @@ describe("ERC20 Utilities", () => {
const mockClient = {
readContract: jest.fn().mockResolvedValue(18),
};
(getClient as jest.Mock).mockReturnValue(mockClient);
(getClientForChain as jest.Mock).mockReturnValue(mockClient);

const result = await getTokenDecimals(mockChainId, mockAddress);

Expand All @@ -122,7 +124,7 @@ describe("ERC20 Utilities", () => {
const mockClient = {
readContract: jest.fn().mockRejectedValue(new Error("Test error")),
};
(getClient as jest.Mock).mockReturnValue(mockClient);
(getClientForChain as jest.Mock).mockReturnValue(mockClient);

await expect(getTokenDecimals(mockChainId, mockAddress)).rejects.toThrow(
"Error fetching token decimals: Error: Test error",
Expand All @@ -135,7 +137,7 @@ describe("ERC20 Utilities", () => {
const mockClient = {
readContract: jest.fn().mockResolvedValue("TEST"),
};
(getClient as jest.Mock).mockReturnValue(mockClient);
(getClientForChain as jest.Mock).mockReturnValue(mockClient);

const result = await getTokenSymbol(mockChainId, mockAddress);

Expand All @@ -146,7 +148,7 @@ describe("ERC20 Utilities", () => {
const mockClient = {
readContract: jest.fn().mockRejectedValue(new Error("Test error")),
};
(getClient as jest.Mock).mockReturnValue(mockClient);
(getClientForChain as jest.Mock).mockReturnValue(mockClient);

await expect(getTokenSymbol(mockChainId, mockAddress)).rejects.toThrow(
"Error fetching token decimals: Error: Test error",
Expand Down