In this notebook we will see how to embed a batch of sequences using ESM C, as well as explore its different layers

# Set up Forge client for ESM C

Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.

In [None]:
# Install esm and other dependencies
! pip install esm
! pip install matplotlib

Collecting esm
  Downloading esm-3.1.5-py3-none-any.whl.metadata (16 kB)
Collecting torchtext (from esm)
  Downloading torchtext-0.18.0-cp311-cp311-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting transformers<4.47.0 (from esm)
  Downloading transformers-4.46.3-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Collecting biotite==0.41.2 (from esm)
  Downloading biotite-0.41.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.1 kB)
Collecting msgpack-numpy (from esm)
  Downloading msgpack_numpy-0.4.8-py2.py3-none-any.whl.metadata (5.0 kB)
Collecting biopython (from esm)
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting brotli (from esm)
  Downloading Brotli-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting zstd (from esm)
  Downloading zstd-1.5.6.6-cp311

In [11]:
from getpass import getpass

token = getpass("Token from Forge console: ")

Token from Forge console: ··········


In [12]:
from esm.sdk import client

model = client(
    model="esmc-300m-2024-12", url="https://forge.evolutionaryscale.ai", token=token
)

In [13]:
SAVE_PATH = "/content/HM_ESM3/esm_outputs_m.pkl"  # 确保 Colab 挂载了 Google Drive

# Set up utilities for embedding sequences

Since we're embedding more than a few sequences, we're going to use a threaded async call to Forge and let Forge take care of batching and parallelization on the backend.

In [14]:
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence
from typing import Dict
import pickle
import tempfile
import shutil

from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinError,
    LogitsConfig,
    LogitsOutput,
    ProteinType,
)

EMBEDDING_CONFIG = LogitsConfig(
    sequence=True, return_embeddings=True, return_hidden_states=True
)

def load_existing_pkl(filename: str) -> Dict[str, LogitsOutput]:
    if os.path.exists(filename):
        with open(filename, "rb") as f:
            return pickle.load(f)
    return {}


def save_to_pkl(data, filename):
    """使用临时文件存储，避免 pkl 文件损坏"""
    temp_file = filename + ".tmp"
    with open(temp_file, "wb") as f:
        pickle.dump(data, f)
    shutil.move(temp_file, filename)

def embed_sequence(model: ESM3InferenceClient, sequence: str) -> LogitsOutput:
    protein = ESMProtein(sequence=sequence)
    protein_tensor = model.encode(protein)
    print(protein_tensor)
    output = model.logits(protein_tensor, EMBEDDING_CONFIG)
    return output


def batch_embed(
    model: ESM3InferenceClient, inputs: Sequence[ProteinType], ids: Sequence[str]
  ) -> Dict[str, LogitsOutput]:

    existing_results = load_existing_pkl(SAVE_PATH)
    print(f"已加载 {len(existing_results)} 个已计算的样本")
    new_inputs = []
    new_ids = []
    results = {}

    for seq, seq_id in zip(inputs, ids):
        if seq_id not in existing_results:  # 只计算新样本
            new_inputs.append(seq)
            new_ids.append(seq_id)

    print(f"需要计算 {len(new_inputs)} 个新样本")

    results = existing_results.copy()

    """Forge supports auto-batching. So batch_embed() is as simple as running a collection
    of embed calls in parallel using asyncio.
    """
    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(embed_sequence, model, seq): seq_id for seq, seq_id in zip(new_inputs, new_ids)}

        for future in futures:
            seq_id = futures[future]  # 当前序列 ID
            try:
                output = future.result()
                results[seq_id] = output  # 存入结果
            except Exception as e:
                results[seq_id] = ESMProteinError(500, str(e))  # 记录错误

            # **每处理 10 个样本就存一次，防止数据丢失**
            if len(results) % 10 == 0:
                save_to_pkl(results, SAVE_PATH)
                print(f"已保存 {len(results)} 个样本到 {SAVE_PATH}")

                # **释放内存**
                del output
                gc.collect()

    save_to_pkl(results, SAVE_PATH)
    print(f"最终保存 {len(results)} 个样本到 {SAVE_PATH}")
    return results

