In [None]:
import pandas as pd
import numpy as np
import copy
from math import ceil
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# Inputs produced by the preprocessing notebooks:
#   df_pos <- 01_therasabdab_preprocessing.ipynb
#   df_unl <- 02_oas_preprocessing.ipynb
DATA_DIR = "/content"
# index_col=0: the CSVs were saved together with their index, which would
# otherwise come back as a spurious "Unnamed: 0" column.
df_pos = pd.read_csv(f"{DATA_DIR}/df_pos.csv", index_col=0)
df_unl = pd.read_csv(f"{DATA_DIR}/df_unl.csv", index_col=0)

In [None]:
# Preview the positive (therapeutic antibody) set loaded above.
df_pos

Unnamed: 0,Therapeutic,HeavySequence,label,cdr3_aa,v_family
0,abelacimab,QVQLLESGGGLVQPGGSLRLSCAASGFTFSTAAMSWVRQAPGKGLE...,1,ARELSYLYSGYYFDY,IGHV3
1,adalimumab,EVQLVESGGGLVQPGRSLRLSCAASGFTFDDYAMHWVRQAPGKGLE...,1,AKVSYLSTASSLDY,IGHV3
2,afimkibart,QVQLVQSGAEVKKPGASVKVSCKASGYDFTYYGISWVRQAPGQGLE...,1,ARENYYGSGAYRGGMDV,IGHV1
3,alemtuzumab,QVQLQESGPGLVRPSQTLSLTCTVSGFTFTDFYMNWVRQPPGRGLE...,1,AREGHTAAPFDY,IGHV4
4,alirocumab,EVQLVESGGGLVQPGGSLRLSCAASGFTFNNYAMNWVRQAPGKGLD...,1,AKDSNWGNFDL,IGHV3
...,...,...,...,...,...
177,xeligekimab,EVQLVESGGGLVQPGGSLRLSCAASGMSMSDYWMNWVRQAPGKGLE...,1,VRDYYDLISDYYIHYWYFDL,IGHV3
178,zamerovimab,QVQLVQSGAEVKKPGASVKVSCKASGYSFTDYIMLWVRQAPGQRLE...,1,ARQGGDGNYVLFDY,IGHV1
179,zigakibart,QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYVMHWVRQAPGQGLE...,1,ARGLGYALYYAMDY,IGHV1
180,ziltivekimab,EVQLVESGGGLVQPGGSLRLSCAASGFTISSNYMIWVRQAPGKGLE...,1,ARWADDHPPWIDL,IGHV3


In [None]:
# Preview the unlabeled (OAS) set loaded above.
df_unl

Unnamed: 0,seq_id,sequence_alignment_aa,label,cdr3_aa,v_family
0,OAS_0,SVKVSCKASGYTFSNYHIHWVRQAPGQGLEWMGIVTPGSGATSYAE...,0,ARETPATGYGDH,IGHV1
1,OAS_1,SVKVSCKASGYTFSNYGISWVRQAPGQGLEWMGWISAYNGDTNYAQ...,0,ARDYRLLDQMLVIDAFDI,IGHV1
2,OAS_2,ASVKVSCKASGYTFTSYGIGWVRQAPGQGLEWMGWISAYNGNTNYA...,0,ERVPATGALWNFDY,IGHV1
3,OAS_3,SVKVSCKASGGTFSSYVINWVRQAPGQGLEWMGRIIPFSGTTNYAQ...,0,AREAVGATFAF,IGHV1
4,OAS_4,QVQLVQSGAEVKKPGASVTVSCQASGYTFSHYALHWVRQAPGQSLE...,0,ARGELYLDS,IGHV1
...,...,...,...,...,...
1994,OAS_1994,SVKVPCKASGYTLSDYALNWVRQAPGQGLEWMGWLNTITGTPTYEQ...,0,ATKGHCSGDGCPGWYV,IGHV7
1995,OAS_1995,ASVKVSCKASGYTFTTYGINWVRQAPGQGLEWTGWINTNTGKPTFA...,0,AHMTPNHSL,IGHV7
1996,OAS_1996,ASVKVSCKASGYKFTNYPMNWVRQAPGHGPEWMGWINTYYGNPTYA...,0,AIGYSFGSVGDGFDR,IGHV7
1997,OAS_1997,QVQLVQSGSELKKPGASVKVSCKASGYAFTRYPMNWVRQAPGQGLE...,0,ARERGYGYLHLDY,IGHV7


