In [2]:
import os
import torch
import torch.nn as nn
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

In [3]:
esm2_path = "model/esm2_model"

tokenizer = AutoTokenizer.from_pretrained(esm2_path, local_files_only=True)
model = AutoModel.from_pretrained(esm2_path, local_files_only=True)
model.eval()

Some weights of EsmModel were not initialized from the model checkpoint at model/esm2_model and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 320, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
    (position_embeddings): Embedding(1026, 320, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-5): 6 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=320, out_features=320, bias=True)
            (key): Linear(in_features=320, out_features=320, bias=True)
            (value): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
    

In [4]:
df = pd.read_csv("data/cancer_data_preprocessed.csv")
df = df[["mutated_protein", "Cancer Stage"]].dropna()

le = LabelEncoder()
df["label"] = le.fit_transform(df["Cancer Stage"])

In [5]:
def encode_sequence(seq):
    # 清理无效字符（严格只保留 vocab 内字符，一般大写字母氨基酸）
    seq = ''.join([aa for aa in seq if aa in tokenizer.get_vocab()])
    inputs = tokenizer(seq, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze(0)

In [6]:
embeddings, labels = [], []

print("🧬 Encoding protein sequences with ESM2...")
for _, row in tqdm(df.iterrows(), total=len(df)):
    try:
        emb = encode_sequence(row["mutated_protein"])
        embeddings.append(emb)
        labels.append(row["label"])
    except Exception as e:
        print("❌ Error:", e)

X = torch.stack(embeddings)
y = torch.tensor(labels)

🧬 Encoding protein sequences with ESM2...


  0%|          | 0/38258 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  1%|          | 329/38258 [02:25<4:40:14,  2.26it/s] 


KeyboardInterrupt: 

In [None]:
X_np, y_np = X.numpy(), y.numpy()

X_train, X_temp, y_train, y_temp = train_test_split(
    X_np, y_np, test_size=0.3, stratify=y_np, random_state=42)

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

X_train, y_train = torch.tensor(X_train), torch.tensor(y_train)
X_val, y_val = torch.tensor(X_val), torch.tensor(y_val)
X_test, y_test = torch.tensor(X_test), torch.tensor(y_test)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=64)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=64)

In [None]:
class ProteinClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )
    def forward(self, x):
        return self.fc(x)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_cls = ProteinClassifier(X.shape[1], len(le.classes_)).to(device)
optimizer = torch.optim.AdamW(model_cls.parameters(), lr=2e-4)
loss_fn = nn.CrossEntropyLoss()

EPOCHS = 10
for epoch in range(EPOCHS):
    model_cls.train()
    train_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        out = model_cls(xb)
        loss = loss_fn(out, yb)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # 验证
    model_cls.eval()
    val_loss = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            out = model_cls(xb)
            loss = loss_fn(out, yb)
            val_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

In [None]:
model_cls.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        preds = model_cls(xb).argmax(dim=1).cpu()
        all_preds.extend(preds)
        all_labels.extend(yb)

acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average="macro")
print(f"✅ Test Accuracy: {acc:.4f}, Macro F1: {f1:.4f}")