# 02 – Baseline CNN Models (Basset-style)

This notebook trains a lightweight CNN baseline for functional element classification.
It expects processed `.npz` files from Notebook 01.

In [1]:
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, average_precision_score
from pathlib import Path

# Directory holding the processed `.npz` splits this notebook reads.
PROC = Path('data/processed')

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# The `with` blocks read the arrays into memory and then close the archive;
# `train` / `val` stay bound afterwards but refer to closed archives.
with np.load(PROC/'train.npz') as train:
    Xtr, ytr = train['X'], train['y']
with np.load(PROC/'val.npz') as val:
    Xva, yva = val['X'], val['y']

# Fail early on a malformed split instead of deep inside the training loop.
assert len(Xtr) == len(ytr), 'train X/y have different numbers of examples'
assert len(Xva) == len(yva), 'val X/y have different numbers of examples'
assert Xtr.shape[1:] == Xva.shape[1:], 'train/val per-example shapes differ'
print('Train:', Xtr.shape, 'Val:', Xva.shape)

Train: (40, 2000, 4) Val: (5, 2000, 4)


In [2]:
class SeqDataset(Dataset):
    """In-memory dataset pairing sequence tensors with their labels.

    Args:
        X: array whose last two axes are swapped on construction, so an
            (N, L, 4) input is stored as a float32 tensor of shape (N, 4, L).
        y: labels, stored as a float32 tensor and indexed in step with X.
    """

    def __init__(self, X, y):
        # Swap the last two axes so channels come before positions: (N, 4, L).
        channels_first = X.transpose(0, 2, 1)
        self.X = torch.tensor(channels_first, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        """Number of examples (size of the leading axis of X)."""
        n_examples = self.X.shape[0]
        return n_examples

    def __getitem__(self, i):
        """Return the (features, label) pair stored at index `i`."""
        features = self.X[i]
        label = self.y[i]
        return features, label

# Wrap each split in a dataset, then batch it. Only the training loader
# shuffles; validation batches are served in stored order.
train_ds = SeqDataset(X=Xtr, y=ytr)
val_ds = SeqDataset(X=Xva, y=yva)
train_loader = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=128, shuffle=False)


In [3]:
# Minimal Basset-style CNN
class BassetMini(nn.Module):
    """Three conv + max-pool stages followed by a two-layer classifier head.

    The convolutions use "same" padding (kernel 19/11/7 with padding 9/5/3),
    so only the pooling stages (3, 4, 4) shorten the sequence. The flattened
    feature size of the first linear layer is derived from the input length.

    Args:
        L: length of the input sequences. Must be at least 48 so that at
            least one position survives the three pooling stages.

    Raises:
        ValueError: if `L` is too short for the pooling stack.

    Forward input is a float tensor of shape (batch, 4, L); the output is a
    tensor of shape (batch,) holding one raw logit per example.
    """

    @staticmethod
    def _conv_out_len(L, k, p, s):
        """Output length of a 1-D convolution with kernel k, padding p, stride s."""
        return (L + 2*p - k)//s + 1

    def __init__(self, L):
        super().__init__()
        # Layers are created in a fixed order (conv1, conv2, conv3, fc1, fc2)
        # and keep their attribute names so saved state_dicts stay loadable.
        self.conv1 = nn.Conv1d(4, 300, kernel_size=19, padding=9)
        self.pool1 = nn.MaxPool1d(3)
        self.conv2 = nn.Conv1d(300, 200, kernel_size=11, padding=5)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(200, 200, kernel_size=7, padding=3)
        self.pool3 = nn.MaxPool1d(4)

        # Track the sequence length through each conv -> pool stage; the
        # floor division mirrors MaxPool1d with stride equal to kernel size.
        L1 = self._conv_out_len(L, 19, 9, 1)//3
        L2 = self._conv_out_len(L1, 11, 5, 1)//4
        L3 = self._conv_out_len(L2, 7, 3, 1)//4
        if L3 < 1:
            # Without this check the model would be built with a zero-width
            # fc1 and only fail later, inside forward, with an opaque error.
            raise ValueError(
                f'Sequence length L={L} is too short for BassetMini; '
                'need L >= 48 so the pooling stages leave at least one position.'
            )

        self.fc1 = nn.Linear(200*L3, 1000)
        self.drop = nn.Dropout(0.3)
        self.fc2 = nn.Linear(1000, 1)  # binary
        self.act = nn.ReLU()

    def forward(self, x):
        """Map a (batch, 4, L) tensor to a (batch,) tensor of logits."""
        x = self.pool1(self.act(self.conv1(x)))
        x = self.pool2(self.act(self.conv2(x)))
        x = self.pool3(self.act(self.conv3(x)))
        x = torch.flatten(x, 1)  # (batch, 200 * L3)
        x = self.act(self.fc1(x))
        x = self.drop(x)
        x = self.fc2(x)
        # Drop the singleton output dimension: (batch, 1) -> (batch,).
        return x.squeeze(1)


