In [2]:
# Cell 1: Imports and device setup
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, average_precision_score, matthews_corrcoef
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import esm
from descriptastorus.descriptors import rdNormalizedDescriptors
import gc
import psutil

# Pick the compute device once; later cells read the module-level `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# On a CUDA machine, also report which GPU was picked and its total memory
# (bytes / 1e9, i.e. decimal gigabytes).
if device.type == "cuda":
    gpu_name = torch.cuda.get_device_name()
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {total_bytes / 1e9:.1f} GB")
    
# Memory monitoring function
def check_memory():
    """Print a short GPU / host memory usage report.

    When CUDA is available, prints the memory currently allocated by PyTorch
    and the peak allocation tracked by PyTorch, both in GB (bytes / 1e9).
    Always prints the host memory usage percentage reported by psutil.
    """
    if torch.cuda.is_available():
        allocated_gb, peak_gb = (
            n_bytes / 1e9
            for n_bytes in (torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())
        )
        print(f"GPU Memory: {allocated_gb:.2f} GB / {peak_gb:.2f} GB max")

    host = psutil.virtual_memory()
    print(f"CPU Memory: {host.percent:.1f}% used")



Using device: cuda
GPU: NVIDIA GeForce RTX 4090
GPU Memory: 25.3 GB


In [3]:
# Cell 2: Load combined dataset
# Reads the positive and sampled-negative pair tables, stacks them and shuffles
# the rows with a fixed seed so the row order is reproducible.
print("Loading combined dataset...")
positives = pd.read_csv('../datasets/cysdb_positives.csv')
sampled_negatives = pd.read_csv('sampled_negatives.csv')

