In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from tqdm.auto import tqdm

# =====================================================
# 1. Load ESM2 embeddings & labels
# =====================================================
# Per-protein embedding matrix; row i is expected to correspond to row i of
# the TSV loaded below.
X = np.load("esm2_features.npy")  # shape = (236607, 480)
print("X shape:", X.shape)

df = pd.read_csv("3_levels_EC.tsv", sep="\t")
# NOTE(review): .astype(str) turns missing EC numbers (NaN) into the literal
# class "nan" -- confirm the TSV has no empty labels.
labels = df["EC number"].astype(str).values

# The embeddings and labels come from two separate files.  Fail fast with an
# explicit message when their row counts disagree, instead of relying on a
# generic inconsistent-length error later in train_test_split.  (Equal counts
# do not prove row alignment; they only rule out the obvious mismatch.)
if len(X) != len(labels):
    raise ValueError(
        f"Embedding rows ({len(X)}) and label rows ({len(labels)}) differ; "
        "esm2_features.npy and 3_levels_EC.tsv must be row-aligned"
    )

# Encode the EC-number strings as integer class ids 0..num_classes-1
# (LabelEncoder assigns ids in sorted order of the unique labels).
le = LabelEncoder()
y = le.fit_transform(labels)

num_classes = len(le.classes_)
print("Total classes:", num_classes)
print("Example classes:", le.classes_[:20])

# =====================================================
# 2. Train-test split
# =====================================================
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    # No stratification: some classes are very rare, and stratified splitting
    # fails for classes with too few samples.  Consequently some classes may
    # end up only in the train or only in the test split.
    stratify=None,
)

print("Train size:", X_train.shape[0])
print("Test size :", X_test.shape[0])

# =====================================================
# 3. Dataset & DataLoader (reshape to pseudo-sequence)
# =====================================================
SEQ_LEN = 15
EMB_DIM = 32  # 15 × 32 = 480