In [4]:
# Seed PyTorch's global RNG before the model is built so that weight
# initialisation (and later draws from the global generator) start from the
# same state on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
L = Xtr.shape[1]  # sequence length: second axis of the training array
model = BassetMini(L).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# The model returns raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()


In [5]:
def evaluate(model, loader):
    """Score `model` on every batch of `loader` and return (AUROC, PR-AUC).

    The model is switched to eval mode (and left there), gradients are
    disabled, and sigmoid probabilities are collected for all batches.
    Inputs are moved to the module-level `device`.

    Args:
        model: module mapping a batch of inputs to one logit per example.
        loader: iterable yielding `(inputs, labels)` batches.

    Returns:
        Tuple `(auroc, prauc)`. A metric is NaN when it cannot be computed
        for the collected labels/scores (the metric function raised
        ValueError), and both are NaN when the loader yields no batches.
    """
    model.eval()
    ys, ps = [], []
    with torch.no_grad():
        for xb, yb in loader:
            # Only the inputs are needed on the device; labels go straight
            # to NumPy for the metric functions.
            logits = model(xb.to(device))
            prob = torch.sigmoid(logits)
            ys.append(yb.detach().cpu().numpy())
            ps.append(prob.detach().cpu().numpy())

    if not ys:
        # np.concatenate raises on an empty list; report "undefined" instead.
        return float('nan'), float('nan')

    y = np.concatenate(ys)
    p = np.concatenate(ps)
    # Catch only ValueError (undefined metric / invalid input) so that
    # unrelated failures still surface instead of silently becoming NaN.
    try:
        auroc = roc_auc_score(y, p)
    except ValueError:
        auroc = float('nan')
    try:
        prauc = average_precision_score(y, p)
    except ValueError:
        prauc = float('nan')
    return auroc, prauc


In [8]:
import csv

# Make sure the output directory exists before anything is written to it.
RESULTS_DIR = Path('results')
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
CKPT_PATH = RESULTS_DIR / 'bassetmini_best.pt'
METRICS_PATH = RESULTS_DIR / 'metrics.csv'

# NaN marks "no valid score yet": the first epoch with a defined AUROC always
# becomes the best, and an all-NaN run is not reported as a fake 0.0.
best = {'epoch': -1, 'auroc': float('nan'), 'prauc': float('nan')}
EPOCHS = 5
for epoch in range(1, EPOCHS+1):
    model.train()  # evaluate() leaves the model in eval mode
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        optim.zero_grad()
        loss.backward()
        optim.step()
    auroc, prauc = evaluate(model, val_loader)
    print(f'Epoch {epoch}: val AUROC={auroc:.4f}, PR-AUC={prauc:.4f}')
    # Comparisons with NaN are always False, so check for it explicitly.
    improved = not np.isnan(auroc) and (best['epoch'] < 0 or auroc > best['auroc'])
    if improved:
        best = {'epoch': epoch, 'auroc': auroc, 'prauc': prauc}
        torch.save(model.state_dict(), CKPT_PATH)

# Append one summary row per run; the file is opened in append mode and no
# header row is written here.
with open(METRICS_PATH, 'a', newline='') as f:
    w = csv.writer(f)
    w.writerow(['BassetMini','binary',f"{best['auroc']:.4f}",f"{best['prauc']:.4f}", 'val'])
print('Best:', best)
if best['epoch'] > 0:
    print(f'Saved checkpoint to {CKPT_PATH.as_posix()} and appended {METRICS_PATH.as_posix()}')
else:
    # Only claim a checkpoint exists when one was actually written.
    print(f'No epoch produced a valid AUROC; no checkpoint saved. Appended {METRICS_PATH.as_posix()}')


Epoch 1: val AUROC=0.5000, PR-AUC=0.8000
Epoch 2: val AUROC=0.5000, PR-AUC=0.8000
Epoch 3: val AUROC=0.5000, PR-AUC=0.8000
Epoch 4: val AUROC=0.5000, PR-AUC=0.8000
Epoch 5: val AUROC=0.5000, PR-AUC=0.8000
Best: {'epoch': 1, 'auroc': 0.5, 'prauc': 0.8}
Saved checkpoint to results/bassetmini_best.pt and appended results/metrics.csv
