In [1]:
import pandas as pd

# Read the BA shards one line at a time with a hand-rolled parser
# (chosen to avoid the pandas tokenizer running out of memory).
files = [
    'data/netMHCpan_training_data/BA/c000_ba',
    'data/netMHCpan_training_data/BA/c001_ba',
    'data/netMHCpan_training_data/BA/c002_ba',
    'data/netMHCpan_training_data/BA/c003_ba',
    'data/netMHCpan_training_data/BA/c004_ba',
]

peptides, affinities, alleles = [], [], []
for shard_path in files:
    with open(shard_path) as shard:
        for raw_line in shard:
            record = raw_line.strip()
            # Blank lines and '#' comment lines carry no data.
            if record == '' or record.startswith('#'):
                continue
            fields = record.split()
            # A usable row needs at least three whitespace-separated columns:
            # peptide, affinity, allele. Anything after the third is ignored.
            if len(fields) >= 3:
                try:
                    affinity = float(fields[1])
                except ValueError:
                    # Second column is not numeric: drop the row.
                    pass
                else:
                    peptides.append(fields[0])
                    affinities.append(affinity)
                    alleles.append(fields[2])

# Affinities are stored as float32; the other two columns stay Python strings.
affinity_column = pd.Series(affinities, dtype='float32')
fullFrame = pd.DataFrame({
    'peptide': peptides,
    'binding_affinity': affinity_column,
    'mhc_allele': alleles,
})

fullFrame.head()

Unnamed: 0,peptide,binding_affinity,mhc_allele
0,AAAAAAYAAM,0.177415,H-2-Db
1,AAAAAAYAAM,0.4631,H-2-Kb
2,AAAAFEAAL,0.362118,BoLA-3:00101
3,AAAAFEAAL,0.468035,BoLA-3:00201
4,AAAAFEAAL,0.522653,HLA-B48:01


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BabyNetMHCpan(nn.Module):
    """Small two-branch regressor for peptide/MHC binding affinity.

    The peptide branch is a single Conv1d + ReLU followed by a global max-pool
    over positions, so peptides of any length >= ``kernel_size`` give a
    fixed-size vector. The MHC branch is one Linear + ReLU over the flattened
    one-hot pseudo-sequence. Both vectors are concatenated and passed through a
    two-layer MLP ending in a sigmoid, so the output lies in [0, 1].

    Args:
        n_amino_acids: One-hot width of a residue (input channels). Default 20.
        pseudo_len: Number of residues in the MHC pseudo-sequence. Default 34.
        hidden_dim: Output width of each encoder branch. Default 64.
        kernel_size: Width of the peptide convolution. Default 3.

    With the default arguments the layers have exactly the shapes of the
    original hard-coded model and are created in the same order.
    """

    def __init__(self, n_amino_acids=20, pseudo_len=34, hidden_dim=64, kernel_size=3):
        super().__init__()
        # Peptide: one-hot (B, n_amino_acids, L) -> small CNN
        self.peptide_encoder = nn.Conv1d(n_amino_acids, hidden_dim, kernel_size=kernel_size)

        # MHC: flattened one-hot pseudo-sequence -> dense embedding
        self.mhc_encoder = nn.Linear(pseudo_len * n_amino_acids, hidden_dim)

        # Combine both embeddings and predict
        self.predictor = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Output in [0,1] for BA
        )

    def forward(self, peptide, mhc_pseudo):
        """Predict binding affinity.

        Args:
            peptide: Float tensor of shape (B, n_amino_acids, L) with
                L >= kernel_size.
            mhc_pseudo: Float tensor of shape (B, pseudo_len, n_amino_acids),
                or already flattened to (B, pseudo_len * n_amino_acids).

        Returns:
            Tensor of shape (B, 1) with values in [0, 1].
        """
        # Convolve along the peptide, then keep the strongest response per
        # channel so the representation does not depend on peptide length.
        p = F.relu(self.peptide_encoder(peptide))
        p = F.adaptive_max_pool1d(p, 1).squeeze(-1)  # (B, hidden_dim)

        # Accept the un-flattened one-hot layout as a convenience.
        if mhc_pseudo.dim() == 3:
            mhc_pseudo = mhc_pseudo.reshape(mhc_pseudo.size(0), -1)
        m = F.relu(self.mhc_encoder(mhc_pseudo))  # (B, hidden_dim)

        z = torch.cat([p, m], dim=1)  # (B, 2 * hidden_dim)
        return self.predictor(z)
    