class ProteinDataset(Dataset):
    """Map-style dataset that views each flat embedding as a pseudo-sequence.

    Each sample of ``X`` (``seq_len * emb_dim`` values, 480 with the module
    defaults) is reshaped to ``(seq_len, emb_dim)`` so a ``batch_first`` LSTM
    can consume it as ``seq_len`` steps of ``emb_dim`` features.

    NOTE(review): the "steps" are consecutive chunks of one embedding vector,
    not residues, so the sequence order imposed on the LSTM is arbitrary.

    Args:
        X: array-like with ``len(X)`` samples holding ``seq_len * emb_dim``
            values each, e.g. shape ``(N, 480)``.
        y: integer class ids, one per sample.
        seq_len: pseudo-time-steps per sample (keyword-only, default
            ``SEQ_LEN``).
        emb_dim: features per step (keyword-only, default ``EMB_DIM``).

    Raises:
        ValueError: if ``X`` and ``y`` have different lengths, or if the
            samples of ``X`` do not hold exactly ``seq_len * emb_dim`` values.
    """

    def __init__(self, X, y, *, seq_len=SEQ_LEN, emb_dim=EMB_DIM):
        X = np.asarray(X)
        n = len(X)
        # Without this check a longer y is silently truncated and a shorter y
        # only fails with IndexError in the middle of an epoch.
        if n != len(y):
            raise ValueError(f"X has {n} samples but y has {len(y)} labels")
        if X.size != n * seq_len * emb_dim:
            raise ValueError(
                f"Each sample must hold seq_len * emb_dim = {seq_len * emb_dim} "
                f"values; got X with shape {X.shape}"
            )
        # reshape (N, seq_len * emb_dim) -> (N, seq_len, emb_dim)
        self.X = torch.tensor(X.reshape(n, seq_len, emb_dim), dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Wrap both splits as datasets; batches of 256, and only the training loader
# reshuffles every epoch (evaluation order stays fixed).
train_ds = ProteinDataset(X_train, y_train)
test_ds = ProteinDataset(X_test, y_test)

train_loader, test_loader = (
    DataLoader(ds, batch_size=256, shuffle=reshuffle)
    for ds, reshuffle in ((train_ds, True), (test_ds, False))
)

# =====================================================
# 4. BiLSTM model
# =====================================================
class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM over a pseudo-sequence, followed by a linear head.

    The input is ``(batch, seq_len, embed_dim)`` (``batch_first=True``).
    The final hidden states of the top LSTM layer in both directions are
    concatenated and mapped to ``num_classes`` raw logits (no softmax).
    """

    def __init__(self, embed_dim=32, hidden_size=128, num_layers=2, num_classes=263):
        super().__init__()
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Two directions are concatenated, so the head sees 2 * hidden_size.
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x):
        # h_n: (num_layers * 2, batch, hidden_size); the last two entries are
        # the top layer's forward and backward final states.
        _, (h_n, _) = self.lstm(x)
        top_forward, top_backward = h_n[-2], h_n[-1]
        return self.fc(torch.cat((top_forward, top_backward), dim=1))

# Prefer the GPU when PyTorch can see one; otherwise train on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Two-layer BiLSTM sized for the pseudo-sequence input and the encoded classes.
model = BiLSTMClassifier(
    embed_dim=EMB_DIM,
    hidden_size=128,
    num_layers=2,
    num_classes=num_classes,
)
model = model.to(device)
print(model)

# =====================================================
# 5. Optimizer & Loss
# =====================================================
# Adam over every model parameter; cross-entropy on the raw logits the model
# returns (it has no softmax layer of its own).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

EPOCHS = 20  # number of full passes over the training set

# =====================================================
# 6. Training loop with tqdm
# =====================================================
# Each epoch makes one pass over the shuffled training loader and reports the
# sample-weighted mean loss and the training accuracy.
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for Xb, yb in pbar:
        Xb, yb = Xb.to(device), yb.to(device)

        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() forces a device->host sync.
        batch_loss = loss.item()
        # The criterion averages over the batch, so weight by batch size to get
        # an exact per-sample mean for the epoch (the last batch may be smaller).
        total_loss += batch_loss * Xb.size(0)

        # Accuracy uses the logits computed before this batch's optimizer step.
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)

        pbar.set_postfix({"loss": batch_loss})

    train_loss = total_loss / total
    train_acc = correct / total
    print(f"Epoch {epoch}: loss={train_loss:.4f}, acc={train_acc:.4f}")

# =====================================================
# 7. Save model
# =====================================================
# Only the parameters are saved: reloading requires constructing a
# BiLSTMClassifier with the same hyper-parameters first.  The LabelEncoder
# (class id -> EC number mapping) is not persisted here.
torch.save(model.state_dict(), "bilstm_ec_esm2.pt")
print("Model saved as 'bilstm_ec_esm2.pt'")

# =====================================================
# 8. Evaluation: test accuracy + classification_report
# =====================================================
model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for Xb, yb in tqdm(test_loader, desc="Evaluating (BiLSTM)"):
        # Only the inputs need the device.  The targets are used solely by the
        # sklearn metrics, so keep them on the CPU instead of making a
        # GPU round-trip.
        logits = model(Xb.to(device))
        preds = logits.argmax(dim=1)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(yb.numpy())

y_true = np.concatenate(all_targets)
y_pred = np.concatenate(all_preds)

test_acc = accuracy_score(y_true, y_pred)
print("\n=== BiLSTM Test Results ===")
print("Test Accuracy (BiLSTM):", test_acc)

# Restrict the report to classes that occur in y_true or y_pred.  Otherwise the
# number of target_names (all encoded classes) would not match the labels
# classification_report infers, and it would raise.
labels_used = np.unique(np.concatenate([y_true, y_pred]))
target_names = le.inverse_transform(labels_used)

print("\nClassification report (BiLSTM, only classes present in test/pred):")
print(classification_report(
    y_true,
    y_pred,
    labels=labels_used,
    target_names=target_names,
    zero_division=0
))


  from .autonotebook import tqdm as notebook_tqdm


X shape: (236607, 480)
Total classes: 263
Example classes: ['1.1.1' '1.1.2' '1.1.3' '1.1.5' '1.1.7' '1.1.9' '1.1.98' '1.1.99'
 '1.10.3' '1.10.5' '1.11.1' '1.11.2' '1.12.1' '1.12.2' '1.12.5' '1.12.7'
 '1.12.98' '1.12.99' '1.13.11' '1.13.12']
Train size: 189285
Test size : 47322
Using device: cuda
BiLSTMClassifier(
  (lstm): LSTM(32, 128, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=263, bias=True)
)


Epoch 1/20: 100%|██████████| 740/740 [00:03<00:00, 185.06it/s, loss=1.2] 


Epoch 1: loss=2.9993, acc=0.3332


Epoch 2/20: 100%|██████████| 740/740 [00:03<00:00, 188.08it/s, loss=1.15] 


Epoch 2: loss=1.0446, acc=0.7669


Epoch 3/20: 100%|██████████| 740/740 [00:03<00:00, 189.65it/s, loss=0.701]


Epoch 3: loss=0.6391, acc=0.8547


Epoch 4/20: 100%|██████████| 740/740 [00:04<00:00, 176.68it/s, loss=0.536]


Epoch 4: loss=0.4688, acc=0.8921


Epoch 5/20: 100%|██████████| 740/740 [00:04<00:00, 182.22it/s, loss=0.314]


Epoch 5: loss=0.3680, acc=0.9148


Epoch 6/20: 100%|██████████| 740/740 [00:03<00:00, 186.51it/s, loss=0.246]


Epoch 6: loss=0.3041, acc=0.9277


Epoch 7/20: 100%|██████████| 740/740 [00:04<00:00, 174.79it/s, loss=0.23] 


Epoch 7: loss=0.2562, acc=0.9384


Epoch 8/20: 100%|██████████| 740/740 [00:04<00:00, 177.59it/s, loss=0.333] 


Epoch 8: loss=0.2269, acc=0.9446


Epoch 9/20: 100%|██████████| 740/740 [00:04<00:00, 178.49it/s, loss=0.185] 


Epoch 9: loss=0.1998, acc=0.9508


Epoch 10/20: 100%|██████████| 740/740 [00:03<00:00, 200.44it/s, loss=0.177] 


Epoch 10: loss=0.1750, acc=0.9562


Epoch 11/20: 100%|██████████| 740/740 [00:04<00:00, 180.27it/s, loss=0.148] 


Epoch 11: loss=0.1575, acc=0.9601


Epoch 12/20: 100%|██████████| 740/740 [00:03<00:00, 203.34it/s, loss=0.0828]


Epoch 12: loss=0.1420, acc=0.9635


Epoch 13/20: 100%|██████████| 740/740 [00:03<00:00, 200.59it/s, loss=0.0781]


Epoch 13: loss=0.1275, acc=0.9672


Epoch 14/20: 100%|██████████| 740/740 [00:03<00:00, 198.22it/s, loss=0.0991]


Epoch 14: loss=0.1167, acc=0.9693


Epoch 15/20: 100%|██████████| 740/740 [00:03<00:00, 189.35it/s, loss=0.155] 


Epoch 15: loss=0.1068, acc=0.9719


Epoch 16/20: 100%|██████████| 740/740 [00:03<00:00, 196.03it/s, loss=0.0673]


Epoch 16: loss=0.0972, acc=0.9739


Epoch 17/20: 100%|██████████| 740/740 [00:04<00:00, 179.05it/s, loss=0.144] 


Epoch 17: loss=0.0858, acc=0.9773


Epoch 18/20: 100%|██████████| 740/740 [00:03<00:00, 198.67it/s, loss=0.0719]


Epoch 18: loss=0.0835, acc=0.9773


Epoch 19/20: 100%|██████████| 740/740 [00:03<00:00, 200.03it/s, loss=0.0531]


Epoch 19: loss=0.0757, acc=0.9790


Epoch 20/20: 100%|██████████| 740/740 [00:03<00:00, 194.78it/s, loss=0.126] 


Epoch 20: loss=0.0738, acc=0.9795
Model saved as 'bilstm_ec_esm2.pt'


Evaluating (BiLSTM): 100%|██████████| 185/185 [00:00<00:00, 337.71it/s]


=== BiLSTM Test Results ===
Test Accuracy (BiLSTM): 0.9612865052195596

Classification report (BiLSTM, only classes present in test/pred):
              precision    recall  f1-score   support

       1.1.1       0.98      0.98      0.98      1401
       1.1.2       0.17      1.00      0.29         1
       1.1.3       0.72      0.95      0.82        19
       1.1.5       0.94      0.95      0.95        64
       1.1.9       0.00      0.00      0.00         2
      1.1.98       0.86      1.00      0.92         6
      1.1.99       0.92      0.82      0.87        40
      1.10.3       0.97      0.97      0.97       119
      1.10.5       1.00      1.00      1.00         1
      1.11.1       0.85      0.92      0.88       224
      1.11.2       1.00      0.17      0.29         6
      1.12.1       1.00      0.33      0.50         3
      1.12.2       0.00      0.00      0.00         1
      1.12.7       1.00      0.67      0.80         3
     1.12.98       1.00      0.43      0.60      


