In [None]:
# -*- coding: utf-8 -*-
# Comprehensive TCR-peptide-MHC Binding Prediction Benchmark
# Dependencies: torch, transformers, numpy, pandas, scikit-learn, tqdm, requests, hashlib

import os, sys, math, json, time, random, requests, warnings, hashlib
from typing import List, Dict, Tuple, Set, Optional, Any, Union
from collections import defaultdict, Counter
from dataclasses import dataclass, field
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import EsmModel, EsmTokenizer
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from sklearn.exceptions import ConvergenceWarning
from sklearn.feature_extraction.text import HashingVectorizer

try:
    from tqdm import tqdm
except ImportError:
    raise ImportError("Please install tqdm: pip install tqdm")

# --------------------------
# Configuration & Utils
# --------------------------

@dataclass
class ExperimentConfig:
    """Settings shared by the data pipeline, the models and the runner.

    List/tuple defaults go through ``default_factory`` so every instance
    gets its own object instead of sharing one across instances.
    """

    # ---- general run settings ----
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    n_runs: int = 3  # repeated runs, aggregated into a confidence interval
    test_size: float = 0.2  # fraction of positive epitopes held out for test
    val_size: float = 0.1  # fraction of positive epitopes held out for validation
    batch_size: int = 64
    epochs: int = 10
    lr: float = 1e-4

    # ---- ESM settings ----
    esm_model_name: str = "facebook/esm2_t12_35M_UR50D"
    esm_freeze_layers: int = 6
    d_model: int = 128

    # ---- k-mer settings ----
    kmer_range: Tuple[int, int] = field(default_factory=lambda: (3, 5))  # inclusive (min_k, max_k)
    hash_features: int = 100000  # total hashed feature budget

    # ---- CNN settings ----
    cnn_channels: List[int] = field(default_factory=lambda: [32, 64, 128])
    cnn_kernel_sizes: List[int] = field(default_factory=lambda: [3, 5, 7])

    # ---- LSTM settings ----
    lstm_hidden: int = 128
    lstm_layers: int = 2

    # ---- retrieval settings ----
    retrieval_top_k: int = 200

    # ---- TCR-distance k-NN settings ----
    tcr_knn_k: List[int] = field(default_factory=lambda: [1, 5, 10])

# One-letter codes used as the amino-acid alphabet; a residue's position in
# AA_STANDARD is its integer index in AA_TO_IDX.
AA_STANDARD = [residue for residue in "ACDEFGHIKLMNPQRSTVWY"]
AA_SET = {*AA_STANDARD}
AA_TO_IDX = dict(zip(AA_STANDARD, range(len(AA_STANDARD))))

# MHC pseudo-sequences
MHC_PSEUDO_SEQUENCES = {
    "HLA-A*02:01": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDGETRKVKAHSQTHRVDLGTLRGYYNQSEAGSHTVQRMYGCDVGSDWRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQTTKHKWEAAHVAEQLRAYLEGTCVEWLRRYLENGKETLQRTDAPKTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGQEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-A*01:01": "GSHSMRYFFTSVSRPGRGEPRFIAMGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-B*07:02": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE"
}

def get_mhc_sequence(mhc_name: str) -> str:
    """Return the stored sequence for ``mhc_name``.

    Alleles missing from ``MHC_PSEUDO_SEQUENCES`` resolve to the
    ``"HLA-A*02:01"`` entry rather than raising.
    """
    fallback = MHC_PSEUDO_SEQUENCES["HLA-A*02:01"]
    if mhc_name in MHC_PSEUDO_SEQUENCES:
        return MHC_PSEUDO_SEQUENCES[mhc_name]
    return fallback

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def aa_to_onehot(sequence: str, max_len: int = 50) -> np.ndarray:
    """One-hot encode ``sequence`` into a ``(max_len, len(AA_STANDARD))`` array.

    The sequence is upper-cased and truncated to ``max_len``. Rows beyond the
    end of the sequence, and rows for characters absent from ``AA_TO_IDX``,
    stay all-zero.
    """
    truncated = sequence.upper()[:max_len]
    encoding = np.zeros((max_len, len(AA_STANDARD)))
    for position, residue in enumerate(truncated):
        column = AA_TO_IDX.get(residue)
        if column is not None:
            encoding[position, column] = 1
    return encoding

def aa_to_indices(sequence: str, max_len: int = 50) -> np.ndarray:
    """Map ``sequence`` to a fixed-length vector of amino-acid indices.

    The sequence is upper-cased and truncated to ``max_len``; each residue
    found in ``AA_TO_IDX`` is replaced by its index.

    Returns:
        An ``int64`` array of length ``max_len``.

    Note:
        Padding positions and characters not in ``AA_TO_IDX`` are left at 0,
        which is also the index of ``"A"``; callers cannot tell them apart
        from this array alone.
    """
    seq = sequence.upper()[:max_len]
    # np.long was removed from NumPy 1.24 and is a platform-dependent C long
    # in NumPy 2.x; int64 is explicit and portable.
    indices = np.zeros(max_len, dtype=np.int64)
    for i, aa in enumerate(seq):
        if aa in AA_TO_IDX:
            indices[i] = AA_TO_IDX[aa]
    return indices

