In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

# =====================================================
# 1. Load ESM2 embeddings & labels
# =====================================================
X = np.load("esm2_features.npy")  # shape = (236607, 480)
print("X shape:", X.shape)

df = pd.read_csv("3_levels_EC.tsv", sep="\t")
labels = df["EC number"].astype(str).values

# Encode label
le = LabelEncoder()
y = le.fit_transform(labels)

num_classes = len(le.classes_)
print("Total classes:", num_classes)
print("Example classes:", le.classes_[:20])

# =====================================================
# 2. Train-test split
# =====================================================
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    stratify=None  # jangan stratify karena ada kelas sangat langka
)

print("Train size:", X_train.shape[0])
print("Test size :", X_test.shape[0])

# =====================================================
# 3. Dataset & DataLoader (reshape ke pseudo-sequence)
# =====================================================
SEQ_LEN = 15
EMB_DIM = 32  # 15 × 32 = 480

class ProteinDataset(Dataset):
    def __init__(self, X, y):
        # reshape (N, 480) -> (N, 15, 32)
        X = X.reshape(len(X), SEQ_LEN, EMB_DIM)
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_ds = ProteinDataset(X_train, y_train)
test_ds  = ProteinDataset(X_test, y_test)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=256, shuffle=False)

# =====================================================
# 4. BiLSTM model
# =====================================================

class BiLSTMClassifier(nn.Module):
    def __init__(self, embed_dim=32, hidden_size=128, num_layers=2, num_classes=263):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )

        # karena bidirectional -> hidden_size * 2
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq_len=15, embed_dim=32)
        out, (hn, cn) = self.lstm(x)
        # hn shape: (num_layers*2, batch, hidden_size)
        # pakai last layer, dua arah: hn[-2] (forward), hn[-1] (backward)
        final = torch.cat((hn[-2, :, :], hn[-1, :, :]), dim=1)  # (batch, hidden_size*2)
        logits = self.fc(final)
        return logits

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model = BiLSTMClassifier(
    embed_dim=EMB_DIM,
    hidden_size=128,
    num_layers=2,
    num_classes=num_classes
).to(device)

print(model)

# =====================================================
# 5. Optimizer & Loss
# =====================================================
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

EPOCHS = 20

# =====================================================
# 6. Training loop with tqdm
# =====================================================
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for Xb, yb in pbar:
        Xb, yb = Xb.to(device), yb.to(device)

        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * Xb.size(0)

        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)

        pbar.set_postfix({"loss": loss.item()})

    train_loss = total_loss / total
    train_acc = correct / total
    print(f"Epoch {epoch}: loss={train_loss:.4f}, acc={train_acc:.4f}")

# =====================================================
# 7. Save model
# =====================================================
torch.save(model.state_dict(), "bilstm_ec_esm2.pt")
print("Model saved as 'bilstm_ec_esm2.pt'")

# =====================================================
# 8. Evaluation: test accuracy + classification_report
# =====================================================
model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for Xb, yb in tqdm(test_loader, desc="Evaluating (BiLSTM)"):
        Xb, yb = Xb.to(device), yb.to(device)
        logits = model(Xb)
        preds = logits.argmax(dim=1)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(yb.cpu().numpy())

y_true = np.concatenate(all_targets)
y_pred = np.concatenate(all_preds)

test_acc = accuracy_score(y_true, y_pred)
print("\n=== BiLSTM Test Results ===")
print("Test Accuracy (BiLSTM):", test_acc)

# Hanya kelas yang muncul di y_true/y_pred supaya tidak error
labels_used = np.unique(np.concatenate([y_true, y_pred]))
target_names = le.inverse_transform(labels_used)

print("\nClassification report (BiLSTM, only classes present in test/pred):")
print(classification_report(
    y_true,
    y_pred,
    labels=labels_used,
    target_names=target_names,
    zero_division=0
))


  from .autonotebook import tqdm as notebook_tqdm


X shape: (236607, 480)
Total classes: 263
Example classes: ['1.1.1' '1.1.2' '1.1.3' '1.1.5' '1.1.7' '1.1.9' '1.1.98' '1.1.99'
 '1.10.3' '1.10.5' '1.11.1' '1.11.2' '1.12.1' '1.12.2' '1.12.5' '1.12.7'
 '1.12.98' '1.12.99' '1.13.11' '1.13.12']
Train size: 189285
Test size : 47322
Using device: cuda
BiLSTMClassifier(
  (lstm): LSTM(32, 128, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=263, bias=True)
)


