In [1]:
# -*- coding: utf-8 -*-
# Enhanced ESM Fine-tuning for TCR-peptide-MHC binding prediction with Maximum Accuracy
# Dependencies: torch, transformers, numpy, pandas, scikit-learn, tqdm, requests, accelerate

import os, sys, math, json, time, random, requests, warnings
from typing import List, Dict, Tuple, Set
import numpy as np
import pandas as pd
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import EsmModel, EsmTokenizer
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, precision_recall_curve, roc_curve
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns

# New imports for acceleration
from accelerate import Accelerator
from accelerate.utils import set_seed as accelerate_set_seed

try:
    from tqdm import tqdm
except ImportError:
    raise ImportError("Please install tqdm: pip install tqdm")

# --------------------------
# Helper Functions
# --------------------------

def get_model_attr(model, attr_name):
    """Safely get model attribute, handling accelerator wrapping"""
    if hasattr(model, 'module'):
        return getattr(model.module, attr_name)
    else:
        return getattr(model, attr_name)

def set_model_attr(model, attr_name, value):
    """Safely set model attribute, handling accelerator wrapping"""
    if hasattr(model, 'module'):
        setattr(model.module, attr_name, value)
    else:
        setattr(model, attr_name, value)

# --------------------------
# Training History Logger
# --------------------------

class TrainingLogger:
    """Enhanced training logger for comprehensive metrics tracking"""
    
    def __init__(self, log_dir: str = "training_logs"):
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)
        
        self.history = {
            'epoch': [],
            'train_loss': [],
            'train_bce_loss': [],
            'train_infonce_loss': [],
            'val_loss': [],
            'val_bce_loss': [],
            'val_infonce_loss': [],
            'val_auc': [],
            'val_auprc': [],
            'val_accuracy': [],
            'val_precision': [],
            'val_recall': [],
            'val_f1': [],
            'learning_rate': [],
            'temperature': [],
            'train_time': [],
            'val_time': []
        }
        
        self.start_time = time.time()
        
    def log_epoch(self, epoch_data: Dict):
        """Log data for one epoch"""
        for key, value in epoch_data.items():
            if key in self.history:
                self.history[key].append(value)
    
    def save_logs(self, filename: str = "training_history.csv"):
        """Save training history to CSV"""
        df = pd.DataFrame(self.history)
        filepath = os.path.join(self.log_dir, filename)
        df.to_csv(filepath, index=False)
        print(f"[Info] Training history saved to {filepath}")
        return filepath
    
    def save_json(self, filename: str = "training_history.json"):
        """Save training history to JSON"""
        filepath = os.path.join(self.log_dir, filename)
        with open(filepath, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"[Info] Training history saved to {filepath}")
        return filepath
    
    def plot_training_curves(self, filename: str = "training_curves.png"):
        """Generate training curve plots"""
        if len(self.history['epoch']) == 0:
            return
            
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('Training Progress', fontsize=16, fontweight='bold')
        
        epochs = self.history['epoch']
        
        # Loss curves
        axes[0, 0].plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0, 0].plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        axes[0, 0].set_title('Total Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # BCE Loss
        axes[0, 1].plot(epochs, self.history['train_bce_loss'], 'b-', label='Train BCE', linewidth=2)
        axes[0, 1].plot(epochs, self.history['val_bce_loss'], 'r-', label='Val BCE', linewidth=2)
        axes[0, 1].set_title('BCE Loss')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('BCE Loss')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # InfoNCE Loss
        axes[0, 2].plot(epochs, self.history['train_infonce_loss'], 'b-', label='Train InfoNCE', linewidth=2)
        axes[0, 2].plot(epochs, self.history['val_infonce_loss'], 'r-', label='Val InfoNCE', linewidth=2)
        axes[0, 2].set_title('InfoNCE Loss')
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].set_ylabel('InfoNCE Loss')
        axes[0, 2].legend()
        axes[0, 2].grid(True, alpha=0.3)
        
        # AUC and AUPRC
        axes[1, 0].plot(epochs, self.history['val_auc'], 'g-', label='AUC', linewidth=2)
        axes[1, 0].plot(epochs, self.history['val_auprc'], 'orange', label='AUPRC', linewidth=2)
        axes[1, 0].set_title('Validation Metrics')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
        axes[1, 0].set_ylim(0, 1)
        
        # Accuracy and F1
        axes[1, 1].plot(epochs, self.history['val_accuracy'], 'purple', label='Accuracy', linewidth=2)
        axes[1, 1].plot(epochs, self.history['val_f1'], 'brown', label='F1 Score', linewidth=2)
        axes[1, 1].set_title('Classification Metrics')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        axes[1, 1].set_ylim(0, 1)
        
        # Learning Rate and Temperature
        axes[1, 2].plot(epochs, self.history['learning_rate'], 'navy', label='Learning Rate', linewidth=2)
        ax2 = axes[1, 2].twinx()
        ax2.plot(epochs, self.history['temperature'], 'red', label='Temperature', linewidth=2)
        axes[1, 2].set_title('Learning Rate & Temperature')
        axes[1, 2].set_xlabel('Epoch')
        axes[1, 2].set_ylabel('Learning Rate')
        ax2.set_ylabel('Temperature')
        axes[1, 2].legend(loc='upper left')
        ax2.legend(loc='upper right')
        axes[1, 2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        filepath = os.path.join(self.log_dir, filename)
        plt.savefig(filepath, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"[Info] Training curves saved to {filepath}")
        return filepath

# --------------------------
# Cross-Validation and Metrics
# --------------------------

class CrossValidationEvaluator:
    """5-fold cross-validation with confidence intervals"""
    
    def __init__(self, n_splits: int = 5, random_state: int = 42):
        self.n_splits = n_splits
        self.random_state = random_state
        self.kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        self.results = defaultdict(list)
        
    def compute_detailed_metrics(self, y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5) -> Dict:
        """Compute detailed classification metrics"""
        y_pred = (y_prob >= threshold).astype(int)
        
        # Basic metrics
        accuracy = accuracy_score(y_true, y_pred)
        
        # Handle edge cases for precision/recall
        tp = np.sum((y_true == 1) & (y_pred == 1))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        
        # AUC metrics
        try:
            auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else 0.0
            auprc = average_precision_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else 0.0
        except:
            auc = 0.0
            auprc = 0.0
        
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auc': auc,
            'auprc': auprc
        }
    
    def compute_confidence_interval(self, values: List[float], confidence: float = 0.95) -> Tuple[float, float, float]:
        """Compute mean and confidence interval using bootstrap"""
        values = np.array(values)
        
        if len(values) == 0:
            return 0.0, 0.0, 0.0
        
        mean_val = np.mean(values)
        
        # Bootstrap confidence interval
        n_bootstrap = 1000
        bootstrap_means = []
        
        for _ in range(n_bootstrap):
            bootstrap_sample = np.random.choice(values, size=len(values), replace=True)
            bootstrap_means.append(np.mean(bootstrap_sample))
        
        bootstrap_means = np.array(bootstrap_means)
        alpha = 1 - confidence
        lower_percentile = (alpha / 2) * 100
        upper_percentile = (1 - alpha / 2) * 100
        
        ci_lower = np.percentile(bootstrap_means, lower_percentile)
        ci_upper = np.percentile(bootstrap_means, upper_percentile)
        
        return mean_val, ci_lower, ci_upper
    
    def add_fold_result(self, fold: int, metrics: Dict):
        """Add results from one fold"""
        self.results[fold] = metrics
        
        # Add to overall results
        for metric_name, value in metrics.items():
            if isinstance(value, (int, float)):
                self.results[f'all_{metric_name}'].append(value)
    
    def get_summary_with_ci(self, confidence: float = 0.95) -> Dict:
        """Get summary statistics with confidence intervals"""
        summary = {}
        
        metric_names = ['accuracy', 'precision', 'recall', 'f1', 'auc', 'auprc']
        
        for metric in metric_names:
            values = self.results.get(f'all_{metric}', [])
            if values:
                mean_val, ci_lower, ci_upper = self.compute_confidence_interval(values, confidence)
                summary[metric] = {
                    'mean': mean_val,
                    'std': np.std(values),
                    'ci_lower': ci_lower,
                    'ci_upper': ci_upper,
                    'values': values
                }
        
        return summary
    
    def save_results(self, filepath: str = "cv_results.json"):
        """Save cross-validation results"""
        # Convert defaultdict to regular dict for JSON serialization
        results_dict = dict(self.results)
        
        # Add summary statistics
        results_dict['summary'] = self.get_summary_with_ci()
        
        with open(filepath, 'w') as f:
            json.dump(results_dict, f, indent=2, default=str)
        
        print(f"[Info] Cross-validation results saved to {filepath}")
        return filepath
    
    def plot_cv_results(self, filepath: str = "cv_results.png"):
        """Plot cross-validation results"""
        summary = self.get_summary_with_ci()
        
        if not summary:
            return
        
        metrics = list(summary.keys())
        means = [summary[m]['mean'] for m in metrics]
        ci_lowers = [summary[m]['ci_lower'] for m in metrics]
        ci_uppers = [summary[m]['ci_upper'] for m in metrics]
        
        # Error bars
        errors = [
            [means[i] - ci_lowers[i] for i in range(len(means))],
            [ci_uppers[i] - means[i] for i in range(len(means))]
        ]
        
        plt.figure(figsize=(12, 8))
        
        # Bar plot with error bars
        bars = plt.bar(metrics, means, yerr=errors, capsize=10, alpha=0.7, 
                      color=['skyblue', 'lightgreen', 'lightcoral', 'gold', 'plum', 'lightsalmon'])
        
        plt.title('5-Fold Cross-Validation Results with 95% Confidence Intervals', 
                 fontsize=14, fontweight='bold')
        plt.ylabel('Score', fontsize=12)
        plt.xlabel('Metrics', fontsize=12)
        plt.ylim(0, 1.1)
        
        # Add value labels on bars
        for i, (bar, mean_val, ci_l, ci_u) in enumerate(zip(bars, means, ci_lowers, ci_uppers)):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                    f'{mean_val:.3f}\n[{ci_l:.3f}, {ci_u:.3f}]',
                    ha='center', va='bottom', fontweight='bold', fontsize=10)
        
        plt.xticks(rotation=45)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(filepath, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"[Info] Cross-validation plot saved to {filepath}")
        return filepath