In [None]:
# Column names and masking settings used by the masking code in this cell.
# NOTE: SEQ_COL_POS / SEQ_COL_UNL are rebound to "seq_masked" in later cells.
SEQ_COL_POS = "HeavySequence"          # raw heavy-chain sequence column in df_pos
SEQ_COL_UNL = "sequence_alignment_aa"  # raw sequence column in df_unl
VFAM_COL    = "v_family"               # IGHV family label, used later for stratified splitting
MASK_N      = 8                        # number of N-terminal residues to mask
MASK_CHAR   = "X"                      # mask symbol; "X" is also the tokenizer's catch-all token

def mask_first_n(seq: str, n: int = 8, mask_char: str = "X") -> str:
    """
    Replace the first n amino acids with mask_char.
    If sequence is shorter than n, mask the whole thing.
    """
    if pd.isna(seq):
        return seq
    s = str(seq).strip().upper()
    L = len(s)
    if L <= n:
        return mask_char * L
    return mask_char * n + s[n:]

# Normalise sequences (strip whitespace, upper-case) while leaving missing values
# missing. `.astype(str)` would turn NaN into the string "NAN", which mask_first_n
# would then mask to a bogus "XXX" sequence that later dropna() calls cannot catch.
def _normalise_seq(seq):
    return seq if pd.isna(seq) else str(seq).strip().upper()

# apply to positives
df_pos = df_pos.copy()
df_pos[SEQ_COL_POS] = df_pos[SEQ_COL_POS].map(_normalise_seq)
df_pos["seq_masked"] = df_pos[SEQ_COL_POS].apply(lambda s: mask_first_n(s, MASK_N, MASK_CHAR))

# apply to unlabeled
df_unl = df_unl.copy()
df_unl[SEQ_COL_UNL] = df_unl[SEQ_COL_UNL].map(_normalise_seq)
df_unl["seq_masked"] = df_unl[SEQ_COL_UNL].apply(lambda s: mask_first_n(s, MASK_N, MASK_CHAR))

print("Example masked positive:", df_pos["seq_masked"].iloc[0])
print("Example masked unlabeled:", df_unl["seq_masked"].iloc[0])

Example masked positive: XXXXXXXXGGLVQPGGSLRLSCAASGFTFSTAAMSWVRQAPGKGLEWVSGISGSGSSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARELSYLYSGYYFDYWGQGTLVTVSS
Example masked unlabeled: XXXXXXXXSGYTFSNYHIHWVRQAPGQGLEWMGIVTPGSGATSYAEKFQGRVIMTGDMSTTTAFLERSSLRSDDTALYYCARETPATGYGDHWGQGTLVPVSS


In [None]:
# Check that the seq_masked column was added to the positives.
df_pos.head()

Unnamed: 0,Therapeutic,HeavySequence,label,cdr3_aa,v_family,seq_masked
0,abelacimab,QVQLLESGGGLVQPGGSLRLSCAASGFTFSTAAMSWVRQAPGKGLE...,1,ARELSYLYSGYYFDY,IGHV3,XXXXXXXXGGLVQPGGSLRLSCAASGFTFSTAAMSWVRQAPGKGLE...
1,adalimumab,EVQLVESGGGLVQPGRSLRLSCAASGFTFDDYAMHWVRQAPGKGLE...,1,AKVSYLSTASSLDY,IGHV3,XXXXXXXXGGLVQPGRSLRLSCAASGFTFDDYAMHWVRQAPGKGLE...
2,afimkibart,QVQLVQSGAEVKKPGASVKVSCKASGYDFTYYGISWVRQAPGQGLE...,1,ARENYYGSGAYRGGMDV,IGHV1,XXXXXXXXAEVKKPGASVKVSCKASGYDFTYYGISWVRQAPGQGLE...
3,alemtuzumab,QVQLQESGPGLVRPSQTLSLTCTVSGFTFTDFYMNWVRQPPGRGLE...,1,AREGHTAAPFDY,IGHV4,XXXXXXXXPGLVRPSQTLSLTCTVSGFTFTDFYMNWVRQPPGRGLE...
4,alirocumab,EVQLVESGGGLVQPGGSLRLSCAASGFTFNNYAMNWVRQAPGKGLD...,1,AKDSNWGNFDL,IGHV3,XXXXXXXXGGLVQPGGSLRLSCAASGFTFNNYAMNWVRQAPGKGLD...