Epoch 1/20: 100%|██████████| 740/740 [00:17<00:00, 41.42it/s, loss=1.6] 


Epoch 1: loss=3.1566, acc=0.2978


Epoch 2/20: 100%|██████████| 740/740 [00:17<00:00, 42.83it/s, loss=0.834]


Epoch 2: loss=1.1202, acc=0.7494


Epoch 3/20: 100%|██████████| 740/740 [00:18<00:00, 41.06it/s, loss=0.594]


Epoch 3: loss=0.6680, acc=0.8491


Epoch 4/20: 100%|██████████| 740/740 [00:18<00:00, 39.46it/s, loss=0.483]


Epoch 4: loss=0.4810, acc=0.8905


Epoch 5/20: 100%|██████████| 740/740 [00:18<00:00, 40.26it/s, loss=0.455]


Epoch 5: loss=0.3764, acc=0.9130


Epoch 6/20: 100%|██████████| 740/740 [00:19<00:00, 37.89it/s, loss=0.38] 


Epoch 6: loss=0.3099, acc=0.9276


Epoch 7/20: 100%|██████████| 740/740 [00:19<00:00, 37.82it/s, loss=0.224]


Epoch 7: loss=0.2630, acc=0.9371


Epoch 8/20: 100%|██████████| 740/740 [00:19<00:00, 38.79it/s, loss=0.105]


Epoch 8: loss=0.2294, acc=0.9450


Epoch 9/20: 100%|██████████| 740/740 [00:18<00:00, 39.60it/s, loss=0.138] 


Epoch 9: loss=0.1998, acc=0.9509


Epoch 10/20: 100%|██████████| 740/740 [00:18<00:00, 40.59it/s, loss=0.262] 


Epoch 10: loss=0.1766, acc=0.9558


Epoch 11/20: 100%|██████████| 740/740 [00:18<00:00, 41.00it/s, loss=0.131] 


Epoch 11: loss=0.1586, acc=0.9601


Epoch 12/20: 100%|██████████| 740/740 [00:18<00:00, 40.63it/s, loss=0.119] 


Epoch 12: loss=0.1423, acc=0.9636


Epoch 13/20: 100%|██████████| 740/740 [00:18<00:00, 40.68it/s, loss=0.0962]


Epoch 13: loss=0.1288, acc=0.9669


Epoch 14/20: 100%|██████████| 740/740 [00:18<00:00, 40.41it/s, loss=0.0958]


Epoch 14: loss=0.1158, acc=0.9696


Epoch 15/20: 100%|██████████| 740/740 [00:18<00:00, 40.01it/s, loss=0.0805]


Epoch 15: loss=0.1051, acc=0.9721


Epoch 16/20: 100%|██████████| 740/740 [00:18<00:00, 40.09it/s, loss=0.121] 


Epoch 16: loss=0.1011, acc=0.9729


Epoch 17/20: 100%|██████████| 740/740 [00:18<00:00, 39.19it/s, loss=0.103] 


Epoch 17: loss=0.0909, acc=0.9757


Epoch 18/20: 100%|██████████| 740/740 [00:18<00:00, 40.19it/s, loss=0.179] 


Epoch 18: loss=0.0825, acc=0.9777


Epoch 19/20: 100%|██████████| 740/740 [00:18<00:00, 40.32it/s, loss=0.0968]


Epoch 19: loss=0.0768, acc=0.9788


Epoch 20/20: 100%|██████████| 740/740 [00:18<00:00, 40.13it/s, loss=0.173] 


Epoch 20: loss=0.0692, acc=0.9807
Model saved as 'bilstm_ec_esm2.pt'


Evaluating (BiLSTM): 100%|██████████| 185/185 [00:02<00:00, 70.46it/s]



=== BiLSTM Test Results ===
Test Accuracy (BiLSTM): 0.9581589958158996

Classification report (BiLSTM, only classes present in test/pred):
              precision    recall  f1-score   support

       1.1.1       0.97      0.98      0.97      1401
       1.1.2       0.50      1.00      0.67         1
       1.1.3       0.72      0.95      0.82        19
       1.1.5       0.92      0.95      0.94        64
       1.1.9       0.00      0.00      0.00         2
      1.1.98       1.00      1.00      1.00         6
      1.1.99       0.94      0.82      0.88        40
      1.10.3       0.99      0.96      0.97       119
      1.10.5       1.00      1.00      1.00         1
      1.11.1       0.84      0.94      0.89       224
      1.11.2       1.00      0.50      0.67         6
      1.12.1       0.25      0.33      0.29         3
      1.12.2       0.00      0.00      0.00         1
      1.12.7       1.00      0.33      0.50         3
     1.12.98       0.50      0.29      0.36      

