In [None]:

# =====================================================
#  MLP for EC prediction using ESM2 pooled embeddings
#  - Train + Test + Report
# =====================================================

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm.auto import tqdm

# -----------------------------------------------------
# 1. Load features & labels
# -----------------------------------------------------
X = np.load("esm2_features.npy")  # shape should be (236607, 480)
print("Shape X:", X.shape)

df = pd.read_csv("3_levels_EC.tsv", sep="\t")
labels = df["EC number"].astype(str).values

le = LabelEncoder()
y = le.fit_transform(labels)

num_classes = len(le.classes_)
print("Total classes:", num_classes)
print("Example classes:", le.classes_[:20])

# -----------------------------------------------------
# 2. Train-test split (same style as RF / XGB)
# -----------------------------------------------------
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    stratify=None   # don't stratify because there are very rare classes
)

print("Train size:", X_train.shape[0])
print("Test size :", X_test.shape[0])

# -----------------------------------------------------
# 3. Convert to PyTorch tensors & DataLoader
# -----------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)

X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset  = TensorDataset(X_test_tensor, y_test_tensor)

batch_size = 512  # can be increased if VRAM is sufficient
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=0)

# -----------------------------------------------------
# 4. Define MLP model (adjust as needed)
# -----------------------------------------------------
class MLP(nn.Module):
    def __init__(self, input_dim=480, hidden_dim=512, num_classes=263, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        return self.net(x)

model = MLP(input_dim=X.shape[1], hidden_dim=512, num_classes=num_classes, dropout=0.3).to(device)
print(model)

# -----------------------------------------------------
# 5. Loss, optimizer, training config
# -----------------------------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 20

# -----------------------------------------------------
# 6. Training loop with progress bar
# -----------------------------------------------------
for epoch in range(1, epochs + 1):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")
    for xb, yb in pbar:
        xb = xb.to(device)
        yb = yb.to(device)

        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * xb.size(0)

        preds = torch.argmax(logits, dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)

        pbar.set_postfix({"loss": loss.item()})

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    print(f"Epoch {epoch}: loss={epoch_loss:.4f}, acc={epoch_acc:.4f}")

# -----------------------------------------------------
# 7. Save trained model
# -----------------------------------------------------
torch.save(model.state_dict(), "mlp_ec_esm2.pt")
print("Model saved as 'mlp_ec_esm2.pt'")

# -----------------------------------------------------
# 8. Evaluation on test set
# -----------------------------------------------------
model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for xb, yb in tqdm(test_loader, desc="Evaluating on test set"):
        xb = xb.to(device)
        yb = yb.to(device)

        logits = model(xb)
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(yb.cpu().numpy())

y_true = np.concatenate(all_targets)
y_pred = np.concatenate(all_preds)

test_acc = accuracy_score(y_true, y_pred)
print("\n=== MLP Test Results ===")
print("Test accuracy MLP:", test_acc)

# -----------------------------------------------------
# 9. Classification report (avoid class mismatch errors)
# -----------------------------------------------------
labels_used = np.unique(np.concatenate([y_true, y_pred]))
target_names = le.inverse_transform(labels_used)

print("\nClassification report (MLP, only classes present in test/pred):")
print(classification_report(
    y_true,
    y_pred,
    labels=labels_used,
    target_names=target_names,
    zero_division=0
))


Shape X: (236607, 480)
Total classes: 263
Example classes: ['1.1.1' '1.1.2' '1.1.3' '1.1.5' '1.1.7' '1.1.9' '1.1.98' '1.1.99'
 '1.10.3' '1.10.5' '1.11.1' '1.11.2' '1.12.1' '1.12.2' '1.12.5' '1.12.7'
 '1.12.98' '1.12.99' '1.13.11' '1.13.12']
Train size: 189285
Test size : 47322
Using device: cuda
MLP(
  (net): Sequential(
    (0): Linear(in_features=480, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=263, bias=True)
  )
)


Epoch 1/20: 100%|██████████| 370/370 [00:02<00:00, 165.92it/s, loss=1.13] 


Epoch 1: loss=2.0400, acc=0.5496


Epoch 2/20: 100%|██████████| 370/370 [00:01<00:00, 211.51it/s, loss=0.563]


Epoch 2: loss=0.7293, acc=0.8352


Epoch 3/20: 100%|██████████| 370/370 [00:01<00:00, 216.63it/s, loss=0.492]


Epoch 3: loss=0.5097, acc=0.8845


Epoch 4/20: 100%|██████████| 370/370 [00:01<00:00, 205.69it/s, loss=0.29] 


Epoch 4: loss=0.4044, acc=0.9079


Epoch 5/20: 100%|██████████| 370/370 [00:01<00:00, 193.25it/s, loss=0.256]


Epoch 5: loss=0.3374, acc=0.9218


Epoch 6/20: 100%|██████████| 370/370 [00:01<00:00, 194.96it/s, loss=0.271]


Epoch 6: loss=0.2932, acc=0.9320


Epoch 7/20: 100%|██████████| 370/370 [00:02<00:00, 175.39it/s, loss=0.285]


Epoch 7: loss=0.2607, acc=0.9384


Epoch 8/20: 100%|██████████| 370/370 [00:01<00:00, 193.00it/s, loss=0.211]


Epoch 8: loss=0.2349, acc=0.9438


Epoch 9/20: 100%|██████████| 370/370 [00:01<00:00, 194.47it/s, loss=0.203]


Epoch 9: loss=0.2145, acc=0.9485


Epoch 10/20: 100%|██████████| 370/370 [00:02<00:00, 183.08it/s, loss=0.294]


Epoch 10: loss=0.1974, acc=0.9514


Epoch 11/20: 100%|██████████| 370/370 [00:02<00:00, 181.13it/s, loss=0.169]


Epoch 11: loss=0.1839, acc=0.9549


Epoch 12/20: 100%|██████████| 370/370 [00:01<00:00, 201.22it/s, loss=0.249] 


Epoch 12: loss=0.1724, acc=0.9570


Epoch 13/20: 100%|██████████| 370/370 [00:01<00:00, 185.34it/s, loss=0.184] 


Epoch 13: loss=0.1613, acc=0.9592


Epoch 14/20: 100%|██████████| 370/370 [00:01<00:00, 188.74it/s, loss=0.158] 


Epoch 14: loss=0.1523, acc=0.9612


Epoch 15/20: 100%|██████████| 370/370 [00:01<00:00, 201.17it/s, loss=0.196] 


Epoch 15: loss=0.1440, acc=0.9632


Epoch 16/20: 100%|██████████| 370/370 [00:01<00:00, 193.59it/s, loss=0.142] 


Epoch 16: loss=0.1358, acc=0.9648


Epoch 17/20: 100%|██████████| 370/370 [00:01<00:00, 208.96it/s, loss=0.1]   


Epoch 17: loss=0.1313, acc=0.9657


Epoch 18/20: 100%|██████████| 370/370 [00:01<00:00, 196.01it/s, loss=0.146] 


Epoch 18: loss=0.1242, acc=0.9672


Epoch 19/20: 100%|██████████| 370/370 [00:01<00:00, 208.40it/s, loss=0.128] 


Epoch 19: loss=0.1184, acc=0.9690


Epoch 20/20: 100%|██████████| 370/370 [00:01<00:00, 192.35it/s, loss=0.121] 


Epoch 20: loss=0.1132, acc=0.9698
Model saved as 'mlp_ec_esm2.pt'


Evaluating on test set: 100%|██████████| 93/93 [00:00<00:00, 245.29it/s]


=== MLP Test Results ===
Test accuracy MLP: 0.9756984066607498

Classification report (MLP, only classes present in test/pred):
              precision    recall  f1-score   support

       1.1.1       0.98      0.99      0.98      1401
       1.1.2       0.20      1.00      0.33         1
       1.1.3       0.85      0.89      0.87        19
       1.1.5       0.95      0.97      0.96        64
       1.1.7       0.00      0.00      0.00         0
       1.1.9       0.00      0.00      0.00         2
      1.1.98       1.00      1.00      1.00         6
      1.1.99       0.89      0.85      0.87        40
      1.10.3       0.97      0.97      0.97       119
      1.10.5       1.00      1.00      1.00         1
      1.11.1       0.84      0.96      0.90       224
      1.11.2       1.00      0.50      0.67         6
      1.12.1       0.00      0.00      0.00         3
      1.12.2       0.00      0.00      0.00         1
      1.12.7       1.00      0.33      0.50         3
     1




In [1]:
import torch
import numpy as np
import joblib

from transformers import AutoTokenizer, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
label_encoder = joblib.load("label_encoder_ec_esm2.joblib")
num_classes = len(label_encoder.classes_)
print("Num classes:", num_classes)

class MLP_EC(torch.nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),

            torch.nn.Linear(512, 512),     # hidden 2 (512 -> 512)
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),

            torch.nn.Linear(512, num_classes),  # output
        )

    def forward(self, x):
        return self.net(x)



  from .autonotebook import tqdm as notebook_tqdm


