# In [None]:
# -*- coding: utf-8 -*-
# Comprehensive TCR-peptide-MHC Binding Prediction Benchmark
# Main Table + Auxiliary Methods (BX series)
# Dependencies: torch, transformers, numpy, pandas, scikit-learn, tqdm, requests

import os, sys, math, json, time, random, requests, warnings, hashlib
from typing import List, Dict, Tuple, Set, Optional, Any, Union
from collections import defaultdict, Counter
from dataclasses import dataclass, field
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import EsmModel, EsmTokenizer, BertModel, BertTokenizer
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from sklearn.exceptions import ConvergenceWarning
from sklearn.feature_extraction.text import HashingVectorizer

try:
    from tqdm import tqdm
except ImportError:
    raise ImportError("Please install tqdm: pip install tqdm")

# --------------------------
# Configuration & Utils
# --------------------------

@dataclass
class ExperimentConfig:
    """Hyperparameters shared by the data pipeline and the benchmark models.

    Sequence-valued defaults are created with ``default_factory`` so every
    instance gets its own list/tuple.
    """
    seed: int = 42
    # Chosen once, when the class body is executed (i.e. at import time).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    n_runs: int = 3
    # Fractions of the distinct positive epitopes placed in test / validation
    # (DataManager._split_data takes at least one epitope for each).
    test_size: float = 0.2
    val_size: float = 0.1
    batch_size: int = 64
    epochs: int = 10
    # Base learning rate; ESMFineTuneWithInfoNCE gives the ESM backbone lr * 0.1.
    lr: float = 1e-4
    
    # ESM config
    esm_model_name: str = "facebook/esm2_t12_35M_UR50D"
    # Number of leading ESM encoder layers whose parameters are frozen.
    esm_freeze_layers: int = 6
    # Width of each projector / joint head in ESMFineTuneModel.
    d_model: int = 128
    
    # ProtBert config
    # NOTE(review): the ProtBertModel shown in this file loads no ProtBert
    # weights and does not read these two fields -- confirm they are used elsewhere.
    protbert_model_name: str = "Rostlab/prot_bert"
    protbert_freeze_layers: int = 8
    
    # k-mer config
    kmer_range: Tuple[int, int] = field(default_factory=lambda: (3, 5))
    hash_features: int = 100000
    
    # CNN config
    cnn_channels: List[int] = field(default_factory=lambda: [32, 64, 128])
    cnn_kernel_sizes: List[int] = field(default_factory=lambda: [3, 5, 7])
    
    # LSTM config
    lstm_hidden: int = 128
    lstm_layers: int = 2
    
    # Transformer config
    transformer_layers: int = 4
    transformer_heads: int = 8
    
    # GNN config
    gnn_hidden: int = 128
    gnn_layers: int = 2
    
    # VAE config
    vae_latent_dim: int = 64
    
    # Retrieval config
    retrieval_top_k: int = 200
    
    # TCR distance config
    tcr_knn_k: List[int] = field(default_factory=lambda: [1, 5, 10])

# The 20 standard amino acids in one-letter alphabetical order; a residue's
# position in AA_STANDARD is its integer index in AA_TO_IDX.
AA_STANDARD = [residue for residue in "ACDEFGHIKLMNPQRSTVWY"]
AA_SET = set(AA_STANDARD)
AA_TO_IDX = dict(zip(AA_STANDARD, range(len(AA_STANDARD))))

# MHC sequences keyed by allele name; looked up by get_mhc_sequence().
# NOTE(review): despite the "pseudo-sequence" name these strings are hundreds
# of residues long, which looks like full heavy-chain sequences rather than
# short NetMHCpan-style pseudo-sequences -- confirm which is intended.
# NOTE(review): the HLA-B*07:02 entry is identical to HLA-A*01:01 except at
# index 24 (V instead of M); it looks like a placeholder copy -- verify against
# IPD-IMGT/HLA before relying on B*07:02 results.
MHC_PSEUDO_SEQUENCES = {
    "HLA-A*02:01": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDGETRKVKAHSQTHRVDLGTLRGYYNQSEAGSHTVQRMYGCDVGSDWRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQTTKHKWEAAHVAEQLRAYLEGTCVEWLRRYLENGKETLQRTDAPKTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGQEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-A*01:01": "GSHSMRYFFTSVSRPGRGEPRFIAMGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-B*07:02": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE"
}