In [None]:
# CDR3-based grouping + stratified splitting for unlabeled (OAS)

SEQ_COL_UNL   = "seq_masked"
VFAM_COL      = "v_family"
CDR3_COL_UNL  = "cdr3_aa"
TRAIN_FRAC    = 0.7
VAL_FRAC      = 0.15
SEED   = 59

# Keep only rows that can be keyed to a clone and that actually have a sequence.
unl = df_unl.dropna(subset=[CDR3_COL_UNL, VFAM_COL, SEQ_COL_UNL])

# Clone table: one row per (V-family, CDR3) pair with the number of sequences in it.
clone_sizes = unl.groupby([VFAM_COL, CDR3_COL_UNL]).size()
clone_table = clone_sizes.reset_index(name="n_seq")

print("Number of unique clones in unlabeled:", len(clone_table))

# NOTE(review): `rng` is not used by any visible cell (the per-family shuffle is
# seeded with SEED directly) — candidate for removal.
rng = np.random.default_rng(SEED)



def assign_clones_within_family(fam_df: pd.DataFrame,
                                train_frac=None,
                                val_frac=None,
                                seed=None) -> pd.DataFrame:
    """
    Assign every clone of one V-family to 'train' / 'val' / 'test'.

    Clones are shuffled deterministically (random_state=seed) and then filled
    greedily by sequence count: a clone goes to train while the number of
    sequences already placed before it is below train_frac * total, then to
    val while below (train_frac + val_frac) * total, else to test. Clones are
    never split, so the first clone always lands in train and small families
    may end up with no val/test clones.

    Args:
        fam_df: rows for a single V-family in a clone table; needs an "n_seq" column.
        train_frac, val_frac, seed: default to the notebook's TRAIN_FRAC,
            VAL_FRAC and SEED when None (so existing calls are unchanged).

    Returns:
        A shuffled copy of fam_df with an extra 'split' column (fam_df is not modified).
    """
    train_frac = TRAIN_FRAC if train_frac is None else train_frac
    val_frac = VAL_FRAC if val_frac is None else val_frac
    seed = SEED if seed is None else seed

    df = fam_df.sample(frac=1.0, random_state=seed)
    counts = df["n_seq"].to_numpy()
    total = counts.sum()

    # Sequences already placed *before* each clone (exclusive cumulative sum).
    placed_before = np.cumsum(counts) - counts
    splits = np.where(
        placed_before < train_frac * total,
        "train",
        np.where(placed_before < (train_frac + val_frac) * total, "val", "test"),
    )
    return df.assign(split=splits.tolist())

# Split clones family-by-family so every V-family is spread over train/val/test.
split_clone_table_list = [
    assign_clones_within_family(family_clones)
    for _family, family_clones in clone_table.groupby(VFAM_COL)
]
split_clone_table = pd.concat(split_clone_table_list, ignore_index=True)

print(split_clone_table["split"].value_counts())

# Every sequence inherits the split of its (V-family, CDR3) clone.
clone_keys = [VFAM_COL, CDR3_COL_UNL]
unl = unl.merge(split_clone_table[clone_keys + ["split"]], on=clone_keys, how="left")

# Sanity: every row must have received a split.
assert unl["split"].notna().all(), "Some unlabeled rows got no split!"

print("Unlabeled counts per split:")
print(unl["split"].value_counts())


Number of unique clones in unlabeled: 1781
split
train    1253
val       268
test      260
Name: count, dtype: int64
Unlabeled counts per split:
split
train    1403
val       301
test      295
Name: count, dtype: int64


In [None]:
CDR3_COL_POS = "cdr3_aa"
SEQ_COL_POS  = "seq_masked"

pos = df_pos.copy()

# Clone table for positives: one row per (V-family, CDR3) with its sequence count.
pos_with_cdr3 = pos.dropna(subset=[CDR3_COL_POS])
pos_clone_table = pos_with_cdr3.groupby([VFAM_COL, CDR3_COL_POS]).size().reset_index(name="n_seq")

print("Number of unique positive clones:", len(pos_clone_table))

# Same per-family clone assignment as used for the unlabeled set.
split_pos_clones_list = [
    assign_clones_within_family(family_clones)
    for _family, family_clones in pos_clone_table.groupby(VFAM_COL)
]
split_pos_clones = pd.concat(split_pos_clones_list, ignore_index=True)