# --------------------------
# Data Management
# --------------------------

class DataManager:
    """Load CDR3/MHC/epitope triples, add synthetic negatives, split by epitope.

    After ``load_and_split_data`` the three splits are available as
    ``df_train``, ``df_val`` and ``df_test``. Each has the columns ``CDR3``,
    ``MHC``, ``Epitope`` and ``label`` (1 = observed pair, 0 = synthetic
    negative).
    """

    def __init__(self, config: ExperimentConfig):
        self.config = config
        self.df_train = None
        self.df_val = None
        self.df_test = None
        # Private generator seeded from the config: negative sampling and the
        # epitope split then do not depend on the state of the global
        # ``random`` module at the time the data is loaded.
        self._rng = random.Random(config.seed)

    def load_and_split_data(self, data_path: str = "data.csv"):
        """Read ``data_path``, build negatives and populate the three splits.

        If ``data_path`` does not exist, a small demonstration CSV is written
        there first.
        """
        if not os.path.exists(data_path):
            self._create_demo_data(data_path)

        df = self._preprocess_data(pd.read_csv(data_path))
        df_with_negs = self._build_negatives(df)
        self._split_data(df_with_negs)

        print(f"Data split: Train={len(self.df_train)}, Val={len(self.df_val)}, Test={len(self.df_test)}")

    def _create_demo_data(self, path: str):
        """Write a small header + rows demonstration CSV to ``path``."""
        demo_data = [
            ["CDR3", "MHC", "Epitope"],
            ["CASSLEETQYF", "HLA-A*02:01", "GILGFVFTL"],
            ["CASSFRGTQYF", "HLA-A*02:01", "NLVPMVATV"],
            ["CASRPGLAGGRPEQYF", "HLA-A*02:01", "TPRVTGGGAM"],
            ["CSVEGGSTDTQYF", "HLA-A*02:01", "ELAGIGILTV"],
            ["CASSQDTQYF", "HLA-A*02:01", "LLWNGPMAV"],
            ["CAWRNTGQLYF", "HLA-A*02:01", "KLVALGINAV"],
            ["CASTLESGQYF", "HLA-A*02:01", "VTEHDTLLY"],
            ["CASSPPRVYNEQFF", "HLA-A*02:01", "LLWNGPMAV"],
            ["CASSPGQGAYNEQFF", "HLA-A*02:01", "GILGFVFTL"],
            ["CASSLEETQYF", "HLA-A*01:01", "TTPESANL"],
            ["CASSFRGTQYF", "HLA-B*07:02", "APRTLVYLL"],
            # Add more diverse examples
            ["CASSRGQGVYNEQFF", "HLA-A*02:01", "FLKEKGGL"],
            ["CASSPRGTDTQYF", "HLA-A*02:01", "IMDQVPFSV"],
            ["CASSFDRVGDNEQFF", "HLA-A*02:01", "KLGGALQAK"],
            ["CASSLVGAGGRPEQYF", "HLA-A*02:01", "GVYDGREHTV"],
            ["CASSQVGQGAYNEQFF", "HLA-A*01:01", "IVDCLTEMY"],
            ["CASSLGQGAVGEQFF", "HLA-B*07:02", "FPVRPQVPL"],
            ["CASSITGQGDNEQFF", "HLA-A*02:01", "YLQPRTFLL"],
            ["CASSFGQGAYNEQFF", "HLA-A*02:01", "ALWEIQQVV"],
        ]

        with open(path, "w", encoding="utf-8") as f:
            f.writelines(",".join(row) + "\n" for row in demo_data)

    def _preprocess_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalise the three sequence columns and drop unusable rows.

        Rows with any missing value are removed *before* the string
        conversion; otherwise a missing CDR3/epitope would become the literal
        text "NAN", which consists solely of valid amino-acid letters and
        would survive the alphabet filter below.
        """
        df = df.dropna().copy()
        for column in ("CDR3", "MHC", "Epitope"):
            df[column] = df[column].astype(str).str.strip().str.upper()
        df = df.drop_duplicates()

        def is_valid(seq: str) -> bool:
            # Non-empty and made only of letters from the amino-acid alphabet.
            return bool(seq) and all(ch in AA_SET for ch in seq)

        keep = df["Epitope"].map(is_valid).astype(bool) & df["CDR3"].map(is_valid).astype(bool)
        return df[keep]

    def _build_negatives(self, df: pd.DataFrame, k_neg: int = 4) -> pd.DataFrame:
        """Return ``df`` (label 1) plus up to ``k_neg`` negatives per positive row.

        For each positive (CDR3, MHC, epitope) row, negatives pair the same
        CDR3/MHC with other epitopes of the same length taken from ``df``.
        When fewer than ``k_neg`` such epitopes exist, the remainder are
        single-point mutants of the true epitope. A candidate that is itself
        a known positive triple is never emitted as a negative, so no triple
        ends up with contradictory labels.
        """
        df = df.copy()
        df["label"] = 1

        # Group the distinct epitopes by length so decoys match in length.
        peps_by_len = defaultdict(list)
        for pep in df["Epitope"].unique():
            peps_by_len[len(pep)].append(pep)

        known_positives = set(zip(df["CDR3"], df["MHC"], df["Epitope"]))

        negatives = []
        for cdr3, mhc, pep in zip(df["CDR3"], df["MHC"], df["Epitope"]):
            neg_peps = [
                p for p in peps_by_len[len(pep)]
                if p != pep and (cdr3, mhc, p) not in known_positives
            ]

            if len(neg_peps) >= k_neg:
                selected_negs = self._rng.sample(neg_peps, k_neg)
            else:
                n_missing = k_neg - len(neg_peps)
                selected_negs = neg_peps + [self._mutate_peptide(pep) for _ in range(n_missing)]

            for neg_pep in selected_negs:
                # A mutant can, by chance, equal a real positive epitope.
                if (cdr3, mhc, neg_pep) in known_positives:
                    continue
                negatives.append({"CDR3": cdr3, "MHC": mhc, "Epitope": neg_pep, "label": 0})

        if not negatives:
            return df.drop_duplicates()

        df_neg = pd.DataFrame(negatives)
        return pd.concat([df, df_neg], ignore_index=True).drop_duplicates()

    def _mutate_peptide(self, peptide: str, n_mut: int = 1) -> str:
        """Return ``peptide`` with ``n_mut`` random single-residue substitutions.

        Each substitution picks a random position and replaces the residue
        there with a different letter from ``AA_STANDARD``. An empty peptide
        is returned unchanged.
        """
        if not peptide:
            return peptide
        s = list(peptide)
        for _ in range(n_mut):
            pos = self._rng.randint(0, len(s) - 1)
            original = s[pos]
            s[pos] = self._rng.choice([aa for aa in AA_STANDARD if aa != original])
        return "".join(s)

    def _split_data(self, df: pd.DataFrame):
        """Assign rows to train/val/test by epitope so no epitope spans two splits.

        The distinct epitopes of the positive rows are shuffled and cut into
        test / val / train groups using ``config.test_size`` and
        ``config.val_size``. Every row (positive or negative) goes to the
        split that owns its ``Epitope`` value; rows whose epitope is not a
        positive epitope (for example mutated decoys) land in none of the
        three splits.
        """
        positive_epitopes = list(df.loc[df["label"] == 1, "Epitope"].unique())
        self._rng.shuffle(positive_epitopes)

        n_total = len(positive_epitopes)
        n_test = int(n_total * self.config.test_size)
        n_val = int(n_total * self.config.val_size)
        if n_total >= 3:
            # With few epitopes the integer truncation above can give 0, which
            # leaves an empty evaluation split; keep at least one epitope in
            # each requested split while still leaving one for training.
            if self.config.test_size > 0:
                n_test = max(1, n_test)
            if self.config.val_size > 0:
                n_val = max(1, n_val)

        test_epitopes = set(positive_epitopes[:n_test])
        val_epitopes = set(positive_epitopes[n_test:n_test + n_val])
        train_epitopes = set(positive_epitopes[n_test + n_val:])

        self.df_test = df[df["Epitope"].isin(test_epitopes)].reset_index(drop=True)
        self.df_val = df[df["Epitope"].isin(val_epitopes)].reset_index(drop=True)
        self.df_train = df[df["Epitope"].isin(train_epitopes)].reset_index(drop=True)

# --------------------------
# Evaluation Metrics
# --------------------------

class Evaluator:
    """Stateless helpers for scoring one run and aggregating several runs."""

    @staticmethod
    def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
        """Score binary predictions.

        Args:
            y_true: Ground-truth labels (0/1).
            y_pred: Hard predictions (0/1).
            y_prob: Scores for the positive class, used for AUC and AUPRC.

        Returns:
            Dict with ``auc``, ``auprc``, ``accuracy``, ``precision`` and
            ``recall``. Precision/recall are 0 when their denominator is 0.
        """
        # Accept plain lists as well as arrays; the element-wise comparisons
        # below require array semantics.
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        y_prob = np.asarray(y_prob)

        predicted_pos = y_pred == 1
        actual_pos = y_true == 1
        true_pos = int(np.sum(predicted_pos & actual_pos))

        return {
            "auc": roc_auc_score(y_true, y_prob),
            "auprc": average_precision_score(y_true, y_prob),
            "accuracy": accuracy_score(y_true, y_pred),
            "precision": true_pos / max(1, int(np.sum(predicted_pos))),
            "recall": true_pos / max(1, int(np.sum(actual_pos))),
        }

    @staticmethod
    def aggregate_metrics(metrics_list: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
        """Summarise per-run metric dicts as mean, std and a 95% interval.

        The interval is ``mean +/- 1.96 * std / sqrt(n)`` where ``std`` is
        the population standard deviation (``np.std`` default) and ``n`` the
        number of dicts supplied. The metric names are taken from the first
        dict. An empty ``metrics_list`` yields an empty dict.

        NOTE(review): with only a handful of runs this normal approximation
        understates the uncertainty; consider a t-interval with ddof=1.
        """
        if not metrics_list:
            return {}

        n_runs = len(metrics_list)
        aggregated = {}
        for metric in metrics_list[0]:
            values = [m[metric] for m in metrics_list]
            mean_val = np.mean(values)
            std_val = np.std(values)
            half_width = 1.96 * std_val / np.sqrt(n_runs)

            aggregated[metric] = {
                "mean": mean_val,
                "std": std_val,
                "ci_lower": mean_val - half_width,
                "ci_upper": mean_val + half_width,
            }
        return aggregated

# --------------------------
# Base Model Classes
# --------------------------

class BaseModel:
    """Common interface of every benchmark model: ``fit`` then ``predict``."""

    def __init__(self, config: ExperimentConfig):
        # Subclasses read their hyper-parameters from this shared config.
        self.config = config

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Train on ``df_train``; ``df_val`` is available for calibration."""
        raise NotImplementedError

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(hard_predictions, positive_class_probabilities)``."""
        raise NotImplementedError

class TemperatureScaling:
    """Single-parameter calibration: ``p = sigmoid(logit / temperature)``.

    ``temperature`` starts at 1.0 (no rescaling) and keeps that value
    whenever fitting is not possible.
    """

    def __init__(self):
        self.temperature = 1.0

    @staticmethod
    def _sigmoid(x: np.ndarray) -> np.ndarray:
        """Logistic function that never exponentiates a positive number.

        ``exp(-|x|)`` lies in (0, 1], so large-magnitude logits cannot
        overflow the way ``1 / (1 + exp(-x))`` does for very negative ``x``.
        """
        x = np.asarray(x, dtype=float)
        e = np.exp(-np.abs(x))
        return np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

    def fit(self, logits: np.ndarray, labels: np.ndarray):
        """Choose the temperature in [0.1, 10] minimising binary NLL.

        Args:
            logits: Uncalibrated decision scores of a held-out set.
            labels: Matching 0/1 labels.

        The temperature stays/returns to 1.0 when the calibration set is
        empty or mis-sized, when SciPy is not installed, or when the
        optimiser fails.
        """
        logits = np.asarray(logits, dtype=float).ravel()
        labels = np.asarray(labels, dtype=float).ravel()
        if logits.size == 0 or logits.size != labels.size:
            self.temperature = 1.0
            return

        try:
            from scipy.optimize import minimize_scalar
        except ImportError:
            # SciPy is optional; without it the scores are left uncalibrated.
            self.temperature = 1.0
            return

        def nll(temp):
            if temp <= 0:
                return 1e6
            # Clip so log() never sees exactly 0 or 1.
            probs = np.clip(self._sigmoid(logits / temp), 1e-7, 1 - 1e-7)
            return -np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))

        try:
            result = minimize_scalar(nll, bounds=(0.1, 10.0), method='bounded')
            temperature = float(result.x)
        except (ValueError, ArithmeticError) as exc:
            warnings.warn(f"Temperature scaling failed ({exc}); using temperature=1.0")
            temperature = 1.0

        if not np.isfinite(temperature) or temperature <= 0:
            temperature = 1.0
        self.temperature = temperature

    def apply(self, logits: np.ndarray) -> np.ndarray:
        """Return calibrated probabilities ``sigmoid(logits / temperature)``."""
        return self._sigmoid(np.asarray(logits, dtype=float) / self.temperature)

# --------------------------
# Dataset Classes
# --------------------------

class TCRDataset(Dataset):
    """Torch dataset view over a frame with CDR3, MHC, Epitope and label columns."""

    def __init__(self, df: pd.DataFrame, mode: str = "classification"):
        self.mode = mode
        # Positional access in __getitem__ relies on a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample as a dict of three strings and a float label."""
        record = self.df.iloc[idx]
        sample = {
            key: str(record[column])
            for key, column in (("cdr3", "CDR3"), ("mhc", "MHC"), ("epitope", "Epitope"))
        }
        sample["label"] = float(record["label"])
        return sample

