In [None]:
# -*- coding: utf-8 -*-
# Ablation Study for TCR-peptide-MHC Binding Prediction
# Based on Main Model B1: ESM Fine-tuning + InfoNCE + Temperature Scaling

import os, sys, math, json, time, random, warnings
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from dataclasses import dataclass, field
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import EsmModel, EsmTokenizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from sklearn.exceptions import ConvergenceWarning

try:
    from tqdm import tqdm
except ImportError:
    raise ImportError("Please install tqdm: pip install tqdm")

# --------------------------
# Configuration
# --------------------------

def _list_default(*values):
    """Return a ``default_factory`` that builds a fresh list of *values* per instance."""
    return lambda: list(values)


@dataclass
class AblationConfig:
    """Settings for one training run plus the value grids swept by the ablation.

    Scalar fields configure training/evaluation of a single model; each
    ``*_options`` list holds the candidate values for one ablation sweep.
    """

    # --- run / optimisation settings ---
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    n_runs: int = 3           # repeated runs per ablation setting
    test_size: float = 0.2    # fraction of positive epitopes held out for testing
    val_size: float = 0.1     # fraction of positive epitopes used for validation/calibration
    batch_size: int = 32
    epochs: int = 8
    base_lr: float = 2e-4

    # --- backbone ---
    esm_model_name: str = "facebook/esm2_t12_35M_UR50D"

    # --- sweep grids (a fresh list is created for every instance) ---
    freeze_layers_options: List[int] = field(default_factory=_list_default(0, 3, 6, 9, 12))
    infonce_weight_options: List[float] = field(default_factory=_list_default(0.0, 0.05, 0.1, 0.2))
    d_model_options: List[int] = field(default_factory=_list_default(64, 128, 256))
    calibration_options: List[str] = field(default_factory=_list_default("none", "temperature", "platt"))
    neg_ratio_options: List[int] = field(default_factory=_list_default(2, 4, 6, 8))
    mhc_encoding_options: List[str] = field(default_factory=_list_default("pseudo_seq", "simple_embed"))

# The 20 canonical amino-acid one-letter codes, in alphabetical order.
AA_STANDARD = [residue for residue in "ACDEFGHIKLMNPQRSTVWY"]
# Same alphabet as a set, for O(1) membership checks when validating sequences.
AA_SET = set(AA_STANDARD)

# Per-allele MHC sequences returned by get_mhc_sequence(..., "pseudo_seq").
# NOTE(review): despite the name, each entry is ~275 residues long (roughly a
# full class-I heavy-chain extracellular region), not a short pseudo-sequence
# made of contact positions only — confirm which representation is intended.
MHC_PSEUDO_SEQUENCES = {
    "HLA-A*02:01": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDGETRKVKAHSQTHRVDLGTLRGYYNQSEAGSHTVQRMYGCDVGSDWRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQTTKHKWEAAHVAEQLRAYLEGTCVEWLRRYLENGKETLQRTDAPKTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGQEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-A*01:01": "GSHSMRYFFTSVSRPGRGEPRFIAMGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE",
    # NOTE(review): this entry appears to differ from HLA-A*01:01 above at only
    # one position ("FIAV..." vs "FIAM..." around residue 25); it looks like a
    # copy-paste placeholder rather than the real B*07:02 sequence — verify
    # against IMGT/HLA before interpreting the MHC-encoding ablation.
    "HLA-B*07:02": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE"
}

# Short per-allele stand-in sequences for the "simple_embed" MHC encoding.
# NOTE(review): "B" in the B*07:02 entry is not one of the 20 standard residues;
# ESMSequenceEncoder.encode_batch drops such characters, so the model actually
# sees "GSHSMRYFFPT" for that allele.
_MHC_SIMPLE_SEQUENCES = {
    "HLA-A*02:01": "GSHSMRYFFTSV",
    "HLA-A*01:01": "GSHSMRYFFAMT",
    "HLA-B*07:02": "GSHSMRYFFBPT",
}
# Returned for alleles missing from _MHC_SIMPLE_SEQUENCES.
_MHC_SIMPLE_DEFAULT = "GSHSMRYFFDEF"