# Attach each positive's split through its (V-family, CDR3) clone key.
pos_keys = [VFAM_COL, CDR3_COL_POS]
pos = pos.merge(split_pos_clones[pos_keys + ["split"]], on=pos_keys, how="left")

# Rows without a clone key (e.g. missing CDR3) received no split; default them to train.
pos["split"] = pos["split"].fillna("train")


print("Positives counts per split:")
print(pos["split"].value_counts())

Number of unique positive clones: 180
Positives counts per split:
split
train    132
test      25
val       25
Name: count, dtype: int64


In [None]:
# splitting the data

# Columns shared by both sources; positives and unlabeled rows are stacked on these.
USE_COLS = ["seq_masked", VFAM_COL, "label", "split"]

pos_small = pos[USE_COLS].copy()
unl_small = unl[USE_COLS].copy()

def get_split_df(pos_df, unl_df, split_name: str, seed: int = SEED):
    """
    Stack the positive and unlabeled rows belonging to one split and shuffle them.

    Args:
        pos_df, unl_df: frames with a "split" column.
        split_name: which split to extract ("train", "val" or "test").
        seed: random_state for the row shuffle.

    Returns:
        A new frame with a fresh 0..n-1 index, rows in shuffled order.
    """
    parts = [frame[frame["split"] == split_name] for frame in (pos_df, unl_df)]
    combined = pd.concat(parts, ignore_index=True)
    shuffled = combined.sample(frac=1.0, random_state=seed)
    return shuffled.reset_index(drop=True)

train_df = get_split_df(pos_small, unl_small, "train")
val_df   = get_split_df(pos_small, unl_small, "val")
test_df  = get_split_df(pos_small, unl_small, "test")

# Sanity-check class balance and V-family stratification across the three splits.
print("Label counts:")
for label_prefix, frame in (("  train:", train_df), ("  val:  ", val_df), ("  test: ", test_df)):
    print(label_prefix, frame["label"].value_counts().to_dict())

for heading, frame in (("\nTrain V-family distribution:", train_df),
                       ("\nVal V-family distribution:", val_df),
                       ("\nTest V-family distribution:", test_df)):
    print(heading)
    print(frame[VFAM_COL].value_counts(normalize=True))

Label counts:
  train: {0: 1403, 1: 132}
  val:   {0: 301, 1: 25}
  test:  {0: 295, 1: 25}

Train V-family distribution:
v_family
IGHV3    0.456678
IGHV1    0.319218
IGHV4    0.110098
IGHV5    0.044951
IGHV2    0.038436
IGHV7    0.017590
IGHV6    0.012378
IGHV8    0.000651
Name: proportion, dtype: float64

Val V-family distribution:
v_family
IGHV3    0.463190
IGHV1    0.322086
IGHV4    0.110429
IGHV5    0.039877
IGHV2    0.039877
IGHV7    0.015337
IGHV6    0.009202
Name: proportion, dtype: float64

Test V-family distribution:
v_family
IGHV3    0.46250
IGHV1    0.32500
IGHV4    0.11250
IGHV5    0.04375
IGHV2    0.03750
IGHV7    0.01250
IGHV6    0.00625
Name: proportion, dtype: float64


In [None]:
# Define vocabulary + tokenizer for amino acids

# Vocabulary: the 20 canonical amino acids followed by "X" (mask / unknown residue).
AMINO_ACIDS = [*"ACDEFGHIKLMNPQRSTVWY", "X"]

# Token ids start at 1 so that id 0 stays reserved for padding.
aa_to_idx = dict(zip(AMINO_ACIDS, range(1, len(AMINO_ACIDS) + 1)))
PAD_IDX = 0

def seq_to_indices(seq: str):
    """
    Tokenise an amino-acid string into a list of integer ids.

    The input is stripped and upper-cased first; any character outside the
    vocabulary (e.g. B, Z, *) is mapped to the id of "X".
    """
    unknown_id = aa_to_idx["X"]
    normalized = str(seq).strip().upper()
    return [aa_to_idx.get(residue, unknown_id) for residue in normalized]

In [None]:
# Dataset class + collate function

