### Imports

In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import json
from pathlib import Path

### Constants

In [9]:
BATCH_SIZE = 32
EMBED_DIM = 32
HIDDEN_DIM = 64
NUM_EPOCHS = 90
MAX_SEQ_LEN = 22
NUM_CDR3S = 40
VOCAB_SIZE = 22  # 20 amino acids + <PAD>=0 + <UNK>=1
LABEL2IDX = {'tumor': 1, 'normal': 0}

# File paths
basedir = Path("/Users/tusharsingh/Work/Project/DL-cdr3-tumor")
jsonl_path = basedir / "processed" / "cdr3_tumor_normal.jsonl"

### Data Conversion

In [10]:
class CDR3Dataset(Dataset):
    def __init__(self, jsonl_path):
        self.samples = []
        with open(jsonl_path, 'r') as f:
            for line in f:
                item = json.loads(line)
                label = LABEL2IDX[item['label']]
                cdr3_tensor = torch.tensor(item['cdr3s'], dtype=torch.long)
                self.samples.append((cdr3_tensor, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

### Load Data

In [11]:
dataset = CDR3Dataset(jsonl_path)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

### LSTM Model 

In [12]:
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        B, T, L = x.size()  # B=batch, T=CDR3s, L=AA per CDR3
        x = x.view(B * T, L)
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = x[:, -1, :]  # Take last LSTM output
        x = x.view(B, T, -1).permute(0, 2, 1)  # B x H x T
        x = self.pool(x).squeeze(2)
        x = self.fc(x)
        return self.sigmoid(x).squeeze(1)

### Train loop

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMClassifier(vocab_size=VOCAB_SIZE, embed_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM).to(device)
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

In [14]:
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss, correct, total = 0, 0, 0
    for x, y in train_loader:
        x, y = x.to(device), torch.tensor(y, dtype=torch.float32).to(device)
        preds = model(x)
        loss = loss_fn(preds, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * len(y)
        correct += ((preds > 0.5).long() == y.long()).sum().item()
        total += len(y)

    train_acc = correct / total

    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), torch.tensor(y, dtype=torch.float32).to(device)
            preds = model(x)
            val_correct += ((preds > 0.5).long() == y.long()).sum().item()
            val_total += len(y)

    val_acc = val_correct / val_total
    print(f"Epoch {epoch}: Train Loss = {train_loss/total:.4f}, Acc = {train_acc:.4f}")
    print(f"\t\tVal Acc = {val_acc:.4f}")


  x, y = x.to(device), torch.tensor(y, dtype=torch.float32).to(device)
  x, y = x.to(device), torch.tensor(y, dtype=torch.float32).to(device)


Epoch 1: Train Loss = 0.8678, Acc = 0.4941
		Val Acc = 0.5234
Epoch 2: Train Loss = 0.7115, Acc = 0.5106
		Val Acc = 0.4860
Epoch 3: Train Loss = 0.7021, Acc = 0.4518
		Val Acc = 0.5047
Epoch 4: Train Loss = 0.6962, Acc = 0.4682
		Val Acc = 0.4860
Epoch 5: Train Loss = 0.6974, Acc = 0.4988
		Val Acc = 0.5140
Epoch 6: Train Loss = 0.7029, Acc = 0.5082
		Val Acc = 0.4860
Epoch 7: Train Loss = 0.7005, Acc = 0.4871
		Val Acc = 0.4860
Epoch 8: Train Loss = 0.7113, Acc = 0.4988
		Val Acc = 0.5140
Epoch 9: Train Loss = 0.7093, Acc = 0.4729
		Val Acc = 0.4860
Epoch 10: Train Loss = 0.6980, Acc = 0.4565
		Val Acc = 0.5234
Epoch 11: Train Loss = 0.6946, Acc = 0.4541
		Val Acc = 0.4860
Epoch 12: Train Loss = 0.6997, Acc = 0.5106
		Val Acc = 0.5140
Epoch 13: Train Loss = 0.7079, Acc = 0.4800
		Val Acc = 0.4860
Epoch 14: Train Loss = 0.7110, Acc = 0.5012
		Val Acc = 0.5140
Epoch 15: Train Loss = 0.7014, Acc = 0.4965
		Val Acc = 0.5234
Epoch 16: Train Loss = 0.6976, Acc = 0.4941
		Val Acc = 0.4953
E

We experimented with an LSTM architecture to better capture sequential dependencies in each CDR3 sequence. However, the LSTM model also plateaued early in training:

Best Validation Accuracy: ~56.0%

Training Accuracy (Epoch 90): 57.4%

Val Accuracy (Epoch 90): 53.2%

Despite being more expressive, the LSTM model failed to significantly outperform mean pooling.