def get_mhc_sequence(mhc_name: str, encoding_type: str = "pseudo_seq") -> str:
    """Map an MHC allele name to the amino-acid sequence fed to the encoder.

    Args:
        mhc_name: Allele name such as ``"HLA-A*02:01"``.
        encoding_type: ``"pseudo_seq"`` returns the long sequence from
            ``MHC_PSEUDO_SEQUENCES`` (unknown alleles silently fall back to
            HLA-A*02:01); ``"simple_embed"`` returns a short 12-residue
            stand-in (unknown alleles get ``"GSHSMRYFFDEF"``).

    Returns:
        The sequence string for the allele.

    Raises:
        ValueError: if ``encoding_type`` is not one of the two values above
            (previously any other value was silently treated as
            ``"simple_embed"``, which hid typos in ablation configs).
    """
    if encoding_type == "pseudo_seq":
        return MHC_PSEUDO_SEQUENCES.get(mhc_name, MHC_PSEUDO_SEQUENCES["HLA-A*02:01"])
    if encoding_type == "simple_embed":
        return _MHC_SIMPLE_SEQUENCES.get(mhc_name, _MHC_SIMPLE_DEFAULT)
    raise ValueError(
        f"Unknown MHC encoding {encoding_type!r}; expected 'pseudo_seq' or 'simple_embed'"
    )

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

# --------------------------
# Data Management
# --------------------------