class AntibodyDataset(Dataset):
    """
    Map-style dataset yielding (token ids, label) tensor pairs from a DataFrame.

    Sequences are tokenised lazily in __getitem__; labels are converted to a
    float numpy array once, up front.
    """

    def __init__(self, df: pd.DataFrame, seq_col: str = "seq_masked", label_col: str = "label"):
        self.seqs = df[seq_col].tolist()
        self.labels = df[label_col].astype(float).to_numpy()

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        tokens = torch.tensor(seq_to_indices(self.seqs[idx]), dtype=torch.long)
        return tokens, target



def collate_batch(batch):
    """
    Collate (token_tensor, label_tensor) pairs into one right-padded batch.

    Returns:
        padded  - LongTensor (batch, max_len); positions past each sequence's
                  end are filled with PAD_IDX (max_len is the longest in this batch)
        lengths - LongTensor (batch,) with each sequence's original length
        labels  - the label tensors stacked along a new first dimension
    """
    seq_list, label_list = zip(*batch)

    lengths = torch.tensor([len(tokens) for tokens in seq_list], dtype=torch.long)
    padded = torch.full((len(seq_list), int(lengths.max())), PAD_IDX, dtype=torch.long)

    # Each row yielded by iterating `padded` is a view, so filling its prefix
    # writes straight into the batch tensor.
    for row, tokens in zip(padded, seq_list):
        row[: len(tokens)] = tokens

    return padded, lengths, torch.stack(label_list)


In [None]:
# Create the datasets + dataloaders:

BATCH_SIZE = 32

train_dataset, val_dataset, test_dataset = (
    AntibodyDataset(frame) for frame in (train_df, val_df, test_df)
)

# Only the training loader shuffles; val/test keep a fixed order for evaluation.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)


In [None]:
# Build a small CNN model with an embedding layer

class CNNClassifier(nn.Module):
    """
    Multi-kernel 1-D CNN over amino-acid embeddings that outputs one logit per sequence.

    Each Conv1d branch is ReLU-activated and globally max-pooled over the
    sequence axis; the pooled features are concatenated, passed through
    dropout and projected to a single logit.

    Pooling is padding-aware: convolution windows that run past a sequence's
    last real token are excluded. Without this, the pooled value (and so the
    prediction) depends on how much padding the sequence's batch-mates force
    onto it, and padding itself becomes a length signal.
    """

    def __init__(self,
                 vocab_size: int,
                 embed_dim: int = 64,
                 num_filters: int = 128,
                 kernel_sizes=(3, 5, 7),
                 dropout: float = 0.3,
                 padding_idx=None):
        super().__init__()
        # Default to the notebook-wide PAD_IDX so existing call sites are unchanged.
        if padding_idx is None:
            padding_idx = PAD_IDX
        self.padding_idx = padding_idx
        self.kernel_sizes = tuple(kernel_sizes)

        # Module creation order matches the previous version, so parameter init
        # under a fixed seed and state_dict keys are unchanged.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=padding_idx
        )

        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=embed_dim,
                      out_channels=num_filters,
                      kernel_size=k)
            for k in self.kernel_sizes
        ])

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(self.kernel_sizes), 1)

    def forward(self, x, lengths=None):
        """
        x: (batch, seq_len) token indices, right-padded with self.padding_idx
           (as produced by collate_batch).
        lengths: optional (batch,) tensor of real sequence lengths; when omitted
           it is recovered as the number of non-padding tokens in each row.
        returns logits: (batch,)
        """
        if lengths is None:
            lengths = (x != self.padding_idx).sum(dim=1)
        lengths = lengths.to(x.device)

        emb = self.embedding(x).transpose(1, 2)  # (batch, embed_dim, seq_len)

        # Conv1d fails when the input is shorter than the kernel; right-pad with
        # zeros (the same vector the padding embedding holds) so short batches run.
        shortfall = max(self.kernel_sizes) - emb.size(2)
        if shortfall > 0:
            emb = nn.functional.pad(emb, (0, shortfall))

        conv_outs = []
        for conv, k in zip(self.convs, self.kernel_sizes):
            c = torch.relu(conv(emb))  # (batch, num_filters, L')
            # Window t covers tokens t..t+k-1: keep only windows lying entirely
            # inside the real sequence (at least one, for sequences shorter than k).
            n_windows = (lengths - k + 1).clamp(min=1)
            positions = torch.arange(c.size(2), device=c.device)
            keep = positions.unsqueeze(0) < n_windows.unsqueeze(1)  # (batch, L')
            # ReLU outputs are >= 0, so zeroing dropped windows cannot change the max.
            c = c.masked_fill(~keep.unsqueeze(1), 0.0)
            pooled, _ = torch.max(c, dim=2)  # global max pool over valid windows
            conv_outs.append(pooled)

        h = self.dropout(torch.cat(conv_outs, dim=1))
        return self.fc(h).squeeze(1)


