In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import pandas as pd

# Pooled ESM2 embeddings stored as a NumPy array on disk.
# Shape recorded by the original author: (236607, 480).
X = np.load("esm2_features.npy")

# EC annotations come from a tab-separated table; the "EC number" column is
# cast to str before being encoded.
# NOTE(review): nothing in this cell checks that the table rows line up with
# the rows of X (same count, same order) -- confirm against the upstream export.
df = pd.read_csv("3_levels_EC.tsv", delimiter="\t")
labels = df.loc[:, "EC number"].astype(str).values

# Map each distinct EC string to an integer class id.
le = LabelEncoder()
y = le.fit_transform(labels)

# One class per distinct label seen by the encoder.
num_classes = le.classes_.shape[0]
print("Classes:", num_classes)


  from .autonotebook import tqdm as notebook_tqdm


Classes: 263


In [2]:
class CNNDataset(Dataset):
    """In-memory dataset pairing feature rows with integer class labels.

    Features are converted to a float32 tensor and given a singleton axis at
    position 1, so a 2-D input of shape (N, D) is stored as (N, 1, D).
    Labels are converted to an int64 tensor.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        # Insert the channel axis expected by Conv1d: (N, D) -> (N, 1, D).
        self.X = torch.unsqueeze(features, 1)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        # Number of samples is the size of the leading axis.
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# Hold out 20% of the samples for testing. The split is seeded (42) and is
# not stratified by class.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    stratify=None,
    random_state=42,
    test_size=0.2,
)

# Training side: wrap the arrays and let the loader reshuffle each epoch.
train_ds = CNNDataset(X_train, y_train)
train_loader = DataLoader(train_ds, shuffle=True, batch_size=256)

# Test side: same batch size, fixed iteration order.
test_ds = CNNDataset(X_test, y_test)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=256)


In [3]:
class CNNClassifier(nn.Module):
    """1-D convolutional classifier over single-channel feature vectors.

    Three Conv1d -> ReLU -> BatchNorm1d stages (32, 64, 128 channels with
    kernel sizes 7, 5, 3 and length-preserving padding) are followed by global
    average pooling and a linear layer that produces ``num_classes`` logits.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()

        # (out_channels, kernel_size) for each stage, in order.
        stage_specs = ((32, 7), (64, 5), (128, 3))

        layers = []
        channels = 1
        for out_channels, kernel in stage_specs:
            layers.extend(
                [
                    # padding = kernel // 2 keeps the length unchanged for odd kernels
                    nn.Conv1d(channels, out_channels, kernel_size=kernel, padding=kernel // 2),
                    nn.ReLU(),
                    nn.BatchNorm1d(out_channels),
                ]
            )
            channels = out_channels

        # Collapse the length axis: (batch, 128, L) -> (batch, 128, 1)
        layers.append(nn.AdaptiveAvgPool1d(1))

        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        """Return class logits for ``x`` of shape (batch, 1, length)."""
        pooled = self.conv(x).squeeze(-1)  # drop the trailing singleton axis
        return self.fc(pooled)


In [4]:
# Seed PyTorch's global RNG before the model is built so that weight
# initialisation and the DataLoader's shuffle order are repeatable whenever
# this cell is re-run. The value matches the random_state used for the split.
# Some CUDA kernels are still non-deterministic, so GPU runs may differ slightly.
SEED = 42
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CNNClassifier(num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 20

for epoch in range(1, EPOCHS+1):
    model.train()

    # Per-epoch running totals; the loss is weighted by batch size so the
    # epoch figure is a true per-sample mean even if the last batch is short.
    total = 0
    correct = 0
    total_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")

    for Xb, yb in pbar:
        Xb, yb = Xb.to(device), yb.to(device)

        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        # Read the scalar loss once: each .item() forces a device-to-host sync.
        batch_loss = loss.item()
        batch_size = len(yb)

        total_loss += batch_loss * batch_size
        preds = logits.argmax(1)
        correct += (preds == yb).sum().item()
        total += batch_size

        pbar.set_postfix(loss=batch_loss)

    # Training accuracy here is measured on the fly, in train mode, while the
    # weights are still being updated during the epoch.
    print(f"Epoch {epoch}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")


Epoch 1/20: 100%|██████████| 740/740 [00:07<00:00, 93.79it/s, loss=2.45] 


Epoch 1: loss=3.5105, acc=0.2533


Epoch 2/20: 100%|██████████| 740/740 [00:06<00:00, 109.37it/s, loss=1.67]


Epoch 2: loss=2.0762, acc=0.5560


Epoch 3/20: 100%|██████████| 740/740 [00:06<00:00, 107.34it/s, loss=1.37]


Epoch 3: loss=1.4895, acc=0.6871


Epoch 4/20: 100%|██████████| 740/740 [00:06<00:00, 108.80it/s, loss=0.97] 


Epoch 4: loss=1.1754, acc=0.7546


Epoch 5/20: 100%|██████████| 740/740 [00:06<00:00, 109.32it/s, loss=0.63] 


Epoch 5: loss=0.9834, acc=0.7950


Epoch 6/20: 100%|██████████| 740/740 [00:06<00:00, 108.14it/s, loss=0.559]


Epoch 6: loss=0.8516, acc=0.8213


Epoch 7/20: 100%|██████████| 740/740 [00:06<00:00, 108.72it/s, loss=0.822]


Epoch 7: loss=0.7512, acc=0.8414


Epoch 8/20: 100%|██████████| 740/740 [00:06<00:00, 107.34it/s, loss=0.902]


Epoch 8: loss=0.6740, acc=0.8577


Epoch 9/20: 100%|██████████| 740/740 [00:06<00:00, 110.08it/s, loss=0.785]


Epoch 9: loss=0.6133, acc=0.8706


Epoch 10/20: 100%|██████████| 740/740 [00:06<00:00, 107.95it/s, loss=0.467]


Epoch 10: loss=0.5638, acc=0.8801


Epoch 11/20: 100%|██████████| 740/740 [00:06<00:00, 108.25it/s, loss=0.443]


Epoch 11: loss=0.5239, acc=0.8882


Epoch 12/20: 100%|██████████| 740/740 [00:06<00:00, 109.11it/s, loss=0.421]


Epoch 12: loss=0.4886, acc=0.8951


Epoch 13/20: 100%|██████████| 740/740 [00:06<00:00, 107.08it/s, loss=0.249]


Epoch 13: loss=0.4598, acc=0.9010


Epoch 14/20: 100%|██████████| 740/740 [00:06<00:00, 109.69it/s, loss=0.456]


Epoch 14: loss=0.4356, acc=0.9065


Epoch 15/20: 100%|██████████| 740/740 [00:06<00:00, 107.92it/s, loss=0.376]


Epoch 15: loss=0.4121, acc=0.9102


Epoch 16/20: 100%|██████████| 740/740 [00:06<00:00, 108.05it/s, loss=0.414]


Epoch 16: loss=0.3928, acc=0.9143


Epoch 17/20: 100%|██████████| 740/740 [00:06<00:00, 108.83it/s, loss=0.347]


Epoch 17: loss=0.3761, acc=0.9174


Epoch 18/20: 100%|██████████| 740/740 [00:06<00:00, 109.19it/s, loss=0.346]


Epoch 18: loss=0.3585, acc=0.9204


Epoch 19/20: 100%|██████████| 740/740 [00:06<00:00, 108.89it/s, loss=0.415]


Epoch 19: loss=0.3461, acc=0.9237


Epoch 20/20: 100%|██████████| 740/740 [00:06<00:00, 108.17it/s, loss=0.393]

Epoch 20: loss=0.3332, acc=0.9260





In [5]:
# Held-out evaluation: switch the model to eval mode and disable gradient
# tracking while counting correct top-1 predictions over the test loader.
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for Xb, yb in tqdm(test_loader, desc="Evaluating"):
        Xb = Xb.to(device)
        yb = yb.to(device)

        logits = model(Xb)
        preds = torch.argmax(logits, dim=1)

        correct += preds.eq(yb).sum().item()
        total += yb.shape[0]

print("Test accuracy:", correct / total)


Evaluating: 100%|██████████| 185/185 [00:00<00:00, 237.03it/s]

Test accuracy: 0.8819576518321288