class AblationDataManager:
    """Load the TCR/MHC/epitope table, add synthetic negatives and split it by epitope.

    After ``load_and_split_data`` the ``df_train`` / ``df_val`` / ``df_test``
    attributes hold disjoint epitope groups, each with CDR3, MHC, Epitope and
    label columns.
    """

    # Internal column recording the source epitope of a *mutated* negative
    # peptide, so ``_split_data`` can route it to the source epitope's split.
    _SPLIT_KEY = "_split_epitope"

    def __init__(self, config: "AblationConfig"):
        self.config = config
        self.df_train = None
        self.df_val = None
        self.df_test = None

    def load_and_split_data(self, data_path: str = "data.csv", neg_ratio: int = 4):
        """Load data and create train/val/test splits with the given negative ratio.

        Raises:
            FileNotFoundError: if ``data_path`` does not exist.
        """
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data file {data_path} not found. Please ensure data is available.")

        df = pd.read_csv(data_path)
        df = self._preprocess_data(df)
        df_with_negs = self._build_negatives(df, k_neg=neg_ratio)
        self._split_data(df_with_negs)

    def _preprocess_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalise sequences and drop rows that are missing or invalid.

        Missing CDR3/MHC/Epitope values are dropped *before* string
        conversion: otherwise NaN becomes the literal "nan" -> "NAN", which
        passes the amino-acid filter because N and A are valid residues.
        Empty sequences and sequences with non-standard residues are removed.
        """
        seq_cols = ["CDR3", "MHC", "Epitope"]
        df = df.dropna(subset=seq_cols).copy()
        for col in seq_cols:
            df[col] = df[col].astype(str).str.strip().str.upper()
        df = df.dropna().drop_duplicates()

        def is_valid(seq: str) -> bool:
            return bool(seq) and all(ch in AA_SET for ch in seq)

        df = df[df["Epitope"].map(is_valid)]
        df = df[df["CDR3"].map(is_valid)]
        return df

    def _build_negatives(self, df: pd.DataFrame, k_neg: int = 4) -> pd.DataFrame:
        """Append label-0 rows to the positives in ``df``.

        For every positive (CDR3, MHC, peptide) row this adds ``k_neg``
        peptide-swap negatives (other same-length epitopes, topped up with
        single-residue mutants when there are too few) plus one TCR-swap
        negative (a CDR3 seen with a different epitope). Candidate negatives
        that coincide with an observed positive triple are skipped, so no
        triple carries both labels.
        """
        df = df.copy()
        df["label"] = 1

        # Group peptides by length
        peps_by_len = defaultdict(list)
        for pep in df["Epitope"].unique():
            peps_by_len[len(pep)].append(pep)
        pep_sets_by_len = {length: set(peps) for length, peps in peps_by_len.items()}

        positive_triples = set(zip(df["CDR3"], df["MHC"], df["Epitope"]))
        # Cache of "CDR3s observed with some other epitope", per epitope:
        # avoids rescanning the whole frame for every row (was O(n^2)).
        other_tcrs_by_pep = {}

        negatives = []
        for cdr3, mhc, pep in zip(df["CDR3"], df["MHC"], df["Epitope"]):
            # Same-length different peptides
            neg_peps = [p for p in peps_by_len[len(pep)] if p != pep]

            if len(neg_peps) >= k_neg:
                selected = [(p, None) for p in random.sample(neg_peps, k_neg)]
            else:
                selected = [(p, None) for p in neg_peps]
                for _ in range(k_neg - len(neg_peps)):
                    mutant = self._mutate_peptide(pep)
                    # A mutant equal to a real epitope is split like that epitope;
                    # otherwise it follows its source epitope (it is not a
                    # positive epitope, so epitope-based splitting would drop it).
                    source = None if mutant in pep_sets_by_len[len(pep)] else pep
                    selected.append((mutant, source))

            for neg_pep, source in selected:
                if (cdr3, mhc, neg_pep) in positive_triples:
                    continue
                negatives.append({"CDR3": cdr3, "MHC": mhc, "Epitope": neg_pep,
                                  "label": 0, self._SPLIT_KEY: source})

            # Different TCR for same peptide
            if pep not in other_tcrs_by_pep:
                other_tcrs_by_pep[pep] = df.loc[df["Epitope"] != pep, "CDR3"].unique()
            other_tcrs = other_tcrs_by_pep[pep]
            if len(other_tcrs) > 0:
                neg_tcr = random.choice(other_tcrs)
                if (neg_tcr, mhc, pep) not in positive_triples:
                    negatives.append({"CDR3": neg_tcr, "MHC": mhc, "Epitope": pep,
                                      "label": 0, self._SPLIT_KEY: None})

        df_neg = pd.DataFrame(negatives)
        return pd.concat([df, df_neg], ignore_index=True).drop_duplicates()

    def _mutate_peptide(self, peptide: str, n_mut: int = 1) -> str:
        """Return *peptide* with ``n_mut`` random positions replaced by a different residue."""
        s = list(peptide)
        for _ in range(n_mut):
            pos = random.randint(0, len(s) - 1)
            original = s[pos]
            s[pos] = random.choice([aa for aa in AA_STANDARD if aa != original])
        return "".join(s)

    def _split_data(self, df: pd.DataFrame):
        """Split rows into train/val/test by positive epitope to avoid leakage.

        Positive epitopes are shuffled and partitioned using
        ``config.test_size`` / ``config.val_size`` (at least one epitope each
        for test and val). A row goes to the split owning its epitope; mutated
        negatives (tagged via ``_SPLIT_KEY``) go to their source epitope's
        split. The helper column is removed from the resulting frames.
        """
        positive_epitopes = list(df.loc[df["label"] == 1, "Epitope"].unique())
        random.shuffle(positive_epitopes)

        n_total = len(positive_epitopes)
        n_test = max(1, int(n_total * self.config.test_size))
        n_val = max(1, int(n_total * self.config.val_size))

        test_epitopes = set(positive_epitopes[:n_test])
        val_epitopes = set(positive_epitopes[n_test:n_test + n_val])
        train_epitopes = set(positive_epitopes[n_test + n_val:])

        if self._SPLIT_KEY in df.columns:
            routed = df[self._SPLIT_KEY]
            split_key = routed.where(routed.notna(), df["Epitope"])
            df = df.drop(columns=[self._SPLIT_KEY])
        else:
            split_key = df["Epitope"]

        self.df_test = df[split_key.isin(test_epitopes)].reset_index(drop=True)
        self.df_val = df[split_key.isin(val_epitopes)].reset_index(drop=True)
        self.df_train = df[split_key.isin(train_epitopes)].reset_index(drop=True)

# --------------------------
# Model Components
# --------------------------

class ESMSequenceEncoder(nn.Module):
    """Pretrained ESM-2 backbone that returns one [CLS] embedding per sequence.

    This is an ``nn.Module`` on purpose. When an instance is assigned as an
    attribute of another module (e.g. ``AblationModel.esm_encoder``), the
    backbone's weights are registered as sub-parameters. As a result they:

    * appear in ``named_parameters()``, and therefore in the optimizer;
    * are moved by ``.to()``;
    * follow ``train()`` / ``eval()``;
    * are covered by gradient clipping.

    As a plain object the backbone was invisible to the parent module, so it
    was never actually fine-tuned and ``freeze_layers`` had no effect.

    Args:
        model_name: Hugging Face id of the ESM checkpoint.
        device: Device string the backbone is placed on.
        freeze_layers: Number of leading transformer layers set to
            ``requires_grad=False``; later layers and the embeddings stay
            trainable.
    """

    def __init__(self, model_name: str, device: str, freeze_layers: int = 6):
        super().__init__()
        self.device = device
        self.tokenizer = EsmTokenizer.from_pretrained(model_name)
        # Only last_hidden_state is used (see encode_batch); skip the pooler,
        # which is not in the checkpoint and would be randomly initialised.
        self.model = EsmModel.from_pretrained(model_name, add_pooling_layer=False).to(device)

        # Selective unfreezing: layers [0, freeze_layers) frozen, the rest trainable.
        for i, layer in enumerate(self.model.encoder.layer):
            trainable = i >= freeze_layers
            for param in layer.parameters():
                param.requires_grad = trainable

        # Always fine-tune embeddings
        for param in self.model.embeddings.parameters():
            param.requires_grad = True
        if getattr(self.model, "pooler", None) is not None:
            for param in self.model.pooler.parameters():
                param.requires_grad = True

        self.hidden_size = self.model.config.hidden_size

    def encode_batch(self, sequences: List[str], max_length: int = 512) -> torch.Tensor:
        """Embed a batch of sequences.

        Characters outside the 20 standard amino acids are removed (after
        upper-casing); a sequence left empty is replaced by ``"A"``.

        Returns:
            Tensor of shape (len(sequences), hidden_size): the hidden state
            at token position 0 (the [CLS] token).
        """
        clean_seqs = []
        for seq in sequences:
            clean_seq = "".join(c for c in seq.upper() if c in AA_SET)
            clean_seqs.append(clean_seq or "A")

        # Use the backbone's actual device so the encoder still works if the
        # enclosing module is later moved with .to().
        model_device = next(self.model.parameters()).device
        inputs = self.tokenizer(
            clean_seqs, return_tensors="pt", padding=True,
            truncation=True, max_length=max_length
        ).to(model_device)

        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token

    def forward(self, sequences: List[str], max_length: int = 512) -> torch.Tensor:
        """Alias for :meth:`encode_batch` so the module is callable."""
        return self.encode_batch(sequences, max_length)

class AblationModel(nn.Module):
    """Three-tower TCR / MHC / peptide binding classifier on a shared ESM encoder.

    Each input type is embedded by the same ``esm_encoder``, mapped to
    ``d_model`` by its own projection head, and the three projections are
    concatenated. The concatenation feeds both a binary classifier (logits)
    and a linear layer whose output serves as InfoNCE features.
    """

    @staticmethod
    def _projector(in_dim: int, out_dim: int, dropout: float) -> nn.Sequential:
        """Build a Linear -> LayerNorm -> ReLU -> Dropout head mapping in_dim to out_dim."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def __init__(self, esm_encoder: ESMSequenceEncoder, d_model: int = 128, dropout: float = 0.1):
        super().__init__()
        self.esm_encoder = esm_encoder
        self.d_model = d_model

        backbone_width = esm_encoder.hidden_size
        fused_width = d_model * 3

        # Sub-modules are created in a fixed order (tcr, mhc, peptide,
        # classifier, infonce) so seeded weight initialisation is reproducible.
        self.proj_tcr = self._projector(backbone_width, d_model, dropout)
        self.proj_mhc = self._projector(backbone_width, d_model, dropout)
        self.proj_peptide = self._projector(backbone_width, d_model, dropout)

        # Fused MLP head: the same Linear/LayerNorm/ReLU/Dropout stack followed by a scalar output.
        self.classifier = nn.Sequential(
            *self._projector(fused_width, d_model, dropout),
            nn.Linear(d_model, 1),
        )

        self.infonce_proj = nn.Linear(fused_width, d_model)

    def forward(self, cdr3_seqs: List[str], mhc_alleles: List[str], epitope_seqs: List[str], mhc_encoding: str = "pseudo_seq"):
        """Score a batch of (CDR3, MHC allele, epitope) triples.

        Args:
            cdr3_seqs: CDR3 amino-acid sequences.
            mhc_alleles: Allele names, converted to sequences with
                ``get_mhc_sequence(allele, mhc_encoding)``.
            epitope_seqs: Peptide sequences.
            mhc_encoding: Passed through to ``get_mhc_sequence``.

        Returns:
            ``(logits, infonce_features)``: per-sample binding logits with the
            trailing singleton dimension squeezed, and the ``d_model``-wide
            InfoNCE features.
        """
        encode = self.esm_encoder.encode_batch
        mhc_seqs = [get_mhc_sequence(allele, mhc_encoding) for allele in mhc_alleles]

        # Embed all three inputs first, then project them.
        embeddings = [encode(seqs) for seqs in (cdr3_seqs, mhc_seqs, epitope_seqs)]
        heads = (self.proj_tcr, self.proj_mhc, self.proj_peptide)
        fused = torch.cat([head(emb) for head, emb in zip(heads, embeddings)], dim=-1)

        return self.classifier(fused).squeeze(-1), self.infonce_proj(fused)