In [None]:
vocab_size = len(aa_to_idx) + 1  # +1 for padding index 0

# Seed right before construction so the weight initialisation is reproducible;
# the seeding in the training cell runs only after the model already exists.
torch.manual_seed(SEED)
model = CNNClassifier(vocab_size=vocab_size, embed_dim=64, num_filters=128, kernel_sizes=(3, 5, 7), dropout=0.3)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)

cpu


In [None]:
# Loss function and Optimizer

# Up-weight the rare positive class by the train negative:positive ratio.
train_labels = train_df["label"]
num_pos = (train_labels == 1).sum()
num_neg = (train_labels == 0).sum()
print("Train positives:", num_pos, "negatives:", num_neg)

pos_weight_value = num_neg / max(num_pos, 1)  # max(..., 1) guards a split with no positives
print("pos_weight =", pos_weight_value)

pos_weight = torch.tensor(pos_weight_value, dtype=torch.float32, device=device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)


Train positives: 132 negatives: 1403
pos_weight = 10.628787878787879


In [None]:
# Seed the torch RNGs (CPU and all CUDA devices) so the randomness consumed by the
# training loop below (train_loader shuffling, dropout) is reproducible.
# Weight initialisation happens when the model is constructed, not here.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# define train/test/validate function

def run_epoch(model, loader, optimizer=None, loss_fn=None, target_device=None):
    """
    Run one full pass over `loader` and return (mean loss, accuracy).

    If optimizer is None → evaluation mode: model.eval(), and no autograd graph
    is built.
    Otherwise → training mode: model.train() and one optimizer step per batch.

    Args:
        model: module mapping a (batch, seq_len) token tensor to (batch,) logits.
        loader: iterable of (tokens, lengths, labels) batches as produced by
            collate_batch; `lengths` is not used here.
        optimizer: torch optimizer, or None for evaluation.
        loss_fn: loss on (logits, labels); defaults to the notebook's `criterion`.
        target_device: device batches are moved to; defaults to the notebook's `device`.

    Returns:
        (avg_loss, avg_acc): per-sample average of the batch losses (assuming a
        mean-reduced loss) and accuracy at a 0.5 probability threshold; both
        are 0.0 for an empty loader.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    target_device = device if target_device is None else target_device

    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Only record the autograd graph when we are going to backpropagate.
    with torch.set_grad_enabled(training):
        for x, _lengths, y in loader:
            x = x.to(target_device)
            y = y.to(target_device)

            logits = model(x)
            loss = loss_fn(logits, y)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = x.size(0)
            total_loss += loss.item() * batch_size

            # Accuracy is only a rough metric here: with heavy class imbalance
            # the all-negative predictor already scores highly.
            probs = torch.sigmoid(logits.detach())
            preds = (probs >= 0.5).float()
            total_correct += (preds == y).sum().item()
            total_samples += batch_size

    avg_loss = total_loss / max(total_samples, 1)
    avg_acc  = total_correct / max(total_samples, 1)
    return avg_loss, avg_acc



# Train and validate the model, keeping the weights from the best validation epoch.

EPOCHS = 15

best_val_loss = float("inf")
best_state = None
best_epoch = None  # set up-front so later cells don't hit NameError if val loss never improves (e.g. NaN)

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = run_epoch(model, train_loader, optimizer)
    val_loss, val_acc     = run_epoch(model, val_loader, optimizer=None)

    print(f"Epoch {epoch:02d} | "
          f"train loss: {train_loss:.4f}, acc: {train_acc:.3f} | "
          f"val loss: {val_loss:.4f}, acc: {val_acc:.3f}")

    # Best-checkpoint selection, not early stopping: all EPOCHS always run, and the
    # weights are snapshotted whenever validation loss reaches a new minimum so
    # later overfitting epochs do not affect the final model.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())  # deep copy: state_dict tensors alias the live parameters
        best_epoch = epoch

# Restore the best-validation weights (skipped if no epoch ever improved).
if best_state is not None:
    model.load_state_dict(best_state)

Epoch 01 | train loss: 1.0856, acc: 0.681 | val loss: 0.6812, acc: 0.718
Epoch 02 | train loss: 0.6579, acc: 0.821 | val loss: 0.7224, acc: 0.709
Epoch 03 | train loss: 0.5793, acc: 0.824 | val loss: 0.4849, acc: 0.822
Epoch 04 | train loss: 0.4498, acc: 0.869 | val loss: 0.4212, acc: 0.813
Epoch 05 | train loss: 0.3992, acc: 0.883 | val loss: 0.3918, acc: 0.831
Epoch 06 | train loss: 0.3786, acc: 0.885 | val loss: 0.3595, acc: 0.871
Epoch 07 | train loss: 0.3433, acc: 0.906 | val loss: 0.3416, acc: 0.865
Epoch 08 | train loss: 0.3209, acc: 0.913 | val loss: 0.3544, acc: 0.883
Epoch 09 | train loss: 0.2608, acc: 0.930 | val loss: 0.3318, acc: 0.887
Epoch 10 | train loss: 0.2387, acc: 0.934 | val loss: 0.4006, acc: 0.951
Epoch 11 | train loss: 0.2091, acc: 0.943 | val loss: 0.4184, acc: 0.966
Epoch 12 | train loss: 0.2127, acc: 0.945 | val loss: 0.3555, acc: 0.868
Epoch 13 | train loss: 0.1974, acc: 0.954 | val loss: 0.5584, acc: 0.810
Epoch 14 | train loss: 0.2542, acc: 0.937 | val los

In [None]:
# Test the model

# Evaluate the model (restored to the best-validation weights when one was recorded) on the held-out test split.
test_loss, test_acc = run_epoch(model, test_loader, optimizer=None)
print(f"Test loss: {test_loss:.4f}, acc: {test_acc:.3f}, best epoch:{best_epoch}")


Test loss: 0.2968, acc: 0.934, best epoch:9


In [None]:
# Calculate majority baseline accuracy: the accuracy of always predicting the most frequent test label.

test_counts = test_df["label"].value_counts()
print(test_counts)
baseline_acc = test_counts.div(test_counts.sum()).max()
print("Majority-class baseline acc:", baseline_acc)

label
0    295
1     25
Name: count, dtype: int64
Majority-class baseline acc: 0.921875


In [None]:
# Computes two evaluation metrics: AUROC and AUPRC

from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np

def collect_preds(model, loader, target_device=None):
    """
    Run `model` over every batch in `loader` and gather labels and probabilities.

    Args:
        model: module mapping a (batch, seq_len) token tensor to (batch,) logits.
        loader: iterable of (tokens, lengths, labels) batches; `lengths` is unused.
        target_device: device the inputs are moved to; defaults to the notebook's `device`.

    Returns:
        (all_labels, all_probs): 1-D numpy arrays in loader order, with
        probabilities = sigmoid(logits). Both are empty for an empty loader
        (np.concatenate would otherwise raise on an empty list).
    """
    target_device = device if target_device is None else target_device
    model.eval()
    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for x, _lengths, y in loader:
            logits = model(x.to(target_device))
            prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
            # Labels are never used on the device, so keep them on the CPU.
            label_chunks.append(y.cpu().numpy())

    if not prob_chunks:
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.float32)

    all_probs = np.concatenate(prob_chunks)
    all_labels = np.concatenate(label_chunks)
    return all_labels, all_probs

# Threshold-free metrics on the test split. With positives a small minority of
# the test set, AUPRC is the more informative of the two.
y_true, y_prob = collect_preds(model, test_loader)

print("Test AUROC:", roc_auc_score(y_true, y_prob))
print("Test AUPRC:", average_precision_score(y_true, y_prob))

Test AUROC: 0.9848135593220338
Test AUPRC: 0.879226344974023


In [None]:
# # saving the model parameters


# from pathlib import Path

# # Create models directory
# MODEL_PATH = Path("models")
# MODEL_PATH.mkdir(parents=True, exist_ok=True)

# # Create model save path
# MODEL_NAME = "model_00"
# MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# # Save the model state dict
# print(f"Saving model to: {MODEL_SAVE_PATH}")
# torch.save(obj=model.state_dict(),