def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    The string fields stay Python lists; the per-sample ``label`` values are
    stacked into a float32 tensor stored under ``labels``.
    """
    collated = {
        name: [sample[name] for sample in batch]
        for name in ("cdr3", "mhc", "epitope")
    }
    label_values = [sample["label"] for sample in batch]
    collated["labels"] = torch.tensor(label_values, dtype=torch.float32)
    return collated

# --------------------------
# B2: Frozen ESM + Linear Head (Simplified for demo)
# --------------------------

class FrozenESMModel(BaseModel):
    """Logistic regression over hand-crafted features, with temperature scaling.

    Despite the name, no ESM embedding is computed in this class: features
    are sequence lengths, an HLA-A*02:01 flag and amino-acid counts.
    """

    def __init__(self, config: ExperimentConfig):
        super().__init__(config)
        self.classifier = None
        self.temp_scaling = TemperatureScaling()

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Train the linear classifier and calibrate it on the validation set."""
        x_train = self._extract_simple_features(df_train)
        x_val = self._extract_simple_features(df_val)
        y_train = df_train["label"].values
        y_val = df_val["label"].values

        self.classifier = LogisticRegression(max_iter=1000, random_state=self.config.seed)
        self.classifier.fit(x_train, y_train)

        # Calibrate on held-out decision scores.
        self.temp_scaling.fit(self.classifier.decision_function(x_val), y_val)

    def _extract_simple_features(self, df: pd.DataFrame) -> np.ndarray:
        """Build one feature row per sample.

        Layout: [len(CDR3), len(Epitope), is_A0201, 20 CDR3 residue counts,
        20 epitope residue counts].
        """
        alphabet = "ACDEFGHIKLMNPQRSTVWY"
        rows = []
        for cdr3, mhc, epitope in zip(df["CDR3"], df["MHC"], df["Epitope"]):
            is_a0201 = 1 if "A*02:01" in mhc else 0
            cdr3_counts = [cdr3.count(aa) for aa in alphabet]
            epitope_counts = [epitope.count(aa) for aa in alphabet]
            rows.append([len(cdr3), len(epitope), is_a0201] + cdr3_counts + epitope_counts)
        return np.array(rows)

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, calibrated probabilities)`` at threshold 0.5."""
        scores = self.classifier.decision_function(self._extract_simple_features(df_test))
        probs = self.temp_scaling.apply(scores)
        return (probs > 0.5).astype(int), probs

# --------------------------
# B3: k-mer + LR/MLP
# --------------------------

class KmerFeatureExtractor:
    """Hashed binary k-mer presence features for amino-acid sequences.

    Each sequence is handed directly to a character n-gram
    ``HashingVectorizer`` whose ``ngram_range`` equals ``kmer_range``; the
    character n-grams of a sequence are exactly its k-mers. (Joining a
    Python ``set`` of k-mers into a string and n-gramming that string again
    would add n-grams that straddle k-mer boundaries and would depend on the
    set's iteration order, which is not stable between interpreter runs.)
    """

    def __init__(self, kmer_range: Tuple[int, int], hash_features: int):
        self.kmer_range = kmer_range
        self.hash_features = hash_features
        self.vectorizer = HashingVectorizer(
            n_features=hash_features,
            analyzer='char',
            ngram_range=tuple(kmer_range),
            binary=True
        )

    def extract_kmers(self, sequence: str) -> Set[str]:
        """Return the distinct substrings of length k for every k in ``kmer_range`` (inclusive)."""
        k_min, k_max = self.kmer_range[0], self.kmer_range[1]
        return {
            sequence[start:start + k]
            for k in range(k_min, k_max + 1)
            for start in range(len(sequence) - k + 1)
        }

    def fit_transform(self, sequences: List[str]) -> np.ndarray:
        """Same as ``transform``: the hashing vectorizer has no fitted state."""
        return self.transform(sequences)

    def transform(self, sequences: List[str]) -> np.ndarray:
        """Return a dense ``(len(sequences), hash_features)`` feature matrix.

        NOTE(review): the dense conversion is kept because callers stack the
        result with ``np.hstack``; for large inputs a sparse pipeline would
        use far less memory.
        """
        return self.vectorizer.transform(list(sequences)).toarray()

class KmerModel(BaseModel):
    """k-mer features for TCR, MHC and peptide fed to LR or an MLP.

    ``model_type == "lr"`` selects logistic regression; any other value
    selects the MLP. Scores are calibrated with temperature scaling fitted
    on the validation set.
    """

    def __init__(self, config: ExperimentConfig, model_type: str = "lr"):
        super().__init__(config)
        self.model_type = model_type
        # The hashed-feature budget is shared equally by the three inputs.
        per_input_features = config.hash_features // 3
        self.tcr_extractor = KmerFeatureExtractor(config.kmer_range, per_input_features)
        self.mhc_extractor = KmerFeatureExtractor(config.kmer_range, per_input_features)
        self.pep_extractor = KmerFeatureExtractor(config.kmer_range, per_input_features)
        self.classifier = None
        self.temp_scaling = TemperatureScaling()

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Train the classifier on ``df_train`` and calibrate on ``df_val``."""
        train_features = self._extract_features(df_train, fit=True)
        val_features = self._extract_features(df_val, fit=False)

        train_labels = df_train["label"].values
        val_labels = df_val["label"].values

        if self.model_type == "lr":
            self.classifier = LogisticRegression(max_iter=1000, random_state=self.config.seed)
        else:  # mlp
            self.classifier = MLPClassifier(
                hidden_layer_sizes=(256, 128),
                max_iter=500,
                random_state=self.config.seed
            )

        with warnings.catch_warnings():
            # Hitting max_iter is tolerated for this baseline.
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            self.classifier.fit(train_features, train_labels)

        self.temp_scaling.fit(self._decision_scores(val_features), val_labels)

    @staticmethod
    def _probs_to_logits(probs: np.ndarray, eps: float = 1e-8) -> np.ndarray:
        """Convert probabilities to log-odds, clipped away from 0 and 1.

        Clipping both ends keeps the result finite; an unclipped ``p == 0``
        would give ``-inf`` and poison the temperature fit.
        """
        clipped = np.clip(np.asarray(probs, dtype=float), eps, 1.0 - eps)
        return np.log(clipped) - np.log1p(-clipped)

    def _decision_scores(self, features: np.ndarray) -> np.ndarray:
        """Return logit-scale scores from whichever API the classifier offers."""
        if self.classifier is None:
            raise RuntimeError("KmerModel must be fitted before scoring")
        if hasattr(self.classifier, 'decision_function'):
            return self.classifier.decision_function(features)
        # Classifiers without decision_function: derive log-odds from probabilities.
        return self._probs_to_logits(self.classifier.predict_proba(features)[:, 1])

    def _extract_features(self, df: pd.DataFrame, fit: bool = False) -> np.ndarray:
        """Concatenate the TCR, MHC and peptide k-mer feature blocks column-wise.

        MHC allele names are first mapped to sequences via
        ``get_mhc_sequence``. ``fit`` selects ``fit_transform`` versus
        ``transform`` on the extractors.
        """
        inputs = (
            (self.tcr_extractor, df["CDR3"].tolist()),
            (self.mhc_extractor, [get_mhc_sequence(mhc) for mhc in df["MHC"].tolist()]),
            (self.pep_extractor, df["Epitope"].tolist()),
        )
        blocks = [
            extractor.fit_transform(seqs) if fit else extractor.transform(seqs)
            for extractor, seqs in inputs
        ]
        return np.hstack(blocks)

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, calibrated probabilities)`` at threshold 0.5."""
        test_features = self._extract_features(df_test, fit=False)
        probs = self.temp_scaling.apply(self._decision_scores(test_features))
        preds = (probs > 0.5).astype(int)
        return preds, probs

# --------------------------
# B7: TCR Distance + k-NN (Simplified)
# --------------------------

def compute_tcr_distance(tcr1: str, tcr2: str) -> float:
    """Return a dissimilarity in [0, 1] between two CDR3 strings.

    * Equal length: fraction of positions that differ (normalised Hamming).
    * Different length: ``1 - difflib.SequenceMatcher(...).ratio()``, a
      cheap stand-in for an edit distance.

    Two empty strings are identical and give 0.0.
    """
    if len(tcr1) == len(tcr2):
        if not tcr1:
            # Both empty: avoid dividing by a zero length.
            return 0.0
        return sum(c1 != c2 for c1, c2 in zip(tcr1, tcr2)) / len(tcr1)

    from difflib import SequenceMatcher
    return 1.0 - SequenceMatcher(None, tcr1, tcr2).ratio()

class TCRDistanceModel(BaseModel):
    """k-NN over CDR3 distance, restricted to training rows with the same MHC and epitope.

    NOTE(review): a test row whose (MHC, epitope) pair never occurs in the
    training data gets probability 0.0. If train and test are split by
    epitope, that is every test row, so the scores are constant -- confirm
    this baseline is meaningful under the split in use.
    """

    def __init__(self, config: ExperimentConfig, k: int = 5):
        super().__init__(config)
        self.k = k
        self.train_data = None
        # (MHC, epitope) -> [(cdr3, label), ...] in training-row order.
        self._neighbors_by_target = {}

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Store the training data and index it by (MHC, epitope).

        Indexing once here means ``predict`` only touches the training rows
        that can actually vote for a test row, instead of scanning the whole
        training frame per test row. ``df_val`` is not used.
        """
        self.train_data = df_train.copy()

        index = defaultdict(list)
        columns = zip(df_train["CDR3"], df_train["MHC"], df_train["Epitope"], df_train["label"])
        for cdr3, mhc, epitope, label in columns:
            index[(mhc, epitope)].append((cdr3, label))
        self._neighbors_by_target = dict(index)

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, probabilities)`` from a k-nearest-neighbour vote.

        The probability is the share of positive labels among the ``k``
        closest matching training rows (fewer if fewer exist); ties in
        distance keep training order. The prediction is 1 when that share
        exceeds 0.5.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.train_data is None:
            raise RuntimeError("TCRDistanceModel must be fitted before predict")

        predictions = []
        probabilities = []

        for test_tcr, test_mhc, test_epitope in zip(df_test["CDR3"], df_test["MHC"], df_test["Epitope"]):
            candidates = self._neighbors_by_target.get((test_mhc, test_epitope), [])

            if candidates:
                scored = sorted(
                    ((compute_tcr_distance(test_tcr, train_tcr), label) for train_tcr, label in candidates),
                    key=lambda pair: pair[0],
                )
                k_nearest = scored[:self.k]
                positive_votes = sum(1 for _, label in k_nearest if label == 1)
                prob = positive_votes / len(k_nearest)
                pred = 1 if prob > 0.5 else 0
            else:
                # No training row shares this MHC/epitope pair.
                prob = 0.0
                pred = 0

            predictions.append(pred)
            probabilities.append(prob)

        return np.array(predictions), np.array(probabilities)