# --------------------------
# Enhanced Configuration
# --------------------------

AA_STANDARD = list("ACDEFGHIKLMNPQRSTVWY")
AA_SET = set(AA_STANDARD)

# Extended MHC Class I pseudo-sequences
MHC_PSEUDO_SEQUENCES = {
    "HLA-A*02:01": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDGETRKVKAHSQTHRVDLGTLRGYYNQSEAGSHTVQRMYGCDVGSDWRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQTTKHKWEAAHVAEQLRAYLEGTCVEWLRRYLENGKETLQRTDAPKTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGQEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-A*01:01": "GSHSMRYFFTSVSRPGRGEPRFIAMGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-B*07:02": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-A*03:01": "GSHSMRYFFTSVSRPGRGEPRFIAMGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-B*08:01": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE"
}

def get_mhc_sequence(mhc_name: str) -> str:
    """Get MHC amino acid sequence from allele name"""
    return MHC_PSEUDO_SEQUENCES.get(mhc_name, MHC_PSEUDO_SEQUENCES["HLA-A*02:01"])

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    accelerate_set_seed(seed)

# --------------------------
# Enhanced ESM Encoder
# --------------------------

class EnhancedESMEncoder:
    """Enhanced ESM encoder with minimal freezing for maximum accuracy"""
    def __init__(self, model_name: str = "facebook/esm2_t12_35M_UR50D", device: str = "cpu", 
                 freeze_layers: int = 2):
        print(f"[Info] Loading Enhanced ESM model: {model_name}")
        self.device = device
        self.tokenizer = EsmTokenizer.from_pretrained(model_name)
        self.model = EsmModel.from_pretrained(model_name).to(device)
        
        # Aggressive fine-tuning - only freeze first 2 layers
        self.model.train()
        
        total_layers = len(self.model.encoder.layer)
        trainable_layers = 0
        
        for i, layer in enumerate(self.model.encoder.layer):
            if i < freeze_layers:
                for param in layer.parameters():
                    param.requires_grad = False
            else:
                for param in layer.parameters():
                    param.requires_grad = True
                trainable_layers += 1
        
        # Always fine-tune embeddings and pooler for maximum accuracy
        for param in self.model.embeddings.parameters():
            param.requires_grad = True
        for param in self.model.pooler.parameters():
            param.requires_grad = True
            
        self.hidden_size = self.model.config.hidden_size
        print(f"[Info] Enhanced ESM loaded - Hidden size: {self.hidden_size}, Trainable layers: {trainable_layers}/{total_layers}")
    
    def encode_batch(self, sequences: List[str], max_length: int = 512) -> torch.Tensor:
        """Enhanced batch encoding with better preprocessing"""
        if not sequences:
            return torch.empty(0, self.hidden_size, device=self.device)
        
        # Enhanced sequence cleaning
        clean_seqs = []
        for seq in sequences:
            clean_seq = "".join([c for c in seq.upper() if c in AA_SET])
            clean_seqs.append(clean_seq if clean_seq else "A")
        
        # Batch tokenize with optimized settings
        inputs = self.tokenizer(
            clean_seqs, 
            return_tensors="pt", 
            padding=True, 
            truncation=True, 
            max_length=max_length
        ).to(self.device)
        
        # Forward pass
        outputs = self.model(**inputs)
        # Use [CLS] token representation
        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        return cls_embeddings