In [10]:
features = np.load("esm2_features.npy")  
INPUT_DIM = features.shape[1]
print("INPUT_DIM:", INPUT_DIM)

model = BiLSTMClassifier(INPUT_DIM, num_classes).to(device)

state_dict = torch.load("bilstm_ec_esm2.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()

print("MLP model loaded.")


INPUT_DIM: 480


RuntimeError: Error(s) in loading state_dict for BiLSTMClassifier:
	size mismatch for lstm.weight_ih_l0: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1052, 480]).
	size mismatch for lstm.weight_hh_l0: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([1052, 263]).
	size mismatch for lstm.bias_ih_l0: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1052]).
	size mismatch for lstm.bias_hh_l0: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1052]).
	size mismatch for lstm.weight_ih_l0_reverse: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1052, 480]).
	size mismatch for lstm.weight_hh_l0_reverse: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([1052, 263]).
	size mismatch for lstm.bias_ih_l0_reverse: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1052]).
	size mismatch for lstm.bias_hh_l0_reverse: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1052]).
	size mismatch for lstm.weight_ih_l1: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([1052, 526]).
	size mismatch for lstm.weight_hh_l1: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([1052, 263]).
	size mismatch for lstm.bias_ih_l1: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1052]).
	size mismatch for lstm.bias_hh_l1: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1052]).
	size mismatch for lstm.weight_ih_l1_reverse: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([1052, 526]).
	size mismatch for lstm.weight_hh_l1_reverse: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([1052, 263]).
	size mismatch for lstm.bias_ih_l1_reverse: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1052]).
	size mismatch for lstm.bias_hh_l1_reverse: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1052]).
	size mismatch for fc.weight: copying a param with shape torch.Size([263, 256]) from checkpoint, the shape in current model is torch.Size([263, 526]).

In [4]:
MODEL_NAME = "facebook/esm2_t12_35M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
esm_model = AutoModel.from_pretrained(MODEL_NAME)
esm_model.to(device)
esm_model.eval()

print("Loaded ESM2 model:", MODEL_NAME)


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded ESM2 model: facebook/esm2_t12_35M_UR50D


In [8]:
@torch.no_grad()
def embed_seq(seq: str) -> np.ndarray:
    """
    Embed 1 protein sequence using ESM2 t12 mean pooling.
    seq: string asam amino (A,C,D,...), tanpa spasi.
    return: numpy array shape (INPUT_DIM,)
    """
    tokens = tokenizer(
        [seq],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}

    outputs = esm_model(**tokens)           # last_hidden_state: [1, L, D]
    last_hidden = outputs.last_hidden_state # [1, L, D]
    mask = tokens["attention_mask"].unsqueeze(-1)  # [1, L, 1]

    masked = last_hidden * mask             # [1, L, D]
    summed = masked.sum(dim=1)              # [1, D]
    counts = mask.sum(dim=1)                # [1, 1]
    emb = (summed / counts).cpu().numpy()[0]    # [D]

    # cek dimensi
    if emb.shape[0] != INPUT_DIM:
        raise ValueError(
            f"Dimension embeddings ({emb.shape[0]}) != INPUT_DIM MLP ({INPUT_DIM}). "
            f"Please retrain the MLP model with the correct INPUT_DIM."
        )

    return emb
@torch.no_grad()
def predict_ec(seq: str, top_k: int = 3):
    """
    Predict EC number from one protein sequence.
    seq  : string amino acid (A,C,D,...)
    top_k: print top-k predictions
    return: (pred_ec, conf)
    """
    emb = embed_seq(seq)                                      # (D,)
    x = torch.tensor(emb, dtype=torch.float32, device=device).unsqueeze(0)  # (1, D)

    logits = model(x)
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()     # (num_classes,)

    top_idx = probs.argsort()[::-1][:top_k]
    top_ec  = label_encoder.inverse_transform(top_idx)
    top_conf = probs[top_idx]

    print("Top predictions:")
    for ec, p in zip(top_ec, top_conf):
        print(f"  {ec:10s}  (confidence = {p:.4f})")

    return top_ec[0], top_conf[0]


In [9]:
TEST_SEQ = input("Input sequence protein (A,C,D,... without space): ").strip()
pred_ec, conf = predict_ec(TEST_SEQ, top_k=3)

print("\nFinal prediction:")
print("Predicted EC :", pred_ec)
print("Confidence   :", conf)


IndexError: too many indices for tensor of dimension 2