### Imports 

In [1]:
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, roc_auc_score
from torch.utils.data import DataLoader

### Constants

In [2]:
# Project root. Defaults to the original local checkout; set the
# CDR3_PROJECT_DIR environment variable to run this notebook on another
# machine without editing the cell.
basedir = Path(os.environ.get("CDR3_PROJECT_DIR", "/Users/tusharsingh/Work/Project/DL-cdr3-tumor"))
jsonl_file = basedir / "processed" / "cdr3_tumor_normal.jsonl"  # one JSON record per line
model_path = basedir / "mean_pool_best_model.pt"                # checkpoint loaded in the Run cell
plot_dir = basedir / "plots"                                    # destination for the saved figures

BATCH_SIZE = 22
# The two sizes below define the layer shapes of the model, so they have to
# match the checkpoint at `model_path` or `load_state_dict` will reject it.
EMBEDDING_DIM = 32
VOCAB_SIZE = 22  # 20 AAs + PAD + UNK

### Load Data

In [3]:
class PatientCDR3Dataset(torch.utils.data.Dataset):
    """Dataset of per-patient records read from a JSON Lines file.

    Every non-blank line of the file must be a JSON object with two keys:

    * ``cdr3s`` -- nested list of integer token ids. It has to be
      rectangular so that it converts to a single tensor (presumably
      ``[n_cdr3, seq_len]`` -- confirm against the preprocessing step).
    * ``label`` -- ``'tumor'`` marks the positive class (1.0); any other
      value is mapped to the negative class (0.0).

    All records are parsed eagerly in ``__init__`` and kept in memory.

    Args:
        jsonl_file: Path to the JSON Lines file.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """

    def __init__(self, jsonl_file):
        # Blank lines (typically a trailing newline at the end of the file)
        # are skipped; json.loads("") would otherwise raise.
        with open(jsonl_file, encoding="utf-8") as f:
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of patient records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(tokens, target)`` for the record at position ``idx``.

        ``tokens`` is a ``torch.long`` tensor built from ``cdr3s``;
        ``target`` is a 0-d ``torch.float`` tensor holding 1.0 or 0.0.
        """
        patient = self.data[idx]
        x = torch.tensor(patient['cdr3s'], dtype=torch.long)
        y = 1 if patient['label'] == 'tumor' else 0
        return x, torch.tensor(y, dtype=torch.float)

### Model

In [4]:
class MeanPoolModel(nn.Module):
    """Embed tokens, average them, and score with one linear unit.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # Index 0 is the padding token: its embedding receives no gradient.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.fc = nn.Linear(embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return one probability per sample.

        Args:
            x: Integer token tensor, expected as [B, CDR3, LEN].

        Returns:
            Tensor of shape [B] with sigmoid outputs in [0, 1]. The batch
            dimension is preserved even when B == 1.
        """
        emb = self.embedding(x)                    # [B, CDR3, LEN, D]
        # Plain mean over both axes: padded positions still count in the
        # denominator. Kept as-is so outputs match the trained checkpoint.
        pooled = emb.mean(dim=1).mean(dim=1)       # mean over CDR3s then over LEN
        # squeeze(-1) drops only the output-unit axis. A bare squeeze() would
        # also drop a batch axis of size 1 and hand callers a 0-d tensor.
        return self.sigmoid(self.fc(pooled)).squeeze(-1)

### Evaluation

In [5]:
def evaluate_predictions(model, loader, threshold=0.5):
    """Run ``model`` over ``loader`` and collect targets and predictions.

    Args:
        model: ``nn.Module`` returning one probability per sample.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        threshold: Probabilities strictly greater than this value are
            predicted as class 1. Defaults to 0.5.

    Returns:
        Tuple ``(y_true, y_pred, y_probs)`` of 1-D NumPy arrays: the
        targets, the thresholded 0/1 predictions, and the raw model outputs.
    """
    model.eval()
    true_chunks, prob_chunks = [], []
    with torch.no_grad():
        for x, y in loader:
            output = model(x)
            # reshape(-1) keeps a batch of one iterable even if the model
            # returns a 0-d tensor; .cpu() makes .numpy() valid for outputs
            # that live on an accelerator.
            true_chunks.append(y.detach().cpu().reshape(-1).numpy())
            prob_chunks.append(output.detach().cpu().reshape(-1).numpy())

    # np.concatenate rejects an empty list, so an empty loader is handled
    # explicitly and yields empty arrays.
    y_true = np.concatenate(true_chunks) if true_chunks else np.array([])
    y_probs = np.concatenate(prob_chunks) if prob_chunks else np.array([])
    y_pred = (y_probs > threshold).astype(int)
    return y_true, y_pred, y_probs

### Plots

In [6]:
def plot_confusion_matrix(y_true, y_pred, path):
    """Save a 2x2 confusion-matrix heatmap for the Normal/Tumor task.

    Args:
        y_true: Ground-truth labels encoded as 0 (Normal) / 1 (Tumor).
        y_pred: Predicted labels with the same encoding.
        path: Output image path; its parent directory must already exist.
    """
    # Targets arrive as floats and predictions as ints; cast both so they
    # share one integer label space.
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    # Pinning labels keeps the matrix 2x2 even when a class is absent, so it
    # always lines up with the two tick labels below.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=["Normal", "Tumor"],
                yticklabels=["Normal", "Tumor"],
                ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    fig.tight_layout()
    fig.savefig(path)
    # Close explicitly so the figure is neither rendered inline nor kept alive.
    plt.close(fig)

def plot_roc_curve(y_true, y_probs, path):
    """Save the ROC curve, annotated with its AUC, to ``path``.

    Args:
        y_true: Ground-truth binary labels.
        y_probs: Predicted scores for the positive class.
        path: Output image path.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_probs)
    auc = roc_auc_score(y_true, y_probs)

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.plot(false_pos_rate, true_pos_rate, label=f"AUC = {auc:.2f}")
    # Diagonal reference line for a chance-level classifier.
    ax.plot([0, 1], [0, 1], '--', color='gray')
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)

### Run

In [7]:
# End-to-end evaluation: load the data and the checkpoint, score every
# patient, save both figures, and print the per-class report.
if __name__ == "__main__":
    dataset = PatientCDR3Dataset(jsonl_file)
    # Default collation stacks samples, so every record within a batch must
    # have the same tensor shape.
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)

    model = MeanPoolModel(VOCAB_SIZE, EMBEDDING_DIM)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs. map_location="cpu" lets a checkpoint
    # saved on an accelerator load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    y_true, y_pred, y_probs = evaluate_predictions(model, loader)

    # savefig does not create missing directories.
    plot_dir.mkdir(parents=True, exist_ok=True)
    plot_confusion_matrix(y_true, y_pred, plot_dir / "confusion_matrix.png")
    plot_roc_curve(y_true, y_probs, plot_dir / "roc_curve.png")

    print("Classification Report:\n")
    print(classification_report(y_true, y_pred, target_names=["Normal", "Tumor"]))

Classification Report:

              precision    recall  f1-score   support

      Normal       0.61      0.41      0.49       266
       Tumor       0.56      0.74      0.64       266

    accuracy                           0.58       532
   macro avg       0.58      0.58      0.56       532
weighted avg       0.58      0.58      0.56       532



  model.load_state_dict(torch.load(model_path))