# --------------------------
# B8: Peptide-MHC Prior + Random TCR
# --------------------------

class PeptideMHCPriorModel(BaseModel):
    """TCR-agnostic baseline: positive rate of each (epitope, MHC) pair in training.

    ``mhc_binding_scores`` maps ``(epitope, mhc)`` to the fraction of
    training rows for that pair whose label is 1.
    """

    def __init__(self, config: ExperimentConfig):
        super().__init__(config)
        self.mhc_binding_scores = {}

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Recompute the per-pair positive rates from ``df_train``.

        Counts are accumulated in a local table and the result replaces
        ``mhc_binding_scores`` wholesale, so calling ``fit`` again on the
        same instance starts from scratch instead of mixing counts with
        previously stored rates. ``df_val`` is not used.
        """
        counts = defaultdict(lambda: [0, 0])  # key -> [positives, total]
        for epitope, mhc, label in zip(df_train["Epitope"], df_train["MHC"], df_train["label"]):
            tally = counts[(epitope, mhc)]
            tally[1] += 1
            if label == 1:
                tally[0] += 1

        # Every key was created by at least one row, so total is never 0.
        self.mhc_binding_scores = {
            key: positives / total for key, (positives, total) in counts.items()
        }

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, probabilities)``.

        Pairs seen in training get their stored rate; unseen pairs get a
        uniform random score from the global ``random`` module. The
        prediction is 1 when the score exceeds 0.5.
        """
        predictions = []
        probabilities = []

        for epitope, mhc in zip(df_test["Epitope"], df_test["MHC"]):
            key = (epitope, mhc)
            if key in self.mhc_binding_scores:
                prob = self.mhc_binding_scores[key]
            else:
                # Unknown peptide-MHC pair, use random
                prob = random.random()

            predictions.append(1 if prob > 0.5 else 0)
            probabilities.append(prob)

        return np.array(predictions), np.array(probabilities)

