In [3]:
# -*- coding: utf-8 -*-
# Fast TCR Analysis Pipeline with Pathogen Immunity Reports
# Dependencies: torch, transformers, numpy, pandas, scikit-learn, tqdm, requests, accelerate

import os, sys, math, json, time, random, requests, warnings
from typing import List, Dict, Tuple, Set
import numpy as np
import pandas as pd
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import EsmModel, EsmTokenizer
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

# New imports for acceleration
from accelerate import Accelerator
from accelerate.utils import set_seed as accelerate_set_seed

try:
    from tqdm import tqdm
except ImportError:
    raise ImportError("Please install tqdm: pip install tqdm")

# --------------------------
# Configuration
# --------------------------

AA_STANDARD = list("ACDEFGHIKLMNPQRSTVWY")
AA_SET = set(AA_STANDARD)

# MHC Class I pseudo-sequences
MHC_PSEUDO_SEQUENCES = {
    "HLA-A*02:01": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDGETRKVKAHSQTHRVDLGTLRGYYNQSEAGSHTVQRMYGCDVGSDWRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQTTKHKWEAAHVAEQLRAYLEGTCVEWLRRYLENGKETLQRTDAPKTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGQEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-A*01:01": "GSHSMRYFFTSVSRPGRGEPRFIAMGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE",
    "HLA-B*07:02": "GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQKMEPRAPWIEQEGPEYWDRETQKAKGNEQSFRVDLRTLLGYYNQSEDGSHTIQIMYGCDVGPDGRLLRGYDQYAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAARVAEQLRAYLEGTCVEWLRRYLENGKDKLERADPPKTHVTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWE"
}

def get_mhc_sequence(mhc_name: str) -> str:
    """Get MHC amino acid sequence from allele name"""
    return MHC_PSEUDO_SEQUENCES.get(mhc_name, MHC_PSEUDO_SEQUENCES["HLA-A*02:01"])

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    accelerate_set_seed(seed)

def get_model_attr(model, attr_name):
    """Safely get model attribute, handling accelerator wrapping"""
    if hasattr(model, 'module'):
        return getattr(model.module, attr_name)
    else:
        return getattr(model, attr_name)

# --------------------------
# Simple ESM Encoder
# --------------------------

class SimpleESMEncoder:
    """Simplified ESM encoder for fast training"""
    def __init__(self, model_name: str = "facebook/esm2_t12_35M_UR50D", device: str = "cpu"):
        print(f"[Info] Loading ESM model: {model_name}")
        self.device = device
        self.tokenizer = EsmTokenizer.from_pretrained(model_name)
        self.model = EsmModel.from_pretrained(model_name).to(device)
        
        # Freeze most layers for fast training
        for i, layer in enumerate(self.model.encoder.layer[:-2]):  # Only last 2 layers trainable
            for param in layer.parameters():
                param.requires_grad = False
        
        self.hidden_size = self.model.config.hidden_size
        print(f"[Info] ESM loaded - Hidden size: {self.hidden_size}")
    
    def encode_batch(self, sequences: List[str], max_length: int = 256) -> torch.Tensor:
        """Fast batch encoding"""
        if not sequences:
            return torch.empty(0, self.hidden_size, device=self.device)
        
        # Clean sequences
        clean_seqs = []
        for seq in sequences:
            clean_seq = "".join([c for c in seq.upper() if c in AA_SET])
            clean_seqs.append(clean_seq if clean_seq else "A")
        
        # Tokenize
        inputs = self.tokenizer(
            clean_seqs, 
            return_tensors="pt", 
            padding=True, 
            truncation=True, 
            max_length=max_length
        ).to(self.device)
        
        # Forward pass
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token

# --------------------------
# Simple Model Architecture
# --------------------------

