### Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import json
from pathlib import Path
import numpy as np
import random

### Constants

In [2]:
# Paths
basedir = Path("/Users/tusharsingh/Work/Project/DL-cdr3-tumor")
jsonl_file = basedir / "processed" / "cdr3_tumor_normal.jsonl"
# Constants
MAX_SEQ_LEN = 22
MAX_CDR3_PER_PATIENT = 20
VOCAB_SIZE = 22  # 20 AAs + PAD + UNK
EMBEDDING_DIM = 32
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 0.1

### Dataset Class

In [3]:
class PatientCDR3Dataset(Dataset):
    def __init__(self, jsonl_path):
        self.samples = []
        with open(jsonl_path, 'r') as f:
            for line in f:
                entry = json.loads(line)
                self.samples.append(entry)

        self.label_map = {"tumor": 1, "normal": 0}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = np.array(self.samples[idx]["cdr3s"], dtype=np.int64)
        y = self.label_map[self.samples[idx]["label"]]
        return torch.tensor(x), torch.tensor(y)


### Mean Pooling model 

In [4]:
class MeanPoolModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(MeanPoolModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.fc = nn.Linear(embedding_dim, 1)
    
    def forward(self, x):
        # x shape: [batch_size, num_cdr3s, seq_len]
        bsz, n_seqs, seq_len = x.shape
        x = x.view(bsz * n_seqs, seq_len)

        # Embed
        embedded = self.embedding(x)  # [bsz * n_seqs, seq_len, embed_dim]

        # Mean pool over sequence
        mean_seq = embedded.mean(dim=1)  # [bsz * n_seqs, embed_dim]

        # Reshape back to [batch_size, n_seqs, embed_dim]
        mean_seq = mean_seq.view(bsz, n_seqs, -1)

        # Mean pool across CDR3s
        pooled = mean_seq.mean(dim=1)  # [batch_size, embed_dim]

        out = self.fc(pooled)  # [batch_size, 1]
        return torch.sigmoid(out).squeeze()


### Train and Evaluate Functions

In [5]:
def train_epoch(model, loader, criterion, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    for x, y in loader:
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y.float())
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = (output > 0.5).int()
        correct += (preds == y).sum().item()
    
    return total_loss / len(loader), correct / len(loader.dataset)

def evaluate(model, loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            output = model(x)
            preds = (output > 0.5).int()
            correct += (preds == y).sum().item()
    return correct / len(loader.dataset)


### Training 

In [6]:
# Load dataset
full_dataset = PatientCDR3Dataset(jsonl_file)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

# Initialize model
model = MeanPoolModel(VOCAB_SIZE, EMBEDDING_DIM)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train loop
best_val_acc = 0
for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    val_acc = evaluate(model, val_loader)
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Acc = {train_acc:.4f}")
    print(f"\tVal Acc = {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), basedir / "mean_pool_best_model.pt")
        print("Best model saved")


Epoch 1: Train Loss = 0.6961, Acc = 0.4918
	Val Acc = 0.4579
Best model saved
Epoch 2: Train Loss = 0.6999, Acc = 0.5106
	Val Acc = 0.5234
Best model saved
Epoch 3: Train Loss = 0.7191, Acc = 0.4753
	Val Acc = 0.4579
Epoch 4: Train Loss = 0.6973, Acc = 0.4988
	Val Acc = 0.5327
Best model saved
Epoch 5: Train Loss = 0.6968, Acc = 0.5365
	Val Acc = 0.5421
Best model saved
Epoch 6: Train Loss = 0.6944, Acc = 0.5271
	Val Acc = 0.5327
Epoch 7: Train Loss = 0.6855, Acc = 0.5741
	Val Acc = 0.5701
Best model saved
Epoch 8: Train Loss = 0.6813, Acc = 0.5765
	Val Acc = 0.5327
Epoch 9: Train Loss = 0.6875, Acc = 0.5482
	Val Acc = 0.4673
Epoch 10: Train Loss = 0.6820, Acc = 0.5812
	Val Acc = 0.5421
Epoch 11: Train Loss = 0.6805, Acc = 0.5765
	Val Acc = 0.5234
Epoch 12: Train Loss = 0.6881, Acc = 0.5647
	Val Acc = 0.5140
Epoch 13: Train Loss = 0.6980, Acc = 0.5035
	Val Acc = 0.5607
Epoch 14: Train Loss = 0.6899, Acc = 0.5600
	Val Acc = 0.4860
Epoch 15: Train Loss = 0.6837, Acc = 0.5741
	Val Acc = 0

This model used tokenized and padded CDR3 sequences, followed by a simple embedding and mean pooling across the 20 normalized CDR3s per patient. Despite its simplicity, it achieved moderate performance:

Final Validation Accuracy: 54.2%

Best Epoch Accuracy: 58.8% (Train), 54.2% (Validation)