# Concatenate, shuffle (frac=1 keeps every row) and renumber the index
combined_df = (
    pd.concat([positives, sampled_negatives], ignore_index=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

# 'Balance' is the mean of the Activity column (0.5 for a balanced 0/1 label)
print(f"Dataset composition:")
print(f"  Positives: {len(positives):,}")
print(f"  Negatives: {len(sampled_negatives):,}")
print(f"  Total: {len(combined_df):,}")
print(f"  Balance: {combined_df['Activity'].mean():.3f}")

check_memory()



Loading combined dataset...
Dataset composition:
  Positives: 49,511
  Negatives: 49,511
  Total: 99,022
  Balance: 0.500
GPU Memory: 0.00 GB / 0.00 GB max
CPU Memory: 8.7% used


In [4]:
# Cell 3: Protein embeddings via ESM2-t33
# Produces two dicts used by the later cells:
#   prot_index: label 'P{i}' -> full protein sequence
#   mean_seq:   label 'P{i}' -> mean-pooled ESM-2 layer-33 embedding (float32)
print("Setting up ESM-2 33-layer model...")

# Rough lower bound of free GPU memory wanted before loading the model: the
# 650M-parameter checkpoint alone is roughly 2.6 GB in fp32, plus activations.
ESM_MIN_FREE_GB = 4.0

# Clear GPU memory first
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()
    check_memory()

    # check_memory() only reports what *this* process allocated. Another
    # process holding the GPU shows up only in the device-wide free figure,
    # so warn early instead of failing with an opaque OOM during loading.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    if free_bytes / 1e9 < ESM_MIN_FREE_GB:
        print(
            f"WARNING: only {free_bytes / 1e9:.2f} GB of GPU memory is free "
            f"(wanted >= {ESM_MIN_FREE_GB:.1f} GB). Another process may be using "
            "the GPU; loading ESM-2 is likely to run out of memory."
        )

# Load ESM model
model_esm, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model_esm = model_esm.eval().to(device)
batch_converter = alphabet.get_batch_converter()
print("ESM-2 33-layer model loaded successfully")

# Get unique proteins from the combined dataset
unique_proteins = combined_df[['Entry', 'Sequence']].drop_duplicates()
print(f"Unique proteins to embed: {len(unique_proteins):,}")

# Embedding parameters
MAX_LEN = 512      # sequences longer than this are truncated before embedding
BATCH_SIZE = 4     # Start with small batch
ESM_LAYER = 33     # representation layer that is mean-pooled


def _embed_sequences(sequences, model, converter, max_len=MAX_LEN, layer=ESM_LAYER):
    """Return one mean-pooled float32 embedding per sequence, in input order.

    Each sequence is truncated to `max_len` residues, tokenised with
    `converter`, run through `model` without gradients, and the `layer`
    representations of its residue positions are averaged.
    """
    data = [(str(pos), seq[:max_len]) for pos, seq in enumerate(sequences)]
    _, strs, tokens = converter(data)
    tokens = tokens.to(device)

    with torch.no_grad():
        reps = model(tokens, repr_layers=[layer], return_contacts=False)['representations'][layer]

    # Position 0 is the CLS token; pool only the len(s) residue positions
    # that follow it so positions after the sequence never enter the mean.
    pooled = [
        reps[row, 1:len(s) + 1].mean(0).cpu().numpy().astype(np.float32)
        for row, s in enumerate(strs)
    ]

    del tokens, reps
    torch.cuda.empty_cache()
    return pooled


# Embeddings are keyed by sequence: the embedding depends only on the
# sequence, and Cell 5 joins on 'Sequence'. Keying by 'Entry' would let two
# different sequences that share an entry id overwrite each other.
protein_embeddings = {}
sequence_list = unique_proteins['Sequence'].drop_duplicates().tolist()
print("Generating protein embeddings...")

for start in tqdm(range(0, len(sequence_list), BATCH_SIZE), desc="ESM-2 batches"):
    batch_seqs = sequence_list[start:start + BATCH_SIZE]

    hit_oom = False
    try:
        batch_embs = _embed_sequences(batch_seqs, model_esm, batch_converter)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        # Only record the failure here. While the except clause is running,
        # the exception still references the failed frame (and its tensors),
        # so the retry must happen after the clause has been left.
        hit_oom = True

    if hit_oom:
        print(f"OOM at batch {start//BATCH_SIZE + 1}, reducing batch size...")
        torch.cuda.empty_cache()
        # Process one by one
        batch_embs = [
            _embed_sequences([seq], model_esm, batch_converter)[0]
            for seq in batch_seqs
        ]

    protein_embeddings.update(zip(batch_seqs, batch_embs))

print(f"Generated embeddings for {len(protein_embeddings)} proteins")

# Create mapping compatible with the training pipeline
prot_index = {}
mean_seq = {}

for i, (_entry, seq) in enumerate(unique_proteins.values):
    if seq in protein_embeddings:
        label = f'P{i}'
        prot_index[label] = seq
        mean_seq[label] = protein_embeddings[seq]

print(f"Created mappings for {len(mean_seq)} proteins")

# Free ESM model memory (the pooled vectors stay alive through mean_seq)
del model_esm, alphabet, batch_converter
del protein_embeddings
torch.cuda.empty_cache()
gc.collect()

print("ESM model freed, ready for training")
check_memory()


Setting up ESM-2 33-layer model...
GPU Memory: 0.00 GB / 0.00 GB max
CPU Memory: 8.8% used


OutOfMemoryError: CUDA out of memory. Tried to allocate 26.00 MiB. GPU 0 has a total capacity of 23.52 GiB of which 16.75 MiB is free. Process 2115792 has 22.72 GiB memory in use. Including non-PyTorch memory, this process has 688.00 MiB memory in use. Of the allocated memory 282.40 MiB is allocated by PyTorch, and 21.60 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Cell 4: Molecule descriptors via RDKit
# Produces two dicts with identical key sets:
#   mol_index: label 'M{i}' -> SMILES
#   mol_feats: label 'M{i}' -> float32 descriptor vector
# A molecule is registered in BOTH dicts only when its descriptors were
# computed. Cell 5 builds its SMILES->label map from mol_index, so a molecule
# present there but missing from mol_feats would survive the "valid pairs"
# filter and later raise KeyError inside the Dataset.
print("Computing molecular descriptors...")
gen = rdNormalizedDescriptors.RDKit2DNormalized()
mol_index, mol_feats = {}, {}
smiles_list = combined_df['SMILES'].unique().tolist()

n_failed = 0    # molecules skipped because no descriptors could be computed
n_patched = 0   # molecules whose non-finite descriptor values were zero-filled

for i, smi in enumerate(tqdm(smiles_list, desc="RDKit descriptors")):
    label = f'M{i}'
    feats = None
    try:
        desc = gen.process(smiles=smi)
        # process() yields None for unparsable input; otherwise element 0 is
        # the generator's success flag and the descriptors start at index 1.
        if desc is not None and len(desc) > 1 and bool(desc[0]):
            feats = np.array(desc[1:], dtype=np.float32)
    except Exception as e:
        print(f"Error processing {smi}: {e}")

    if feats is None:
        n_failed += 1
        continue

    # A single NaN/inf feature would propagate through the network and turn
    # the loss into NaN. Zero-filling is a pragmatic default -- TODO confirm
    # it suits the normalized descriptors used here.
    non_finite = ~np.isfinite(feats)
    if non_finite.any():
        feats[non_finite] = 0.0
        n_patched += 1

    mol_index[label] = smi
    mol_feats[label] = feats

print(f"Generated features for {len(mol_feats)} molecules")
if n_failed:
    print(f"  Skipped {n_failed} molecules without usable descriptors")
if n_patched:
    print(f"  Zero-filled non-finite descriptor values in {n_patched} molecules")



In [None]:
# Cell 5: Map dataset to indices
# Attaches protein / molecule feature labels to every row and keeps only the
# rows for which BOTH feature vectors exist.
print("Mapping dataset to feature indices...")

# Build the lookup tables only from labels that really have features. A label
# without a feature vector must map to NaN here so that the dropna() below
# removes the pair instead of letting it fail later with a KeyError.
prot_map = {seq: pid for pid, seq in prot_index.items() if pid in mean_seq}
mol_map = {smi: mid for mid, smi in mol_index.items() if mid in mol_feats}

combined_df['P_index'] = combined_df['Sequence'].map(prot_map)
combined_df['M_index'] = combined_df['SMILES'].map(mol_map)

# Filter out any pairs without valid mappings
valid_pairs = combined_df.dropna(subset=['P_index', 'M_index']).copy()
n_dropped = len(combined_df) - len(valid_pairs)
print(f"Valid pairs after mapping: {len(valid_pairs):,}")
if n_dropped:
    print(f"  Dropped {n_dropped:,} pairs without protein or molecule features")

# pair_indices[k] is the (protein label, molecule label) of sample k; y[k] is its label
pair_indices = list(valid_pairs[['P_index','M_index']].itertuples(index=False, name=None))
y = valid_pairs['Activity'].values.astype(np.float32)

if len(y) == 0:
    raise ValueError(
        "No protein-molecule pair has both feature vectors; "
        "check the embedding and descriptor cells."
    )

print(f"Final dataset: {len(y)} pairs, {y.mean():.3f} positive ratio")



In [None]:
# Cell 6: 5-Fold Cross-Validation Setup
from sklearn.model_selection import StratifiedKFold

print("Setting up 5-fold cross-validation...")

# Stratified, shuffled 5-fold splitter with a fixed seed; the split is made
# over sample positions 0..len(pair_indices)-1.
# NOTE(review): folds are stratified on the label only, so the same protein or
# molecule may appear in both train and validation folds -- confirm this is intended.
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
all_indices = [*range(len(pair_indices))]

# One list per tracked quantity; the training cell appends one entry per fold
cv_results = {
    key: []
    for key in ('fold_auprcs', 'fold_accs', 'fold_mccs', 'fold_histories')
}

print(f"Total samples: {len(y):,}")
print(f"Positive ratio: {y.mean():.3f}")



In [None]:
# Cell 7: Dataset & DataLoader definitions (same as before)
class ProtMolDataset(Dataset):
    """Dataset over a subset of the module-level protein-molecule pairs.

    `indices` holds positions into the global `pair_indices` / `y`. Each item
    is the tuple (protein vector, molecule vector, target) looked up through
    the global `mean_seq` and `mol_feats` tables.
    """

    def __init__(self, indices):
        # Positions of the samples that belong to this split
        self.indices = indices

    def __len__(self):
        """Number of samples in this split."""
        return len(self.indices)

    def __getitem__(self, ix):
        """Return (protein features, molecule features, target) for item `ix`."""
        sample_pos = self.indices[ix]
        prot_label, mol_label = pair_indices[sample_pos]
        prot_vec = mean_seq[prot_label]
        mol_vec = mol_feats[mol_label]
        return prot_vec, mol_vec, y[sample_pos]

def collate_fn(batch):
    """Collate (protein vector, molecule vector, target) samples into tensors.

    Args:
        batch: sequence of 3-tuples as produced by ProtMolDataset.__getitem__.

    Returns:
        Tuple (prots, mols, tgts) of float32 tensors moved to the module-level
        `device`; prots and mols have one row per sample, tgts is 1-D.

    NOTE(review): because the tensors are moved to `device` here, this
    function is only safe in the main process when `device` is CUDA
    (DataLoader num_workers=0, pin_memory=False).
    """
    prot_rows, mol_rows, targets = zip(*batch)
    # Stack into one contiguous array first: building a tensor directly from
    # a Python list of ndarrays converts element by element and is very slow.
    prots = torch.as_tensor(np.stack(prot_rows), dtype=torch.float32)
    mols = torch.as_tensor(np.stack(mol_rows), dtype=torch.float32)
    tgts = torch.as_tensor(np.asarray(targets, dtype=np.float32))
    return prots.to(device), mols.to(device), tgts.to(device)



In [None]:
# Cell 8: Model definition (same as before)
class MolProteinCrossAttention(nn.Module):
    """Binary interaction classifier over a protein vector and a molecule vector.

    Both inputs are projected to `emb_dim`, the protein embedding attends to
    the molecule embedding, and an MLP with a sigmoid output turns the fused
    representation into a probability.

    Args:
        prot_dim: width of the protein feature vector.
        mol_dim: width of the molecule feature vector.
        emb_dim: shared embedding width (must be divisible by the 8 heads).
        dropout_rate: dropout used in the projections, attention and predictor.

    Note on the attention step: each input is a single token, so the softmax
    over the one molecule key is always exactly 1 and the raw attention output
    does not depend on the protein query at all. `forward` therefore adds the
    attention output to the protein embedding (residual connection); without
    it the prediction would be a function of the molecule alone. The parameter
    shapes are unchanged.
    """

    def __init__(self, prot_dim, mol_dim, emb_dim=256, dropout_rate=0.2):
        super().__init__()

        # Protein and molecule projection layers
        self.prot_proj = nn.Sequential(
            nn.Linear(prot_dim, emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(emb_dim, emb_dim)
        )

        self.mol_proj = nn.Sequential(
            nn.Linear(mol_dim, emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(emb_dim, emb_dim)
        )

        # Cross-attention mechanism
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=emb_dim,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Final prediction layers
        self.predictor = nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(emb_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and a small constant bias for every Linear layer."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.fill_(0.01)

    def forward(self, x_prot, x_mol):
        """Return interaction probabilities of shape [batch].

        Args:
            x_prot: protein features, [batch, prot_dim].
            x_mol: molecule features, [batch, mol_dim].
        """
        # Project to common embedding space
        prot_emb = self.prot_proj(x_prot)  # [batch, emb_dim]
        mol_emb = self.mol_proj(x_mol)     # [batch, emb_dim]

        # Add sequence dimension for attention
        prot_seq = prot_emb.unsqueeze(1)   # [batch, 1, emb_dim]
        mol_seq = mol_emb.unsqueeze(1)     # [batch, 1, emb_dim]

        # Cross-attention: protein queries molecule
        attn_out, _ = self.cross_attn(prot_seq, mol_seq, mol_seq)

        # Residual connection: with a single key the attention output carries
        # molecule information only, so the protein embedding is added back.
        prot_attended = prot_emb + attn_out.squeeze(1)  # [batch, emb_dim]

        # Combine attended protein with original molecule
        combined = torch.cat([prot_attended, mol_emb], dim=1)  # [batch, emb_dim*2]

        # Final prediction
        output = self.predictor(combined)
        return output.squeeze(-1)

# Infer both input widths from the first stored feature vector of each table
PROT_DIM, MOL_DIM = (
    len(list(feature_table.values())[0])
    for feature_table in (mean_seq, mol_feats)
)

print(f"Model input dimensions: Protein={PROT_DIM}, Molecule={MOL_DIM}")



In [None]:
# Cell 9: 5-Fold Cross-Validation Training
def _collect_predictions(model, loader):
    """Run `model` in eval mode over `loader`; return (predictions, targets) as NumPy arrays."""
    model.eval()
    batch_preds, batch_tgts = [], []
    with torch.no_grad():
        for prot, mol, tgt in loader:
            pred = model(prot.to(device), mol.to(device))
            batch_preds.append(pred.cpu())
            batch_tgts.append(tgt.cpu())
    return torch.cat(batch_preds).numpy(), torch.cat(batch_tgts).numpy()


def _fold_metrics(tgts, preds):
    """Return (AUPRC, accuracy, MCC) as plain floats; hard labels use a 0.5 threshold."""
    hard_preds = preds > 0.5
    return (
        float(average_precision_score(tgts, preds)),
        float(accuracy_score(tgts, hard_preds)),
        float(matthews_corrcoef(tgts, hard_preds)),
    )


def train_single_fold(train_idx, val_idx, fold_num, num_epochs=50, max_patience=8):
    """Train and evaluate a fresh model on one cross-validation fold.

    Args:
        train_idx: sample positions used for training.
        val_idx: sample positions used for validation.
        fold_num: fold number; used in log lines and in the checkpoint name
            'best_model_fold_{fold_num}.pt'.
        num_epochs: maximum number of training epochs.
        max_patience: stop after this many consecutive epochs without a new
            best validation AUPRC.

    Returns:
        (auprc, accuracy, mcc, history) where the metrics are those of the
        best checkpoint on the validation fold and `history` holds the
        per-epoch 'train_losses' and 'val_auprcs' lists.
    """
    print(f"\n{'='*20} FOLD {fold_num} {'='*20}")

    # Create data loaders for this fold
    train_dataset = ProtMolDataset(train_idx)
    val_dataset = ProtMolDataset(val_idx)

    # collate_fn already returns tensors on `device`. Worker processes must
    # not touch CUDA and CUDA tensors cannot be pinned, so batches are built
    # in the main process; items are in-memory lookups, so this is cheap.
    loader_kwargs = dict(collate_fn=collate_fn, num_workers=0, pin_memory=False)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, **loader_kwargs)

    print(f"Fold {fold_num}: Train={len(train_idx):,}, Val={len(val_idx):,}")
    print(f"Train balance: {y[train_idx].mean():.3f}, Val balance: {y[val_idx].mean():.3f}")

    # Initialize fresh model for this fold
    model = MolProteinCrossAttention(PROT_DIM, MOL_DIM).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    # mode='max': the scheduler is stepped with the validation AUPRC
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=3)

    # Training tracking for this fold
    checkpoint_path = f'best_model_fold_{fold_num}.pt'
    # -inf so that the first epoch always writes a checkpoint; the final
    # evaluation can then never pick up a stale file from an earlier run.
    best_auprc = float('-inf')
    train_losses, val_auprcs = [], []
    patience_counter = 0

    for epoch in range(1, num_epochs + 1):
        # Training phase
        model.train()
        train_loss = 0.0

        for prot, mol, tgt in train_loader:
            # No-op when collate_fn already placed the batch on `device`
            prot, mol, tgt = prot.to(device), mol.to(device), tgt.to(device)
            optimizer.zero_grad()
            pred = model(prot, mol)
            loss = criterion(pred, tgt)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation phase
        preds, tgts = _collect_predictions(model, val_loader)
        auprc, acc, mcc = _fold_metrics(tgts, preds)
        val_auprcs.append(auprc)

        # Learning rate scheduling
        scheduler.step(auprc)

        # Save best model for this fold
        if auprc > best_auprc:
            best_auprc = auprc
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1

        # Print progress every 10 epochs
        if epoch % 10 == 0:
            print(f"  Epoch {epoch:02d}: Loss {avg_train_loss:.4f}, AUPRC {auprc:.4f}, Acc {acc:.4f}, MCC {mcc:.4f}")

        # Early stopping
        if patience_counter >= max_patience:
            print(f"  Early stopping at epoch {epoch}")
            break

    # Final evaluation for this fold with the best checkpoint
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    final_preds, final_tgts = _collect_predictions(model, val_loader)
    fold_auprc, fold_acc, fold_mcc = _fold_metrics(final_tgts, final_preds)

    print(f"Fold {fold_num} Results: AUPRC {fold_auprc:.4f}, Acc {fold_acc:.4f}, MCC {fold_mcc:.4f}")

    # Cleanup
    del model, optimizer, scheduler
    torch.cuda.empty_cache()
    gc.collect()

    return fold_auprc, fold_acc, fold_mcc, {'train_losses': train_losses, 'val_auprcs': val_auprcs}

# Run 5-fold cross-validation
print("Starting 5-fold cross-validation...")

# Start from empty result lists so that re-running this cell replaces the
# previous results instead of appending a second set of folds to them.
for result_key in ('fold_auprcs', 'fold_accs', 'fold_mccs', 'fold_histories'):
    cv_results[result_key] = []

# Fold numbers are 1-based; each fold trains a fresh model
for fold_num, (train_idx, val_idx) in enumerate(kfold.split(all_indices, y), 1):
    fold_auprc, fold_acc, fold_mcc, fold_history = train_single_fold(train_idx, val_idx, fold_num)

    cv_results['fold_auprcs'].append(fold_auprc)
    cv_results['fold_accs'].append(fold_acc)
    cv_results['fold_mccs'].append(fold_mcc)
    cv_results['fold_histories'].append(fold_history)



In [None]:
# Cell 10: Cross-Validation Results Summary with Plots
import matplotlib.pyplot as plt
import seaborn as sns

# Text summary of the cross-validation run: one line per fold, then mean ± std
banner = "=" * 60
print("\n" + banner)
print("5-FOLD CROSS-VALIDATION RESULTS")
print(banner)

# Per-fold metric lists, in fold order (reused by the plotting code below)
auprcs, accs, mccs = (
    cv_results[key] for key in ('fold_auprcs', 'fold_accs', 'fold_mccs')
)

print("Per-fold results:")
for i, (auprc, acc, mcc) in enumerate(zip(auprcs, accs, mccs), start=1):
    print(f"  Fold {i}: AUPRC {auprc:.4f}, Acc {acc:.4f}, MCC {mcc:.4f}")

# np.std is the population standard deviation (ddof=0)
print(f"\nMean ± Std:")
for metric_label, fold_values in (('AUPRC', auprcs), ('Accuracy', accs), ('MCC', mccs)):
    print(f"  {metric_label}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Create comprehensive plots
# 3x3 overview figure. The number of folds is taken from the results rather
# than hard-coded, so the cell also works when fewer or more folds were run.
n_folds = len(auprcs)
if n_folds == 0:
    raise ValueError("cv_results is empty; run the cross-validation cell first")

plt.style.use('default')
fig, axes = plt.subplots(3, 3, figsize=(20, 15))

folds = range(1, n_folds + 1)
metrics = ['AUPRC', 'Accuracy', 'MCC']
fold_scores = [auprcs, accs, mccs]   # same order as `metrics`

# 1. Per-fold performance comparison (one bar chart per metric, mean as dashed line)
bar_styles = [('skyblue', 'navy'), ('lightgreen', 'darkgreen'), ('salmon', 'darkred')]
for ax, metric_name, scores, (face, edge) in zip(axes[0], metrics, fold_scores, bar_styles):
    ax.bar(folds, scores, alpha=0.7, color=face, edgecolor=edge)
    ax.axhline(y=np.mean(scores), color='red', linestyle='--', alpha=0.8,
               label=f'Mean: {np.mean(scores):.3f}')
    ax.set_xlabel('Fold')
    ax.set_ylabel(metric_name)
    ax.set_title(f'{metric_name} Across Folds')
    ax.legend()
    ax.grid(True, alpha=0.3)

# 2. Training curves for each fold
curve_panels = [
    (axes[1, 0], 'train_losses', 'Training Loss', 'Training Loss Curves'),
    (axes[1, 1], 'val_auprcs', 'Validation AUPRC', 'Validation AUPRC Curves'),
]
for ax, history_key, y_label, title in curve_panels:
    for i, history in enumerate(cv_results['fold_histories']):
        ax.plot(history[history_key], label=f'Fold {i+1}', alpha=0.7)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# 3. Performance distribution
ax = axes[1, 2]
ax.boxplot(fold_scores)
# Tick labels are set explicitly instead of through boxplot's `labels=`
# keyword, which newer matplotlib releases deprecate.
ax.set_xticks(range(1, len(metrics) + 1))
ax.set_xticklabels(metrics)
ax.set_ylabel('Score')
ax.set_title('Performance Distribution')
ax.grid(True, alpha=0.3)

# 4. Combined metrics plot (grouped bars, one group per fold)
ax = axes[2, 0]
x = np.arange(n_folds)
width = 0.25
ax.bar(x - width, auprcs, width, label='AUPRC', alpha=0.8)
ax.bar(x, accs, width, label='Accuracy', alpha=0.8)
ax.bar(x + width, mccs, width, label='MCC', alpha=0.8)
ax.set_xlabel('Fold')
ax.set_ylabel('Score')
ax.set_title('All Metrics by Fold')
ax.set_xticks(x)
ax.set_xticklabels([f'Fold {i}' for i in folds])
ax.legend()
ax.grid(True, alpha=0.3)

# 5. Performance vs Mean
ax = axes[2, 1]
means = [np.mean(scores) for scores in fold_scores]
stds = [np.std(scores) for scores in fold_scores]
ax.errorbar(metrics, means, yerr=stds, fmt='o-', capsize=5, capthick=2, markersize=8)
ax.set_ylabel('Score')
ax.set_title('Mean Performance ± Std')
ax.grid(True, alpha=0.3)

# 6. Best vs Worst fold comparison (folds ranked by AUPRC, numbers are 1-based)
ax = axes[2, 2]
best_fold = int(np.argmax(auprcs)) + 1
worst_fold = int(np.argmin(auprcs)) + 1
best_scores = [scores[best_fold - 1] for scores in fold_scores]
worst_scores = [scores[worst_fold - 1] for scores in fold_scores]
x_pos = np.arange(len(metrics))
ax.bar(x_pos - 0.2, best_scores, 0.4, label=f'Best (Fold {best_fold})', alpha=0.8)
ax.bar(x_pos + 0.2, worst_scores, 0.4, label=f'Worst (Fold {worst_fold})', alpha=0.8)
ax.set_xlabel('Metric')
ax.set_ylabel('Score')
ax.set_title('Best vs Worst Fold')
ax.set_xticks(x_pos)
ax.set_xticklabels(metrics)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('cv_results_comprehensive.png', dpi=300, bbox_inches='tight')
plt.show()

# Additional detailed training plots
# One panel per fold on a 2x3 grid: training loss on the left axis and
# validation AUPRC on a twin right axis.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
panels = axes.ravel()   # row-major, so panel k sits at row k // 3, column k % 3

# Individual fold training curves (only the first 5 histories are drawn)
for i, history in enumerate(cv_results['fold_histories'][:5]):
    ax = panels[i]
    ax2 = ax.twinx()
    epochs = range(1, len(history['train_losses']) + 1)

    # Plot training loss and validation AUPRC
    line1 = ax.plot(epochs, history['train_losses'], 'b-', label='Training Loss', alpha=0.7)
    line2 = ax2.plot(epochs, history['val_auprcs'], 'r-', label='Validation AUPRC', alpha=0.7)

    ax.set_title(f'Fold {i+1} Training Progress')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Training Loss', color='b')
    ax2.set_ylabel('Validation AUPRC', color='r')
    ax.grid(True, alpha=0.3)

    # Combined legend covering the lines of both y-axes
    lines = line1 + line2
    labels = [curve.get_label() for curve in lines]
    ax.legend(lines, labels, loc='center right')

# Remove empty subplot (the sixth panel of the grid)
axes[1, 2].remove()

plt.tight_layout()
plt.savefig('individual_fold_training.png', dpi=300, bbox_inches='tight')
plt.show()

# Summary statistics plot
def _mean_curve(curves):
    """Epoch-wise mean of per-fold curves that may differ in length.

    Early stopping ends folds at different epochs, so the curves cannot be
    stacked directly. Shorter curves are padded with NaN and each epoch is
    averaged over the folds that actually reached it.
    """
    longest = max((len(curve) for curve in curves), default=0)
    if longest == 0:
        return np.array([])
    padded = np.full((len(curves), longest), np.nan)
    for row, curve in enumerate(curves):
        padded[row, :len(curve)] = curve
    return np.nanmean(padded, axis=0)


fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))

# Correlation matrix of metrics
metric_df = pd.DataFrame({
    'AUPRC': auprcs,
    'Accuracy': accs,
    'MCC': mccs
})
metric_names = list(metric_df.columns)
n_folds = len(metric_df)
correlation_matrix = metric_df.corr()
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0, ax=ax1)
ax1.set_title('Metric Correlations Across Folds')

# Performance variability: coefficient of variation (std / mean) per metric
cv_scores = [np.std(scores) / np.mean(scores) for scores in (auprcs, accs, mccs)]
ax2.bar(metric_names, cv_scores, color=['skyblue', 'lightgreen', 'salmon'])
ax2.set_ylabel('Coefficient of Variation')
ax2.set_title('Performance Variability (CV)')
ax2.grid(True, alpha=0.3)

# Fold ranking by AUPRC (best first; tick labels carry the 1-based fold number)
fold_rankings = np.argsort(auprcs)[::-1] + 1
rank_positions = range(1, n_folds + 1)
colors = plt.cm.viridis(np.linspace(0, 1, n_folds))
ax3.bar(rank_positions, np.sort(auprcs)[::-1], color=colors)
ax3.set_xlabel('Rank')
ax3.set_ylabel('AUPRC')
ax3.set_title('Fold Ranking by AUPRC')
ax3.set_xticks(rank_positions)
ax3.set_xticklabels([f'Fold {r}' for r in fold_rankings])
ax3.grid(True, alpha=0.3)

# Performance improvement over epochs (average over the folds still training)
avg_train_losses = _mean_curve([hist['train_losses'] for hist in cv_results['fold_histories']])
avg_val_auprcs = _mean_curve([hist['val_auprcs'] for hist in cv_results['fold_histories']])

ax4_twin = ax4.twinx()
epochs = range(1, len(avg_train_losses) + 1)
loss_line = ax4.plot(epochs, avg_train_losses, 'b-', label='Avg Training Loss')
auprc_line = ax4_twin.plot(epochs, avg_val_auprcs, 'r-', label='Avg Validation AUPRC')
ax4.set_xlabel('Epoch')
ax4.set_ylabel('Training Loss', color='b')
ax4_twin.set_ylabel('Validation AUPRC', color='r')
ax4.set_title('Average Training Progress')
ax4.grid(True, alpha=0.3)
# The two curves live on different axes; build one legend for both
ax4.legend(loss_line + auprc_line, ['Avg Training Loss', 'Avg Validation AUPRC'], loc='center right')

plt.tight_layout()
plt.savefig('cv_analysis_detailed.png', dpi=300, bbox_inches='tight')
plt.show()

# Save CV results
# The metric functions may return NumPy scalars. Pickled NumPy scalars are
# rejected by torch.load's restricted (weights-only) unpickler, so the saved
# copy holds plain Python floats only; cv_results itself is left untouched.
cv_results_plain = {
    'fold_auprcs': [float(v) for v in cv_results['fold_auprcs']],
    'fold_accs': [float(v) for v in cv_results['fold_accs']],
    'fold_mccs': [float(v) for v in cv_results['fold_mccs']],
    'fold_histories': [
        {curve_name: [float(v) for v in curve] for curve_name, curve in history.items()}
        for history in cv_results['fold_histories']
    ],
}
torch.save(cv_results_plain, 'cv_results.pt')

print("\nFiles saved:")
print("- best_model_fold_1.pt through best_model_fold_5.pt")
print("- cv_results.pt")
print("- cv_results_comprehensive.png")
print("- individual_fold_training.png") 
print("- cv_analysis_detailed.png")

check_memory()
print("\n✅ 5-Fold Cross-Validation with comprehensive plotting complete!")