above占用太多内存

In [15]:
import os
import gc
import pickle
import shutil
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence, Dict
from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinError,
    LogitsConfig,
    LogitsOutput,
    ProteinType,
)


EMBEDDING_CONFIG = LogitsConfig(
    sequence=True, return_embeddings=True, return_hidden_states=True
)

def load_existing_pkl(filename: str) -> Dict[str, LogitsOutput]:
    """加载已保存的 .pkl 文件，避免读取完整数据导致 OOM"""
    if os.path.exists(filename) and os.path.getsize(filename) > 0:
        try:
            with open(filename, "rb") as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError):
            print(f"⚠ `{filename}` 文件损坏，删除...")
            os.remove(filename)
    return {}

def save_to_pkl_incremental(new_data, filename):
    """增量保存数据，避免全量写入导致 OOM"""
    existing_data = load_existing_pkl(filename)
    existing_data.update(new_data)  # 仅合并新数据
    temp_file = filename + ".tmp"
    with open(temp_file, "wb") as f:
        pickle.dump(existing_data, f)
    shutil.move(temp_file, filename)

def embed_sequence(model: ESM3InferenceClient, sequence: str) -> LogitsOutput:
    """计算单个序列的 embedding"""
    print("length of sequence", len(sequence))
    protein = ESMProtein(sequence=sequence)
    protein_tensor = model.encode(protein)
    output = model.logits(protein_tensor, EMBEDDING_CONFIG)
    return output

def batch_embed(
    model: ESM3InferenceClient, inputs: Sequence[ProteinType], ids: Sequence[str]
) -> None:
    """优化版本：分批计算 embedding，避免 OOM"""

    # **仅加载 ID，不读取所有数据**
    existing_results = load_existing_pkl(SAVE_PATH)
    print(f"✅ 已加载 {len(existing_results)} 个已计算样本")

    # **筛选出需要计算的序列**
    new_inputs = []
    new_ids = []
    for seq, seq_id in zip(inputs, ids):
        if seq_id not in existing_results:  # 只计算新样本
            new_inputs.append(seq)
            new_ids.append(seq_id)

    print(f"🚀 需要计算 {len(new_inputs)} 个新样本")

    # **限制线程数，避免 OOM**
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = {executor.submit(embed_sequence, model, seq): seq_id for seq, seq_id in zip(new_inputs, new_ids)}
        batch_results = {}

        for future in futures:
            seq_id = futures[future]  # 当前序列 ID
            try:
                output = future.result()
                batch_results[seq_id] = output  # 存入结果
            except Exception as e:
                batch_results[seq_id] = ESMProteinError(500, str(e))  # 记录错误

            # **每 5 个样本就保存一次，防止数据丢失**
            if len(batch_results) >= 1:
                save_to_pkl_incremental(batch_results, SAVE_PATH)
                print(f"✅ 已保存 {len(batch_results)} 个新样本到 {SAVE_PATH}")
                batch_results.clear()  # 清空 RAM
                gc.collect()  # 释放内存

    # **最终保存所有剩余数据**
    if batch_results:
        save_to_pkl_incremental(batch_results, SAVE_PATH)
        print(f"✅ 最终保存 {len(batch_results)} 个样本到 {SAVE_PATH}")

    print(f"🎉 计算完成，数据保存在 {SAVE_PATH}")


# Requesting a specific hidden layer

ESM C 6B's hidden states are really large, so we only allow one specific layer to be requested per API call. This also works for other ESM C models, but it is required for ESM C 6B.
Refer to https://forge.evolutionaryscale.ai/console to find the number of hidden layers for each model.

In [None]:
# ESMC_6B_EMBEDDING_CONFIG = LogitsConfig(return_hidden_states=True, ith_hidden_layer=55)

# Load dataset

