In [2]:
import pandas as pd

# Five BA training files c000_ba..c004_ba; each is whitespace-separated with
# columns peptide, binding-affinity target, MHC allele and no header row.
files = [f'data/netMHCpan_training_data/BA/c{i:03d}_ba' for i in range(5)]

# Read every file first and concatenate once: growing a frame with pd.concat
# inside the loop re-copies all previously loaded rows on every iteration.
frames = [
    pd.read_csv(file, sep=r'\s+', header=None, names=['peptide', 'binding_affinity', 'mhc_allele'])
    for file in files
]
fullFrame = pd.concat(frames, axis=0, ignore_index=True)

# ensure numeric target; rows whose affinity cannot be parsed are dropped
fullFrame['binding_affinity'] = pd.to_numeric(fullFrame['binding_affinity'], errors='coerce')
fullFrame = fullFrame.dropna(subset=['binding_affinity']).reset_index(drop=True)

fullFrame.head()

Unnamed: 0,peptide,binding_affinity,mhc_allele
0,AAAAAAYAAM,0.177415,H-2-Db
1,AAAAAAYAAM,0.4631,H-2-Kb
2,AAAAFEAAL,0.362118,BoLA-3:00101
3,AAAAFEAAL,0.468035,BoLA-3:00201
4,AAAAFEAAL,0.522653,HLA-B48:01


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BabyNetMHCpan(nn.Module):
    """Toy pan-allele peptide/MHC binding-affinity regressor.

    The peptide is encoded by one 1-D convolution over its one-hot matrix
    followed by global max pooling; the MHC pseudo-sequence is flattened and
    passed through one linear layer. The two embeddings are concatenated and
    mapped to a single sigmoid score in [0, 1].

    Args:
        n_channels: Alphabet size of the one-hot encoding (channels of the
            peptide input and columns of the pseudo-sequence input).
        pseudo_len: Number of positions in an MHC pseudo-sequence.
        hidden: Width of each branch embedding and of the MLP hidden layer.
        kernel_size: Width of the peptide convolution; peptides shorter than
            this cannot be encoded (``nn.Conv1d`` raises).

    The defaults reproduce the original hard-coded architecture (20 channels,
    34-position pseudos, 64-d embeddings, kernel 3), so existing ``state_dict``
    checkpoints still load.
    """

    def __init__(self, n_channels: int = 20, pseudo_len: int = 34, hidden: int = 64, kernel_size: int = 3):
        super().__init__()
        # Peptide branch: one-hot (B, n_channels, L) -> (B, hidden, L - kernel_size + 1).
        self.peptide_encoder = nn.Conv1d(n_channels, hidden, kernel_size=kernel_size)

        # MHC branch: flattened one-hot pseudo-sequence -> hidden-d embedding.
        self.mhc_encoder = nn.Linear(pseudo_len * n_channels, hidden)

        # Head: concatenated [peptide, mhc] embeddings -> score in [0, 1].
        self.predictor = nn.Sequential(
            nn.Linear(2 * hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, peptide, mhc_pseudo):
        """Score a batch of peptide/MHC pairs.

        Args:
            peptide: (B, n_channels, L) one-hot peptides (channel-first, as
                ``nn.Conv1d`` expects).
            mhc_pseudo: (B, pseudo_len, n_channels) one-hot pseudo-sequences,
                or the same already flattened to (B, pseudo_len * n_channels).

        Returns:
            (B, 1) tensor of sigmoid scores.
        """
        # Global max pooling makes the peptide embedding independent of L, so
        # batches padded to different lengths share one encoder.
        # NOTE(review): zero-padded positions added at collate time are not
        # masked, so conv windows overlapping the padding still take part in
        # the max pool.
        p = F.relu(self.peptide_encoder(peptide))
        p = F.adaptive_max_pool1d(p, 1).squeeze(-1)

        # Accept either the 2-D (positions, channels) layout or pre-flattened input.
        if mhc_pseudo.dim() == 3:
            mhc_pseudo = mhc_pseudo.flatten(start_dim=1)
        m = F.relu(self.mhc_encoder(mhc_pseudo))

        z = torch.cat([p, m], dim=1)
        return self.predictor(z)
    


In [4]:
# Parse MHC pseudo-sequences and compute coverage vs BA rows
from pathlib import Path

pseudo_path = Path('data/netMHCpan_training_data/MHC_pseudo.dat')

# Each useful line is "<allele> <pseudo-sequence> ..."; blank lines, '#'
# comments and lines with fewer than two fields are skipped, and a repeated
# allele keeps the sequence from its last occurrence.
allele_to_pseudo_raw = {}
with pseudo_path.open() as handle:
    for raw_line in handle:
        stripped = raw_line.strip()
        if stripped == '' or stripped.startswith('#'):
            continue
        fields = stripped.split()
        if len(fields) >= 2:
            allele_to_pseudo_raw[fields[0]] = fields[1]

# Keep only canonical 34-aa pseudo sequences to match model expectation (34 * 20)
allele_to_pseudo = {}
for name, pseudo_seq in allele_to_pseudo_raw.items():
    if len(pseudo_seq) == 34:
        allele_to_pseudo[name] = pseudo_seq

coverage = fullFrame['mhc_allele'].isin(allele_to_pseudo).mean()
print(f"MHC pseudo coverage: {coverage*100:.1f}% of BA rows")


MHC pseudo coverage: 100.0% of BA rows


In [None]:
# Canonicalize allele names to improve pseudo coverage and matching
import re

_NON_ALNUM_ASCII = re.compile(r"[^A-Za-z0-9]")


def canonicalize_allele(a: str) -> str:
    """Reduce an allele name to lowercase ASCII letters and digits.

    Punctuation, whitespace and non-ASCII characters are dropped, so
    ``'HLA-A*02:01'`` and ``'HLA-A02:01'`` both become ``'hlaa0201'``.
    Non-string input (e.g. a missing value from pandas) yields ``''``.
    """
    if isinstance(a, str):
        return _NON_ALNUM_ASCII.sub("", a).lower()
    return ""

# Build canonical-name -> pseudo maps.
canon_to_pseudo_all = {canonicalize_allele(a): s for a, s in allele_to_pseudo_raw.items()}
# Filter by length *before* canonicalizing: when two raw names collapse to the
# same canonical key, a non-34-aa entry must not overwrite (and so hide) a
# valid 34-aa pseudo for that key.
canon_to_pseudo_34 = {
    canonicalize_allele(a): s for a, s in allele_to_pseudo_raw.items() if len(s) == 34
}

# Compare exact vs canonical coverage
exact_cov = fullFrame['mhc_allele'].isin(allele_to_pseudo).mean()
canon_cov = fullFrame['mhc_allele'].map(canonicalize_allele).isin(canon_to_pseudo_34).mean()
print(f"Exact-name coverage: {exact_cov*100:.1f}% | Canonical coverage: {canon_cov*100:.1f}%")


In [5]:
# Encoding utilities
import string

AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")  # 20 standard AAs in fixed order
AA_TO_INDEX = {aa: i for i, aa in enumerate(AMINO_ACIDS)}


def one_hot_sequence(seq: str, vocab=AA_TO_INDEX, num_channels=20):
    """One-hot encode ``seq`` channel-first as a ``(num_channels, len(seq))`` float32 tensor.

    This is the layout ``nn.Conv1d`` expects (channels, length). Characters not
    in ``vocab`` (non-standard residues such as 'X', gaps, lowercase letters)
    silently become an all-zero column rather than raising.
    """
    x = torch.zeros(num_channels, len(seq), dtype=torch.float32)
    for pos, aa in enumerate(seq):
        idx = vocab.get(aa)
        if idx is not None:
            x[idx, pos] = 1.0
    return x


def one_hot_pseudo(seq: str, num_channels=20, vocab=AA_TO_INDEX):
    """One-hot encode an MHC pseudo-sequence position-first as ``(len(seq), num_channels)``.

    This is the transpose of ``one_hot_sequence``'s layout; the model flattens
    it to ``len(seq) * num_channels`` features. ``vocab`` defaults to the same
    20-letter alphabet and can be overridden for consistency with
    ``one_hot_sequence``. Unknown characters give an all-zero row.
    """
    x = torch.zeros(len(seq), num_channels, dtype=torch.float32)
    for pos, aa in enumerate(seq):
        idx = vocab.get(aa)
        if idx is not None:
            x[pos, idx] = 1.0
    return x


In [6]:
# Torch Dataset for BA data
from torch.utils.data import Dataset, DataLoader

class BADataset(Dataset):
    """Map-style dataset of (peptide one-hot, pseudo one-hot, affinity) samples.

    Rows whose ``mhc_allele`` has no entry in ``allele_to_pseudo`` are dropped.
    Column values are snapshotted into lists at construction time, so later
    edits to ``self.df`` are not seen by ``__getitem__``.
    """

    def __init__(self, df, allele_to_pseudo):
        # Filter rows to those we have pseudo sequences for
        self.df = df[df['mhc_allele'].isin(allele_to_pseudo)].reset_index(drop=True)
        self.allele_to_pseudo = allele_to_pseudo
        # Cache columns as plain Python lists: DataFrame.iloc row access builds a
        # pandas Series on every call, which dominates per-item cost.
        self._peptides = self.df['peptide'].tolist()
        self._alleles = self.df['mhc_allele'].tolist()
        self._targets = self.df['binding_affinity'].astype(float).tolist()

    def __len__(self):
        return len(self._peptides)

    def __getitem__(self, idx):
        peptide = self._peptides[idx]
        allele = self._alleles[idx]
        y = self._targets[idx]
        pep_oh = one_hot_sequence(peptide)  # (20, L)
        pseudo_oh = one_hot_pseudo(self.allele_to_pseudo[allele])  # (34, 20)
        return pep_oh, pseudo_oh, y


def collate_batch(batch):
    """Stack a list of (peptide, pseudo, y) samples into batch tensors.

    Peptides (C, L_i) are right-padded with zeros to the longest L in the batch
    and stacked to (B, C, Lmax); pseudos are stacked along a new batch axis;
    targets become a float32 (B, 1) column.
    """
    import torch
    peptides, pseudos, targets = zip(*batch)
    longest = max(pep.shape[1] for pep in peptides)
    # The (left, right) pad pair applies to the last (length) dimension.
    pep_tensor = torch.stack(
        [torch.nn.functional.pad(pep, (0, longest - pep.shape[1])) for pep in peptides],
        dim=0,
    )
    pseudo_tensor = torch.stack(pseudos, dim=0)
    y_tensor = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1)
    return pep_tensor, pseudo_tensor, y_tensor