# --------------------------
# Enhanced Model Architecture
# --------------------------

class MultiHeadAttention(nn.Module):
    """Multi-head attention for better representation fusion"""
    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v):
        batch_size = q.size(0)
        
        # Transform and reshape
        q = self.q_linear(q).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        context = torch.matmul(attn, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        return self.out(context)

class EnhancedTCRModel(nn.Module):
    """Enhanced model with maximum capacity for accuracy"""
    def __init__(self, esm_encoder: EnhancedESMEncoder, d_model: int = 512, dropout: float = 0.1):
        super().__init__()
        self.esm_encoder = esm_encoder
        self.esm_hidden_size = esm_encoder.hidden_size
        self.d_model = d_model
        
        # Enhanced projection layers with more capacity
        self.proj_tcr = nn.Sequential(
            nn.Linear(self.esm_hidden_size, d_model * 2),
            nn.LayerNorm(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        self.proj_mhc = nn.Sequential(
            nn.Linear(self.esm_hidden_size, d_model * 2),
            nn.LayerNorm(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        self.proj_peptide = nn.Sequential(
            nn.Linear(self.esm_hidden_size, d_model * 2),
            nn.LayerNorm(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        # Cross-attention between components
        self.tcr_peptide_attention = MultiHeadAttention(d_model, n_heads=8, dropout=dropout)
        self.mhc_peptide_attention = MultiHeadAttention(d_model, n_heads=8, dropout=dropout)
        
        # Enhanced fusion with multiple layers
        self.fusion = nn.Sequential(
            nn.Linear(d_model * 5, d_model * 4),  # 5 = tcr + mhc + peptide + 2 attention outputs
            nn.LayerNorm(d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model * 2),
            nn.LayerNorm(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        # Multi-layer classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.LayerNorm(d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, d_model // 4),
            nn.LayerNorm(d_model // 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 4, 1)
        )
        
        # InfoNCE projection for contrastive learning
        self.infonce_proj = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.LayerNorm(d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, d_model // 4)
        )
        
        # Learnable temperature
        self.temperature = nn.Parameter(torch.tensor(0.07))
        
    def forward(self, cdr3_seqs: List[str], mhc_alleles: List[str], peptide_seqs: List[str]) -> Dict[str, torch.Tensor]:
        # Convert MHC alleles to sequences
        mhc_seqs = [get_mhc_sequence(allele) for allele in mhc_alleles]
        
        # ESM encoding
        cdr3_emb = self.esm_encoder.encode_batch(cdr3_seqs)
        mhc_emb = self.esm_encoder.encode_batch(mhc_seqs)
        peptide_emb = self.esm_encoder.encode_batch(peptide_seqs)
        
        # Projection
        cdr3_proj = self.proj_tcr(cdr3_emb)
        mhc_proj = self.proj_mhc(mhc_emb)
        peptide_proj = self.proj_peptide(peptide_emb)
        
        # Cross-attention
        tcr_pep_attn = self.tcr_peptide_attention(
            cdr3_proj.unsqueeze(1), peptide_proj.unsqueeze(1), peptide_proj.unsqueeze(1)
        ).squeeze(1)
        
        mhc_pep_attn = self.mhc_peptide_attention(
            mhc_proj.unsqueeze(1), peptide_proj.unsqueeze(1), peptide_proj.unsqueeze(1)
        ).squeeze(1)
        
        # Enhanced fusion
        combined = torch.cat([cdr3_proj, mhc_proj, peptide_proj, tcr_pep_attn, mhc_pep_attn], dim=-1)
        fused = self.fusion(combined)
        
        # Classification
        logits = self.classifier(fused).squeeze(-1)
        
        # InfoNCE features
        infonce_features = self.infonce_proj(fused)
        
        return {
            'logits': logits,
            'infonce_features': infonce_features,
            'fused_features': fused
        }

# --------------------------
# Enhanced Loss Functions
# --------------------------

class InfoNCELoss(nn.Module):
    """InfoNCE contrastive loss for better representation learning"""
    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature
        
    def forward(self, features: torch.Tensor, labels: torch.Tensor, temperature: torch.Tensor):
        # Normalize features
        features = F.normalize(features, dim=1)
        
        batch_size = features.shape[0]
        if batch_size < 2:
            return torch.tensor(0.0, device=features.device)
        
        # Compute similarity matrix
        sim_matrix = torch.matmul(features, features.t()) / temperature
        
        # Create positive mask
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.t()).float().to(features.device)
        mask = mask - torch.eye(batch_size, device=features.device)  # Remove self-similarity
        
        # InfoNCE loss
        exp_sim = torch.exp(sim_matrix)
        sum_exp_sim = exp_sim.sum(dim=1, keepdim=True)
        
        loss = 0
        num_positives = 0
        
        for i in range(batch_size):
            pos_sim = exp_sim[i] * mask[i]
            if pos_sim.sum() > 0:
                loss += -torch.log(pos_sim.sum() / sum_exp_sim[i].clamp(min=1e-8))
                num_positives += 1
        
        return loss / max(1, num_positives)

# --------------------------
# Enhanced Data Processing (same as before - abbreviated for space)
# --------------------------

def sliding_window_peptides(seq: str, lengths: List[int]) -> Set[str]:
    """Extract peptides of specified lengths from sequence"""
    outs = set()
    L = len(seq)
    for k in lengths:
        if k <= 0 or k > L: 
            continue
        for i in range(0, L - k + 1):
            p = seq[i:i+k]
            if all(ch in AA_SET for ch in p):
                outs.add(p)
    return outs

def enhanced_mutate_peptide(peptide: str, n_mut: int = 1, mutation_strategies: List[str] = None) -> str:
    """Enhanced peptide mutation with multiple strategies"""
    if mutation_strategies is None:
        mutation_strategies = ['random', 'conservative', 'hydrophobic']
    
    s = list(peptide)
    L = len(s)
    strategy = random.choice(mutation_strategies)
    
    # Conservative mutations (similar amino acids)
    conservative_groups = {
        'A': ['V', 'I', 'L'], 'V': ['A', 'I', 'L'], 'I': ['A', 'V', 'L'], 'L': ['A', 'V', 'I'],
        'F': ['Y', 'W'], 'Y': ['F', 'W'], 'W': ['F', 'Y'],
        'S': ['T'], 'T': ['S'],
        'D': ['E'], 'E': ['D'],
        'N': ['Q'], 'Q': ['N'],
        'K': ['R'], 'R': ['K']
    }
    
    # Hydrophobic amino acids
    hydrophobic = ['A', 'V', 'I', 'L', 'M', 'F', 'Y', 'W']
    
    n = min(n_mut, L)
    positions = random.sample(range(L), n)
    
    for pos in positions:
        original = s[pos]
        
        if strategy == 'conservative' and original in conservative_groups:
            candidates = conservative_groups[original]
        elif strategy == 'hydrophobic' and original in hydrophobic:
            candidates = hydrophobic
        else:  # random
            candidates = [aa for aa in AA_STANDARD if aa != original]
        
        if candidates:
            s[pos] = random.choice(candidates)
    
    return "".join(s)

def augment_positive_samples(df_positive: pd.DataFrame, target_count: int) -> pd.DataFrame:
    """Augment positive samples through various data augmentation techniques"""
    augmented_rows = []
    current_count = len(df_positive)
    
    if current_count >= target_count:
        return df_positive
    
    needed_samples = target_count - current_count
    print(f"[Info] Augmenting {needed_samples} positive samples...")
    
    # Convert to list for easier sampling
    positive_rows = df_positive.to_dict('records')
    
    for _ in range(needed_samples):
        # Randomly select a positive sample to augment
        base_row = random.choice(positive_rows)
        cdr3, mhc, peptide = base_row["CDR3"], base_row["MHC"], base_row["Epitope"]
        
        # Choose augmentation strategy
        strategy = random.choice(['peptide_mutation', 'cdr3_mutation', 'conservative_mutation'])
        
        if strategy == 'peptide_mutation':
            # Light mutation of peptide (1-2 positions)
            new_peptide = enhanced_mutate_peptide(peptide, n_mut=random.choice([1, 2]), 
                                                 mutation_strategies=['conservative'])
            augmented_rows.append({
                "CDR3": cdr3,
                "MHC": mhc,
                "Epitope": new_peptide,
                "label": 1
            })
        
        elif strategy == 'cdr3_mutation':
            # Light mutation of CDR3 (1-2 positions for longer sequences)
            if len(cdr3) > 8:
                new_cdr3 = enhanced_mutate_peptide(cdr3, n_mut=random.choice([1, 2]), 
                                                  mutation_strategies=['conservative'])
                augmented_rows.append({
                    "CDR3": new_cdr3,
                    "MHC": mhc,
                    "Epitope": peptide,
                    "label": 1
                })
            else:
                # For shorter CDR3, just duplicate with slight variation in context
                augmented_rows.append({
                    "CDR3": cdr3,
                    "MHC": mhc,
                    "Epitope": peptide,
                    "label": 1
                })
        
        else:  # conservative_mutation
            # Very conservative mutation of both peptide and CDR3
            new_peptide = enhanced_mutate_peptide(peptide, n_mut=1, 
                                                 mutation_strategies=['conservative'])
            if len(cdr3) > 8:
                new_cdr3 = enhanced_mutate_peptide(cdr3, n_mut=1, 
                                                  mutation_strategies=['conservative'])
            else:
                new_cdr3 = cdr3
            
            augmented_rows.append({
                "CDR3": new_cdr3,
                "MHC": mhc,
                "Epitope": new_peptide,
                "label": 1
            })
    
    # Combine original and augmented samples
    df_augmented = pd.DataFrame(augmented_rows)
    df_combined = pd.concat([df_positive, df_augmented], axis=0, ignore_index=True)
    df_combined = df_combined.drop_duplicates()
    
    print(f"[Info] Augmented positive samples: {current_count} -> {len(df_combined)}")
    return df_combined

def enhanced_build_negatives_with_balance(df: pd.DataFrame, k_neg: int = 8, balance_ratio: float = 1.5) -> pd.DataFrame:
    """Enhanced negative sampling with positive sample balancing"""
    df = df.copy()
    df["label"] = 1
    original_positive_count = len(df)
    
    # Group peptides by length and properties
    peps = df["Epitope"].dropna().astype(str).str.upper().tolist()
    peps = [p for p in peps if all(c in AA_SET for c in p)]
    
    peps_by_len = {}
    for p in set(peps):
        peps_by_len.setdefault(len(p), []).append(p)
    
    # Group TCRs by MHC
    tcrs_by_mhc = {}
    for _, row in df.iterrows():
        mhc = str(row["MHC"])
        tcr = str(row["CDR3"])
        tcrs_by_mhc.setdefault(mhc, []).append(tcr)
    
    rows = []
    
    for _, r in df.iterrows():
        c, m, p = str(r["CDR3"]), str(r["MHC"]), str(r["Epitope"])
        p = p.upper()
        L = len(p)
        
        # Strategy 1: Different peptides, same length
        cand_list = peps_by_len.get(L, [])
        neg_peptides = []
        
        for _ in range(k_neg // 4):
            if cand_list:
                q = random.choice(cand_list)
                if q != p:
                    neg_peptides.append(q)
        
        # Strategy 2: Mutated peptides (multiple mutation levels)
        for mut_level in [1, 2, 3]:
            for _ in range(k_neg // 4):
                neg_peptides.append(enhanced_mutate_peptide(p, n_mut=mut_level))
        
        # Strategy 3: Cross-MHC negatives
        for other_mhc, other_tcrs in tcrs_by_mhc.items():
            if other_mhc != m and other_tcrs:
                neg_tcr = random.choice(other_tcrs)
                rows.append({"CDR3": neg_tcr, "MHC": m, "Epitope": p, "label": 0})
        
        # Add peptide negatives
        for q in neg_peptides:
            rows.append({"CDR3": c, "MHC": m, "Epitope": q, "label": 0})
    
    df_neg = pd.DataFrame(rows)
    negative_count = len(df_neg)
    
    print(f"[Info] Generated {negative_count} negative samples for {original_positive_count} positive samples")
    
    # Balance positive samples if needed
    target_positive_count = max(original_positive_count, int(negative_count / balance_ratio))
    
    if target_positive_count > original_positive_count:
        print(f"[Info] Balancing dataset: target positive samples = {target_positive_count}")
        df_balanced_positive = augment_positive_samples(df, target_positive_count)
    else:
        df_balanced_positive = df
    
    # Combine balanced positive and negative samples
    out = pd.concat([df_balanced_positive[["CDR3","MHC","Epitope","label"]], df_neg], axis=0, ignore_index=True)
    out = out.drop_duplicates()
    
    final_positive = len(out[out["label"] == 1])
    final_negative = len(out[out["label"] == 0])
    
    print(f"[Info] Final balanced dataset: {final_positive} positives + {final_negative} negatives = {len(out)} total")
    print(f"[Info] Positive/Negative ratio: {final_positive/final_negative:.2f}")
    
    return out

def split_by_epitope(df_all: pd.DataFrame, seed: int = 42, ratio=(8,1,1)) -> pd.DataFrame:
    """Split data by epitope to avoid data leakage"""
    train_r, val_r, test_r = ratio
    peps = df_all[df_all["label"]==1]["Epitope"].drop_duplicates().sample(frac=1.0, random_state=seed).tolist()
    n = len(peps)
    n_train = int(n * train_r / sum(ratio))
    n_val = int(n * val_r / sum(ratio))
    train_peps = set(peps[:n_train])
    val_peps = set(peps[n_train:n_train+n_val])
    
    def split_label(p):
        if p in train_peps: return "train"
        if p in val_peps: return "val"
        return "test"
    
    df_all = df_all.copy()
    df_all["split"] = df_all["Epitope"].map(split_label)
    return df_all

# --------------------------
# Enhanced Dataset
# --------------------------

class EnhancedDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True).copy()
        
    def __len__(self):
        return len(self.df)
        
    def __getitem__(self, idx: int) -> Dict:
        r = self.df.iloc[idx]
        return {
            "cdr3": str(r["CDR3"]),
            "mhc": str(r["MHC"]), 
            "peptide": str(r["Epitope"]),
            "label": float(r["label"]) if "label" in self.df.columns else 1.0
        }
    
    def collate_fn(self, batch: List[Dict]) -> Dict:
        cdr3_seqs = [item["cdr3"] for item in batch]
        mhc_alleles = [item["mhc"] for item in batch]
        peptide_seqs = [item["peptide"] for item in batch]
        labels = torch.tensor([item["label"] for item in batch], dtype=torch.float32)
        
        return {
            "cdr3_seqs": cdr3_seqs,
            "mhc_alleles": mhc_alleles,
            "peptide_seqs": peptide_seqs,
            "labels": labels
        }

# --------------------------
# Fixed Training with Logging
# --------------------------

def evaluate_model(model, dataloader, bce_loss, infonce_loss, device):
    """Evaluate model and return detailed metrics"""
    model.eval()
    total_loss = 0.0
    total_bce = 0.0
    total_infonce = 0.0
    all_logits = []
    all_labels = []
    
    with torch.no_grad():
        for batch in dataloader:
            labels = batch["labels"].to(device)
            outputs = model(batch["cdr3_seqs"], batch["mhc_alleles"], batch["peptide_seqs"])
            
            bce = bce_loss(outputs['logits'], labels)
            # Fixed: Use helper function to get temperature
            temperature = get_model_attr(model, 'temperature')
            infonce = infonce_loss(outputs['infonce_features'], labels, temperature)
            
            total_loss += bce.item() + 0.3 * infonce.item()  # Weighted combination
            total_bce += bce.item()
            total_infonce += infonce.item()
            
            all_logits.extend(outputs['logits'].cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Compute metrics
    logits_array = np.array(all_logits)
    labels_array = np.array(all_labels)
    probs_array = torch.sigmoid(torch.tensor(logits_array)).numpy()
    
    cv_evaluator = CrossValidationEvaluator()
    metrics = cv_evaluator.compute_detailed_metrics(labels_array, probs_array)
    
    metrics.update({
        'total_loss': total_loss / len(dataloader),
        'bce_loss': total_bce / len(dataloader),
        'infonce_loss': total_infonce / len(dataloader)
    })
    
    return metrics

def train_enhanced_model_with_logging(df_train: pd.DataFrame, df_val: pd.DataFrame, esm_encoder: EnhancedESMEncoder, 
                                    d_model=512, lr=1e-4, batch_size=16, epochs=25, log_dir="training_logs"):
    """Enhanced training with comprehensive logging"""
    
    # Initialize accelerator with BF16
    accelerator = Accelerator(
        mixed_precision='bf16',
        gradient_accumulation_steps=2,
        log_with="tensorboard",
        project_dir=log_dir
    )
    
    device = accelerator.device
    print(f"[Info] Training Enhanced Model with Logging - Device: {device}")
    print(f"[Info] Mixed Precision: {accelerator.mixed_precision}")
    print(f"[Info] Epochs: {epochs}, Batch Size: {batch_size}, Model Dim: {d_model}")
    
    # Initialize logger
    logger = TrainingLogger(log_dir=log_dir)
    
    # Move ESM encoder to correct device
    esm_encoder.model = esm_encoder.model.to(device)
    esm_encoder.device = device
    
    model = EnhancedTCRModel(esm_encoder, d_model=d_model).to(device)
    
    # Enhanced loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    infonce_loss = InfoNCELoss()
    
    # Advanced optimizer with different learning rates
    esm_params = []
    proj_params = []
    attn_params = []
    
    for name, param in model.named_parameters():
        if 'esm_encoder' in name:
            esm_params.append(param)
        elif 'attention' in name:
            attn_params.append(param)
        else:
            proj_params.append(param)
    
    optimizer = torch.optim.AdamW([
        {'params': esm_params, 'lr': lr * 0.05, 'weight_decay': 0.01},
        {'params': attn_params, 'lr': lr * 0.5, 'weight_decay': 0.05},
        {'params': proj_params, 'lr': lr, 'weight_decay': 0.1}
    ])
    
    # Enhanced scheduler with warmup
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=epochs//3, T_mult=2, eta_min=1e-7
    )
    
    train_dataset = EnhancedDataset(df_train)
    val_dataset = EnhancedDataset(df_val)
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        collate_fn=train_dataset.collate_fn, 
        num_workers=0,
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        collate_fn=val_dataset.collate_fn, 
        num_workers=0,
        pin_memory=True
    )
    
    # Prepare everything with accelerator
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )
    
    best_val_auc = 0.0
    best_model_state = None
    patience = 8
    no_improve = 0
    
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        
        # Training
        model.train()
        train_loss = 0.0
        train_bce = 0.0
        train_infonce = 0.0
        train_steps = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch} Training", disable=not accelerator.is_local_main_process)
        
        for batch in progress_bar:
            with accelerator.accumulate(model):
                labels = batch["labels"]
                
                outputs = model(batch["cdr3_seqs"], batch["mhc_alleles"], batch["peptide_seqs"])
                
                # Combined loss with InfoNCE
                bce = bce_loss(outputs['logits'], labels)
                # Fixed: Use helper function to get temperature
                temperature = get_model_attr(model, 'temperature')
                infonce = infonce_loss(outputs['infonce_features'], labels, temperature)
                
                # Dynamic weight for InfoNCE
                infonce_weight = min(0.5, 0.1 + 0.4 * epoch / epochs)
                total_loss = bce + infonce_weight * infonce
                
                # Backward pass with accelerator
                accelerator.backward(total_loss)
                
                # Gradient clipping
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                
                # Accumulate losses
                train_loss += total_loss.item()
                train_bce += bce.item()
                train_infonce += infonce.item()
                train_steps += 1
                
                # Update progress bar
                if accelerator.is_local_main_process:
                    progress_bar.set_postfix({
                        'loss': f"{total_loss.item():.4f}",
                        'bce': f"{bce.item():.4f}",
                        'infonce': f"{infonce.item():.4f}"
                    })
        
        train_time = time.time() - epoch_start_time
        
        # Validation
        val_start_time = time.time()
        val_metrics = evaluate_model(model, val_loader, bce_loss, infonce_loss, device)
        val_time = time.time() - val_start_time
        
        # Calculate averages
        avg_train_loss = train_loss / max(1, train_steps)
        avg_train_bce = train_bce / max(1, train_steps)
        avg_train_infonce = train_infonce / max(1, train_steps)
        
        current_lr = scheduler.get_last_lr()[0]
        current_temp = get_model_attr(model, 'temperature').item()
        
        # Log epoch data
        if accelerator.is_local_main_process:
            epoch_data = {
                'epoch': epoch,
                'train_loss': avg_train_loss,
                'train_bce_loss': avg_train_bce,
                'train_infonce_loss': avg_train_infonce,
                'val_loss': val_metrics['total_loss'],
                'val_bce_loss': val_metrics['bce_loss'],
                'val_infonce_loss': val_metrics['infonce_loss'],
                'val_auc': val_metrics['auc'],
                'val_auprc': val_metrics['auprc'],
                'val_accuracy': val_metrics['accuracy'],
                'val_precision': val_metrics['precision'],
                'val_recall': val_metrics['recall'],
                'val_f1': val_metrics['f1'],
                'learning_rate': current_lr,
                'temperature': current_temp,
                'train_time': train_time,
                'val_time': val_time
            }
            
            logger.log_epoch(epoch_data)
            
            print(f"Epoch {epoch}:")
            print(f"  Train: Loss={avg_train_loss:.4f}, BCE={avg_train_bce:.4f}, InfoNCE={avg_train_infonce:.4f}")
            print(f"  Val: Loss={val_metrics['total_loss']:.4f}, AUC={val_metrics['auc']:.4f}, AUPRC={val_metrics['auprc']:.4f}")
            print(f"  Val: Acc={val_metrics['accuracy']:.4f}, F1={val_metrics['f1']:.4f}, Prec={val_metrics['precision']:.4f}, Rec={val_metrics['recall']:.4f}")
            print(f"  LR={current_lr:.2e}, Temp={current_temp:.4f}, Time={train_time:.1f}s+{val_time:.1f}s")
            
            # Early stopping based on AUC
            if val_metrics['auc'] > best_val_auc:
                best_val_auc = val_metrics['auc']
                best_model_state = accelerator.get_state_dict(model)
                no_improve = 0
                print(f"  ‚úì New best model! AUC: {best_val_auc:.4f}")
            else:
                no_improve += 1
                if no_improve >= patience:
                    print(f"  Early stopping at epoch {epoch}")
                    break
        
        # Wait for all processes
        accelerator.wait_for_everyone()
    
    # Save training logs and plots
    if accelerator.is_local_main_process:
        logger.save_logs()
        logger.save_json()
        logger.plot_training_curves()
    
    # Load best model
    if accelerator.is_local_main_process and best_model_state:
        accelerator.load_state(model, best_model_state)
        print(f"[Info] Loaded best model with AUC: {best_val_auc:.4f}")
    
    accelerator.wait_for_everyone()
    
    # Return unwrapped model and best AUC
    unwrapped_model = accelerator.unwrap_model(model)
    return unwrapped_model, best_val_auc

# --------------------------
# 5-Fold Cross-Validation Implementation
# --------------------------

def run_cross_validation(df_all: pd.DataFrame, esm_model_name: str = "facebook/esm2_t12_35M_UR50D",
                        d_model: int = 512, lr: float = 1e-4, batch_size: int = 16, epochs: int = 15):
    """Run 5-fold cross-validation with detailed logging"""
    
    print("\n" + "="*70)
    print("RUNNING 5-FOLD CROSS-VALIDATION")
    print("="*70)
    
    # Initialize cross-validation evaluator
    cv_evaluator = CrossValidationEvaluator(n_splits=5, random_state=42)
    
    # Get unique epitopes for splitting
    positive_df = df_all[df_all["label"] == 1].copy()
    unique_epitopes = positive_df["Epitope"].unique()
    
    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    
    for fold, (train_epitope_idx, val_epitope_idx) in enumerate(kfold.split(unique_epitopes)):
        print(f"\n[Fold {fold + 1}/5] Starting training...")
        
        # Split epitopes
        train_epitopes = set(unique_epitopes[train_epitope_idx])
        val_epitopes = set(unique_epitopes[val_epitope_idx])
        
        # Split data based on epitopes
        train_mask = df_all["Epitope"].isin(train_epitopes)
        val_mask = df_all["Epitope"].isin(val_epitopes)
        
        df_train_fold = df_all[train_mask].reset_index(drop=True)
        df_val_fold = df_all[val_mask].reset_index(drop=True)
        
        print(f"[Fold {fold + 1}] Train: {len(df_train_fold)} samples, Val: {len(df_val_fold)} samples")
        
        # Initialize ESM encoder for this fold
        esm_encoder = EnhancedESMEncoder(
            model_name=esm_model_name,
            device="cpu",  # Will be handled by accelerator
            freeze_layers=2
        )
        
        # Train model for this fold
        fold_log_dir = f"training_logs/fold_{fold + 1}"
        model, best_auc = train_enhanced_model_with_logging(
            df_train_fold, df_val_fold, esm_encoder,
            d_model=d_model, lr=lr, batch_size=batch_size, epochs=epochs,
            log_dir=fold_log_dir
        )
        
        # Move model to CPU for evaluation
        model = model.cpu()
        model.esm_encoder.device = "cpu"
        model.esm_encoder.model = model.esm_encoder.model.cpu()
        
        # Evaluate on validation set
        val_dataset = EnhancedDataset(df_val_fold)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=val_dataset.collate_fn)
        
        model.eval()
        all_logits = []
        all_labels = []
        
        with torch.no_grad():
            for batch in val_loader:
                labels = batch["labels"]
                outputs = model(batch["cdr3_seqs"], batch["mhc_alleles"], batch["peptide_seqs"])
                
                all_logits.extend(outputs['logits'].cpu().numpy())
                all_labels.extend(labels.numpy())
        
        # Compute metrics for this fold
        logits_array = np.array(all_logits)
        labels_array = np.array(all_labels)
        probs_array = torch.sigmoid(torch.tensor(logits_array)).numpy()
        
        fold_metrics = cv_evaluator.compute_detailed_metrics(labels_array, probs_array)
        fold_metrics['best_training_auc'] = best_auc
        
        cv_evaluator.add_fold_result(fold + 1, fold_metrics)
        
        print(f"[Fold {fold + 1}] Results:")
        print(f"  AUC: {fold_metrics['auc']:.4f}")
        print(f"  AUPRC: {fold_metrics['auprc']:.4f}")
        print(f"  Accuracy: {fold_metrics['accuracy']:.4f}")
        print(f"  F1: {fold_metrics['f1']:.4f}")
        print(f"  Precision: {fold_metrics['precision']:.4f}")
        print(f"  Recall: {fold_metrics['recall']:.4f}")
        
        # Clean up memory
        del model, esm_encoder
        torch.cuda.empty_cache()
    
    # Save and display results
    cv_results_path = cv_evaluator.save_results("cv_results.json")
    cv_plot_path = cv_evaluator.plot_cv_results("cv_results.png")
    
    # Display final summary
    summary = cv_evaluator.get_summary_with_ci()
    
    print("\n" + "="*70)
    print("5-FOLD CROSS-VALIDATION RESULTS WITH 95% CONFIDENCE INTERVALS")
    print("="*70)
    
    for metric, stats in summary.items():
        mean = stats['mean']
        ci_lower = stats['ci_lower']
        ci_upper = stats['ci_upper']
        std = stats['std']
        
        print(f"{metric.upper():>10}: {mean:.4f} ¬± {std:.4f} (CI: [{ci_lower:.4f}, {ci_upper:.4f}])")
    
    return cv_evaluator, summary

# --------------------------
# Main Enhanced Pipeline
# --------------------------

def main():
    set_seed(42)
    
    print("[Info] Enhanced TCR Analysis Pipeline with Logging and Cross-Validation")
    print("="*70)
    
    # Load data
    if not os.path.exists("data.csv"):
        raise FileNotFoundError("data.csv not found! Please ensure your data file exists.")
    
    print("[Info] Loading training data...")
    df = pd.read_csv("data.csv")
    required_cols = {"CDR3", "MHC", "Epitope"}
    if not required_cols.issubset(set(df.columns)):
        raise ValueError(f"data.csv must contain columns: {required_cols}")
    
    # Enhanced data cleaning
    df["CDR3"] = df["CDR3"].astype(str).str.upper()
    df["MHC"] = df["MHC"].astype(str).str.upper()
    df["Epitope"] = df["Epitope"].astype(str).str.upper()
    df = df.dropna().drop_duplicates()
    df = df[df["Epitope"].map(lambda p: all(ch in AA_SET for ch in p))]
    df = df[df["CDR3"].map(lambda p: all(ch in AA_SET for ch in p))]
    
    print(f"[Info] Loaded {len(df)} training examples")
    
    # Enhanced negative sampling with balanced positive augmentation
    print("[Step] Generating enhanced negative samples with positive balancing...")
    df_all = enhanced_build_negatives_with_balance(df, k_neg=10, balance_ratio=1.2)
    
    # Run 5-fold cross-validation
    print("[Step] Running 5-fold cross-validation...")
    cv_evaluator, cv_summary = run_cross_validation(
        df_all, 
        esm_model_name="facebook/esm2_t12_35M_UR50D",
        d_model=128,
        lr=1e-4,
        batch_size=64,
        epochs=10 
    )
    
    print("\n" + "="*70)
    print("‚úÖ ENHANCED ANALYSIS WITH LOGGING AND CV COMPLETED!")
    print("üìä Cross-Validation Results:")
    for metric, stats in cv_summary.items():
        print(f"   {metric.upper()}: {stats['mean']:.4f} ¬± {stats['std']:.4f} (CI: [{stats['ci_lower']:.4f}, {stats['ci_upper']:.4f}])")
    print("üìÅ Generated Files:")
    print("   ‚Ä¢ training_logs/ - Training history and plots")
    print("   ‚Ä¢ cv_results.json - Cross-validation results")
    print("   ‚Ä¢ cv_results.png - Cross-validation plot")
    print("="*70)

if __name__ == "__main__":
    main()

2025-09-23 18:13:35.957687: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-23 18:13:37.137574: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[Info] Enhanced TCR Analysis Pipeline with Logging and Cross-Validation
[Info] Loading training data...
[Info] Loaded 5004 training examples
[Step] Generating enhanced negative samples with positive balancing...
[Info] Generated 209419 negative samples for 5004 positive samples
[Info] Balancing dataset: target positive samples = 174515
[Info] Augmenting 169511 positive samples...
[Info] Augmented positive samples: 5004 -> 166094
[Info] Final balanced dataset: 166094 positives + 86644 negatives = 252738 total
[Info] Positive/Negative ratio: 1.92
[Step] Running 5-fold cross-validation...

RUNNING 5-FOLD CROSS-VALIDATION

[Fold 1/5] Starting training...
[Fold 1] Train: 175813 samples, Val: 54811 samples
[Info] Loading Enhanced ESM model: facebook/esm2_t12_35M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Info] Enhanced ESM loaded - Hidden size: 480, Trainable layers: 10/12
[Info] Training Enhanced Model with Logging - Device: cuda
[Info] Mixed Precision: bf16
[Info] Epochs: 10, Batch Size: 64, Model Dim: 128


Epoch 1 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:36<00:00,  2.46it/s, loss=0.6077, bce=0.5257, infonce=0.5860]


Epoch 1:
  Train: Loss=0.6652, BCE=0.5791, InfoNCE=0.6150
  Val: Loss=0.6084, AUC=0.5720, AUPRC=0.7591
  Val: Acc=0.7261, F1=0.8357, Prec=0.7314, Rec=0.9748
  LR=6.00e-07, Temp=0.0783, Time=1116.9s+184.9s
  ‚úì New best model! AUC: 0.5720


Epoch 2 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:08<00:00,  2.52it/s, loss=0.5458, bce=0.4547, infonce=0.5062]


Epoch 2:
  Train: Loss=0.6419, BCE=0.5367, InfoNCE=0.5846
  Val: Loss=0.6646, AUC=0.5092, AUPRC=0.7104
  Val: Acc=0.7216, F1=0.8318, Prec=0.7321, Rec=0.9631
  LR=6.09e-07, Temp=0.0814, Time=1088.9s+181.6s


Epoch 3 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:04<00:00,  2.53it/s, loss=0.6427, bce=0.5324, infonce=0.5014]


Epoch 3:
  Train: Loss=0.6381, BCE=0.5139, InfoNCE=0.5644
  Val: Loss=0.7143, AUC=0.5190, AUPRC=0.7228
  Val: Acc=0.6867, F1=0.7994, Prec=0.7370, Rec=0.8733
  LR=3.71e-06, Temp=0.0837, Time=1084.7s+181.6s


Epoch 4 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:04<00:00,  2.53it/s, loss=0.6059, bce=0.4664, infonce=0.5363]