This dataset is taken from Muir, et al. 2024 ["Evolutionary-Scale Enzymology Enables Biochemical Constant Prediction Across a Multi-Peaked Catalytic Landscape"](https://doi.org/10.1101/2024.10.23.619915) which explores a model enzyme called Adenylate Kinase (ADK). Adenylate Kinase appears in many different organisms with different structural classes (referred to as its "lid type"). We'll embed this set of ADK sequences and see if we can recover known structural classes.

In [None]:
# !wget --no-check-certificate "https://docs.google.com/uc?export=download&id=1SpOkL11MJxIgy99dqufvUNJuCiuhxuyg" -O adk.csv

--2025-03-03 07:50:02--  https://docs.google.com/uc?export=download&id=1SpOkL11MJxIgy99dqufvUNJuCiuhxuyg
Resolving docs.google.com (docs.google.com)... 142.251.2.101, 142.251.2.113, 142.251.2.138, ...
Connecting to docs.google.com (docs.google.com)|142.251.2.101|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=1SpOkL11MJxIgy99dqufvUNJuCiuhxuyg&export=download [following]
--2025-03-03 07:50:02--  https://drive.usercontent.google.com/download?id=1SpOkL11MJxIgy99dqufvUNJuCiuhxuyg&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 74.125.137.132, 2607:f8b0:4023:c03::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|74.125.137.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 43132 (42K) [application/octet-stream]
Saving to: ‘adk.csv’


2025-03-03 07:50:04 (7.13 MB/s) - ‘adk.csv’ saved [43132/43132]



In [16]:
import pandas as pd
adk_path = "/content/HOM_MouseHumans_sequence.csv"
df = pd.read_csv(adk_path)
df = df[["human_id", "mouse_id", "sequence_h", "sequence_m"]]


In [17]:
path2 = '/content/1.10.510.10_pair.csv'
df2 = pd.read_csv(path2)
df2 = df2[["ids", "ids2"]]
len(df2)

298

In [18]:
# prompt: 在df2中存在的id对，对应df中的行human_id的sequence_h序列写入列表中

sequences_h = []
sequences_m = []
for index, row in df2.iterrows():
  id1 = row['ids']
  id2 = row['ids2']
  # Find matching rows in df based on human_id and mouse_id
  matching_rows = df[(df['human_id'] == id1) | (df['mouse_id'] == id2)]
  # Extract sequence_h values from matching rows
  sequence_h = ','.join(matching_rows['sequence_h'].astype(str).tolist())
  sequences_h.append(sequence_h)
  sequence_m = ','.join(matching_rows['sequence_m'].astype(str).tolist())
  sequences_m.append(sequence_m)

# df2.loc[:, 'sequence_h'] = sequences_h
df2['sequence_h'] = sequences_h
df2['sequence_m'] = sequences_m
df2


Unnamed: 0,ids,ids2,sequence_h,sequence_m
0,P53667,P53668,MRLTLLCCTWREERMGEEGSELPVCASCGQRIYDGQYLQALNADWH...,MRLTLLCCTWREERMGEEGSELPVCASCGQRIYDGQYLQALNADWH...
1,Q99986,Q80X41,MPRVKAAQAGRQSSAKRHLAEQFAVGEIITDMAKKEWKVGLPIGQG...,MPRVKAAQAGRPGPAKRRLAEQFAAGEVLTDMSRKEWKLGLPIGQG...
2,Q96PF2,O54863,MDDATVLRKKGYIVGINLGKGSYAKVKSAYSERLKFNVAVKIIDRK...,MDDAAVLRKKGYIVGINLGKGSYAKVKSAYSERLKFNVAVKIIDRK...
3,Q13043,Q9JI11,METVQLRNPPRRQLKKLDEDSLTKQPEEVFDVLEKLGEGSYGSVYK...,METVQLRNPPRRQLKKLDEDSLTKQPEEVFDVLEKLGEGSYGSVYK...
4,O15075,Q9JLM8,MSFGRDMELEHFDERDKAQRYSRGSRVNGLPSPTHSAHCSFYRTRT...,MSFGRDMELEHFDERDKAQRYSRGSRVNGLPSPTHSAHCSFYRTRT...
...,...,...,...,...
293,P49761,O35492,MHHCKRYRSPEPDPYLSYRWKRRRSYSREHEGRLRYPSRREPPPRR...,MHHCKRYRSPEPDPYLSYRWKRRRSYSREHEGRLRYPSRREPPPRR...
294,Q15835,Q9WVL4,MDFGSLETVVANSAFIAARGSFDGSSSQPSRDKKYLAKLKLPPLSK...,MDFGSLETVVANSAFIAARGSFDGSSTPSSRDKKYLAKLRLPPLSK...
295,Q92918,P70218,MDVVDPDIFNRDPRDHYDLLQRLGGGTYGEVFKARDKVSGDLVALK...,MALVDPDIFNKDPREHYDLLQRLGGGTYGEVFKARDKVSKDLVALK...
296,Q13705,P27040,MTAPWVALALLWGSLCAGSGRGEAETRECIYYNANWELERTNQSGL...,MTAPWAALALLWGSLCAGSGRGEAETRECIYYNANWELERTNQSGL...


In [None]:
# import matplotlib.pyplot as plt
# import pandas as pd
# import seaborn as sns

# adk_path = "adk.csv"
# df = pd.read_csv(adk_path)
# df = df[["org_name", "sequence", "lid_type", "temperature"]]
# df = df[df["lid_type"] != "other"]  # drop one structural class for simplicity

In [19]:
import gc
import pickle
from typing import Dict
import os

batch_size = int(len(df2) / 5)
api_no = 0

In [None]:
# You may see some error messages due to rate limits on each Forge account,
# but this will retry until the embedding job is complete
# This may take a few minutes to run
outputs_h = batch_embed(model, df2["sequence_m"].tolist()[batch_size*(api_no):batch_size*(api_no+1)], df2["ids2"].tolist()[batch_size*(api_no):batch_size*(api_no+1)])

✅ 已加载 25 个已计算样本
🚀 需要计算 34 个新样本


In [None]:
import torch

with open(SAVE_PATH, "rb") as f:
    saved_outputs_h = pickle.load(f)

# all_mean_embeddings_h = {
#     seq_id: torch.mean(output.hidden_states, dim=-2).squeeze()
#     for seq_id, output in saved_outputs_h.items() if isinstance(output, LogitsOutput)
# }
all_mean_embeddings_h = [
    torch.mean(output.hidden_states, dim=-2).squeeze()
    for _, output in saved_outputs_h.items() if isinstance(output, LogitsOutput)
]
print(f"成功计算 {len(all_mean_embeddings_h)} 个 embedding")

成功计算 3 个 embedding


In [None]:
import torch

# we'll summarize the embeddings using their mean across the sequence dimension
# which allows us to compare embeddings for sequences of different lengths
all_mean_embeddings_h = [
    torch.mean(output.hidden_states, dim=-2).squeeze() for output in outputs_h
]
# all_mean_embeddings_m = [
#     torch.mean(output.hidden_states, dim=-2).squeeze() for output in outputs_m
# ]


In [None]:

# now we have a list of tensors of [num_layers, hidden_size]
print("embedding shape [num_layers, hidden_size]:", all_mean_embeddings_h[0].shape)

embedding shape [num_layers, hidden_size]: torch.Size([31, 960])


In [None]:
all_mean_embeddings_h

[tensor([[ 1.5179e-01, -2.4687e-02, -1.4900e-01,  ..., -4.6417e-02,
           1.9227e-01, -1.7083e-01],
         [-2.5817e-01, -2.4541e-01, -1.4340e-01,  ..., -9.2550e-02,
           1.2583e-01, -3.3051e-01],
         [-2.2053e-01, -2.0242e-01, -3.9532e-02,  ...,  3.3464e-02,
           2.1106e-01, -4.6161e-01],
         ...,
         [-3.8543e+00,  6.9117e+00,  1.2235e+01,  ..., -1.0095e+01,
          -7.8135e+00,  4.4944e-01],
         [ 9.4077e+00,  1.4029e+01,  4.9635e+00,  ..., -9.9581e+00,
          -3.5056e+01, -1.4664e+01],
         [-1.1882e-02,  1.3286e-03,  1.5210e-02,  ..., -1.2846e-02,
          -3.5730e-02, -9.4714e-03]]),
 tensor([[ 1.3955e-01,  8.7762e-02, -3.1518e-01,  ..., -6.1653e-02,
           1.2724e-01, -1.7230e-01],
         [-3.1643e-01, -2.1005e-01, -1.4017e-01,  ..., -1.3113e-01,
           1.2547e-01, -2.6205e-01],
         [-2.6563e-01, -1.0068e-01,  6.0194e-02,  ...,  1.5872e-02,
           1.6433e-01, -4.1852e-01],
         ...,
         [-7.4761e-01,  4

# Examine the performance of different layer embeddings

For this example, we're going to use PCA to visualize whether the embeddings separate our proteins by their structural class. To assess the quality of our PCA, we fit a K means classifier with three clusters, corresponding to the three structural classes of our enzyme, and compute the [rand index](https://en.wikipedia.org/wiki/Rand_index), a measure of the quality of the clustering.

In [None]:
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
N_KMEANS_CLUSTERS = 3

In [None]:


def plot_embeddings_at_layer(all_mean_embeddings_h, all_mean_embeddings_m, layer_idx: int, df: pd.DataFrame):
    # 将 `sequence_h` 和 `sequence_m` 嵌入向量合并
    embeddings_h = torch.stack([embedding[layer_idx, :] for embedding in all_mean_embeddings_h]).numpy()
    embeddings_m = torch.stack([embedding[layer_idx, :] for embedding in all_mean_embeddings_m]).numpy()

    # 合并两个嵌入
    stacked_mean_embeddings = np.vstack([embeddings_h, embeddings_m])

    # 使用 PCA 降维到 2D
    pca = PCA(n_components=2)
    projected_mean_embeddings = pca.fit_transform(stacked_mean_embeddings)

    # KMeans 聚类
    # kmeans = KMeans(n_clusters=3, random_state=0).fit(projected_mean_embeddings)
    # kmeans_labels = kmeans.labels_

    # 计算 Adjusted Rand Index (ARI) 评估聚类效果
    # rand_index = adjusted_rand_score([0] * len(embeddings_h) + [1] * len(embeddings_m), kmeans_labels)

    # 标签：0 表示 `sequence_h`，1 表示 `sequence_m`
    labels = [0] * len(embeddings_h) + [1] * len(embeddings_m)

    # 绘制 PCA + KMeans 聚类结果
    plt.figure(figsize=(6, 6))
    sns.scatterplot(
        x=projected_mean_embeddings[:, 0],
        y=projected_mean_embeddings[:, 1],
        hue=labels,  # 用颜色表示 KMeans 结果
        palette="viridis",
        style=labels,  # 用形状区分 sequence_h (o) 和 sequence_m (s)
        markers=["o", "s"],
    )

    # 标题
    # plt.title(f"PCA of Mean Embeddings at Layer {layer_idx} (Rand Index: {rand_index:.2f})")
    plt.title(f"PCA of Mean Embeddings at Layer {layer_idx}")
    plt.xlabel("PC 1")
    plt.ylabel("PC 2")
    # plt.legend(title="Cluster", loc="best")
    plt.show()


In [None]:
plot_embeddings_at_layer(all_mean_embeddings_h, all_mean_embeddings_m, layer_idx=30, df=df)
plot_embeddings_at_layer(all_mean_embeddings_h, all_mean_embeddings_m, layer_idx=12, df=df)

NameError: name 'all_mean_embeddings_m' is not defined

We see that the top principal components of layer 12 separate structural classes better than that of layer 30. Embed away! And keep in mind that different layers may be better or worse for your particular use-case.