# BADataset drops rows whose allele lacks an entry in allele_to_pseudo (the
# 34-aa filtered map); the trailing expression displays how many remain.
dataset = BADataset(df=fullFrame, allele_to_pseudo=allele_to_pseudo)
len(dataset)


208093

In [7]:
# Train/val split and DataLoaders
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# NOTE(review): this is a random row-level split. The same peptide is measured
# against several alleles, so it can land in both train and val, which likely
# makes val metrics optimistic. The BA files look like pre-made partitions
# (c000..c004); holding one out may be a fairer validation — confirm.
train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.1, random_state=42)

train_ds = Subset(dataset, train_idx)
val_ds = Subset(dataset, val_idx)

batch_size = 256
# A dedicated seeded generator makes the shuffle order reproducible on a fresh
# kernel, independent of any other use of torch's global RNG.
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch, num_workers=0,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_batch, num_workers=0)

len(train_ds), len(val_ds)


(187283, 20810)

In [8]:
# Instantiate model and simple training loop
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed torch's global RNG so weight initialisation is reproducible on a fresh
# kernel (Restart & Run All starts from the same model).
torch.manual_seed(42)
model = BabyNetMHCpan().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()


def evaluate(loader, net=None, criterion=None, target_device=None):
    """Return the per-sample mean loss of a model over ``loader``.

    Args:
        loader: Iterable of ``(pep, pseudo, y)`` batches, e.g. a DataLoader
            built with ``collate_batch``.
        net: Model to evaluate; defaults to the notebook-global ``model``
            (looked up at call time).
        criterion: Loss function, assumed to use mean reduction (as
            ``torch.nn.MSELoss`` does by default); defaults to the global ``loss_fn``.
        target_device: Device batches are moved to; defaults to the global ``device``.

    Returns:
        Mean loss per sample, or 0.0 for an empty loader. Each batch mean is
        re-weighted by its batch size so a short final batch is not over-counted.

    Side effect: leaves ``net`` in eval mode.
    """
    net = model if net is None else net
    criterion = loss_fn if criterion is None else criterion
    target_device = device if target_device is None else target_device

    net.eval()
    loss_sum, n = 0.0, 0
    with torch.no_grad():
        for pep, pseudo, y in loader:
            pep = pep.to(target_device)
            pseudo = pseudo.to(target_device)
            y = y.to(target_device)
            y_pred = net(pep, pseudo)
            loss = criterion(y_pred, y)
            loss_sum += loss.item() * y.size(0)
            n += y.size(0)
    return loss_sum / max(n, 1)