# --------------------------
# Experiment Runner
# --------------------------

class ExperimentRunner:
    """Train and evaluate every baseline ``config.n_runs`` times and tabulate the results."""

    def __init__(self, config: ExperimentConfig):
        self.config = config
        self.data_manager = DataManager(config)
        self.evaluator = Evaluator()

    def run_all_experiments(self) -> pd.DataFrame:
        """Run all baselines and return one row per (model, metric).

        Each model is re-seeded with ``config.seed + run`` and fitted on the
        same train/val/test split. A run that raises is reported and skipped;
        a model with no successful run contributes no rows.

        Returns:
            DataFrame with columns Model, Metric, Mean, Std, CI_Lower,
            CI_Upper and CI_String (empty, but with these columns, if every
            run failed).
        """
        print("Loading and preparing data...")
        self.data_manager.load_and_split_data()

        # Simplified model set for demo; each value builds a model from a config.
        models = {
            "B2_Frozen_ESM": FrozenESMModel,
            "B3_KMer_LR": lambda config: KmerModel(config, "lr"),
            "B3_KMer_MLP": lambda config: KmerModel(config, "mlp"),
            "B7_TCRdist_k1": lambda config: TCRDistanceModel(config, k=1),
            "B7_TCRdist_k5": lambda config: TCRDistanceModel(config, k=5),
            "B7_TCRdist_k10": lambda config: TCRDistanceModel(config, k=10),
            "B8_PeptideMHC_Prior": PeptideMHCPriorModel,
        }

        columns = ["Model", "Metric", "Mean", "Std", "CI_Lower", "CI_Upper", "CI_String"]
        results = []

        for model_name, model_class in models.items():
            print(f"\n{'='*60}")
            print(f"Running {model_name}...")
            print(f"{'='*60}")

            model_results = []

            for run in range(self.config.n_runs):
                print(f"Run {run + 1}/{self.config.n_runs}")

                # Set seed for reproducibility
                set_seed(self.config.seed + run)

                try:
                    model = model_class(self.config)
                    model.fit(self.data_manager.df_train, self.data_manager.df_val)

                    y_pred, y_prob = model.predict(self.data_manager.df_test)
                    y_true = self.data_manager.df_test["label"].values

                    metrics = self.evaluator.compute_metrics(y_true, y_pred, y_prob)
                    model_results.append(metrics)

                    print(f"  AUC: {metrics['auc']:.4f}, AUPRC: {metrics['auprc']:.4f}")

                except Exception as e:
                    # Deliberate best-effort: one failing baseline must not
                    # abort the rest of the benchmark.
                    print(f"  Error in {model_name} run {run + 1}: {str(e)}")
                    continue

            if not model_results:
                continue

            aggregated = self.evaluator.aggregate_metrics(model_results)
            for metric, stats in aggregated.items():
                # Take the half-width from the aggregated interval so it is
                # based on the runs that actually succeeded, not on the
                # configured n_runs.
                half_width = stats["ci_upper"] - stats["mean"]
                results.append({
                    "Model": model_name,
                    "Metric": metric.upper(),
                    "Mean": stats["mean"],
                    "Std": stats["std"],
                    "CI_Lower": stats["ci_lower"],
                    "CI_Upper": stats["ci_upper"],
                    "CI_String": f"{stats['mean']:.4f} ± {half_width:.4f}"
                })

        return pd.DataFrame(results, columns=columns)

    def save_results(self, results_df: pd.DataFrame, output_path: str = "experiment_results.csv"):
        """Write ``results_df`` to ``output_path`` and print a Model x Metric summary.

        The summary table is skipped when there are no result rows, because
        there is nothing to pivot.
        """
        results_df.to_csv(output_path, index=False)

        print(f"\n{'='*100}")
        print("EXPERIMENT RESULTS SUMMARY")
        print(f"{'='*100}")

        if not results_df.empty:
            # Pivot table for better readability
            pivot_df = results_df.pivot(index="Model", columns="Metric", values="CI_String")
            print(pivot_df.to_string())

        print(f"\nDetailed results saved to: {output_path}")
        print(f"{'='*100}")