Device: cuda
Num classes: 263


In [2]:
features = np.load("esm2_features.npy")  
INPUT_DIM = features.shape[1]
print("INPUT_DIM:", INPUT_DIM)

model = MLP_EC(INPUT_DIM, num_classes).to(device)

state_dict = torch.load("mlp_ec_esm2.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()

print("MLP model loaded.")


INPUT_DIM: 480
MLP model loaded.


  state_dict = torch.load("mlp_ec_esm2.pt", map_location=device)


In [3]:
MODEL_NAME = "facebook/esm2_t12_35M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
esm_model = AutoModel.from_pretrained(MODEL_NAME)
esm_model.to(device)
esm_model.eval()

print("Loaded ESM2 model:", MODEL_NAME)


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded ESM2 model: facebook/esm2_t12_35M_UR50D


In [None]:
@torch.no_grad()
def embed_seq(seq: str) -> np.ndarray:
    """
    Embed 1 protein sequence using ESM2 t12 mean pooling.
    seq: amino acid string (A,C,D,...), without spaces.
    return: numpy array shape (INPUT_DIM,)
    """
    tokens = tokenizer(
        [seq],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}

    outputs = esm_model(**tokens)           # last_hidden_state: [1, L, D]
    last_hidden = outputs.last_hidden_state # [1, L, D]
    mask = tokens["attention_mask"].unsqueeze(-1)  # [1, L, 1]

    masked = last_hidden * mask             # [1, L, D]
    summed = masked.sum(dim=1)              # [1, D]
    counts = mask.sum(dim=1)                # [1, 1]
    emb = (summed / counts).cpu().numpy()[0]    # [D]

    # check dimension
    if emb.shape[0] != INPUT_DIM:
        raise ValueError(
            f"Embedding dimension ({emb.shape[0]}) != INPUT_DIM MLP ({INPUT_DIM}). "
            "This means MODEL_NAME used is different from preprocessing."
        )

    return emb
@torch.no_grad()
def predict_ec(seq: str, top_k: int = 3):
    """
    Predict EC number from one protein sequence.
    seq  : amino acid string (A,C,D,...)
    top_k: show top N predictions
    """
    emb = embed_seq(seq)                                      # (D,)
    x = torch.tensor(emb, dtype=torch.float32, device=device).unsqueeze(0)  # (1, D)

    logits = model(x)
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()     # (num_classes,)

    top_idx = probs.argsort()[::-1][:top_k]
    top_ec  = label_encoder.inverse_transform(top_idx)
    top_conf = probs[top_idx]

    print("Top predictions:")
    for ec, p in zip(top_ec, top_conf):
        print(f"  {ec:10s}  (confidence = {p:.4f})")

    return top_ec[0], top_conf[0]


In [6]:
TEST_SEQ = input("Input sequence protein (A,C,D,... without space): ").strip()
pred_ec, conf = predict_ec(TEST_SEQ, top_k=3)

print("\nFinal prediction:")
print("Predicted EC :", pred_ec)
print("Confidence   :", conf)


Top predictions:
  2.7.1       (confidence = 0.7490)
  2.3.1       (confidence = 0.1324)
  3.1.1       (confidence = 0.0702)

Final prediction:
Predicted EC : 2.7.1
Confidence   : 0.74900246