epochs = 3
for epoch in range(1, epochs + 1):
    # evaluate() at the end of each epoch switches to eval mode, so re-enable training here.
    model.train()
    for batch in train_loader:
        pep, pseudo, y = (tensor.to(device) for tensor in batch)
        optimizer.zero_grad()
        y_pred = model(pep, pseudo)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
    val_loss = evaluate(val_loader)
    print(f"Epoch {epoch} | val MSE: {val_loss:.4f}")


Epoch 1 | val MSE: 0.0656
Epoch 2 | val MSE: 0.0568
Epoch 3 | val MSE: 0.0531


In [None]:
# Sanity-check one forward pass
pep, pseudo, y = next(iter(train_loader))
print(pep.shape, pseudo.shape, y.shape)
with torch.no_grad():
    out = model(pep.to(device), pseudo.to(device))
# Invariants: one score per sample, aligned with the targets, inside the sigmoid's [0, 1] range.
assert out.shape == y.shape == (pep.shape[0], 1), (out.shape, y.shape)
assert bool(((out >= 0) & (out <= 1)).all()), 'predictions outside [0, 1]'
print(out.shape, out[:5].squeeze().cpu().numpy())

torch.Size([256, 20, 12]) torch.Size([256, 34, 20]) torch.Size([256, 1])
torch.Size([256, 1]) [0.10367729 0.33692557 0.63325983 0.07783097 0.3745204 ]