Epoch 4:
  Train: Loss=0.6067, BCE=0.4702, InfoNCE=0.5251
  Val: Loss=0.7235, AUC=0.5440, AUPRC=0.7495
  Val: Acc=0.7160, F1=0.8287, Prec=0.7284, Rec=0.9610
  LR=6.14e-07, Temp=0.0844, Time=1084.2s+181.7s


Epoch 5 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:03<00:00,  2.54it/s, loss=0.6621, bce=0.4943, infonce=0.5592]


Epoch 5:
  Train: Loss=0.6127, BCE=0.4585, InfoNCE=0.5141
  Val: Loss=0.7257, AUC=0.5794, AUPRC=0.8000
  Val: Acc=0.7142, F1=0.8279, Prec=0.7268, Rec=0.9618
  LR=4.83e-06, Temp=0.0842, Time=1083.5s+181.7s
  ‚úì New best model! AUC: 0.5794


Epoch 6 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:05<00:00,  2.53it/s, loss=0.5675, bce=0.4049, infonce=0.4783]


Epoch 6:
  Train: Loss=0.6250, BCE=0.4520, InfoNCE=0.5090
  Val: Loss=0.7637, AUC=0.5711, AUPRC=0.8017
  Val: Acc=0.7056, F1=0.8213, Prec=0.7254, Rec=0.9464
  LR=3.71e-06, Temp=0.0845, Time=1085.1s+182.0s