def get_mhc_sequence(mhc_name: str) -> str:
    """Return the stored sequence for ``mhc_name``.

    Alleles missing from MHC_PSEUDO_SEQUENCES silently fall back to the
    HLA-A*02:01 entry.
    """
    fallback = MHC_PSEUDO_SEQUENCES["HLA-A*02:01"]
    return MHC_PSEUDO_SEQUENCES.get(mhc_name, fallback)

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG, and torch (CPU and all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

def aa_to_indices(sequence: str, max_len: int = 50,
                  vocab: Optional[Dict[str, int]] = None) -> np.ndarray:
    """Encode an amino-acid sequence as a fixed-length integer index array.

    Args:
        sequence: Amino-acid string (case-insensitive).
        max_len: Output length; longer sequences are truncated and shorter
            ones are right-padded with 0.
        vocab: Residue -> index mapping. Defaults to ``AA_TO_IDX``.

    Returns:
        ``int64`` array of shape ``(max_len,)``. Residues missing from
        ``vocab`` are left as 0.

    Note:
        With the default vocabulary 'A' also maps to 0, so alanine, padding
        and unknown residues are indistinguishable. Pass a ``vocab`` whose
        indices start at 1 if that matters.
    """
    mapping = AA_TO_IDX if vocab is None else vocab
    # np.long does not exist in NumPy 1.24-1.26 and is the platform C long in
    # 2.x; int64 is portable and matches torch.long.
    indices = np.zeros(max_len, dtype=np.int64)
    for i, aa in enumerate(sequence.upper()[:max_len]):
        indices[i] = mapping.get(aa, 0)
    return indices

def compute_sequence_similarity(seq1: str, seq2: str) -> float:
    """Return difflib's ``SequenceMatcher.ratio()`` for the two sequences (0.0 to 1.0)."""
    import difflib

    matcher = difflib.SequenceMatcher(isjunk=None, a=seq1, b=seq2)
    return matcher.ratio()

# --------------------------
# Data Management (Enhanced)
# --------------------------

class DataManager:
    """Load TCR-peptide-MHC pairs, add synthetic negatives, and split by epitope.

    After :meth:`load_and_split_data`, ``df_train``, ``df_val`` and ``df_test``
    hold DataFrames with columns ``CDR3``, ``MHC``, ``Epitope`` and ``label``
    (1 = observed pair, 0 = generated negative).
    """

    # Helper column linking each row to the positive epitope it was derived
    # from. It exists only between _build_negatives and _split_data.
    _ANCHOR_COL = "_anchor_epitope"
    _KEY_COLS = ("CDR3", "MHC", "Epitope")

    def __init__(self, config: ExperimentConfig):
        self.config = config
        self.df_train = None
        self.df_val = None
        self.df_test = None

    def load_and_split_data(self, data_path: str = "data.csv"):
        """Load ``data_path`` (writing a demo CSV there if it is missing) and split it.

        Populates ``df_train``/``df_val``/``df_test`` and prints their sizes.
        """
        if not os.path.exists(data_path):
            self._create_enhanced_demo_data(data_path)

        df = pd.read_csv(data_path)
        df = self._preprocess_data(df)
        df_with_negs = self._build_negatives(df)
        self._split_data(df_with_negs)

        print(f"Data split: Train={len(self.df_train)}, Val={len(self.df_val)}, Test={len(self.df_test)}")

    def _create_enhanced_demo_data(self, path: str):
        """Write a small demonstration CSV (header CDR3,MHC,Epitope) to ``path``."""
        demo_data = [
            ["CDR3", "MHC", "Epitope"],
            # HLA-A*02:01 epitopes
            ["CASSLEETQYF", "HLA-A*02:01", "GILGFVFTL"],
            ["CASSFRGTQYF", "HLA-A*02:01", "NLVPMVATV"],
            ["CASRPGLAGGRPEQYF", "HLA-A*02:01", "TPRVTGGGAM"],
            ["CSVEGGSTDTQYF", "HLA-A*02:01", "ELAGIGILTV"],
            ["CASSQDTQYF", "HLA-A*02:01", "LLWNGPMAV"],
            ["CAWRNTGQLYF", "HLA-A*02:01", "KLVALGINAV"],
            ["CASTLESGQYF", "HLA-A*02:01", "VTEHDTLLY"],
            ["CASSPPRVYNEQFF", "HLA-A*02:01", "LLWNGPMAV"],
            ["CASSPGQGAYNEQFF", "HLA-A*02:01", "GILGFVFTL"],
            ["CASSRGQGVYNEQFF", "HLA-A*02:01", "FLKEKGGL"],
            ["CASSPRGTDTQYF", "HLA-A*02:01", "IMDQVPFSV"],
            ["CASSFDRVGDNEQFF", "HLA-A*02:01", "KLGGALQAK"],
            ["CASSLVGAGGRPEQYF", "HLA-A*02:01", "GVYDGREHTV"],
            ["CASSITGQGDNEQFF", "HLA-A*02:01", "YLQPRTFLL"],
            ["CASSFGQGAYNEQFF", "HLA-A*02:01", "ALWEIQQVV"],
            # HLA-A*01:01 epitopes
            ["CASSLEETQYF", "HLA-A*01:01", "TTPESANL"],
            ["CASSQVGQGAYNEQFF", "HLA-A*01:01", "IVDCLTEMY"],
            ["CASSRGDTQYF", "HLA-A*01:01", "VTEHDTLLY"],
            ["CASSPPGQGAYNEQFF", "HLA-A*01:01", "TTPESANL"],
            ["CASSLVGAYNEQFF", "HLA-A*01:01", "IVDCLTEMY"],
            # HLA-B*07:02 epitopes
            ["CASSFRGTQYF", "HLA-B*07:02", "APRTLVYLL"],
            ["CASSLGQGAVGEQFF", "HLA-B*07:02", "FPVRPQVPL"],
            ["CASSRGDNEQFF", "HLA-B*07:02", "APRTLVYLL"],
            ["CASSPPGAYNEQFF", "HLA-B*07:02", "FPVRPQVPL"],
            ["CASSLVGNEQFF", "HLA-B*07:02", "GPRLGVRAT"],
            # Additional diverse examples
            ["CASSYDRGDTQYF", "HLA-A*02:01", "GLCTLVAML"],
            ["CASSQGQGAYNEQFF", "HLA-A*02:01", "RLRAEAQVK"],
            ["CASSLGDTQYF", "HLA-A*01:01", "ASNENMETM"],
            ["CASSPRNEQFF", "HLA-B*07:02", "RPRGEVRFL"],
            ["CASSRGQGDTQYF", "HLA-A*02:01", "FRDYVDRFYKTLRAEQASQE"],  # Longer epitope
        ]

        with open(path, "w", encoding="utf-8") as f:
            for row in demo_data:
                f.write(",".join(row) + "\n")

    def _preprocess_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize sequences and drop rows that cannot be encoded.

        Removes rows with a missing or empty CDR3/MHC/Epitope, rows whose CDR3
        or Epitope contains a non-standard residue, and exact duplicates.
        Values are stripped and upper-cased.
        """
        cols = list(self._KEY_COLS)
        # Drop missing values *before* casting: astype(str) turns NaN into
        # "nan", which upper-cases to "NAN" -- made entirely of valid residues.
        df = df.dropna(subset=cols).copy()
        for col in cols:
            df[col] = df[col].astype(str).str.strip().str.upper()
        df = df.drop_duplicates()

        def is_valid_sequence(seq: str) -> bool:
            # all() of an empty string is True, so reject "" explicitly; empty
            # sequences later cause division by zero in feature extraction.
            return bool(seq) and all(ch in AA_SET for ch in seq)

        df = df[df["Epitope"].map(is_valid_sequence) & df["CDR3"].map(is_valid_sequence)]
        return df[df["MHC"] != ""]

    def _build_negatives(self, df: pd.DataFrame, k_neg: int = 4) -> pd.DataFrame:
        """Label observed pairs as positives and append synthetic negatives.

        For every positive (CDR3, MHC, Epitope) row:

        * Strategy 1 pairs the same CDR3/MHC with ``k_neg`` other peptides of
          the same length, topping up with single-point mutants of the true
          peptide when too few exist.
        * Strategy 2 pairs the same peptide/MHC with one CDR3 that was observed
          with some other epitope.

        A negative never equals an observed positive triple, so no triple
        carries both labels. The returned frame includes the helper column
        ``_anchor_epitope`` (the positive epitope each row was derived from),
        which :meth:`_split_data` consumes and drops.
        """
        df = df.copy()
        df["label"] = 1
        df[self._ANCHOR_COL] = df["Epitope"]

        # Observed triples. A TCR recorded against two epitopes would otherwise
        # be drawn as a negative for one of its own positives.
        positives = set(zip(df["CDR3"], df["MHC"], df["Epitope"]))

        # Group peptides by length for realistic negatives
        peps_by_len = defaultdict(list)
        for pep in df["Epitope"].unique():
            peps_by_len[len(pep)].append(pep)

        # Epitopes seen with each TCR and TCRs in first-seen order: together they
        # yield the strategy-2 pool without re-filtering the frame for every row.
        epitopes_by_tcr = defaultdict(set)
        for tcr, pep in zip(df["CDR3"], df["Epitope"]):
            epitopes_by_tcr[tcr].add(pep)
        all_tcrs = list(dict.fromkeys(df["CDR3"]))
        tcr_pools: Dict[Tuple[str, str], List[str]] = {}

        negatives = []
        for cdr3, mhc, pep in zip(df["CDR3"], df["MHC"], df["Epitope"]):
            # Strategy 1: same TCR/MHC, different same-length peptides
            neg_peps = [p for p in peps_by_len[len(pep)]
                        if p != pep and (cdr3, mhc, p) not in positives]
            if len(neg_peps) >= k_neg:
                selected_negs = random.sample(neg_peps, k_neg)
            else:
                selected_negs = neg_peps + [self._mutate_peptide(pep) for _ in range(k_neg - len(neg_peps))]

            for neg_pep in selected_negs:
                if (cdr3, mhc, neg_pep) in positives:
                    continue  # a random mutant happened to hit an observed pair
                negatives.append({"CDR3": cdr3, "MHC": mhc, "Epitope": neg_pep,
                                  "label": 0, self._ANCHOR_COL: pep})

            # Strategy 2: same peptide/MHC, a TCR observed with another epitope
            pool_key = (mhc, pep)
            if pool_key not in tcr_pools:
                tcr_pools[pool_key] = [
                    t for t in all_tcrs
                    if epitopes_by_tcr[t] - {pep} and (t, mhc, pep) not in positives
                ]
            pool = tcr_pools[pool_key]
            if pool:
                negatives.append({"CDR3": random.choice(pool), "MHC": mhc, "Epitope": pep,
                                  "label": 0, self._ANCHOR_COL: pep})

        df_neg = pd.DataFrame(negatives, columns=[*self._KEY_COLS, "label", self._ANCHOR_COL])
        combined = pd.concat([df, df_neg], ignore_index=True)
        # The same negative triple can be generated from several anchors; keep
        # the first occurrence only.
        return combined.drop_duplicates(subset=[*self._KEY_COLS, "label"])

    def _mutate_peptide(self, peptide: str, n_mut: int = 1) -> str:
        """Return ``peptide`` with ``n_mut`` random single-residue substitutions.

        Each substitution picks a uniformly random position (positions may
        repeat) and replaces it with a different standard amino acid.
        """
        s = list(peptide)
        for _ in range(n_mut):
            pos = random.randint(0, len(s) - 1)
            original = s[pos]
            s[pos] = random.choice([aa for aa in AA_STANDARD if aa != original])
        return "".join(s)

    def _split_data(self, df: pd.DataFrame):
        """Partition rows into train/val/test so that no epitope spans two splits.

        Distinct positive epitopes are shuffled with ``random`` and allotted to
        test, validation and train in that order (at least one each for test
        and validation). A row follows its own Epitope when that is a positive
        epitope; otherwise (e.g. a mutated negative peptide) it follows the
        epitope it was generated from, when that information is available.
        """
        positive_epitopes = list(df.loc[df["label"] == 1, "Epitope"].unique())
        random.shuffle(positive_epitopes)

        n_test = max(1, int(len(positive_epitopes) * self.config.test_size))
        n_val = max(1, int(len(positive_epitopes) * self.config.val_size))

        test_epitopes = set(positive_epitopes[:n_test])
        val_epitopes = set(positive_epitopes[n_test:n_test + n_val])
        train_epitopes = set(positive_epitopes[n_test + n_val:])
        if not train_epitopes:
            warnings.warn(
                f"Only {len(positive_epitopes)} positive epitope(s): the training split is empty. "
                "Provide more epitopes or lower test_size/val_size.",
                stacklevel=2,
            )

        if self._ANCHOR_COL in df.columns:
            # Without the anchor fallback, mutated negative peptides match no
            # split and are silently discarded.
            is_positive_epitope = df["Epitope"].isin(set(positive_epitopes))
            split_key = df["Epitope"].where(is_positive_epitope, df[self._ANCHOR_COL])
            df = df.drop(columns=self._ANCHOR_COL)
        else:
            split_key = df["Epitope"]

        self.df_test = df[split_key.isin(test_epitopes)].reset_index(drop=True)
        self.df_val = df[split_key.isin(val_epitopes)].reset_index(drop=True)
        self.df_train = df[split_key.isin(train_epitopes)].reset_index(drop=True)

# --------------------------
# Base Classes
# --------------------------

class BaseModel:
    """Interface shared by every benchmark method.

    Subclasses keep the experiment configuration on ``self.config`` and
    override :meth:`fit` and :meth:`predict`; the base versions only raise
    ``NotImplementedError``.
    """

    def __init__(self, config: ExperimentConfig):
        self.config = config

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Train on ``df_train`` (``df_val`` may be used for calibration); subclass hook."""
        raise NotImplementedError()

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(binary predictions, positive-class probabilities)`` for ``df_test``; subclass hook."""
        raise NotImplementedError()

class TemperatureScaling:
    """Post-hoc calibration of binary logits with a single temperature ``T``.

    :meth:`fit` picks ``T`` in [0.1, 10] that minimises the (clipped) binary
    negative log-likelihood of ``sigmoid(logits / T)``. :meth:`apply` returns
    the calibrated probabilities. ``T`` stays at 1.0 (no rescaling) whenever
    fitting is not possible.
    """

    def __init__(self):
        self.temperature = 1.0

    @staticmethod
    def _sigmoid(x) -> np.ndarray:
        # Algebraically equal to 1 / (1 + exp(-x)) but cannot overflow, so very
        # negative logits do not emit "overflow encountered in exp" warnings.
        return 0.5 * (1.0 + np.tanh(np.asarray(x, dtype=float) / 2.0))

    def fit(self, logits: np.ndarray, labels: np.ndarray):
        """Fit the temperature on validation logits and 0/1 labels.

        Falls back to ``T = 1.0`` with a warning if SciPy is unavailable or the
        inputs are empty, of different sizes, or non-numeric.
        """
        self.temperature = 1.0
        try:
            from scipy.optimize import minimize_scalar
        except ImportError:
            warnings.warn("scipy is not installed; temperature scaling disabled (T=1.0).")
            return

        try:
            z = np.asarray(logits, dtype=float).ravel()
            y = np.asarray(labels, dtype=float).ravel()
            if z.size == 0 or z.shape != y.shape:
                raise ValueError(
                    f"expected equally sized, non-empty logits and labels, got {z.shape} and {y.shape}"
                )

            def nll(temp):
                if temp <= 0:
                    return 1e6
                probs = np.clip(TemperatureScaling._sigmoid(z / temp), 1e-7, 1 - 1e-7)
                return -np.mean(y * np.log(probs) + (1 - y) * np.log(1 - probs))

            result = minimize_scalar(nll, bounds=(0.1, 10.0), method="bounded")
        except (TypeError, ValueError) as exc:
            # Best effort: a calibration failure must not abort the benchmark.
            warnings.warn(f"Temperature scaling failed ({exc}); using T=1.0.")
            return
        self.temperature = float(result.x)

    def apply(self, logits: np.ndarray) -> np.ndarray:
        """Return calibrated probabilities ``sigmoid(logits / T)``."""
        return self._sigmoid(np.asarray(logits, dtype=float) / self.temperature)

class Evaluator:
    """Static helpers to score binary predictions and summarise repeated runs."""

    @staticmethod
    def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
        """Compute AUC, AUPRC, accuracy, precision and recall.

        Args:
            y_true: 0/1 ground-truth labels.
            y_pred: 0/1 hard predictions.
            y_prob: Scores for the positive class.

        Returns:
            Dict keyed by ``auc``, ``auprc``, ``accuracy``, ``precision`` and
            ``recall``. ``auc`` is NaN when ``y_true`` has a single class and
            ``auprc`` is NaN when it has no positives; both are undefined
            there, and sklearn would raise or warn. Precision/recall are 0 when
            their denominators are 0.
        """
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        y_prob = np.asarray(y_prob)
        both_classes = np.unique(y_true).size >= 2
        has_positive = bool((y_true == 1).any())
        true_pos = ((y_pred == 1) & (y_true == 1)).sum()
        return {
            "auc": roc_auc_score(y_true, y_prob) if both_classes else float("nan"),
            "auprc": average_precision_score(y_true, y_prob) if has_positive else float("nan"),
            "accuracy": accuracy_score(y_true, y_pred),
            "precision": true_pos / max(1, (y_pred == 1).sum()),
            "recall": true_pos / max(1, (y_true == 1).sum()),
        }

    @staticmethod
    def aggregate_metrics(metrics_list: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
        """Summarise per-run metrics as mean, sample std and a 95% CI of the mean.

        The CI uses the sample standard deviation (ddof=1) and a Student-t
        critical value with ``n - 1`` degrees of freedom. It falls back to 1.96
        when SciPy is unavailable. With the few runs used in this benchmark,
        the normal approximation with population std understates the interval.
        With a single run, std is 0 and the interval collapses to the mean.
        Metric names are taken from the first run. An empty ``metrics_list``
        raises IndexError.
        """
        n_runs = len(metrics_list)
        metric_names = list(metrics_list[0].keys())

        critical = 1.96
        if n_runs > 1:
            try:
                from scipy import stats
                critical = float(stats.t.ppf(0.975, df=n_runs - 1))
            except ImportError:
                pass

        aggregated = {}
        for metric in metric_names:
            values = np.asarray([m[metric] for m in metrics_list], dtype=float)
            mean_val = float(np.mean(values))
            std_val = float(np.std(values, ddof=1)) if n_runs > 1 else 0.0
            half_width = critical * std_val / np.sqrt(n_runs)
            aggregated[metric] = {
                "mean": mean_val,
                "std": std_val,
                "ci_lower": mean_val - half_width,
                "ci_upper": mean_val + half_width,
            }
        return aggregated

# --------------------------
# B1: ESM Fine-tuning + InfoNCE + Temperature Scaling (Main Model)
# --------------------------

class ESMSequenceEncoder(nn.Module):
    """Embed protein sequences with a Hugging Face ESM model (position-0 vectors).

    Subclasses ``nn.Module`` so the wrapped ESM weights are registered as a
    submodule. A parent module that stores this encoder as an attribute then
    sees them in ``parameters()``/``named_parameters()``, and therefore in
    optimizers and gradient clipping built from those. ``train()``/``eval()``
    on the parent also reach the ESM model.

    Args:
        model_name: Model id passed to ``from_pretrained``.
        device: Device the ESM model and tokenized inputs are moved to.
        freeze_layers: Number of leading encoder layers to freeze. Later
            layers, the embeddings and the pooler (if present) are trainable.
    """

    def __init__(self, model_name: str, device: str, freeze_layers: int = 6):
        super().__init__()
        print(f"[Info] Loading ESM model: {model_name}")
        self.device = device
        self.tokenizer = EsmTokenizer.from_pretrained(model_name)
        self.model = EsmModel.from_pretrained(model_name).to(device)

        # Selective unfreezing for fine-tuning
        total_layers = len(self.model.encoder.layer)
        for i, layer in enumerate(self.model.encoder.layer):
            layer.requires_grad_(i >= freeze_layers)

        # Always fine-tune embeddings and pooler
        self.model.embeddings.requires_grad_(True)
        pooler = getattr(self.model, "pooler", None)
        if pooler is not None:
            pooler.requires_grad_(True)

        self.hidden_size = self.model.config.hidden_size
        trainable_layers = max(0, total_layers - freeze_layers)
        print(f"[Info] ESM loaded, hidden_size: {self.hidden_size}, trainable layers: {trainable_layers}")

    def encode_batch(self, sequences: List[str], max_length: int = 512) -> torch.Tensor:
        """Return the hidden state at token position 0 for each sequence.

        Non-standard residues are stripped and a sequence left empty is
        replaced by "A". The result has shape ``(batch, hidden_size)``.
        """
        clean_seqs = []
        for seq in sequences:
            clean_seq = "".join([c for c in seq.upper() if c in AA_SET])
            clean_seqs.append(clean_seq if clean_seq else "A")

        inputs = self.tokenizer(
            clean_seqs, return_tensors="pt", padding=True,
            truncation=True, max_length=max_length
        ).to(self.device)

        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token

    def forward(self, sequences: List[str], max_length: int = 512) -> torch.Tensor:
        """Alias for :meth:`encode_batch`, so the encoder can be called like a module."""
        return self.encode_batch(sequences, max_length=max_length)

class ESMFineTuneModel(nn.Module):
    """Binding classifier over TCR, MHC and peptide embeddings from one encoder.

    All three inputs are embedded by the shared ``esm_encoder``. Each embedding
    goes through its own projector (Linear -> LayerNorm -> ReLU -> Dropout) to
    ``d_model`` dims, and the three are concatenated into a ``3 * d_model``
    joint vector.
    """

    def __init__(self, esm_encoder: ESMSequenceEncoder, d_model: int = 128, dropout: float = 0.1):
        super().__init__()
        self.esm_encoder = esm_encoder
        self.d_model = d_model
        esm_width = esm_encoder.hidden_size
        joint_width = 3 * d_model

        def make_projector() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(esm_width, d_model),
                nn.LayerNorm(d_model),
                nn.ReLU(),
                nn.Dropout(dropout),
            )

        # One projector per input type; attribute names double as state_dict keys.
        self.proj_tcr = make_projector()
        self.proj_mhc = make_projector()
        self.proj_peptide = make_projector()

        # MLP head producing one logit per (TCR, MHC, peptide) triple.
        self.classifier = nn.Sequential(
            nn.Linear(joint_width, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 1),
        )

        # Separate linear head whose output feeds the contrastive loss.
        self.infonce_proj = nn.Linear(joint_width, d_model)

    def forward(self, cdr3_seqs: List[str], mhc_alleles: List[str], epitope_seqs: List[str]):
        """Score a batch of triples.

        MHC allele names are converted to sequences with ``get_mhc_sequence``.

        Returns:
            ``(logits, infonce_features)``: the classifier output with its last
            dimension squeezed, and the ``d_model``-wide contrastive features.
        """
        embed = self.esm_encoder.encode_batch
        # Run every encoder call before any projector so that stochastic layers
        # (dropout in train mode) draw random numbers in a fixed order.
        tcr_raw = embed(cdr3_seqs)
        mhc_raw = embed([get_mhc_sequence(name) for name in mhc_alleles])
        pep_raw = embed(epitope_seqs)

        joint = torch.cat(
            [self.proj_tcr(tcr_raw), self.proj_mhc(mhc_raw), self.proj_peptide(pep_raw)],
            dim=-1,
        )
        return self.classifier(joint).squeeze(-1), self.infonce_proj(joint)

class InfoNCELoss(nn.Module):
    """Contrastive loss that pulls together the features of positive samples.

    Features are L2-normalised and ``s = f @ f.T / temperature``. For each pair
    ``(i, j)`` with ``i < j`` where both samples have label 1, the term is
    ``-log(exp(s_ij) / (sum_k exp(s_ik) + 1e-8))``. The sum over ``k`` runs over
    the whole batch, including ``k == i``. The loss is the mean over those
    pairs, or a constant 0 (no gradient) when fewer than two positives exist.

    NOTE(review): common supervised-contrastive formulations exclude the
    self-similarity term from the denominator and use every positive as an
    anchor; this keeps the original definition -- confirm it is intended.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features: torch.Tensor, labels: torch.Tensor):
        """Return the scalar loss for ``features`` (batch, dim) and 1-D ``labels``."""
        features = F.normalize(features, dim=1)
        is_pos = (labels == 1).to(features.device)

        if int(is_pos.sum()) < 2:
            return torch.tensor(0.0, device=features.device)

        sim = features @ features.t() / self.temperature
        # log of the per-anchor denominator, shared by all pairs with anchor i.
        log_denominator = torch.log(torch.exp(sim).sum(dim=1) + 1e-8)

        n = is_pos.shape[0]
        upper = torch.ones(n, n, dtype=torch.bool, device=features.device).triu(diagonal=1)
        pair_mask = is_pos.unsqueeze(1) & is_pos.unsqueeze(0) & upper

        # -log(exp(s_ij) / D_i) == log(D_i) - s_ij, evaluated for all pairs at once
        # instead of an O(batch^2) Python loop.
        per_pair = log_denominator.unsqueeze(1) - sim
        return per_pair[pair_mask].mean()