In [None]:
# Validation metrics (Pearson/Spearman) on the val set (SciPy optional)
import numpy as np

try:
    from scipy.stats import pearsonr, spearmanr  # type: ignore
    use_scipy = True
except ImportError:
    # Only a missing SciPy should switch to the fallback; any other error
    # raised while importing it is a real problem and should surface.
    use_scipy = False


def _pearson(a, b):
    """Pearson correlation of two equal-length 1-D arrays (epsilon-guarded denominator)."""
    a_c = a - a.mean()
    b_c = b - b.mean()
    num = (a_c * b_c).sum()
    den = np.sqrt((a_c ** 2).sum()) * np.sqrt((b_c ** 2).sum())
    return num / (den + 1e-12)


def _average_ranks(x):
    """1-based ranks in which tied values share the mean of their ranks."""
    return pd.Series(x).rank(method='average').to_numpy()


model.eval()
all_y, all_pred = [], []
with torch.no_grad():
    for pep, pseudo, y in val_loader:
        y = y.numpy().ravel()
        pred = model(pep.to(device), pseudo.to(device)).cpu().numpy().ravel()
        all_y.append(y)
        all_pred.append(pred)
all_y = np.concatenate(all_y)
all_pred = np.concatenate(all_pred)

if use_scipy:
    p = pearsonr(all_y, all_pred)[0]
    s = spearmanr(all_y, all_pred)[0]
else:
    # Pearson without SciPy
    p = _pearson(all_y, all_pred)
    # Spearman = Pearson on ranks. Tied values must share an averaged rank to
    # agree with scipy.stats.spearmanr; argsort().argsort() ranks ties arbitrarily.
    s = _pearson(_average_ranks(all_y), _average_ranks(all_pred))

print(f"Pearson: {p:.3f} | Spearman: {s:.3f}")


In [None]:
# Save model helper and quick inference function
import torch
from pathlib import Path

ckpt_path = Path('tmp') / 'baby_netmhcpan.pt'  # default checkpoint location, relative to the working directory
ckpt_path.parent.mkdir(exist_ok=True, parents=True)  # ensure tmp/ exists before any save

def save_model(model, path=None):
    """Write ``model.state_dict()`` to ``path`` and report where it went.

    Args:
        model: Any ``nn.Module``; only its ``state_dict`` is saved, so reload it
            into a freshly constructed model of the same architecture with
            ``load_state_dict(torch.load(path))``.
        path: Target file (str or Path). Defaults to the notebook-level
            ``ckpt_path``, looked up at call time. Missing parent directories
            are created, so custom locations work too.
    """
    if path is None:
        path = ckpt_path
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f'Saved to {path}')


def predict_peptide_allele(peptide: str, allele: str):
    """Predict the binding score of one peptide against one MHC allele.

    The allele is looked up via ``canonicalize_allele``, so punctuation/case
    variants ('HLA-A*02:01' vs 'hlaa0201') resolve to the same entry; only
    34-aa pseudo-sequences are eligible.

    Args:
        peptide: Amino-acid string. Surrounding whitespace is stripped and it is
            upper-cased, because the one-hot encoder silently zeroes any
            character outside its uppercase alphabet.
        allele: Allele name in any punctuation/case variant.

    Returns:
        float: the model's sigmoid output for the pair.

    Raises:
        ValueError: if the allele has no 34-aa pseudo-sequence or the peptide is empty.
    """
    akey = canonicalize_allele(allele)
    if akey not in canon_to_pseudo_34:
        raise ValueError(f'No pseudo for allele {allele}')
    peptide = peptide.strip().upper()
    if not peptide:
        raise ValueError('Empty peptide')
    pep = one_hot_sequence(peptide).unsqueeze(0)  # (1, 20, L)
    pseudo = one_hot_pseudo(canon_to_pseudo_34[akey]).unsqueeze(0)  # (1, 34, 20)
    # Infer in eval mode, then restore whatever mode the caller had
    # (this may be called between training epochs).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = model(pep.to(device), pseudo.to(device)).item()
    finally:
        model.train(was_training)
    return pred

# Example usage (after training):
# save_model(model)
# predict_peptide_allele('SIINFEKL', 'H-2-Kb')