# --------------------------
# Main Function
# --------------------------

def main():
    """Build the demo configuration, run the benchmark and save the results."""
    # Demo overrides; epochs is reduced relative to the config default.
    overrides = dict(seed=42, n_runs=3, epochs=5, batch_size=32, lr=2e-4)
    config = ExperimentConfig(**overrides)

    print("Starting comprehensive TCR-peptide-MHC binding prediction benchmark...")
    print(f"Configuration: {config}")

    runner = ExperimentRunner(config)
    runner.save_results(runner.run_all_experiments())


if __name__ == "__main__":
    main()

Starting comprehensive TCR-peptide-MHC binding prediction benchmark...
Configuration: ExperimentConfig(seed=42, device='cuda', n_runs=3, test_size=0.2, val_size=0.1, batch_size=32, epochs=5, lr=0.0002, esm_model_name='facebook/esm2_t12_35M_UR50D', esm_freeze_layers=6, d_model=128, kmer_range=(3, 5), hash_features=100000, cnn_channels=[32, 64, 128], cnn_kernel_sizes=[3, 5, 7], lstm_hidden=128, lstm_layers=2, retrieval_top_k=200, tcr_knn_k=[1, 5, 10])
Loading and preparing data...
Data split: Train=16993, Val=2479, Test=4992

Running B2_Frozen_ESM...
Run 1/3
  AUC: 0.4811, AUPRC: 0.2579
Run 2/3
  AUC: 0.4811, AUPRC: 0.2579
Run 3/3
  AUC: 0.4811, AUPRC: 0.2579

Running B3_KMer_LR...
Run 1/3
  AUC: 0.3881, AUPRC: 0.1729
Run 2/3
  AUC: 0.3881, AUPRC: 0.1729
Run 3/3
  AUC: 0.3881, AUPRC: 0.1729

Running B3_KMer_MLP...
Run 1/3
  AUC: 0.4905, AUPRC: 0.2087
Run 2/3
  AUC: 0.4905, AUPRC: 0.2087
Run 3/3
  AUC: 0.4905, AUPRC: 0.2087

Running B7_TCRdist_k1...
Run 1/3
  AUC: 0.5000, AUPRC: 0.2234
Ru