In [3]:
# Parse MHC pseudo-sequences and compute coverage vs BA rows
from pathlib import Path

pseudo_path = Path('data/netMHCpan_training_data/MHC_pseudo.dat')

# allele name -> pseudo-sequence exactly as listed in the file (any length).
# If a name occurs more than once, the last occurrence wins.
allele_to_pseudo_raw = {}
with pseudo_path.open() as pseudo_file:
    for raw_line in pseudo_file:
        fields = raw_line.split()
        # Skip blank lines, '#' comments, and rows without both a name and a sequence.
        if len(fields) < 2 or fields[0].startswith('#'):
            continue
        allele_to_pseudo_raw[fields[0]] = fields[1]

# Keep only canonical 34-aa pseudo sequences to match model expectation (34 * 20)
allele_to_pseudo = {}
for name, pseudo in allele_to_pseudo_raw.items():
    if len(pseudo) == 34:
        allele_to_pseudo[name] = pseudo

# Fraction of BA rows whose allele name has a usable pseudo-sequence.
has_pseudo = fullFrame['mhc_allele'].isin(allele_to_pseudo)
coverage = has_pseudo.mean()
print(f"MHC pseudo coverage: {coverage*100:.1f}% of BA rows")


MHC pseudo coverage: 100.0% of BA rows


In [4]:
# Canonicalize allele names to improve pseudo coverage and matching
import re

def canonicalize_allele(a: str) -> str:
    """Reduce an allele name to a lowercase ASCII-alphanumeric lookup key.

    Every character outside ``[A-Za-z0-9]`` is removed and the remainder is
    lowercased, so spellings that differ only in punctuation or case map to the
    same key. Non-string input yields the empty string.
    """
    if isinstance(a, str):
        alnum_only = re.sub(r"[^A-Za-z0-9]", "", a)
        return alnum_only.lower()
    return ""

# Canonical-key lookup over every pseudo-sequence, then restricted to 34-aa entries.
# Names that collapse to the same canonical key overwrite each other (last one wins).
canon_to_pseudo_all = {}
for raw_name, pseudo in allele_to_pseudo_raw.items():
    canon_to_pseudo_all[canonicalize_allele(raw_name)] = pseudo
canon_to_pseudo_34 = {
    key: pseudo for key, pseudo in canon_to_pseudo_all.items() if len(pseudo) == 34
}

# Compare exact-name matching against canonical-key matching on the BA rows
allele_column = fullFrame['mhc_allele']
exact_cov = allele_column.isin(allele_to_pseudo).mean()
canon_cov = allele_column.map(canonicalize_allele).isin(canon_to_pseudo_34).mean()
print(f"Exact-name coverage: {exact_cov*100:.1f}% | Canonical coverage: {canon_cov*100:.1f}%")


Exact-name coverage: 100.0% | Canonical coverage: 100.0%


In [5]:
# Residue -> integer index encoding; one-hot expansion happens later, on the device.
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")  # the 20 standard residues; position defines the index
AA_TO_INDEX = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))
PAD_INDEX = len(AMINO_ACIDS)  # 20: first index past the real residues, reserved for padding