class InfoNCELoss(nn.Module):
    """Supervised InfoNCE loss that pulls the batch's positive samples together.

    Features are L2-normalised and pairwise similarities are divided by
    ``temperature``. Every sample with ``label == 1`` is an anchor, and every
    *other* positive in the batch is one of its targets. The softmax
    denominator runs over all other samples in the batch, positives and
    negatives alike, with the anchor itself excluded.

    Two changes from the previous implementation:

    * It no longer includes the anchor's self-similarity (always the largest
      term) in the denominator.
    * It averages over all ordered (anchor, positive) pairs instead of only
      i < j. The old scheme made the loss depend on batch order.

    The loss is computed in a single vectorised pass.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features: torch.Tensor, labels: torch.Tensor):
        """Compute the loss.

        Args:
            features: (batch, dim) feature matrix.
            labels: (batch,) binary labels; 1 marks a positive sample.

        Returns:
            Scalar loss averaged over all ordered (anchor, positive) pairs, or
            a zero tensor when the batch has fewer than two positives.
        """
        features = F.normalize(features, dim=1)
        pos_mask = labels == 1

        if pos_mask.sum() < 2:
            return torch.tensor(0.0, device=features.device)

        sim_matrix = torch.matmul(features, features.t()) / self.temperature
        n = sim_matrix.size(0)
        self_mask = torch.eye(n, dtype=torch.bool, device=sim_matrix.device)

        # log sum_{a != i} exp(sim[i, a]) for each anchor row i (stable via logsumexp).
        log_denom = torch.logsumexp(sim_matrix.masked_fill(self_mask, float("-inf")), dim=1)
        log_prob = sim_matrix - log_denom.unsqueeze(1)

        # (i, j) with both samples positive and i != j.
        pair_mask = pos_mask.unsqueeze(1) & pos_mask.unsqueeze(0) & ~self_mask
        return -log_prob[pair_mask].mean()

class TemperatureScaling:
    """Post-hoc calibration that divides logits by one fitted temperature.

    ``fit`` chooses T in [0.1, 10] minimising the binary negative
    log-likelihood of ``sigmoid(logits / T)`` on held-out data; ``apply``
    returns the calibrated probabilities. T stays 1.0 if fitting is
    impossible (empty input or SciPy not installed).
    """

    def __init__(self):
        self.temperature = 1.0

    @staticmethod
    def _sigmoid(x: np.ndarray) -> np.ndarray:
        """Logistic function computed as exp(-log(1 + e^-x)); never overflows for large |x|."""
        return np.exp(-np.logaddexp(0.0, -x))

    def fit(self, logits: np.ndarray, labels: np.ndarray):
        """Fit the temperature on validation logits and binary labels."""
        logits = np.asarray(logits, dtype=float).ravel()
        labels = np.asarray(labels, dtype=float).ravel()
        self.temperature = 1.0
        if logits.size == 0:
            return

        try:
            from scipy.optimize import minimize_scalar
        except ImportError:
            # Deliberate best-effort fallback, but no longer silent, and other
            # errors (shape mismatches etc.) are no longer swallowed.
            warnings.warn("scipy is unavailable; temperature scaling falls back to T=1.0",
                          RuntimeWarning)
            return

        def nll(temp):
            z = logits / temp
            # -log(sigmoid(z)) = logaddexp(0, -z); -log(1 - sigmoid(z)) = logaddexp(0, z).
            # Exact and overflow-free, unlike exp() followed by clipping.
            return float(np.mean(labels * np.logaddexp(0.0, -z)
                                 + (1.0 - labels) * np.logaddexp(0.0, z)))

        result = minimize_scalar(nll, bounds=(0.1, 10.0), method='bounded')
        self.temperature = float(result.x)

    def apply(self, logits: np.ndarray) -> np.ndarray:
        """Return calibrated probabilities ``sigmoid(logits / temperature)``."""
        return self._sigmoid(np.asarray(logits, dtype=float) / self.temperature)

class PlattScaling:
    """Platt calibration: a 1-D logistic regression fitted on raw logits.

    Before ``fit`` is called, ``apply`` falls back to the plain sigmoid.
    """

    def __init__(self):
        self.classifier = None

    def fit(self, logits: np.ndarray, labels: np.ndarray):
        """Fit the logistic map from logits to labels (convergence warnings suppressed)."""
        features = logits.reshape(-1, 1)
        self.classifier = LogisticRegression()
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            self.classifier.fit(features, labels)

    def apply(self, logits: np.ndarray) -> np.ndarray:
        """Return calibrated positive-class probabilities for ``logits``."""
        if self.classifier is not None:
            return self.classifier.predict_proba(logits.reshape(-1, 1))[:, 1]
        return 1 / (1 + np.exp(-logits))

# --------------------------
# Dataset
# --------------------------

class TCRDataset(Dataset):
    """Map-style dataset over a DataFrame with CDR3, MHC, Epitope and label columns.

    Each item is a dict holding the three sequences as ``str`` and the label
    as ``float``; items are addressed by position after the index is reset.
    """

    # (output key, DataFrame column) for the string-valued fields.
    _STRING_FIELDS = (("cdr3", "CDR3"), ("mhc", "MHC"), ("epitope", "Epitope"))

    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        item = {key: str(record[column]) for key, column in self._STRING_FIELDS}
        item["label"] = float(record["label"])
        return item

def collate_fn(batch):
    """Merge TCRDataset items into one batch dict.

    The sequence fields stay as Python lists of strings (they are tokenised
    later by the encoder); labels are stacked into a float32 tensor under
    the key ``"labels"``.
    """
    text_fields = ("cdr3", "mhc", "epitope")
    collated = {name: [sample[name] for sample in batch] for name in text_fields}
    collated["labels"] = torch.tensor([sample["label"] for sample in batch], dtype=torch.float32)
    return collated

# --------------------------
# Ablation Experiment Runner
# --------------------------

class AblationExperiment:
    """One-factor-at-a-time ablation runner for the B1 TCR-pMHC model.

    Each sweep varies one hyper-parameter over the matching
    ``AblationConfig.*_options`` list while all others stay at
    ``_BASE_PARAMS``. Every setting is trained ``config.n_runs`` times (seeds
    ``config.seed + run``). Each metric is summarised as the mean, the sample
    standard deviation, and a Student-t 95% confidence interval.
    """

    # Values held fixed while another hyper-parameter is swept.
    # freeze_layers=6 is a fixed default; it is NOT selected from the
    # results of the freeze_layers sweep.
    _BASE_PARAMS = {
        "freeze_layers": 6,
        "infonce_weight": 0.1,
        "d_model": 128,
        "calibration": "temperature",
        "neg_ratio": 4,
        "mhc_encoding": "pseudo_seq",
    }

    # (section title, swept parameter == ablation_type, config options attribute, label prefix)
    _SWEEPS = (
        ("1. ESM Freezing Strategy Ablation", "freeze_layers", "freeze_layers_options", "freeze"),
        ("2. InfoNCE Weight Ablation", "infonce_weight", "infonce_weight_options", "weight"),
        ("3. Model Dimension Ablation", "d_model", "d_model_options", "dim"),
        ("4. Calibration Method Ablation", "calibration", "calibration_options", "calib"),
        ("5. Negative Sampling Ratio Ablation", "neg_ratio", "neg_ratio_options", "ratio"),
        ("6. MHC Encoding Ablation", "mhc_encoding", "mhc_encoding_options", "enc"),
    )

    def __init__(self, config: "AblationConfig"):
        self.config = config
        self.data_manager = AblationDataManager(config)

    # ------------------------------------------------------------------
    # Single run
    # ------------------------------------------------------------------

    def run_single_experiment(self, freeze_layers: int, infonce_weight: float, d_model: int,
                            calibration: str, neg_ratio: int, mhc_encoding: str) -> Dict[str, float]:
        """Train, calibrate and evaluate one model with the given hyper-parameters.

        The data are reloaded and negatives regenerated with ``neg_ratio``
        each call, so the split follows the RNG state set by the caller.

        Returns:
            Test-set metrics: auc, auprc, accuracy, precision, recall
            (auc/auprc are NaN if the test set contains a single class).
        """
        device = self.config.device

        # Load data with specified negative ratio
        self.data_manager.load_and_split_data(neg_ratio=neg_ratio)

        esm_encoder = ESMSequenceEncoder(self.config.esm_model_name, device, freeze_layers)
        model = AblationModel(esm_encoder, d_model).to(device)
        optimizer = self._build_optimizer(model)

        bce_loss = nn.BCEWithLogitsLoss()
        infonce_loss = InfoNCELoss()

        train_loader = self._make_loader(self.data_manager.df_train, shuffle=True)
        val_loader = self._make_loader(self.data_manager.df_val, shuffle=False)
        test_loader = self._make_loader(self.data_manager.df_test, shuffle=False)

        # Training loop
        model.train()
        for _ in range(self.config.epochs):
            for batch in train_loader:
                labels = batch["labels"].to(device)
                logits, features = model(batch["cdr3"], batch["mhc"], batch["epitope"], mhc_encoding)

                loss = bce_loss(logits, labels)
                if infonce_weight > 0:
                    loss = loss + infonce_weight * infonce_loss(features, labels)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

        # Calibrate on validation logits, then evaluate on test.
        model.eval()
        val_logits, val_labels = self._predict(model, val_loader, mhc_encoding)
        calibrator = self._fit_calibrator(calibration, val_logits, val_labels)

        test_logits, test_labels = self._predict(model, test_loader, mhc_encoding)
        if calibrator is not None:
            test_probs = calibrator.apply(test_logits)
        else:
            test_probs = 1 / (1 + np.exp(-test_logits))

        metrics = self._compute_metrics(test_labels, test_probs)

        # Release the backbone before the next run builds a fresh one.
        del model, esm_encoder, optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return metrics

    def _build_optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
        """AdamW with a 10x smaller learning rate for ESM backbone parameters.

        Parameters are grouped by whether their name contains "esm_encoder";
        frozen parameters (``requires_grad=False``) are left out.
        """
        esm_params, head_params = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            (esm_params if "esm_encoder" in name else head_params).append(param)

        return torch.optim.AdamW([
            {"params": esm_params, "lr": self.config.base_lr * 0.1},
            {"params": head_params, "lr": self.config.base_lr},
        ], weight_decay=0.01)

    def _make_loader(self, df: pd.DataFrame, shuffle: bool) -> DataLoader:
        """Wrap ``df`` in a TCRDataset DataLoader using the configured batch size."""
        return DataLoader(TCRDataset(df), batch_size=self.config.batch_size,
                          shuffle=shuffle, collate_fn=collate_fn)

    @staticmethod
    def _predict(model: nn.Module, loader: DataLoader, mhc_encoding: str):
        """Return (logits, labels) as 1-D NumPy arrays for every batch in ``loader``."""
        logit_chunks, label_chunks = [], []
        with torch.no_grad():
            for batch in loader:
                logits, _ = model(batch["cdr3"], batch["mhc"], batch["epitope"], mhc_encoding)
                logit_chunks.append(logits.cpu().numpy())
                label_chunks.append(batch["labels"].numpy())
        if not logit_chunks:
            return np.array([]), np.array([])
        return np.concatenate(logit_chunks), np.concatenate(label_chunks)

    @staticmethod
    def _fit_calibrator(calibration: str, logits: np.ndarray, labels: np.ndarray):
        """Fit the requested calibrator; any value other than temperature/platt means none."""
        if calibration == "temperature":
            calibrator = TemperatureScaling()
        elif calibration == "platt":
            calibrator = PlattScaling()
        else:  # none
            return None
        calibrator.fit(logits, labels)
        return calibrator

    @staticmethod
    def _compute_metrics(labels: np.ndarray, probs: np.ndarray) -> Dict[str, float]:
        """Threshold-free (auc, auprc) and 0.5-threshold metrics for binary labels."""
        preds = (probs > 0.5).astype(int)
        true_pos = int(((preds == 1) & (labels == 1)).sum())
        # Ranking metrics are undefined when only one class is present.
        both_classes = np.unique(labels).size == 2
        return {
            "auc": roc_auc_score(labels, probs) if both_classes else float("nan"),
            "auprc": average_precision_score(labels, probs) if both_classes else float("nan"),
            "accuracy": accuracy_score(labels, preds),
            "precision": true_pos / max(1, int((preds == 1).sum())),
            "recall": true_pos / max(1, int((labels == 1).sum())),
        }

    # ------------------------------------------------------------------
    # Sweeps
    # ------------------------------------------------------------------

    @staticmethod
    def _summarize(values) -> Tuple[float, float, str]:
        """Return (mean, sample std, "mean ± half-width") for one metric across runs.

        The half-width is a two-sided 95% Student-t interval,
        t_{0.975, n-1} * s / sqrt(n). The previous population-std / z=1.96
        interval was far too narrow for n_runs=3. With fewer than two values,
        std and half-width are 0.
        """
        arr = np.asarray(values, dtype=float)
        mean_val = float(np.mean(arr))
        n = arr.size
        if n < 2:
            std_val, half_width = 0.0, 0.0
        else:
            std_val = float(np.std(arr, ddof=1))
            try:
                from scipy import stats
                crit = float(stats.t.ppf(0.975, df=n - 1))
            except ImportError:
                crit = 1.96  # normal approximation if SciPy is missing
            half_width = crit * std_val / np.sqrt(n)
        return mean_val, std_val, f"{mean_val:.4f} ± {half_width:.4f}"

    def _run_sweep(self, title: str, param_name: str, options, prefix: str,
                   base_params: Dict) -> List[Dict]:
        """Run every value of one hyper-parameter and return long-format result rows."""
        print("=" * 60)
        print(title)
        print("=" * 60)

        rows = []
        for value in options:
            print(f"Testing {param_name} = {value}")

            run_metrics = []
            for run in range(self.config.n_runs):
                set_seed(self.config.seed + run)
                params = {**base_params, param_name: value}
                run_metrics.append(self.run_single_experiment(**params))
            if not run_metrics:
                continue

            for metric_name in run_metrics[0]:
                mean_val, std_val, ci_string = self._summarize(
                    [m[metric_name] for m in run_metrics])
                rows.append({
                    "ablation_type": param_name,
                    "parameter": f"{prefix}_{value}",
                    "metric": metric_name,
                    "mean": mean_val,
                    "std": std_val,
                    "ci_string": ci_string,
                })
        return rows

    def run_ablation_studies(self) -> pd.DataFrame:
        """Run all six sweeps and return one row per (setting, metric)."""
        base_params = dict(self._BASE_PARAMS)
        results = []
        for title, param_name, options_attr, prefix in self._SWEEPS:
            options = getattr(self.config, options_attr)
            results.extend(self._run_sweep(title, param_name, options, prefix, base_params))
        return pd.DataFrame(results)

    def save_results(self, results_df: pd.DataFrame, output_path: str = "ablation_results.csv"):
        """Save results to CSV and print one mean ± CI table per ablation type."""
        results_df.to_csv(output_path, index=False)

        print(f"\n{'='*120}")
        print("ABLATION STUDY RESULTS SUMMARY")
        print(f"{'='*120}")

        metric_order = ["auc", "auprc", "accuracy", "precision", "recall"]
        for ablation_type in results_df["ablation_type"].unique():
            print(f"\n{ablation_type.upper()} ABLATION:")
            print("-" * 60)

            subset = results_df[results_df["ablation_type"] == ablation_type]
            pivot = subset.pivot(index="parameter", columns="metric", values="ci_string")
            # pivot() sorts the index lexicographically ("freeze_12" < "freeze_3");
            # restore the order in which the settings were swept.
            pivot = pivot.reindex(subset["parameter"].unique())
            pivot = pivot[[m for m in metric_order if m in pivot.columns]]

            print(pivot.to_string())

        print(f"\nDetailed results saved to: {output_path}")
        print(f"{'='*120}")

# --------------------------
# Main Function
# --------------------------

def main():
    """Run every ablation sweep and save/print the summary.

    Results are written by ``AblationExperiment.save_results`` to its default
    path, ``ablation_results.csv``, in the working directory.
    """
    config = AblationConfig(
        seed=42,
        n_runs=3,
        epochs=10,      # overrides the dataclass default of 8 (this is more, not fewer)
        batch_size=64,  # overrides the default of 32
        base_lr=1e-4,   # overrides the default of 2e-4
    )

    print("Starting Ablation Study for TCR-peptide-MHC Binding Prediction")
    print(f"Device: {config.device}")
    print(f"Configuration: {config.n_runs} runs, {config.epochs} epochs per run")

    # Run ablation studies
    experiment = AblationExperiment(config)
    results_df = experiment.run_ablation_studies()

    # Save and display results
    experiment.save_results(results_df)


if __name__ == "__main__":
    main()

2025-09-23 11:51:11.039507: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-23 11:51:11.092162: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Starting Ablation Study for TCR-peptide-MHC Binding Prediction
Device: cuda
Configuration: 3 runs, 10 epochs per run
1. ESM Freezing Strategy Ablation
Testing freeze_layers = 0


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