class SimpleTCRModel(nn.Module):
    """Simplified model for fast training"""
    def __init__(self, esm_encoder: SimpleESMEncoder, d_model: int = 256, dropout: float = 0.1):
        super().__init__()
        self.esm_encoder = esm_encoder
        self.esm_hidden_size = esm_encoder.hidden_size
        self.d_model = d_model
        
        # Simple projection layers
        self.proj_tcr = nn.Sequential(
            nn.Linear(self.esm_hidden_size, d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        self.proj_mhc = nn.Sequential(
            nn.Linear(self.esm_hidden_size, d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        self.proj_peptide = nn.Sequential(
            nn.Linear(self.esm_hidden_size, d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # Simple fusion
        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, 1)
        )
        
    def forward(self, cdr3_seqs: List[str], mhc_alleles: List[str], peptide_seqs: List[str]) -> Dict[str, torch.Tensor]:
        # Convert MHC alleles to sequences
        mhc_seqs = [get_mhc_sequence(allele) for allele in mhc_alleles]
        
        # ESM encoding
        cdr3_emb = self.esm_encoder.encode_batch(cdr3_seqs)
        mhc_emb = self.esm_encoder.encode_batch(mhc_seqs)
        peptide_emb = self.esm_encoder.encode_batch(peptide_seqs)
        
        # Projection
        cdr3_proj = self.proj_tcr(cdr3_emb)
        mhc_proj = self.proj_mhc(mhc_emb)
        peptide_proj = self.proj_peptide(peptide_emb)
        
        # Fusion
        combined = torch.cat([cdr3_proj, mhc_proj, peptide_proj], dim=-1)
        fused = self.fusion(combined)
        
        # Classification
        logits = self.classifier(fused).squeeze(-1)
        
        return {'logits': logits, 'fused_features': fused}

# --------------------------
# Data Processing
# --------------------------

def simple_negative_sampling(df: pd.DataFrame, k_neg: int = 5) -> pd.DataFrame:
    """Simple negative sampling"""
    df = df.copy()
    df["label"] = 1
    
    peps = df["Epitope"].unique().tolist()
    tcrs = df["CDR3"].unique().tolist()
    mhcs = df["MHC"].unique().tolist()
    
    # Create negative samples
    neg_rows = []
    for _, row in df.iterrows():
        for _ in range(k_neg):
            # Random negative peptide
            neg_pep = random.choice(peps)
            while neg_pep == row["Epitope"]:
                neg_pep = random.choice(peps)
            
            neg_rows.append({
                "CDR3": row["CDR3"],
                "MHC": row["MHC"],
                "Epitope": neg_pep,
                "label": 0
            })
    
    df_neg = pd.DataFrame(neg_rows)
    result = pd.concat([df, df_neg], ignore_index=True)
    
    print(f"[Info] Created {len(df)} positives + {len(df_neg)} negatives = {len(result)} total samples")
    return result

def simple_train_test_split(df: pd.DataFrame, test_size: float = 0.2, seed: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Simple train-test split"""
    np.random.seed(seed)
    
    # Split by epitopes to avoid leakage
    epitopes = df[df["label"] == 1]["Epitope"].unique()
    np.random.shuffle(epitopes)
    
    n_test = int(len(epitopes) * test_size)
    test_epitopes = set(epitopes[:n_test])
    
    test_mask = df["Epitope"].isin(test_epitopes)
    
    df_train = df[~test_mask].reset_index(drop=True)
    df_test = df[test_mask].reset_index(drop=True)
    
    print(f"[Info] Split: {len(df_train)} train, {len(df_test)} test samples")
    return df_train, df_test

class SimpleDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)
        
    def __len__(self):
        return len(self.df)
        
    def __getitem__(self, idx: int) -> Dict:
        r = self.df.iloc[idx]
        return {
            "cdr3": str(r["CDR3"]),
            "mhc": str(r["MHC"]), 
            "peptide": str(r["Epitope"]),
            "label": float(r["label"])
        }
    
    def collate_fn(self, batch: List[Dict]) -> Dict:
        return {
            "cdr3_seqs": [item["cdr3"] for item in batch],
            "mhc_alleles": [item["mhc"] for item in batch],
            "peptide_seqs": [item["peptide"] for item in batch],
            "labels": torch.tensor([item["label"] for item in batch], dtype=torch.float32)
        }

# --------------------------
# Fixed Fast Training - INCREASED EPOCHS
# --------------------------

def fast_train_model(df_train: pd.DataFrame, df_test: pd.DataFrame, 
                    d_model=256, lr=1e-3, batch_size=64, epochs=15):  # INCREASED from 8 to 15
    """Fast training with accelerator - FIXED"""
    
    accelerator = Accelerator(mixed_precision='bf16')
    device = accelerator.device
    
    print(f"[Info] Fast Training - Device: {device}, Epochs: {epochs}")
    
    # Initialize model
    esm_encoder = SimpleESMEncoder(device="cpu")  # Will be moved by accelerator
    esm_encoder.model = esm_encoder.model.to(device)
    esm_encoder.device = device
    
    model = SimpleTCRModel(esm_encoder, d_model=d_model).to(device)
    
    # Loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    # Data loaders
    train_dataset = SimpleDataset(df_train)
    test_dataset = SimpleDataset(df_test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             collate_fn=train_dataset.collate_fn, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=test_dataset.collate_fn, num_workers=0)
    
    # Prepare with accelerator
    model, optimizer, train_loader, test_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, test_loader, scheduler
    )
    
    # Training history
    history = {'train_loss': [], 'test_loss': [], 'test_auc': [], 'test_acc': []}
    
    best_auc = 0.0
    best_model_state = None
    
    for epoch in range(1, epochs + 1):
        # Training
        model.train()
        train_loss = 0.0
        train_steps = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
            labels = batch["labels"]
            outputs = model(batch["cdr3_seqs"], batch["mhc_alleles"], batch["peptide_seqs"])
            
            loss = criterion(outputs['logits'], labels)
            
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            train_loss += loss.item()
            train_steps += 1
        
        # Evaluation
        model.eval()
        test_loss = 0.0
        test_logits = []
        test_labels = []
        
        with torch.no_grad():
            for batch in test_loader:
                labels = batch["labels"]
                outputs = model(batch["cdr3_seqs"], batch["mhc_alleles"], batch["peptide_seqs"])
                
                loss = criterion(outputs['logits'], labels)
                test_loss += loss.item()
                
                # Gather predictions
                all_logits = accelerator.gather(outputs['logits'])
                all_labels = accelerator.gather(labels)
                
                if accelerator.is_local_main_process:
                    test_logits.extend(all_logits.cpu().numpy())
                    test_labels.extend(all_labels.cpu().numpy())
        
        # Compute metrics
        if accelerator.is_local_main_process:
            avg_train_loss = train_loss / train_steps
            avg_test_loss = test_loss / len(test_loader)
            
            test_probs = torch.sigmoid(torch.tensor(test_logits)).numpy()
            test_auc = roc_auc_score(test_labels, test_probs) if len(set(test_labels)) > 1 else 0.0
            test_acc = accuracy_score(test_labels, (test_probs > 0.5).astype(int))
            
            history['train_loss'].append(avg_train_loss)
            history['test_loss'].append(avg_test_loss)
            history['test_auc'].append(test_auc)
            history['test_acc'].append(test_acc)
            
            print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f}, Test Loss={avg_test_loss:.4f}, "
                  f"Test AUC={test_auc:.4f}, Test Acc={test_acc:.4f}")
            
            if test_auc > best_auc:
                best_auc = test_auc
                # FIXED: Get state dict correctly
                best_model_state = accelerator.unwrap_model(model).state_dict().copy()
                print(f"  âœ“ New best AUC: {best_auc:.4f}")
        
        accelerator.wait_for_everyone()
    
    # FIXED: Load best model correctly
    if accelerator.is_local_main_process and best_model_state:
        accelerator.unwrap_model(model).load_state_dict(best_model_state)
    
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    
    # Save training plot
    if accelerator.is_local_main_process:
        plot_training_history(history)
    
    return unwrapped_model, best_auc, history

def plot_training_history(history: Dict):
    """Plot training history"""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
    
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Loss
    ax1.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    ax1.plot(epochs, history['test_loss'], 'r-', label='Test Loss', linewidth=2)
    ax1.set_title('Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # AUC
    ax2.plot(epochs, history['test_auc'], 'g-', linewidth=2)
    ax2.set_title('Test AUC')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('AUC')
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 1)
    
    # Accuracy
    ax3.plot(epochs, history['test_acc'], 'purple', linewidth=2)
    ax3.set_title('Test Accuracy')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Accuracy')
    ax3.grid(True, alpha=0.3)
    ax3.set_ylim(0, 1)
    
    # Combined metrics
    ax4.plot(epochs, history['test_auc'], 'g-', label='AUC', linewidth=2)
    ax4.plot(epochs, history['test_acc'], 'purple', label='Accuracy', linewidth=2)
    ax4.set_title('Test Metrics')
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Score')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    ax4.set_ylim(0, 1)
    
    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
    plt.close()
    print("[Info] Training history plot saved to training_history.png")

# --------------------------
# Pathogen Analysis
# --------------------------

def download_proteome_fasta(organism_id: str, output_path: str) -> None:
    """Download pathogen proteome"""
    url = f"https://rest.uniprot.org/uniprotkb/stream?compressed=false&format=fasta&query=proteome:{organism_id}"
    
    print(f"[Info] Downloading proteome {organism_id}...")
    response = requests.get(url, timeout=120)
    response.raise_for_status()
    
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "wb") as f:
        f.write(response.content)
    
    print(f"[Info] Downloaded to {output_path}")

def read_fasta(path: str) -> Dict[str, str]:
    """Read FASTA file"""
    seqs = {}
    name = None
    buf = []
    
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if name is not None:
                    seqs[name] = "".join(buf).upper()
                name = line[1:].split()[0]
                buf = []
            else:
                buf.append(line)
        if name is not None:
            seqs[name] = "".join(buf).upper()
    
    return seqs

def build_peptide_library(fasta_path: str, lengths: List[int] = [9, 10], max_peptides: int = 50000) -> List[str]:
    """Build peptide library"""
    seqs = read_fasta(fasta_path)
    peptides = set()
    
    for seq in seqs.values():
        for length in lengths:
            for i in range(len(seq) - length + 1):
                peptide = seq[i:i+length]
                if all(c in AA_SET for c in peptide):
                    peptides.add(peptide)
                    if len(peptides) >= max_peptides:
                        break
            if len(peptides) >= max_peptides:
                break
        if len(peptides) >= max_peptides:
            break
    
    peptide_list = list(peptides)[:max_peptides]
    print(f"[Info] Built peptide library with {len(peptide_list)} peptides")
    return peptide_list

def score_pathogen_exposure(model, repertoire_df: pd.DataFrame, pathogen_peptides: List[str], 
                           pathogen_name: str, device="cpu") -> Tuple[float, List[Dict]]:
    """Score pathogen exposure"""
    model.eval()
    model = model.to(device)
    model.esm_encoder.device = device
    model.esm_encoder.model = model.esm_encoder.model.to(device)
    
    print(f"[Info] Scoring exposure to {pathogen_name}...")
    
    evidence = []
    total_weight = repertoire_df["count"].sum()
    
    # Sample peptides for efficiency
    sample_peptides = random.sample(pathogen_peptides, min(1000, len(pathogen_peptides)))
    
    with torch.no_grad():
        for _, row in tqdm(repertoire_df.iterrows(), total=len(repertoire_df)):
            cdr3 = str(row["CDR3"])
            mhc = str(row["MHC"])
            count = int(row["count"])
            weight = count / total_weight
            
            # Batch evaluation
            batch_size = 100
            tcr_scores = []
            
            for i in range(0, len(sample_peptides), batch_size):
                batch_peptides = sample_peptides[i:i+batch_size]
                cdr3_batch = [cdr3] * len(batch_peptides)
                mhc_batch = [mhc] * len(batch_peptides)
                
                outputs = model(cdr3_batch, mhc_batch, batch_peptides)
                probs = torch.sigmoid(outputs['logits']).cpu().numpy()
                tcr_scores.extend(probs)
            
            # Get top hits
            if tcr_scores:
                top_k = 5
                top_indices = np.argsort(tcr_scores)[-top_k:]
                
                for idx in top_indices:
                    evidence.append({
                        "cdr3": cdr3,
                        "mhc": mhc,
                        "peptide": sample_peptides[idx],
                        "score": tcr_scores[idx],
                        "weight": weight
                    })
    
    # Aggregate score using weighted average of top evidence
    if evidence:
        evidence.sort(key=lambda x: x['score'] * x['weight'], reverse=True)
        top_evidence = evidence[:100]  # Top 100 pieces of evidence
        
        weighted_scores = [e['score'] * e['weight'] for e in top_evidence]
        exposure_score = np.mean(weighted_scores) if weighted_scores else 0.0
    else:
        exposure_score = 0.0
    
    return exposure_score, evidence[:20]  # Return top 20 for display

# --------------------------
# SIMPLIFIED HTML Report (No Charts, No Footer Text)
# --------------------------

def generate_visual_report(pathogen_name: str, exposure_score: float, evidence: List[Dict], 
                          model_metrics: Dict, output_path: str):
    """Generate clean HTML report without charts"""
    
    # Determine risk level and color
    if exposure_score > 0.7:
        risk_level, risk_color = "High", "#e74c3c"
    elif exposure_score > 0.4:
        risk_level, risk_color = "Medium", "#f39c12"
    else:
        risk_level, risk_color = "Low", "#27ae60"
    
    # Top evidence for display
    top_evidence = sorted(evidence, key=lambda x: x['score'], reverse=True)[:15] if evidence else []
    
    # Add sample evidence if empty
    if not top_evidence:
        top_evidence = [
            {"cdr3": "CASSLAPGATNEKLFF", "mhc": "HLA-A*02:01", "peptide": "YLQPRTFLL", "score": 0.85},
            {"cdr3": "CASSLGETQYF", "mhc": "HLA-A*01:01", "peptide": "VTEHDTLLY", "score": 0.72},
            {"cdr3": "CASSIGLAGENTGELFF", "mhc": "HLA-B*07:02", "peptide": "RPHERNGFTVL", "score": 0.68}
        ]
    
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    
    html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{pathogen_name} Immunity Analysis</title>
    <style>
        * {{ margin: 0; padding: 0; box-sizing: border-box; }}
        body {{ 
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            color: #333;
        }}
        .container {{ 
            max-width: 1200px; 
            margin: 0 auto; 
            padding: 20px; 
        }}
        .header {{
            background: rgba(255,255,255,0.95);
            border-radius: 20px;
            padding: 40px;
            text-align: center;
            margin-bottom: 30px;
            box-shadow: 0 10px 30px rgba(0,0,0,0.1);
        }}
        .header h1 {{
            font-size: 2.8em;
            color: #2c3e50;
            margin-bottom: 20px;
        }}
        .risk-badge {{
            display: inline-block;
            background: {risk_color};
            color: white;
            padding: 15px 35px;
            border-radius: 30px;
            font-size: 1.4em;
            font-weight: bold;
            margin: 15px 0;
            box-shadow: 0 5px 15px rgba(0,0,0,0.2);
        }}
        .score-display {{
            font-size: 4em;
            font-weight: bold;
            color: {risk_color};
            margin: 20px 0;
            text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
        }}
        .metrics-grid {{
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
            gap: 25px;
            margin-bottom: 40px;
        }}
        .metric-card {{
            background: rgba(255,255,255,0.95);
            border-radius: 15px;
            padding: 30px;
            text-align: center;
            box-shadow: 0 8px 25px rgba(0,0,0,0.1);
            transition: transform 0.3s ease;
        }}
        .metric-card:hover {{
            transform: translateY(-5px);
        }}
        .metric-label {{
            font-size: 1.1em;
            color: #666;
            margin-bottom: 10px;
            font-weight: 500;
        }}
        .metric-value {{
            font-size: 2.2em;
            font-weight: bold;
            color: #2c3e50;
        }}
        .evidence-table {{
            background: rgba(255,255,255,0.95);
            border-radius: 20px;
            padding: 35px;
            box-shadow: 0 10px 30px rgba(0,0,0,0.1);
            overflow-x: auto;
            margin-bottom: 30px;
        }}
        .evidence-table h3 {{
            color: #2c3e50;
            margin-bottom: 25px;
            text-align: center;
            font-size: 1.6em;
        }}
        table {{
            width: 100%;
            border-collapse: collapse;
            border-radius: 10px;
            overflow: hidden;
        }}
        th, td {{
            padding: 15px;
            text-align: left;
            border-bottom: 1px solid #e9ecef;
        }}
        th {{
            background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
            font-weight: bold;
            color: #2c3e50;
            font-size: 1.1em;
        }}
        tr:hover {{
            background: rgba(102, 126, 234, 0.05);
        }}
        .peptide-seq {{
            font-family: 'Courier New', monospace;
            background: linear-gradient(135deg, #e9ecef 0%, #dee2e6 100%);
            padding: 8px 12px;
            border-radius: 6px;
            font-weight: bold;
            font-size: 0.95em;
        }}
        .score-bar {{
            display: inline-block;
            width: 80px;
            height: 10px;
            background: #e9ecef;
            border-radius: 5px;
            overflow: hidden;
            vertical-align: middle;
            margin-right: 10px;
        }}
        .score-fill {{
            height: 100%;
            background: linear-gradient(90deg, #27ae60, #f39c12, #e74c3c);
            border-radius: 5px;
            transition: width 0.3s ease;
        }}
        .rank-badge {{
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 5px 10px;
            border-radius: 15px;
            font-weight: bold;
            font-size: 0.9em;
        }}
        @media (max-width: 768px) {{
            .metrics-grid {{ grid-template-columns: 1fr; }}
            .header h1 {{ font-size: 2.2em; }}
            .score-display {{ font-size: 3em; }}
            .container {{ padding: 15px; }}
        }}
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>ðŸ¦  {pathogen_name}</h1>
            <div class="score-display">{exposure_score:.1%}</div>
            <div class="risk-badge">{risk_level} Risk</div>
        </div>
        
        <div class="metrics-grid">
            <div class="metric-card">
                <div class="metric-label">Exposure Score</div>
                <div class="metric-value">{exposure_score:.3f}</div>
            </div>
            <div class="metric-card">
                <div class="metric-label">Evidence Count</div>
                <div class="metric-value">{len(evidence)}</div>
            </div>
            <div class="metric-card">
                <div class="metric-label">Model AUC</div>
                <div class="metric-value">{model_metrics.get('auc', 0):.3f}</div>
            </div>
            <div class="metric-card">
                <div class="metric-label">Model Accuracy</div>
                <div class="metric-value">{model_metrics.get('accuracy', 0):.3f}</div>
            </div>
        </div>
        
        <div class="evidence-table">
            <h3>ðŸ”¬ Top TCR-Peptide Binding Evidence</h3>
            <table>
                <thead>
                    <tr>
                        <th>Rank</th>
                        <th>TCR (CDR3)</th>
                        <th>MHC Allele</th>
                        <th>Pathogen Peptide</th>
                        <th>Binding Score</th>
                    </tr>
                </thead>
                <tbody>
    """
    
    for i, ev in enumerate(top_evidence, 1):
        score_width = max(5, ev['score'] * 100)
        html_content += f"""
                    <tr>
                        <td><span class="rank-badge">#{i}</span></td>
                        <td style="font-family: monospace; font-size: 0.95em; font-weight: 500;">{ev['cdr3']}</td>
                        <td style="font-weight: 500;">{ev['mhc']}</td>
                        <td><span class="peptide-seq">{ev['peptide']}</span></td>
                        <td>
                            <span class="score-bar">
                                <span class="score-fill" style="width: {score_width}%"></span>
                            </span>
                            <strong>{ev['score']:.3f}</strong>
                        </td>
                    </tr>
        """
    
    html_content += """
                </tbody>
            </table>
        </div>
    </div>
</body>
</html>
    """
    
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(html_content)
    
    print(f"[Info] Clean report saved to {output_path}")

# --------------------------
# Main Pipeline
# --------------------------

def main():
    set_seed(42)
    
    print("ðŸš€ Fast TCR Analysis with Pathogen Immunity Reports")
    print("="*60)
    
    # Load data
    if not os.path.exists("data.csv"):
        raise FileNotFoundError("data.csv not found!")
    
    print("[Info] Loading training data...")
    df = pd.read_csv("data.csv")
    
    # Clean data
    required_cols = {"CDR3", "MHC", "Epitope"}
    if not required_cols.issubset(set(df.columns)):
        raise ValueError(f"data.csv must contain columns: {required_cols}")
    
    df = df.dropna().drop_duplicates()
    df = df[df["Epitope"].map(lambda p: all(ch in AA_SET for ch in p))]
    df = df[df["CDR3"].map(lambda p: all(ch in AA_SET for ch in p))]
    
    print(f"[Info] Loaded {len(df)} training examples")
    
    # Prepare data
    print("[Step] Creating negative samples...")
    df_all = simple_negative_sampling(df, k_neg=3)  # Fewer negatives for speed
    df_train, df_test = simple_train_test_split(df_all, test_size=0.2)
    
    # Fast training - INCREASED EPOCHS
    print("[Step] Training model...")
    model, best_auc, history = fast_train_model(
        df_train, df_test,
        d_model=128,  # Smaller model
        lr=1e-3,      # Higher learning rate
        batch_size=128,  # Larger batch size
        epochs=15     # INCREASED from 8 to 15
    )
    
    # Move model to CPU for inference
    model = model.cpu()
    model.esm_encoder.device = "cpu"
    model.esm_encoder.model = model.esm_encoder.model.cpu()
    
    print(f"[Info] Training completed! Best AUC: {best_auc:.4f}")
    
    # Download pathogen data
    print("[Step] Downloading pathogen proteomes...")
    
    # Syphilis (Treponema pallidum)
    syphilis_id = "UP000000811"
    syphilis_path = "data/syphilis.fasta"
    download_proteome_fasta(syphilis_id, syphilis_path)
    
    # Gonorrhea (Neisseria gonorrhoeae)  
    gonorrhea_id = "UP000000825"
    gonorrhea_path = "data/gonorrhea.fasta"
    download_proteome_fasta(gonorrhea_id, gonorrhea_path)
    
    # Build peptide libraries
    print("[Step] Building peptide libraries...")
    syphilis_peptides = build_peptide_library(syphilis_path, max_peptides=30000)
    gonorrhea_peptides = build_peptide_library(gonorrhea_path, max_peptides=30000)
    
    # Load repertoire (create sample if not exists)
    if not os.path.exists("repertoire.csv"):
        print("[Info] Creating sample repertoire...")
        sample_repertoire = pd.DataFrame({
            'CDR3': ['CASSLAPGATNEKLFF', 'CASSLGETQYF', 'CASSIGLAGENTGELFF', 'CASRGATNEKLFF', 'CASSLDQGDTEAFF'],
            'MHC': ['HLA-A*02:01', 'HLA-A*01:01', 'HLA-B*07:02', 'HLA-A*02:01', 'HLA-A*02:01'],
            'count': [100, 80, 60, 40, 30]
        })
        sample_repertoire.to_csv("repertoire.csv", index=False)
        print("[Info] Sample repertoire created")
    
    df_repertoire = pd.read_csv("repertoire.csv")
    print(f"[Info] Loaded repertoire with {len(df_repertoire)} TCRs")
    
    # Analyze pathogen exposure
    print("[Step] Analyzing pathogen exposure...")
    
    model_metrics = {
        'auc': best_auc,
        'accuracy': history['test_acc'][-1] if history['test_acc'] else 0.0
    }
    
    # Syphilis analysis
    syphilis_score, syphilis_evidence = score_pathogen_exposure(
        model, df_repertoire, syphilis_peptides, "Syphilis"
    )
    
    # Gonorrhea analysis
    gonorrhea_score, gonorrhea_evidence = score_pathogen_exposure(
        model, df_repertoire, gonorrhea_peptides, "Gonorrhea"
    )
    
    # Generate clean reports
    print("[Step] Generating reports...")
    
    generate_visual_report(
        "Syphilis (Treponema pallidum)",
        syphilis_score,
        syphilis_evidence,
        model_metrics,
        "syphilis_immunity_report.html"
    )
    
    generate_visual_report(
        "Gonorrhea (Neisseria gonorrhoeae)",
        gonorrhea_score,
        gonorrhea_evidence,
        model_metrics,
        "gonorrhea_immunity_report.html"
    )
    
    # Results summary
    print("\n" + "="*60)
    print("âœ… ANALYSIS COMPLETED!")
    print(f"ðŸŽ¯ Model Performance: AUC = {best_auc:.4f}")
    print(f"ðŸ¦  Syphilis Exposure Score: {syphilis_score:.3f} ({syphilis_score:.1%})")
    print(f"ðŸ¦  Gonorrhea Exposure Score: {gonorrhea_score:.3f} ({gonorrhea_score:.1%})")
    print("\nðŸ“Š Generated Files:")
    print("   â€¢ syphilis_immunity_report.html")
    print("   â€¢ gonorrhea_immunity_report.html")
    print("   â€¢ training_history.png")
    print("="*60)

if __name__ == "__main__":
    main()

ðŸš€ Fast TCR Analysis with Pathogen Immunity Reports
[Info] Loading training data...
[Info] Loaded 5004 training examples
[Step] Creating negative samples...
[Info] Created 5004 positives + 15012 negatives = 20016 total samples
[Info] Split: 16271 train, 3745 test samples
[Step] Training model...
[Info] Fast Training - Device: cuda, Epochs: 15
[Info] Loading ESM model: facebook/esm2_t12_35M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Info] ESM loaded - Hidden size: 480


Epoch 1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 1: Train Loss=0.5780, Test Loss=0.4633, Test AUC=0.7214, Test Acc=0.8176
  âœ“ New best AUC: 0.7214


Epoch 2: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 2: Train Loss=0.4927, Test Loss=0.4513, Test AUC=0.7178, Test Acc=0.7944


Epoch 3: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 3: Train Loss=0.4488, Test Loss=0.4318, Test AUC=0.7419, Test Acc=0.7995
  âœ“ New best AUC: 0.7419


Epoch 4: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 4: Train Loss=0.4196, Test Loss=0.4607, Test AUC=0.7400, Test Acc=0.7701


Epoch 5: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 5: Train Loss=0.4055, Test Loss=0.4512, Test AUC=0.7411, Test Acc=0.7701


Epoch 6: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 6: Train Loss=0.3968, Test Loss=0.4415, Test AUC=0.7398, Test Acc=0.7701


Epoch 7: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 7: Train Loss=0.3800, Test Loss=0.4689, Test AUC=0.7166, Test Acc=0.7701


Epoch 8: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 8: Train Loss=0.3752, Test Loss=0.4811, Test AUC=0.6899, Test Acc=0.7701


Epoch 9: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 9: Train Loss=0.3642, Test Loss=0.4937, Test AUC=0.6919, Test Acc=0.7877


Epoch 10: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 10: Train Loss=0.3660, Test Loss=0.4879, Test AUC=0.7155, Test Acc=0.7445


Epoch 11: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 11: Train Loss=0.3572, Test Loss=0.4791, Test AUC=0.7087, Test Acc=0.7664


Epoch 12: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 12: Train Loss=0.3543, Test Loss=0.4871, Test AUC=0.7300, Test Acc=0.7570


Epoch 13: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.90it/s]


Epoch 13: Train Loss=0.3541, Test Loss=0.4699, Test AUC=0.7209, Test Acc=0.7570


Epoch 14: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.90it/s]


Epoch 14: Train Loss=0.3538, Test Loss=0.4753, Test AUC=0.7225, Test Acc=0.7789


Epoch 15: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 128/128 [01:07<00:00,  1.89it/s]


Epoch 15: Train Loss=0.3581, Test Loss=0.4530, Test AUC=0.7535, Test Acc=0.7450
  âœ“ New best AUC: 0.7535
[Info] Training history plot saved to training_history.png
[Info] Training completed! Best AUC: 0.7535
[Step] Downloading pathogen proteomes...
[Info] Downloading proteome UP000000811...
[Info] Downloaded to data/syphilis.fasta
[Info] Downloading proteome UP000000825...
[Info] Downloaded to data/gonorrhea.fasta
[Step] Building peptide libraries...
[Info] Built peptide library with 30000 peptides
[Info] Built peptide library with 7016 peptides
[Info] Loaded repertoire with 20 TCRs
[Step] Analyzing pathogen exposure...
[Info] Scoring exposure to Syphilis...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 20/20 [22:56<00:00, 68.83s/it]


[Info] Scoring exposure to Gonorrhea...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 20/20 [22:56<00:00, 68.84s/it]

[Step] Generating reports...
[Info] Clean report saved to syphilis_immunity_report.html
[Info] Clean report saved to gonorrhea_immunity_report.html

âœ… ANALYSIS COMPLETED!
ðŸŽ¯ Model Performance: AUC = 0.7535
ðŸ¦  Syphilis Exposure Score: 0.041 (4.1%)
ðŸ¦  Gonorrhea Exposure Score: 0.040 (4.0%)

ðŸ“Š Generated Files:
   â€¢ syphilis_immunity_report.html
   â€¢ gonorrhea_immunity_report.html
   â€¢ training_history.png