def seq_to_index_tensor(seq: str):
    """Encode an amino-acid string as a 1-D ``torch.long`` tensor of indices.

    Each character is looked up in ``AA_TO_INDEX``. An empty string gives an
    empty tensor.
    """
    import torch
    indices = []
    for residue in seq:
        # NOTE: symbols outside the 20-letter alphabet fall back to index 0,
        # i.e. they are encoded exactly like 'A'.
        indices.append(AA_TO_INDEX.get(residue, 0))
    return torch.tensor(indices, dtype=torch.long)


In [6]:
# Torch Dataset for BA data (index-based, pickle-safe for multiprocessing)
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class BADataset(Dataset):
    """Index-encoded (peptide, MHC pseudo-sequence, affinity) examples.

    Rows of ``df`` whose ``mhc_allele`` is not a key of ``allele_to_pseudo`` are
    dropped. Only plain Python lists and a NumPy array are stored on the
    instance; tensors are created per item in ``__getitem__``.
    """

    def __init__(self, df, allele_to_pseudo):
        # Restrict to rows whose allele has a pseudo-sequence
        keep_mask = df['mhc_allele'].isin(allele_to_pseudo)
        self.df = df[keep_mask].reset_index(drop=True)
        self.targets = self.df['binding_affinity'].astype('float32').to_numpy()
        self.peptides = self.df['peptide'].tolist()
        self.alleles = self.df['mhc_allele'].tolist()
        # Encode each distinct allele's pseudo-sequence once, as a list of ints
        # (no torch tensors are stored on the instance).
        self.allele_to_pseudo_idx_list = {}
        for name in sorted(set(self.alleles)):
            self.allele_to_pseudo_idx_list[name] = [
                AA_TO_INDEX.get(residue, 0) for residue in allele_to_pseudo[name]
            ]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(peptide_indices, pseudo_indices, affinity)`` for row ``idx``."""
        import torch
        # Tensors are built here, per item, so no tensor storage lives on the dataset
        residues = self.peptides[idx]
        pep_idx = torch.tensor([AA_TO_INDEX.get(aa, 0) for aa in residues], dtype=torch.long)
        pseudo_codes = self.allele_to_pseudo_idx_list[self.alleles[idx]]
        pseudo_idx = torch.tensor(pseudo_codes, dtype=torch.long)
        return pep_idx, pseudo_idx, self.targets[idx]


def collate_batch(batch):
    """Merge ``(peptide_idx, pseudo_idx, target)`` samples into batch tensors.

    Returns:
        pep_padded: (B, Lmax) long tensor, shorter peptides right-padded with ``PAD_INDEX``.
        pseudo_idx: (B, P) long tensor; every sample's pseudo tensor must have the same length.
        y: (B, 1) float32 tensor of targets.
    """
    import torch
    columns = list(zip(*batch))
    pep_list, pseudo_list, y_list = columns
    pep_padded = pad_sequence(pep_list, batch_first=True, padding_value=PAD_INDEX)
    pseudo_idx = torch.stack(pseudo_list)
    targets = torch.tensor(y_list, dtype=torch.float32)
    y = targets[:, None]  # add the trailing feature dimension
    return pep_padded, pseudo_idx, y


# Build the dataset from the parsed BA frame; BADataset drops rows whose allele
# is missing from allele_to_pseudo (the 34-aa pseudo-sequence map).
dataset = BADataset(fullFrame, allele_to_pseudo)
len(dataset)  # displayed as the cell output: number of rows kept


208093

In [7]:
# Train/val split and DataLoaders
from sklearn.model_selection import train_test_split
import os
from torch.utils.data import Subset, DataLoader

# Random 90/10 split over row indices; random_state pins the partition between runs.
# NOTE(review): the split is by row, and the same peptide occurs with several
# alleles in the BA data, so a peptide can appear in both train and val.
train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.1, random_state=42)

train_ds = Subset(dataset, train_idx)
val_ds = Subset(dataset, val_idx)
assert len(train_ds) + len(val_ds) == len(dataset), "split lost or duplicated rows"

batch_size = 512  # adjust higher if GPU memory allows

# Pinned (page-locked) host memory only speeds up host->GPU copies. Requesting it
# on a machine without CUDA has no benefit and makes DataLoader emit a warning,
# so enable it only when a GPU is actually present.
pin_memory = torch.cuda.is_available()

# Stable, single-process loaders to avoid shared-memory mmap errors in containers
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch,
                          num_workers=0, pin_memory=pin_memory)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_batch,
                        num_workers=0, pin_memory=pin_memory)

len(train_ds), len(val_ds)


(187283, 20810)

In [8]:
# Report whether PyTorch can see a CUDA device.
print('cuda' if torch.cuda.is_available() else 'cpu')

cuda


In [9]:
# Instantiate model and simple training loop
import torch
import torch.nn.functional as F

# Enable cuDNN autotune and TF32 on Ampere/ADA
torch.backends.cudnn.benchmark = True
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except (AttributeError, RuntimeError) as exc:
    # TF32 is a best-effort speed-up: if this PyTorch build does not expose or
    # does not allow setting the flags, keep going, but say so instead of
    # hiding it.
    print(f"TF32 not enabled: {exc}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'  # mixed precision is only used on GPU
model = BabyNetMHCpan().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
# torch.cuda.amp.GradScaler(...) is deprecated and emits a FutureWarning;
# torch.amp.GradScaler('cuda', ...) is the replacement with the same behavior.
# With enabled=False (CPU run) the scaler is a transparent pass-through.
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)


def to_one_hot_on_gpu(pep_idx, pseudo_idx):
    """Expand index tensors into float one-hot tensors on their current device.

    Args:
        pep_idx: (B, Lmax) long tensor of residue indices, padded with ``PAD_INDEX``.
        pseudo_idx: (B, P) long tensor with values in [0, 20).

    Returns:
        pep_oh: (B, 20, Lmax) float tensor; padded positions are all-zero columns.
        pseudo_oh: (B, P, 20) float tensor.
    """
    # Encode with one extra class for the pad symbol, then drop that class so
    # padding contributes nothing to the convolution input.
    with_pad_class = F.one_hot(pep_idx, num_classes=PAD_INDEX + 1)
    residues_only = with_pad_class.narrow(-1, 0, 20)  # (B, Lmax, 20)
    pep_oh = residues_only.transpose(1, 2).float()  # channels-first: (B, 20, Lmax)

    pseudo_oh = F.one_hot(pseudo_idx, num_classes=20).float()  # (B, P, 20)
    return pep_oh, pseudo_oh


def evaluate(loader):
    """Return the mean loss of the global ``model`` over every example in ``loader``.

    Uses the notebook-level ``model``, ``loss_fn`` and ``device``. Each batch
    loss is weighted by its batch size so a smaller final batch does not skew
    the average. Returns 0.0 for an empty loader.

    Side effect: the model is left in eval mode; callers that keep training
    must switch back with ``model.train()``.
    """
    model.eval()
    loss_sum, n = 0.0, 0
    with torch.no_grad():
        for pep_idx, pseudo_idx, y in loader:
            pep_idx = pep_idx.to(device, non_blocking=True)
            pseudo_idx = pseudo_idx.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            pep_oh, pseudo_oh = to_one_hot_on_gpu(pep_idx, pseudo_idx)
            # torch.cuda.amp.autocast is deprecated (FutureWarning); torch.amp.autocast
            # with an explicit device_type is the supported spelling.
            with torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
                y_pred = model(pep_oh, pseudo_oh)
                loss = loss_fn(y_pred, y)
            # loss is a per-batch mean; multiply back by the batch size to sum per example
            loss_sum += loss.item() * y.size(0)
            n += y.size(0)
    return loss_sum / max(n, 1)


epochs = 20
for epoch in range(1, epochs + 1):
    # evaluate() leaves the model in eval mode, so re-enter train mode every epoch
    model.train()
    for pep_idx, pseudo_idx, y in train_loader:
        pep_idx = pep_idx.to(device, non_blocking=True)
        pseudo_idx = pseudo_idx.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        pep_oh, pseudo_oh = to_one_hot_on_gpu(pep_idx, pseudo_idx)

        optimizer.zero_grad(set_to_none=True)
        # torch.cuda.amp.autocast is deprecated (FutureWarning); torch.amp.autocast
        # with an explicit device_type is the supported spelling.
        with torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
            y_pred = model(pep_oh, pseudo_oh)
            loss = loss_fn(y_pred, y)
        # Scale the loss before backward, then let the scaler unscale the
        # gradients, step the optimizer, and adapt its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    val_loss = evaluate(val_loader)
    print(f"Epoch {epoch} | val MSE: {val_loss:.4f}")


  scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')
  with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
  with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):


Epoch 1 | val MSE: 0.0703
Epoch 2 | val MSE: 0.0639
Epoch 3 | val MSE: 0.0575
Epoch 4 | val MSE: 0.0527
Epoch 5 | val MSE: 0.0488
Epoch 6 | val MSE: 0.0465
Epoch 7 | val MSE: 0.0453
Epoch 8 | val MSE: 0.0443
Epoch 9 | val MSE: 0.0434
Epoch 10 | val MSE: 0.0427
Epoch 11 | val MSE: 0.0419
Epoch 12 | val MSE: 0.0422
Epoch 13 | val MSE: 0.0416
Epoch 14 | val MSE: 0.0413
Epoch 15 | val MSE: 0.0404
Epoch 16 | val MSE: 0.0403
Epoch 17 | val MSE: 0.0400
Epoch 18 | val MSE: 0.0398
Epoch 19 | val MSE: 0.0394
Epoch 20 | val MSE: 0.0394


In [10]:
# Sanity-check one forward pass on a single training batch
pep_idx, pseudo_idx, y = next(iter(train_loader))
pep_on_device = pep_idx.to(device, non_blocking=True)
pseudo_on_device = pseudo_idx.to(device, non_blocking=True)
with torch.no_grad():
    pep_oh, pseudo_oh = to_one_hot_on_gpu(pep_on_device, pseudo_on_device)
    out = model(pep_oh, pseudo_oh)
# Show the output shape and the first five predictions.
first_five = out[:5].squeeze().cpu().numpy()
print(out.shape, first_five)

torch.Size([512, 1]) [0.04793084 0.03500205 0.27789724 0.07731942 0.9229736 ]


In [11]:
# Validation metrics (Pearson/Spearman) on the val set (SciPy optional)
import numpy as np

try:
    from scipy.stats import pearsonr, spearmanr  # type: ignore
    use_scipy = True
except Exception:
    # Deliberately broad: SciPy is optional, and any failure to load it simply
    # selects the NumPy fallback below.
    use_scipy = False


def _average_ranks(values):
    """Return 1-based ranks of a 1-D array, tied values sharing their mean rank.

    ``argsort().argsort()`` would hand tied values distinct, arbitrary ranks,
    which makes the fallback Spearman disagree with ``scipy.stats.spearmanr``
    whenever the data contain ties.
    """
    _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)                # highest rank in each tie group
    mean_rank = last_rank - (counts - 1) / 2.0   # midpoint of the group's rank range
    return mean_rank[inverse.ravel()]


# Collect targets and predictions over the whole validation set
model.eval()
all_y, all_pred = [], []
with torch.no_grad():
    for pep_idx, pseudo_idx, y in val_loader:
        y_np = y.numpy().ravel()
        pep_idx = pep_idx.to(device, non_blocking=True)
        pseudo_idx = pseudo_idx.to(device, non_blocking=True)
        pep_oh, pseudo_oh = to_one_hot_on_gpu(pep_idx, pseudo_idx)
        pred_np = model(pep_oh, pseudo_oh).cpu().numpy().ravel()
        all_y.append(y_np)
        all_pred.append(pred_np)
all_y = np.concatenate(all_y)
all_pred = np.concatenate(all_pred)

if use_scipy:
    p = pearsonr(all_y, all_pred)[0]
    s = spearmanr(all_y, all_pred)[0]
else:
    # Pearson without SciPy (epsilon guards against a zero-variance denominator)
    y_mean = all_y.mean(); p_mean = all_pred.mean()
    num = ((all_y - y_mean) * (all_pred - p_mean)).sum()
    den = np.sqrt(((all_y - y_mean) ** 2).sum()) * np.sqrt(((all_pred - p_mean) ** 2).sum())
    p = num / (den + 1e-12)
    # Spearman = Pearson correlation of tie-aware (average) ranks
    y_rank = _average_ranks(all_y)
    p_rank = _average_ranks(all_pred)
    yrm = y_rank.mean(); prm = p_rank.mean()
    s_num = ((y_rank - yrm) * (p_rank - prm)).sum()
    s_den = np.sqrt(((y_rank - yrm) ** 2).sum()) * np.sqrt(((p_rank - prm) ** 2).sum())
    s = s_num / (s_den + 1e-12)

print(f"Pearson: {p:.3f} | Spearman: {s:.3f}")


Pearson: 0.715 | Spearman: 0.656


In [12]:
# Save model helper and quick inference function
import torch
from pathlib import Path

ckpt_path = Path('tmp/baby_netmhcpan.pt')
# Make sure the default checkpoint's parent directory exists.
ckpt_path.parent.mkdir(exist_ok=True, parents=True)

def save_model(model, path=ckpt_path):
    """Write ``model.state_dict()`` to ``path`` and print where it went.

    ``path`` defaults to ``ckpt_path`` as it was when this function was defined.
    """
    weights = model.state_dict()
    torch.save(weights, path)
    print(f'Saved to {path}')


def predict_peptide_allele(peptide: str, allele: str):
    """Predict the binding score of one peptide against one MHC allele.

    Uses the notebook-level ``model`` and ``device``. The allele is matched by
    its canonical key (punctuation and case ignored) against the 34-aa
    pseudo-sequence table.

    Args:
        peptide: Amino-acid string, at least as long as the model's conv kernel.
        allele: Allele name, e.g. ``'H-2-Kb'``.

    Returns:
        The model output as a Python float in [0, 1].

    Raises:
        ValueError: If no 34-aa pseudo-sequence is known for ``allele``, or if
            ``peptide`` is not a string long enough for the convolution.
    """
    akey = canonicalize_allele(allele)
    if akey not in canon_to_pseudo_34:
        raise ValueError(f'No pseudo for allele {allele}')
    # An empty or too-short peptide would otherwise fail inside Conv1d with an
    # opaque size error; reject it up front with a clear message.
    min_len = model.peptide_encoder.kernel_size[0]
    if not isinstance(peptide, str) or len(peptide) < min_len:
        raise ValueError(f'Peptide must be a string of at least {min_len} residues, got {peptide!r}')
    pep_idx = seq_to_index_tensor(peptide).unsqueeze(0).to(device)  # (1, L)
    pseudo_idx = seq_to_index_tensor(canon_to_pseudo_34[akey]).unsqueeze(0).to(device)  # (1, 34)
    pep_oh, pseudo_oh = to_one_hot_on_gpu(pep_idx, pseudo_idx)
    # torch.cuda.amp.autocast is deprecated (FutureWarning); torch.amp.autocast
    # with an explicit device_type is the supported spelling.
    with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
        pred = model(pep_oh, pseudo_oh).item()
    return pred

# Example usage (after training):
# save_model(model)
# predict_peptide_allele('SIINFEKL', 'H-2-Kb')