Epoch 7 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:04<00:00,  2.53it/s, loss=0.5542, bce=0.3869, infonce=0.4403]


Epoch 7:
  Train: Loss=0.5931, BCE=0.4135, InfoNCE=0.4725
  Val: Loss=0.7592, AUC=0.6198, AUPRC=0.8303
  Val: Acc=0.7109, F1=0.8270, Prec=0.7226, Rec=0.9665
  LR=2.05e-06, Temp=0.0838, Time=1084.7s+181.5s
  ‚úì New best model! AUC: 0.6198


Epoch 8 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:05<00:00,  2.53it/s, loss=0.5268, bce=0.3516, infonce=0.4170]


Epoch 8:
  Train: Loss=0.5601, BCE=0.3762, InfoNCE=0.4379
  Val: Loss=0.7603, AUC=0.6373, AUPRC=0.8378
  Val: Acc=0.7154, F1=0.8301, Prec=0.7241, Rec=0.9725
  LR=6.16e-07, Temp=0.0837, Time=1085.3s+181.8s
  ‚úì New best model! AUC: 0.6373


Epoch 9 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:05<00:00,  2.53it/s, loss=0.5436, bce=0.3720, infonce=0.3730]


Epoch 9:
  Train: Loss=0.5518, BCE=0.3584, InfoNCE=0.4203
  Val: Loss=0.7212, AUC=0.6146, AUPRC=0.8258
  Val: Acc=0.7103, F1=0.8245, Prec=0.7271, Rec=0.9520
  LR=5.00e-06, Temp=0.0822, Time=1085.1s+181.6s


Epoch 10 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2747/2747 [18:04<00:00,  2.53it/s, loss=0.7248, bce=0.4564, infonce=0.5370]


Epoch 10:
  Train: Loss=0.6508, BCE=0.4142, InfoNCE=0.4733
  Val: Loss=0.7106, AUC=0.6251, AUPRC=0.8317
  Val: Acc=0.7154, F1=0.8272, Prec=0.7307, Rec=0.9531
  LR=4.83e-06, Temp=0.0809, Time=1084.7s+182.0s
[Info] Training history saved to training_logs/fold_1/training_history.csv
[Info] Training history saved to training_logs/fold_1/training_history.json
[Info] Training curves saved to training_logs/fold_1/training_curves.png


TypeError: Accelerator.load_state() takes from 1 to 2 positional arguments but 3 were given