class TCRDataset(Dataset):
    """Map-style dataset yielding one ``{cdr3, mhc, epitope, label}`` dict per row.

    The frame's index is reset, so positions ``0 .. len - 1`` address rows.
    """

    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        cdr3, mhc, epitope, label = (record[col] for col in ("CDR3", "MHC", "Epitope", "label"))
        return dict(cdr3=str(cdr3), mhc=str(mhc), epitope=str(epitope), label=float(label))

def collate_fn(batch):
    """Merge TCRDataset items: sequences stay as lists, labels become a float32 tensor."""
    cdr3s, mhcs, epitopes, labels = [], [], [], []
    for item in batch:
        cdr3s.append(item["cdr3"])
        mhcs.append(item["mhc"])
        epitopes.append(item["epitope"])
        labels.append(item["label"])
    return {
        "cdr3": cdr3s,
        "mhc": mhcs,
        "epitope": epitopes,
        "labels": torch.tensor(labels, dtype=torch.float32),
    }

class ESMFineTuneWithInfoNCE(BaseModel):
    """B1 main model: partially fine-tuned ESM + BCE/InfoNCE loss + temperature scaling."""

    def __init__(self, config: ExperimentConfig):
        super().__init__(config)
        self.model = None
        self.temp_scaling = TemperatureScaling()

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Train on ``df_train`` for ``config.epochs`` epochs, then calibrate on ``df_val``.

        The loss is ``BCE + 0.1 * InfoNCE`` with gradient-norm clipping at 1.0.
        After training, the temperature is fitted on the validation logits.
        """
        esm_encoder = ESMSequenceEncoder(
            self.config.esm_model_name,
            self.config.device,
            self.config.esm_freeze_layers
        )
        self.model = ESMFineTuneModel(esm_encoder, self.config.d_model).to(self.config.device)

        optimizer, trainable_params = self._build_optimizer(esm_encoder)
        bce_loss = nn.BCEWithLogitsLoss()
        infonce_loss = InfoNCELoss()
        train_loader = self._make_loader(df_train, shuffle=True)

        for _ in range(self.config.epochs):
            self._set_train_mode(True)
            for batch in train_loader:
                labels = batch["labels"].to(self.config.device)
                logits, features = self.model(batch["cdr3"], batch["mhc"], batch["epitope"])

                # Combined loss
                loss = bce_loss(logits, labels) + 0.1 * infonce_loss(features, labels)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()

        val_logits, val_labels = self._collect_logits(df_val)
        self.temp_scaling.fit(val_logits, val_labels)

    def _build_optimizer(self, esm_encoder: ESMSequenceEncoder):
        """Create AdamW with a 10x lower learning rate for the ESM backbone.

        Backbone parameters are taken from ``esm_encoder.model`` directly, so
        they are optimised whether or not the encoder is registered as a
        submodule of ``self.model``. Filtering by ``requires_grad`` excludes
        frozen layers. Returns ``(optimizer, all trainable parameters)``.
        """
        backbone_all = list(esm_encoder.model.parameters())
        backbone_ids = {id(p) for p in backbone_all}
        backbone = [p for p in backbone_all if p.requires_grad]
        heads = [p for p in self.model.parameters()
                 if id(p) not in backbone_ids and p.requires_grad]

        groups = [{"params": heads, "lr": self.config.lr}]
        if backbone:
            groups.insert(0, {"params": backbone, "lr": self.config.lr * 0.1})
        optimizer = torch.optim.AdamW(groups, weight_decay=0.01)
        return optimizer, backbone + heads

    def _set_train_mode(self, mode: bool):
        """Put the heads and the ESM backbone into train (True) or eval (False) mode."""
        self.model.train(mode)
        # Toggle the backbone explicitly as well, in case it is not a registered submodule.
        self.model.esm_encoder.model.train(mode)

    def _make_loader(self, df: pd.DataFrame, shuffle: bool) -> DataLoader:
        """Wrap ``df`` in a TCRDataset DataLoader using the configured batch size."""
        return DataLoader(TCRDataset(df), batch_size=self.config.batch_size,
                          shuffle=shuffle, collate_fn=collate_fn)

    def _collect_logits(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(logits, labels)`` for ``df`` in row order, computed in eval mode."""
        self._set_train_mode(False)
        logit_chunks, label_chunks = [], []
        with torch.no_grad():
            for batch in self._make_loader(df, shuffle=False):
                logits, _ = self.model(batch["cdr3"], batch["mhc"], batch["epitope"])
                logit_chunks.append(logits.cpu().numpy().reshape(-1))
                label_chunks.append(batch["labels"].numpy().reshape(-1))
        if not logit_chunks:
            return np.empty(0), np.empty(0)
        return np.concatenate(logit_chunks), np.concatenate(label_chunks)

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, calibrated probabilities)`` for ``df_test``.

        Raises:
            RuntimeError: if called before :meth:`fit`.
        """
        if self.model is None:
            raise RuntimeError("ESMFineTuneWithInfoNCE.predict() called before fit().")
        logits, _ = self._collect_logits(df_test)
        probs = self.temp_scaling.apply(logits)
        preds = (probs > 0.5).astype(int)
        return preds, probs

# --------------------------
# Other Main Models (B2-B8) - Simplified versions
# --------------------------

class SimpleFeatureModel(BaseModel):
    """Hand-crafted sequence features + a scikit-learn classifier (B2, B7, B8).

    ``model_type`` selects the classifier: ``"lr"`` (logistic regression),
    ``"rf"`` (random forest), ``"svm"`` (SVC with probabilities) or ``"mlp"``
    (multi-layer perceptron). Scores are calibrated with temperature scaling
    on the validation set.
    """

    _AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"
    _HYDROPHOBIC = frozenset("AILMFPWV")
    _CHARGED = frozenset("DEKR")

    def __init__(self, config: ExperimentConfig, model_type: str = "lr"):
        super().__init__(config)
        self.model_type = model_type
        self.classifier = None
        self.temp_scaling = TemperatureScaling()

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Train the classifier on ``df_train`` and calibrate it on ``df_val``.

        Raises:
            ValueError: if ``model_type`` is not one of the supported names.
        """
        train_features = self._extract_features(df_train)
        val_features = self._extract_features(df_val)

        train_labels = df_train["label"].values
        val_labels = df_val["label"].values

        self.classifier = self._make_classifier()
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            self.classifier.fit(train_features, train_labels)

        self.temp_scaling.fit(self._logits(val_features), val_labels)

    def _make_classifier(self):
        """Instantiate the (unfitted) classifier named by ``model_type``."""
        seed = self.config.seed
        factories = {
            "lr": lambda: LogisticRegression(max_iter=1000, random_state=seed),
            "rf": lambda: RandomForestClassifier(n_estimators=100, random_state=seed),
            "svm": lambda: SVC(probability=True, random_state=seed),
            "mlp": lambda: MLPClassifier(max_iter=500, random_state=seed),
        }
        if self.model_type not in factories:
            raise ValueError(
                f"Unknown model_type {self.model_type!r}; expected one of {sorted(factories)}"
            )
        return factories[self.model_type]()

    def _logits(self, features: np.ndarray) -> np.ndarray:
        """Return real-valued scores: ``decision_function`` if available, else logit(p)."""
        if hasattr(self.classifier, "decision_function"):
            return self.classifier.decision_function(features)
        # Clip first: probabilities of exactly 0 or 1 (common for random
        # forests) would otherwise give -inf or an asymmetric cap.
        probs = np.clip(self.classifier.predict_proba(features)[:, 1], 1e-7, 1 - 1e-7)
        return np.log(probs / (1 - probs))

    @staticmethod
    def _stable_bucket(text: str, n_buckets: int) -> int:
        # Built-in hash() of str is salted per process (PYTHONHASHSEED), so it
        # gave different feature values on every run; a digest is reproducible.
        digest = hashlib.sha256(str(text).encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") % n_buckets

    @classmethod
    def _featurize(cls, cdr3: str, epitope: str, mhc: str) -> List[float]:
        """Return the 47-value feature vector for one (CDR3, Epitope, MHC) row.

        The vector holds lengths, a 10-bucket MHC id, hydrophobic and charged
        fractions, then the 20-residue composition of the CDR3 and of the
        epitope.
        """
        cdr3_n = max(1, len(cdr3))  # guard against division by zero
        epi_n = max(1, len(epitope))
        cdr3_comp = [cdr3.count(aa) / cdr3_n for aa in cls._AA_ORDER]
        epi_comp = [epitope.count(aa) / epi_n for aa in cls._AA_ORDER]
        return [
            len(cdr3),
            len(epitope),
            cls._stable_bucket(mhc, 10),
            sum(aa in cls._HYDROPHOBIC for aa in cdr3) / cdr3_n,
            sum(aa in cls._CHARGED for aa in cdr3) / cdr3_n,
            sum(aa in cls._HYDROPHOBIC for aa in epitope) / epi_n,
            sum(aa in cls._CHARGED for aa in epitope) / epi_n,
            *cdr3_comp,
            *epi_comp,
        ]

    def _extract_features(self, df: pd.DataFrame) -> np.ndarray:
        """Featurize every row of ``df`` into a (rows, 47) float array."""
        rows = [self._featurize(cdr3, epitope, mhc)
                for cdr3, epitope, mhc in zip(df["CDR3"], df["Epitope"], df["MHC"])]
        return np.array(rows, dtype=float)

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, calibrated probabilities)`` for ``df_test``.

        Raises:
            RuntimeError: if called before :meth:`fit`.
        """
        if self.classifier is None:
            raise RuntimeError("SimpleFeatureModel.predict() called before fit().")
        logits = self._logits(self._extract_features(df_test))
        probs = self.temp_scaling.apply(logits)
        preds = (probs > 0.5).astype(int)
        return preds, probs

# --------------------------
# BX1: ProtBert-based Model (Auxiliary)
# --------------------------

class ProtBertModel(BaseModel):
    """BX1 placeholder for a ProtBert-based model.

    NOTE: no ProtBert weights are loaded. Each row becomes a 100-value
    hand-crafted vector (lengths, hashed n-gram sums, a positional term,
    zero padding). Logistic regression is trained on these vectors and its
    scores are calibrated with temperature scaling on the validation set.
    """

    _N_FEATURES = 100

    def __init__(self, config: ExperimentConfig):
        super().__init__(config)
        self.model = None
        self.classifier = None
        self.temp_scaling = TemperatureScaling()

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Fit logistic regression on ``df_train`` features and calibrate on ``df_val``."""
        # For demo, use simplified ProtBert-like features
        # In practice, you would load actual ProtBert model
        train_features = self._extract_protbert_features(df_train)
        val_features = self._extract_protbert_features(df_val)

        train_labels = df_train["label"].values
        val_labels = df_val["label"].values

        self.classifier = LogisticRegression(max_iter=1000, random_state=self.config.seed)
        self.classifier.fit(train_features, train_labels)

        val_logits = self.classifier.decision_function(val_features)
        self.temp_scaling.fit(val_logits, val_labels)

    @staticmethod
    def _stable_hash(text: str) -> int:
        # Built-in hash() of str is salted per process (PYTHONHASHSEED), which
        # made these features differ between runs; a digest is reproducible.
        return int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:8], "big")

    def _extract_protbert_features(self, df: pd.DataFrame) -> np.ndarray:
        """Return a (rows, 100) array of simplified "ProtBert-like" features.

        Per row: the CDR3/epitope/MHC-sequence lengths; for the CDR3, the
        epitope and the first 20 MHC residues, the hashed 2-gram and 3-gram
        sums (each n-gram hashed mod 1000, sum / 1000); ``sin(i / 10000)`` for
        the first 20 CDR3 positions; then zero padding up to 100 values.
        """
        features = []
        for cdr3_seq, epitope_seq, mhc_name in zip(df["CDR3"], df["Epitope"], df["MHC"]):
            mhc_seq = get_mhc_sequence(mhc_name)

            feature_vector = [len(cdr3_seq), len(epitope_seq), len(mhc_seq)]

            # N-gram features (simulate BERT-like attention to subsequences)
            for seq in (cdr3_seq, epitope_seq, mhc_seq[:20]):  # Truncate MHC
                for k in (2, 3):
                    ngrams = (seq[i:i + k] for i in range(len(seq) - k + 1))
                    feature_vector.append(sum(self._stable_hash(g) % 1000 for g in ngrams) / 1000)

            # Positional encoding simulation
            for i, aa in enumerate(cdr3_seq[:20]):  # Truncate
                feature_vector.append(math.sin(i / 10000) if aa in "ACDEFGHIKLMNPQRSTVWY" else 0)

            # Pad / truncate to a fixed length
            padded = feature_vector + [0.0] * self._N_FEATURES
            features.append(padded[:self._N_FEATURES])

        return np.array(features)

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, calibrated probabilities)`` for ``df_test``.

        Raises:
            RuntimeError: if called before :meth:`fit`.
        """
        if self.classifier is None:
            raise RuntimeError("ProtBertModel.predict() called before fit().")
        test_features = self._extract_protbert_features(df_test)
        logits = self.classifier.decision_function(test_features)
        probs = self.temp_scaling.apply(logits)
        preds = (probs > 0.5).astype(int)
        return preds, probs

# --------------------------
# BX2: Simple Graph Neural Network (Auxiliary)
# --------------------------

class SimpleGraphModel(BaseModel):
    """BX2: similarity-threshold neighbour vote (a stand-in for a graph model).

    A test pair's probability is the mean label of all training rows where
    the *average* of the CDR3 similarity and the epitope similarity
    (``compute_sequence_similarity``) exceeds ``similarity_threshold``. It is
    0.5 when no training row qualifies. MHC is ignored.
    """

    def __init__(self, config: ExperimentConfig):
        super().__init__(config)
        self.train_data = None
        self.similarity_threshold = 0.7

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Memorise a copy of the training data (``df_val`` is unused)."""
        self.train_data = df_train.copy()

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(predictions, probabilities)`` for each row of ``df_test``.

        Rows sharing the same (CDR3, Epitope) get the same score, so each
        distinct pair is scored only once.

        Raises:
            RuntimeError: if called before :meth:`fit`.
        """
        if self.train_data is None:
            raise RuntimeError("SimpleGraphModel.predict() called before fit().")

        # Materialise the training columns once; calling iterrows() inside the
        # per-test-row loop rebuilt a Series for every training row each time.
        train_rows = list(zip(self.train_data["CDR3"], self.train_data["Epitope"],
                              self.train_data["label"]))
        threshold = self.similarity_threshold
        pair_scores: Dict[Tuple[str, str], float] = {}

        predictions = []
        probabilities = []
        for test_tcr, test_epitope in zip(df_test["CDR3"], df_test["Epitope"]):
            key = (test_tcr, test_epitope)
            if key not in pair_scores:
                votes = [
                    label
                    for tcr, epitope, label in train_rows
                    if (compute_sequence_similarity(test_tcr, tcr)
                        + compute_sequence_similarity(test_epitope, epitope)) / 2 > threshold
                ]
                # Default for no similar examples
                pair_scores[key] = np.mean(votes) if votes else 0.5

            prob = pair_scores[key]
            predictions.append(1 if prob > 0.5 else 0)
            probabilities.append(prob)

        return np.array(predictions), np.array(probabilities)

# --------------------------
# BX3: Simple VAE-like Model (Auxiliary)
# --------------------------

class SimpleVAEModel(BaseModel):
    """VAE-inspired baseline: fixed "latent" features + temperature-scaled logistic regression.

    No VAE is actually trained. Each (CDR3, Epitope) pair is mapped to summary
    statistics of its residues' character codes (standing in for mu/sigma),
    padded with hash-derived pseudo-random dimensions up to
    ``config.vae_latent_dim``. A logistic regression is fitted on the training
    split and its logits are temperature-scaled on the validation split.
    """

    def __init__(self, config: ExperimentConfig):
        super().__init__(config)
        # Placeholders kept for attribute compatibility; no encoder is trained.
        self.encoder_tcr = None
        self.encoder_epitope = None
        self.classifier = None
        self.temp_scaling = TemperatureScaling()

    def fit(self, df_train: pd.DataFrame, df_val: pd.DataFrame):
        """Fit the classifier on ``df_train`` and calibrate its logits on ``df_val``."""
        train_features = self._extract_vae_features(df_train)
        val_features = self._extract_vae_features(df_val)

        train_labels = df_train["label"].values
        val_labels = df_val["label"].values

        self.classifier = LogisticRegression(max_iter=1000, random_state=self.config.seed)
        self.classifier.fit(train_features, train_labels)

        # Temperature scaling is fitted on held-out validation logits.
        val_logits = self.classifier.decision_function(val_features)
        self.temp_scaling.fit(val_logits, val_labels)

    @staticmethod
    def _char_code_stats(seq: str) -> Tuple[float, float, float, float]:
        """Return ``(mean, std, median, max - min)`` of the character codes of ``seq``.

        An empty sequence yields zeros instead of raising (``np.max`` of an
        empty list raises ``ValueError``).
        """
        if not seq:
            return 0.0, 0.0, 0.0, 0.0
        codes = [ord(aa) for aa in seq]
        return (
            float(np.mean(codes)),
            float(np.std(codes)),
            float(np.median(codes)),
            float(np.max(codes) - np.min(codes)),
        )

    @staticmethod
    def _stable_unit_hash(key: str) -> float:
        """Deterministically map ``key`` to one of 0.00, 0.01, ..., 0.99.

        Uses MD5 (non-cryptographic use) instead of built-in ``hash()``. ``str``
        hashing is salted per interpreter process (PYTHONHASHSEED), so the
        previous features, and thus results, changed between sessions.
        """
        digest = hashlib.md5(key.encode("utf-8")).digest()
        return (int.from_bytes(digest[:8], "big") % 100) / 100

    def _extract_vae_features(self, df: pd.DataFrame) -> np.ndarray:
        """Build the ``(len(df), vae_latent_dim)`` pseudo-latent feature matrix.

        Layout (first 10 columns): CDR3 mean/std, epitope mean/std, CDR3 median,
        epitope median, CDR3 range, epitope range, CDR3 length, epitope length.
        The remaining ``vae_latent_dim - 10`` columns are hash-derived. If
        ``vae_latent_dim < 10`` the vector is truncated to that length.
        """
        latent_dim = self.config.vae_latent_dim
        features = []
        for cdr3_seq, epitope_seq in zip(df["CDR3"], df["Epitope"]):
            c_mean, c_std, c_median, c_range = self._char_code_stats(cdr3_seq)
            e_mean, e_std, e_median, e_range = self._char_code_stats(epitope_seq)

            feature_vector = [
                c_mean, c_std,
                e_mean, e_std,
                c_median, e_median,
                c_range, e_range,
                len(cdr3_seq), len(epitope_seq),
            ]

            # Simulated "learned" dimensions. They depend only on the exact pair
            # string, so they carry no similarity signal between different pairs.
            feature_vector.extend(
                self._stable_unit_hash(f"{cdr3_seq}_{epitope_seq}_{i}")
                for i in range(latent_dim - 10)
            )

            features.append(feature_vector[:latent_dim])

        return np.array(features)

    def predict(self, df_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(hard predictions at 0.5, calibrated probabilities)`` for ``df_test``.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.classifier is None:
            raise RuntimeError("SimpleVAEModel.predict called before fit")
        test_features = self._extract_vae_features(df_test)
        logits = self.classifier.decision_function(test_features)
        probs = self.temp_scaling.apply(logits)
        preds = (probs > 0.5).astype(int)
        return preds, probs

# --------------------------
# Experiment Runner
# --------------------------

class ExperimentRunner:
    """Train and evaluate every benchmark model over seeded runs, then tabulate metrics."""

    def __init__(self, config: ExperimentConfig):
        self.config = config
        self.data_manager = DataManager(config)
        self.evaluator = Evaluator()

    @staticmethod
    def _model_factories() -> Dict[str, Any]:
        """Return the ordered ``name -> factory(config)`` mapping (Main Table + BX series)."""
        return {
            # Main Models
            "B1_ESM_InfoNCE_TempScale": ESMFineTuneWithInfoNCE,  # Our main model
            "B2_Simple_Features_LR": lambda config: SimpleFeatureModel(config, "lr"),
            "B3_Simple_Features_RF": lambda config: SimpleFeatureModel(config, "rf"),
            "B4_Simple_Features_SVM": lambda config: SimpleFeatureModel(config, "svm"),

            # Auxiliary Models (BX series)
            "BX1_ProtBert_Like": ProtBertModel,
            "BX2_Simple_Graph": SimpleGraphModel,
            "BX3_VAE_Like": SimpleVAEModel,
        }

    def _run_once(self, model_name: str, model_factory: Any, run: int) -> Optional[Dict[str, float]]:
        """Execute one seeded fit/predict/evaluate cycle.

        Returns:
            The metrics dict from ``Evaluator.compute_metrics``, or ``None`` if
            any step raised. The traceback is printed so the benchmark carries
            on with the remaining runs and models.
        """
        import dataclasses
        import traceback

        run_seed = self.config.seed + run
        set_seed(run_seed)
        # Models such as SimpleVAEModel seed their own RNG from config.seed
        # (random_state=self.config.seed). With the shared config every run of
        # those models was identical, so the reported spread was trivially
        # zero. Hand each run a config copy carrying the per-run seed.
        run_config = dataclasses.replace(self.config, seed=run_seed)

        try:
            model = model_factory(run_config)
            model.fit(self.data_manager.df_train, self.data_manager.df_val)

            y_pred, y_prob = model.predict(self.data_manager.df_test)
            y_true = self.data_manager.df_test["label"].values

            metrics = self.evaluator.compute_metrics(y_true, y_pred, y_prob)
            print(f"  AUC: {metrics['auc']:.4f}, AUPRC: {metrics['auprc']:.4f}, Acc: {metrics['accuracy']:.4f}")
            return metrics
        except Exception as e:
            # Deliberate top-level boundary: report the failure and keep going.
            print(f"  Error in {model_name} run {run + 1}: {str(e)}")
            traceback.print_exc()
            return None

    def _summary_rows(self, model_name: str, model_results: List[Dict[str, float]]) -> List[Dict[str, Any]]:
        """Aggregate per-run metrics into one results row per metric."""
        aggregated = self.evaluator.aggregate_metrics(model_results)
        # Use the number of runs that actually succeeded, not config.n_runs.
        # Dividing by sqrt(n_runs) after a failed run understated the interval.
        n_successful = len(model_results)

        rows = []
        for metric, stats in aggregated.items():
            half_width = 1.96 * stats["std"] / np.sqrt(n_successful)
            rows.append({
                "Model": model_name,
                "Metric": metric.upper(),
                "Mean": stats["mean"],
                "Std": stats["std"],
                "CI_Lower": stats["ci_lower"],
                "CI_Upper": stats["ci_upper"],
                "CI_String": f"{stats['mean']:.4f} ± {half_width:.4f}"
            })
        return rows

    def run_all_experiments(self) -> pd.DataFrame:
        """Run all experiments and return results table"""
        print("Loading and preparing data...")
        self.data_manager.load_and_split_data()

        results = []

        for model_name, model_factory in self._model_factories().items():
            print(f"\n{'='*60}")
            print(f"Running {model_name}...")
            print(f"{'='*60}")

            model_results = []
            for run in range(self.config.n_runs):
                print(f"Run {run + 1}/{self.config.n_runs}")
                metrics = self._run_once(model_name, model_factory, run)
                if metrics is not None:
                    model_results.append(metrics)

            if model_results:
                results.extend(self._summary_rows(model_name, model_results))
            else:
                print(f"  No successful runs for {model_name}")

        return pd.DataFrame(results)

    def save_results(self, results_df: pd.DataFrame, output_path: str = "comprehensive_results.csv"):
        """Save results to CSV and print summary table"""
        results_df.to_csv(output_path, index=False)

        # Create summary table
        print(f"\n{'='*120}")
        print("COMPREHENSIVE TCR-PEPTIDE-MHC BINDING PREDICTION BENCHMARK")
        print(f"{'='*120}")

        if results_df.empty:
            # Every model failed: the frame has no columns, so pivot/filter by
            # "Model"/"Metric" would raise KeyError. Report instead of crashing.
            print("No successful runs; nothing to summarise.")
            print(f"\nDetailed results saved to: {output_path}")
            print(f"{'='*120}")
            return

        # Pivot table for better readability
        pivot_df = results_df.pivot(index="Model", columns="Metric", values="CI_String")

        # Reorder columns for better presentation
        metric_order = ["AUC", "AUPRC", "ACCURACY", "PRECISION", "RECALL"]
        available_metrics = [col for col in metric_order if col in pivot_df.columns]
        pivot_df = pivot_df[available_metrics]

        print(pivot_df.to_string())

        # Highlight best performing models
        print(f"\n{'='*60}")
        print("PERFORMANCE HIGHLIGHTS:")
        print(f"{'='*60}")

        auc_results = results_df[results_df["Metric"] == "AUC"].sort_values("Mean", ascending=False)
        print("Top 3 Models by AUC:")
        for rank, (_, row) in enumerate(auc_results.head(3).iterrows(), start=1):
            print(f"{rank}. {row['Model']}: {row['CI_String']}")

        print(f"\nDetailed results saved to: {output_path}")
        print(f"{'='*120}")

# --------------------------
# Main Function
# --------------------------

def main():
    """Configure the benchmark, run every model, then save and print the results table."""
    # Configuration
    config = ExperimentConfig(
        seed=42,
        n_runs=3,
        epochs=8,  # below the dataclass default of 10; presumably only used by epoch-trained models (e.g. B1)
        batch_size=32,
        lr=2e-4,
        d_model=128,
    )

    print("Starting comprehensive TCR-peptide-MHC binding prediction benchmark...")
    print(f"Device: {config.device}")
    print(f"Configuration: epochs={config.epochs}, batch_size={config.batch_size}, lr={config.lr}")

    # Run experiments
    runner = ExperimentRunner(config)
    results_df = runner.run_all_experiments()

    # Save the CSV and print the summary tables
    runner.save_results(results_df)

# Entry point: runs the full benchmark when this script/cell is executed
# directly (the captured output below shows main() ran on cell execution).
if __name__ == "__main__":
    main()

2025-09-20 14:18:10.964036: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-20 14:18:11.016180: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Starting comprehensive TCR-peptide-MHC binding prediction benchmark...
Device: cuda
Configuration: epochs=8, batch_size=32, lr=0.0002
Loading and preparing data...
Data split: Train=20162, Val=2608, Test=6593

Running B1_ESM_InfoNCE_TempScale...
Run 1/3
[Info] Loading ESM model: facebook/esm2_t12_35M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Info] ESM loaded, hidden_size: 480, trainable layers: 6
  AUC: 0.6428, AUPRC: 0.2744, Acc: 0.7767
Run 2/3
[Info] Loading ESM model: facebook/esm2_t12_35M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Info] ESM loaded, hidden_size: 480, trainable layers: 6
  AUC: 0.6400, AUPRC: 0.2866, Acc: 0.7767
Run 3/3
[Info] Loading ESM model: facebook/esm2_t12_35M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Info] ESM loaded, hidden_size: 480, trainable layers: 6
  AUC: 0.6156, AUPRC: 0.2886, Acc: 0.7767

Running B2_Simple_Features_LR...
Run 1/3
  AUC: 0.5086, AUPRC: 0.2108, Acc: 0.7767
Run 2/3
  AUC: 0.5086, AUPRC: 0.2108, Acc: 0.7767
Run 3/3
  AUC: 0.5086, AUPRC: 0.2108, Acc: 0.7767

Running B3_Simple_Features_RF...
Run 1/3


  val_logits = np.log(val_probs / (1 - val_probs + 1e-8))
  logits = np.log(probs / (1 - probs + 1e-8))


  AUC: 0.5373, AUPRC: 0.2556, Acc: 0.7767
Run 2/3


  val_logits = np.log(val_probs / (1 - val_probs + 1e-8))
  logits = np.log(probs / (1 - probs + 1e-8))


  AUC: 0.5373, AUPRC: 0.2556, Acc: 0.7767
Run 3/3


  val_logits = np.log(val_probs / (1 - val_probs + 1e-8))
  logits = np.log(probs / (1 - probs + 1e-8))


  AUC: 0.5373, AUPRC: 0.2556, Acc: 0.7767

Running B4_Simple_Features_SVM...
Run 1/3
  AUC: 0.5829, AUPRC: 0.2685, Acc: 0.7767
Run 2/3
  AUC: 0.5829, AUPRC: 0.2685, Acc: 0.7767
Run 3/3
  AUC: 0.5829, AUPRC: 0.2685, Acc: 0.7767

Running BX1_ProtBert_Like...
Run 1/3
  AUC: 0.4028, AUPRC: 0.1767, Acc: 0.7767
Run 2/3
  AUC: 0.4028, AUPRC: 0.1767, Acc: 0.7767
Run 3/3
  AUC: 0.4028, AUPRC: 0.1767, Acc: 0.7767

Running BX2_Simple_Graph...
Run 1/3
  AUC: 0.4628, AUPRC: 0.1955, Acc: 0.7702
Run 2/3
  AUC: 0.4628, AUPRC: 0.1955, Acc: 0.7702

Running BX3_VAE_Like...
Run 1/3
  AUC: 0.5593, AUPRC: 0.2704, Acc: 0.7767
Run 2/3
  AUC: 0.5593, AUPRC: 0.2704, Acc: 0.7767
Run 3/3
  AUC: 0.5593, AUPRC: 0.2704, Acc: 0.7767

COMPREHENSIVE TCR-PEPTIDE-MHC BINDING PREDICTION BENCHMARK
Metric                                AUC            AUPRC         ACCURACY        PRECISION           RECALL
Model                                                                                                        
B1_ESM_In