In [None]:
# Mount Google Drive (Colab runtime only); later cells read and write under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
"""
Enhanced ESM2 Embedding Extraction for Protein ΔΔG Prediction
==============================================================

This notebook implements signal-amplified embedding extraction from ESM2-650M
to address weak mutation signals in protein stability prediction.

Key Features:
- Multi-layer aggregation (layers 30-33)
- Derived mutation features (Δ, |Δ|, cosine, L2)
- Z-normalization with train-only statistics
- Lazy HDF5 loading for memory efficiency
- Comprehensive validation and diagnostics

Author: Enhanced pipeline for Stage 2 training
"""

In [None]:
################################################################################
# CELL 1: SETUP AND DEPENDENCIES
################################################################################

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from scipy.stats import pearsonr
from pathlib import Path
from tqdm.auto import tqdm
import h5py
from typing import Tuple, List, Dict, Optional
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
import json
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this is a blanket filter, so warnings raised by later cells are hidden too.
warnings.filterwarnings('ignore')

# Plot defaults shared by all figures in the notebook
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})

# Choose the compute device once; later cells read the module-level `device`
has_gpu = torch.cuda.is_available()
device = torch.device('cuda' if has_gpu else 'cpu')
print(f"Using device: {device}")
if has_gpu:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.2f} GB")

In [None]:
################################################################################
# CELL 2: INSTALL ESM AND CREATE DIRECTORIES
################################################################################

# Install fair-esm
import subprocess
import sys

# Import fair-esm, installing the pinned release into the running kernel on first use.
try:
    import esm
    print("✓ ESM already installed")
except ImportError:
    import importlib

    print("Installing fair-esm...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "fair-esm==2.0.0", "-q"])
    # The package appeared on disk after the interpreter started; refresh the
    # import system's finder caches so the import below can see it.
    importlib.invalidate_caches()
    import esm
    print("✓ ESM installed successfully")

# Create directory structure
output_dir = Path("/content/drive/MyDrive/Protein_prediction_model/abyssal_embeddings/stage_2")
output_dir.mkdir(parents=True, exist_ok=True)

# One sub-directory per split plus one for figures.
# NOTE(review): the extraction cell writes "<split>_embeddings.h5" directly into
# output_dir, so the per-split directories appear to stay empty in this notebook.
for split in ['train', 'val', 'test']:
    (output_dir / split).mkdir(exist_ok=True)

(output_dir / 'visualizations').mkdir(exist_ok=True)
print(f"✓ Directory structure created at {output_dir}")


In [None]:
################################################################################
# CELL 3: LOAD MUTATION DATASET
################################################################################

def load_mutation_data(csv_path: str) -> pd.DataFrame:
    """
    Read the mutation table and tag every row with a train/val/test split.

    Expected columns:
    - wt_seq: Wild-type sequence
    - mut_seq: Mutant sequence
    - pos: Mutation position (0-indexed)
    - ddG: Experimental ΔΔG (kcal/mol)
    - wt_cluster: Cluster ID for splitting

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        The loaded frame with an added 'split' column. Rows whose 'wt_cluster'
        is in the validation list get 'val', rows in the test list get 'test',
        and everything else (including unlisted clusters) stays 'train'.
    """
    # Cluster assignment (described as an 80/10/10 split by the original author).
    # NOTE(review): the labels mix ints and strings and are matched by exact value
    # and type -- e.g. '4' is listed for train while 4 is listed for test. Confirm
    # against the dtype that read_csv actually produces for 'wt_cluster'.
    train_clusters = [23, 3, 186, 100, 119, 174, 'HHH', 146, 225, '81', '47', 123, 29, 142, 181, 223, '53', 155,
                      213, '99', 113, 16, 224, 147, '9', 150, '56', '49', 196, 107, 226, 232, 169, 191, 163, 127,
                      'EEHEE', 161, '4', 11, 179, 28, '87', '73', 112, 120, '50', 104, 128, 168, 231, '59', 109,
                      38, '76', '96', 199, 192, '93', 185, 'EEHH', 210, 'EHEE', 30, '62', 173, 206, '48', 190,
                      '45', '74', 205, 27, 158, '97', 208, 215, '72', 166, 152, '89', 101, 194, 139, 102, 145,
                      '65', '58', '44', '92', 24, '40', 36, 'hall', 135, 209, 230, '71', 121, '5', 193, '64']  # kept for reference only

    test_clusters = [170, 117, 202, 132, 218, '88', 156, '57', 114, 32, '95', 'HEEH', 184, 129, '67', '51',
                     '55', 4, 227]

    val_clusters = [34, 164, 20, '7', 15, 149, 105]

    table = pd.read_csv(csv_path)

    # Everything starts as train; held-out clusters are relabelled afterwards.
    # 'val' is applied before 'test', so 'test' wins if a label were in both lists.
    table['split'] = 'train'
    held_out = {'val': val_clusters, 'test': test_clusters}
    for split_name, cluster_ids in held_out.items():
        in_split = table['wt_cluster'].isin(cluster_ids)
        table.loc[in_split, 'split'] = split_name

    return table

# Load your dataset
csv_path = "/content/drive/MyDrive/Protein_prediction_model/k50_cleaned.csv"
mutation_df = load_mutation_data(csv_path)

# Count rows per split once instead of filtering the frame three times
split_sizes = mutation_df['split'].value_counts()
wt_lengths = mutation_df['wt_seq'].str.len()

print("Dataset Statistics:")
print(f"Total samples: {len(mutation_df)}")
print(f"Train: {split_sizes.get('train', 0)}")
print(f"Val: {split_sizes.get('val', 0)}")
print(f"Test: {split_sizes.get('test', 0)}")
print(f"\nΔΔG range: [{mutation_df['ddG'].min():.2f}, {mutation_df['ddG'].max():.2f}]")
print(f"ΔΔG mean±std: {mutation_df['ddG'].mean():.2f}±{mutation_df['ddG'].std():.2f}")
print(f"Sequence length range: {wt_lengths.min()}-{wt_lengths.max()}")

In [None]:
################################################################################
# CELL 4: ENHANCED ESM2 EXTRACTOR CLASS
################################################################################

class ESM2EmbeddingExtractor:
    """
    Extract multi-layer averaged, per-residue embeddings from a frozen ESM2 model.

    The model named by `model_name` is loaded through fair-esm, moved to `device`,
    put in eval mode and frozen. For each (sequence, position) pair the hidden
    states of the last `num_layers_to_average` layers at that residue are averaged.

    Attributes:
        model_name: Name passed to the fair-esm loader.
        device: torch.device the model lives on.
        batch_size: Stored for callers; not used by the methods of this class.
        embedding_dim: `embed_dim` reported by the loaded model.
        num_layers: `num_layers` reported by the loaded model.
        layers_to_extract: 1-based layer indices that are averaged.
    """

    def __init__(
        self,
        model_name: str = "esm2_t33_650M_UR50D",
        device: torch.device = None,
        batch_size: int = 16,
        num_layers_to_average: int = 4
    ):
        """
        Load and freeze the model.

        Args:
            model_name: fair-esm pretrained model name (or path to a local .pt file).
            device: Target device; defaults to CUDA when available, else CPU.
            batch_size: Kept on the instance for callers that want a default.
            num_layers_to_average: How many of the final layers to average (>= 1).
        """
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.num_layers_to_average = num_layers_to_average
        self.model_name = model_name

        print(f"Loading {model_name}...")
        # Honour `model_name` instead of always loading the 650M checkpoint.
        self.model, self.alphabet = esm.pretrained.load_model_and_alphabet(model_name)
        self.batch_converter = self.alphabet.get_batch_converter()

        # Move to device and freeze
        self.model = self.model.to(self.device)
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        # Model specifications
        self.embedding_dim = self.model.embed_dim
        self.num_layers = self.model.num_layers

        # Validate and set layer extraction (last N layers, 1-based indices)
        assert 1 <= num_layers_to_average <= self.num_layers, (
            f"num_layers_to_average must be in [1, {self.num_layers}], got {num_layers_to_average}"
        )
        self.layers_to_extract = list(range(
            self.num_layers - num_layers_to_average + 1,
            self.num_layers + 1
        ))

        print(f"✓ Model: {model_name}")
        print(f"✓ Embedding dim: {self.embedding_dim}")
        print(f"✓ Layers to average: {self.layers_to_extract}")
        print(f"✓ Device: {self.device}")
        print("=" * 60)

    def extract_position_embedding(
        self,
        sequence: str,
        position: int
    ) -> np.ndarray:
        """
        Extract the multi-layer averaged embedding of one residue.

        Args:
            sequence: Amino-acid sequence.
            position: 0-indexed residue position inside `sequence`.

        Returns:
            1-D array of length `embedding_dim`.

        Raises:
            IndexError: If `position` is outside the sequence.
        """
        # Single-sequence case is the batch case with one element.
        return self.extract_batch_embeddings([sequence], [position])[0]

    def extract_batch_embeddings(
        self,
        sequences: List[str],
        positions: List[int]
    ) -> np.ndarray:
        """
        Extract multi-layer averaged residue embeddings for a batch.

        Args:
            sequences: Amino-acid sequences.
            positions: One 0-indexed residue position per sequence.

        Returns:
            Array of shape (len(sequences), embedding_dim); an empty
            (0, embedding_dim) float32 array when the batch is empty.

        Raises:
            ValueError: If `sequences` and `positions` differ in length.
            IndexError: If a position is negative or not inside its sequence.
        """
        if len(sequences) != len(positions):
            raise ValueError(
                f"Got {len(sequences)} sequences but {len(positions)} positions"
            )
        if len(sequences) == 0:
            return np.empty((0, self.embedding_dim), dtype=np.float32)

        # Reject indices that would otherwise silently select the <cls> token,
        # the end-of-sequence token or batch padding.
        for row, (seq, pos) in enumerate(zip(sequences, positions)):
            if not 0 <= pos < len(seq):
                raise IndexError(
                    f"Position {pos} is outside sequence {row} of length {len(seq)}"
                )

        data = [(f"protein_{i}", seq) for i, seq in enumerate(sequences)]
        _, _, batch_tokens = self.batch_converter(data)
        batch_tokens = batch_tokens.to(self.device)

        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=self.layers_to_extract)

        # Gather the residue of interest for every sample in one indexing op.
        rows = torch.arange(len(positions), device=batch_tokens.device)
        cols = torch.as_tensor(
            [int(pos) + 1 for pos in positions],  # +1 for <cls> token
            dtype=torch.long,
            device=batch_tokens.device
        )
        per_layer = [
            results["representations"][layer][rows, cols]
            for layer in self.layers_to_extract
        ]
        averaged = torch.stack(per_layer, dim=0).mean(dim=0)

        return averaged.cpu().numpy()

# Initialize extractor
# NOTE: batch_size is only stored on the extractor instance; the extraction cell
# passes its own batch_size argument to extract_and_cache_embeddings.
extractor = ESM2EmbeddingExtractor(
    device=device,
    batch_size=4,
    num_layers_to_average=4
)

In [None]:
################################################################################
# CELL 5: DERIVED FEATURE COMPUTATION
################################################################################

def compute_derived_features(
    wt_emb: np.ndarray,
    mut_emb: np.ndarray
) -> Dict[str, np.ndarray]:
    """
    Derive mutation descriptors from one wild-type / mutant embedding pair.

    Args:
        wt_emb: Wild-type embedding vector.
        mut_emb: Mutant embedding vector (same length as `wt_emb`).

    Returns:
        Dict with, in this order:
        - 'delta': mut - wt (directional change)
        - 'abs_delta': |mut - wt| (per-dimension magnitude)
        - 'cosine_similarity': dot(wt, mut) / (||wt|| * ||mut|| + 1e-8)
        - 'l2_distance': ||mut - wt||
    """
    difference = mut_emb - wt_emb

    # The 1e-8 term keeps the ratio finite when either vector has zero norm.
    norm_product = np.linalg.norm(wt_emb) * np.linalg.norm(mut_emb)

    features = {'delta': difference, 'abs_delta': np.abs(difference)}
    features['cosine_similarity'] = np.dot(wt_emb, mut_emb) / (norm_product + 1e-8)
    features['l2_distance'] = np.linalg.norm(difference)
    return features


In [None]:
################################################################################
# CELL 6: EMBEDDING EXTRACTION AND CACHING
################################################################################

def _add_to_moments(moments: Dict[str, list], name: str, rows: np.ndarray) -> None:
    """Add a (n_rows, dim) block to the running float64 column sum / sum of squares of `name`."""
    rows64 = np.asarray(rows, dtype=np.float64)
    if name not in moments:
        moments[name] = [np.zeros(rows64.shape[1]), np.zeros(rows64.shape[1])]
    moments[name][0] += rows64.sum(axis=0)
    moments[name][1] += np.square(rows64).sum(axis=0)


def _mean_and_std(col_sum: np.ndarray, col_sq_sum: np.ndarray, count: int):
    """Turn running sums into per-column mean and population std (ddof=0)."""
    mean = col_sum / count
    # Clamp tiny negative values produced by floating-point cancellation.
    variance = np.maximum(col_sq_sum / count - np.square(mean), 0.0)
    return mean, np.sqrt(variance)


def extract_and_cache_embeddings(
    mutation_df: pd.DataFrame,
    extractor: ESM2EmbeddingExtractor,
    output_dir: Path,
    batch_size: int = 32
):
    """
    Extract embeddings with derived features and save them to one HDF5 file per split.

    Each non-empty split in ('train', 'val', 'test') is written to
    `output_dir / "<split>_embeddings.h5"` (opened in write mode, so an existing
    file is replaced) and then re-opened read-only for sanity checks.

    HDF5 structure per split:
    - wt_embeddings: (n_samples, embedding_dim)
    - mut_embeddings: (n_samples, embedding_dim)
    - delta_embeddings: (n_samples, embedding_dim)
    - abs_delta_embeddings: (n_samples, embedding_dim)
    - cosine_similarities: (n_samples,)
    - l2_distances: (n_samples,)
    - ddg_values: (n_samples,)
    - norm_stats/: per-dimension mean/std of the four embedding sets (train only)

    Train statistics are accumulated batch by batch as running sums, so the
    training embeddings never have to be held in memory all at once. They are also
    saved to `output_dir / "normalization_stats.npz"`.

    Args:
        mutation_df: Frame with 'wt_seq', 'mut_seq', 'pos', 'ddG' and 'split' columns.
        extractor: Object providing `extract_batch_embeddings`, `embedding_dim`
            and `device`.
        output_dir: Directory receiving the HDF5 and .npz files.
        batch_size: Number of mutations embedded per forward pass.

    Raises:
        AssertionError: If a written file contains NaNs, out-of-range cosine
            values, or a delta matrix in which at most half of the dimensions vary.
    """
    feature_names = ('wt', 'mut', 'delta', 'abs_delta')
    check_rows = 4096  # rows read back per chunk during validation
    dim = extractor.embedding_dim
    train_stats = None

    for split in ['train', 'val', 'test']:
        split_df = mutation_df[mutation_df['split'] == split].reset_index(drop=True)
        n_samples = len(split_df)

        if n_samples == 0:
            continue

        print(f"\n{'='*60}")
        print(f"Processing {split.upper()} split: {n_samples} samples")
        print(f"{'='*60}")

        h5_path = output_dir / f"{split}_embeddings.h5"

        with h5py.File(h5_path, 'w') as h5f:
            # Create datasets
            wt_ds = h5f.create_dataset('wt_embeddings', shape=(n_samples, dim), dtype='float32')
            mut_ds = h5f.create_dataset('mut_embeddings', shape=(n_samples, dim), dtype='float32')
            delta_ds = h5f.create_dataset('delta_embeddings', shape=(n_samples, dim), dtype='float32')
            abs_delta_ds = h5f.create_dataset('abs_delta_embeddings', shape=(n_samples, dim), dtype='float32')
            cosine_ds = h5f.create_dataset('cosine_similarities', shape=(n_samples,), dtype='float32')
            l2_ds = h5f.create_dataset('l2_distances', shape=(n_samples,), dtype='float32')
            ddg_ds = h5f.create_dataset('ddg_values', shape=(n_samples,), dtype='float32')

            # Running sums for train statistics (only filled for the train split)
            moments = {}

            # Process in batches
            for start_idx in tqdm(range(0, n_samples, batch_size), desc=f"{split}"):
                end_idx = min(start_idx + batch_size, n_samples)
                batch_df = split_df.iloc[start_idx:end_idx]

                # Extract embeddings
                wt_seqs = batch_df['wt_seq'].tolist()
                mut_seqs = batch_df['mut_seq'].tolist()
                positions = batch_df['pos'].tolist()

                wt_embs = extractor.extract_batch_embeddings(wt_seqs, positions)
                mut_embs = extractor.extract_batch_embeddings(mut_seqs, positions)
                ddg_vals = batch_df['ddG'].values

                # Compute derived features
                batch_deltas, batch_abs_deltas = [], []
                batch_cosines, batch_l2s = [], []

                for wt_emb, mut_emb in zip(wt_embs, mut_embs):
                    features = compute_derived_features(wt_emb, mut_emb)
                    batch_deltas.append(features['delta'])
                    batch_abs_deltas.append(features['abs_delta'])
                    batch_cosines.append(features['cosine_similarity'])
                    batch_l2s.append(features['l2_distance'])

                batch_deltas = np.array(batch_deltas)
                batch_abs_deltas = np.array(batch_abs_deltas)

                # Save to HDF5
                wt_ds[start_idx:end_idx] = wt_embs
                mut_ds[start_idx:end_idx] = mut_embs
                delta_ds[start_idx:end_idx] = batch_deltas
                abs_delta_ds[start_idx:end_idx] = batch_abs_deltas
                cosine_ds[start_idx:end_idx] = batch_cosines
                l2_ds[start_idx:end_idx] = batch_l2s
                ddg_ds[start_idx:end_idx] = ddg_vals

                if split == 'train':
                    _add_to_moments(moments, 'wt', wt_embs)
                    _add_to_moments(moments, 'mut', mut_embs)
                    _add_to_moments(moments, 'delta', batch_deltas)
                    _add_to_moments(moments, 'abs_delta', batch_abs_deltas)

                # GPU cleanup every fifth batch; use the extractor's own device
                # rather than a notebook-level global.
                if extractor.device.type == 'cuda' and start_idx % (batch_size * 5) == 0:
                    torch.cuda.empty_cache()

            # Compute train normalization statistics
            if split == 'train':
                train_stats = {}
                for name in feature_names:
                    mean, std = _mean_and_std(moments[name][0], moments[name][1], n_samples)
                    train_stats[f'{name}_mean'] = mean.astype(np.float32)
                    # Epsilon keeps later divisions finite for constant dimensions
                    train_stats[f'{name}_std'] = (std + 1e-8).astype(np.float32)

                # Save to HDF5
                for key, value in train_stats.items():
                    h5f.create_dataset(f'norm_stats/{key}', data=value)

                print(f"\n✓ Normalization statistics computed from TRAIN")

        # Sanity checks on what was actually persisted, read back in chunks
        with h5py.File(h5_path, 'r') as h5f:
            readback = {}
            for lo in range(0, n_samples, check_rows):
                hi = min(lo + check_rows, n_samples)
                assert not np.isnan(h5f['wt_embeddings'][lo:hi]).any(), f"{split}: WT has NaN"
                delta_rows = h5f['delta_embeddings'][lo:hi]
                assert not np.isnan(delta_rows).any(), f"{split}: Δ has NaN"
                _add_to_moments(readback, 'delta', delta_rows)

            cosines = h5f['cosine_similarities'][:]
            assert (cosines >= -1.01).all() and (cosines <= 1.01).all()

            _, delta_std = _mean_and_std(readback['delta'][0], readback['delta'][1], n_samples)
            active_dims = (delta_std > 1e-6).sum()
            assert active_dims / len(delta_std) > 0.5, f"{split}: Variance collapse"

            l2_values = h5f['l2_distances'][:]

            print(f"✓ {split} saved and validated ({n_samples} samples)")
            print(f"  Cosine: [{cosines.min():.3f}, {cosines.max():.3f}]")
            print(f"  L2: [{l2_values.min():.3f}, {l2_values.max():.3f}]")
            print(f"  Active dims: {active_dims}/{len(delta_std)}")

    # Save normalization stats globally
    if train_stats:
        np.savez(output_dir / 'normalization_stats.npz', **train_stats)
        print(f"\n✓ Normalization stats saved")

    print("\n" + "="*60)
    print("EMBEDDING EXTRACTION COMPLETE")
    print("="*60)

# Run extraction
# Each call opens the per-split HDF5 files in write mode, so existing caches are regenerated.
extract_and_cache_embeddings(
    mutation_df=mutation_df,
    extractor=extractor,
    output_dir=output_dir,
    batch_size=64
)

In [None]:
################################################################################
# CELL 7: PYTORCH DATASET WITH LAZY LOADING
################################################################################

@dataclass
class EmbeddingConfig:
    """Feature-selection and normalization switches read by MutationEmbeddingDataset."""
    # Each load_* flag controls whether the corresponding key is placed in a sample dict.
    load_wt: bool = True          # 'wt_embedding'
    load_mut: bool = True         # 'mut_embedding'
    load_delta: bool = True       # 'delta_embedding'
    load_abs_delta: bool = True   # 'abs_delta_embedding'
    load_cosine: bool = True      # 'cosine_similarity'
    load_l2: bool = True          # 'l2_distance'
    # Z-score the four embedding features with the stored mean/std statistics;
    # the cosine and L2 scalars are returned as stored.
    normalize: bool = True

class MutationEmbeddingDataset(Dataset):
    """
    PyTorch Dataset over one cached split, reading samples lazily from HDF5.

    Every `__getitem__` opens the split's HDF5 file, reads the requested row of
    each enabled feature and closes the file again. Embedding features are
    z-scored with the supplied/loaded statistics when normalization is enabled.
    """

    # (sample key, HDF5 dataset name, norm-stat prefix, config flag)
    _EMBEDDING_FEATURES = (
        ('wt_embedding', 'wt_embeddings', 'wt', 'load_wt'),
        ('mut_embedding', 'mut_embeddings', 'mut', 'load_mut'),
        ('delta_embedding', 'delta_embeddings', 'delta', 'load_delta'),
        ('abs_delta_embedding', 'abs_delta_embeddings', 'abs_delta', 'load_abs_delta'),
    )
    # (sample key, HDF5 dataset name, config flag) -- never normalized
    _SCALAR_FEATURES = (
        ('cosine_similarity', 'cosine_similarities', 'load_cosine'),
        ('l2_distance', 'l2_distances', 'load_l2'),
    )

    def __init__(
        self,
        split: str,
        data_dir: Path,
        config: EmbeddingConfig = None,
        norm_stats: Optional[Dict] = None
    ):
        """
        Args:
            split: Split name; the file read is `data_dir / "<split>_embeddings.h5"`.
            data_dir: Directory holding the HDF5 files and, optionally,
                'normalization_stats.npz'.
            config: Feature/normalization switches; defaults to EmbeddingConfig().
            norm_stats: Mapping with '<feature>_mean' / '<feature>_std' arrays. When
                omitted, statistics are taken from the train file's 'norm_stats'
                group (train split only) or from 'normalization_stats.npz'.

        Raises:
            FileNotFoundError: If the split's HDF5 file does not exist.
        """
        self.split = split
        self.h5_path = data_dir / f"{split}_embeddings.h5"
        self.config = config or EmbeddingConfig()

        if not self.h5_path.exists():
            raise FileNotFoundError(f"Not found: {self.h5_path}")

        with h5py.File(self.h5_path, 'r') as h5f:
            self.n_samples = h5f['wt_embeddings'].shape[0]
            self.embedding_dim = h5f['wt_embeddings'].shape[1]

            if self.config.normalize:
                self.norm_stats = self._resolve_norm_stats(h5f, data_dir, norm_stats)
            else:
                self.norm_stats = None

        if self.config.normalize and self.norm_stats is None:
            import copy

            print(f"Warning: No norm stats, using raw embeddings")
            # Disable normalization on a private copy: the caller's config object
            # may be shared with other datasets and must not be modified.
            self.config = copy.copy(self.config)
            self.config.normalize = False

        if self.norm_stats:
            # Keep samples float32 even if the statistics arrive in another dtype
            self.norm_stats = {
                key: np.asarray(value, dtype=np.float32)
                for key, value in self.norm_stats.items()
            }

        print(f"✓ {split}: {self.n_samples} samples, dim={self.embedding_dim}")
        print(f"  Normalization: {'ENABLED' if self.config.normalize else 'DISABLED'}")

    def _resolve_norm_stats(self, h5f, data_dir: Path, norm_stats: Optional[Dict]) -> Optional[Dict]:
        """Pick statistics: explicit argument, then train-file group, then .npz; None if absent."""
        if norm_stats is not None:
            return norm_stats

        if self.split == 'train' and 'norm_stats/wt_mean' in h5f:
            stats_group = h5f['norm_stats']
            return {key: stats_group[key][:] for key in stats_group.keys()}

        stats_path = data_dir / 'normalization_stats.npz'
        if stats_path.exists():
            # Context manager closes the archive once the arrays are copied out
            with np.load(stats_path) as loaded:
                return {key: loaded[key] for key in loaded.files}

        return None

    def __len__(self) -> int:
        """Number of samples in the split."""
        return self.n_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Read one sample.

        Returns:
            Dict with the enabled feature tensors (embedding features first, then
            'cosine_similarity' / 'l2_distance') and always 'ddg'.
        """
        sample = {}
        apply_norm = bool(self.config.normalize and self.norm_stats)

        with h5py.File(self.h5_path, 'r') as h5f:
            for out_key, ds_name, stat_prefix, flag in self._EMBEDDING_FEATURES:
                if not getattr(self.config, flag):
                    continue
                values = h5f[ds_name][idx].astype(np.float32)
                if apply_norm:
                    mean = self.norm_stats[f'{stat_prefix}_mean']
                    std = self.norm_stats[f'{stat_prefix}_std']
                    values = (values - mean) / std
                sample[out_key] = torch.from_numpy(values)

            for out_key, ds_name, flag in self._SCALAR_FEATURES:
                if getattr(self.config, flag):
                    sample[out_key] = torch.tensor(h5f[ds_name][idx], dtype=torch.float32)

            sample['ddg'] = torch.tensor(h5f['ddg_values'][idx], dtype=torch.float32)

        return sample

    def get_stats(self) -> Dict:
        """Summary statistics of the stored ddG, cosine-similarity and L2 values."""
        with h5py.File(self.h5_path, 'r') as h5f:
            ddg = h5f['ddg_values'][:]
            cos = h5f['cosine_similarities'][:]
            l2 = h5f['l2_distances'][:]

        return {
            'n_samples': self.n_samples,
            'ddg_mean': float(np.mean(ddg)),
            'ddg_std': float(np.std(ddg)),
            'ddg_range': [float(np.min(ddg)), float(np.max(ddg))],
            'cosine_mean': float(np.mean(cos)),
            'cosine_std': float(np.std(cos)),
            'l2_mean': float(np.mean(l2)),
            'l2_std': float(np.std(l2)),
        }

# Build one dataset per cached split. Val/test reuse the statistics held by the
# train dataset, so nothing is created unless the train cache exists.
datasets = {}
config = EmbeddingConfig(normalize=True)

train_cache = output_dir / 'train_embeddings.h5'
if train_cache.exists():
    datasets['train'] = MutationEmbeddingDataset('train', output_dir, config)
    train_norm_stats = datasets['train'].norm_stats

    for split in ['val', 'test']:
        split_cache = output_dir / f'{split}_embeddings.h5'
        if not split_cache.exists():
            continue
        datasets[split] = MutationEmbeddingDataset(split, output_dir, config, train_norm_stats)

banner = "=" * 60
print("\n" + banner)
print("DATASETS CREATED")
print(banner)

In [None]:
################################################################################
# CELL 8: VALIDATION
################################################################################

def validate_embeddings(datasets: Dict[str, "MutationEmbeddingDataset"]):
    """
    Print summary statistics and run basic integrity checks for every dataset.

    For each split this prints the values returned by `get_stats()`, checks the
    first two samples for NaN/Inf in every feature except 'ddg', and pulls one
    shuffled batch through a DataLoader to confirm batching works.

    Args:
        datasets: Mapping of split name to dataset. Each dataset must provide
            `get_stats()`, `__len__` and `__getitem__` returning a dict of tensors.

    Raises:
        AssertionError: If a checked feature contains NaN or Inf.
    """
    print("\n" + "="*60)
    print("VALIDATION CHECKS")
    print("="*60)

    for split, dataset in datasets.items():
        print(f"\n{split.upper()}:")
        print("-" * 40)

        stats = dataset.get_stats()
        print(f"✓ Samples: {stats['n_samples']}")
        print(f"✓ ΔΔG: {stats['ddg_mean']:.2f}±{stats['ddg_std']:.2f}")
        print(f"✓ Cosine: {stats['cosine_mean']:.3f}±{stats['cosine_std']:.3f}")
        print(f"✓ L2: {stats['l2_mean']:.3f}±{stats['l2_std']:.3f}")

        # Sample checks
        for i in range(min(2, len(dataset))):
            sample = dataset[i]

            for key, val in sample.items():
                if key != 'ddg':
                    assert not torch.isnan(val).any(), f"{key} has NaN"
                    assert not torch.isinf(val).any(), f"{key} has Inf"

            # .get() falls back to 0 when the dataset was configured without the scalar
            print(f"  Sample {i}: ΔΔG={sample['ddg']:.2f}, " +
                  f"L2={sample.get('l2_distance', 0):.3f}, " +
                  f"cos={sample.get('cosine_similarity', 0):.3f}")

        # DataLoader test
        loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
        batch = next(iter(loader))
        print(f"✓ DataLoader works, batch keys: {list(batch.keys())}")

# Check every dataset that was created in the previous cell.
validate_embeddings(datasets)

In [None]:
################################################################################
# CELL 9: VISUALIZATION
################################################################################

def visualize_embedding_quality(dataset, split, n_samples=200, save_path=None):
    """
    Plot a 2x3 diagnostic panel for the first `n_samples` samples of a dataset.

    Panels: L2 histogram, L2 vs |ΔΔG|, cosine vs |ΔΔG|, Δ-embedding norm
    histogram, 2-component PCA of Δ embeddings coloured by ΔΔG, cosine histogram.
    Summary statistics are printed afterwards.

    Only the 'delta_embedding', 'cosine_similarity', 'l2_distance' and 'ddg'
    keys of each sample are read, so delta-only datasets work as well.

    Args:
        dataset: Indexable dataset returning dicts of tensors.
        split: Name used in titles and printed headers.
        n_samples: Maximum number of leading samples to use.
        save_path: If given, the figure is also written to this path.
    """
    print(f"\nVisualizing {split}...")
    n_samples = min(n_samples, len(dataset))

    delta_embs = []
    ddgs, cosines, l2s = [], [], []

    for i in tqdm(range(n_samples), desc="Loading"):
        s = dataset[i]
        delta_embs.append(s['delta_embedding'].numpy())
        ddgs.append(s['ddg'].item())
        cosines.append(s['cosine_similarity'].item())
        l2s.append(s['l2_distance'].item())

    delta_embs = np.array(delta_embs)
    ddgs = np.array(ddgs)
    cosines = np.array(cosines)
    l2s = np.array(l2s)

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle(f'Embedding Analysis - {split.upper()}', fontsize=16, fontweight='bold')

    # 1. L2 distance
    axes[0, 0].hist(l2s, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
    axes[0, 0].set_xlabel('L2 Distance')
    axes[0, 0].set_title('Mutation Distance')
    axes[0, 0].axvline(np.mean(l2s), color='red', linestyle='--', label=f'Mean: {np.mean(l2s):.2f}')
    axes[0, 0].legend()

    # 2. L2 vs |ΔΔG|
    corr_l2, p_l2 = pearsonr(l2s, np.abs(ddgs))
    axes[0, 1].scatter(l2s, np.abs(ddgs), alpha=0.5, s=20)
    axes[0, 1].set_xlabel('L2 Distance')
    axes[0, 1].set_ylabel('|ΔΔG|')
    axes[0, 1].set_title('Distance vs Stability')
    axes[0, 1].text(0.05, 0.95, f'r={corr_l2:.3f}\np={p_l2:.2e}',
                    transform=axes[0, 1].transAxes, va='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # 3. Cosine vs |ΔΔG|
    corr_cos, p_cos = pearsonr(cosines, np.abs(ddgs))
    axes[0, 2].scatter(cosines, np.abs(ddgs), alpha=0.5, s=20, color='coral')
    axes[0, 2].set_xlabel('Cosine Similarity')
    axes[0, 2].set_ylabel('|ΔΔG|')
    axes[0, 2].set_title('Similarity vs Stability')
    axes[0, 2].text(0.05, 0.95, f'r={corr_cos:.3f}\np={p_cos:.2e}',
                    transform=axes[0, 2].transAxes, va='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # 4. Δ embedding norms
    # NOTE(review): these are norms of whatever the dataset returns, i.e. of
    # z-scored vectors when the dataset has normalization enabled.
    delta_norms = np.linalg.norm(delta_embs, axis=1)
    axes[1, 0].hist(delta_norms, bins=30, edgecolor='black', alpha=0.7, color='mediumseagreen')
    axes[1, 0].set_xlabel('||Δ||')
    axes[1, 0].set_title('Δ Embedding Magnitude')
    axes[1, 0].axvline(np.mean(delta_norms), color='red', linestyle='--', label=f'Mean: {np.mean(delta_norms):.2f}')
    axes[1, 0].legend()

    # 5. PCA of Δ
    pca = PCA(n_components=2)
    delta_2d = pca.fit_transform(delta_embs)
    scatter = axes[1, 1].scatter(delta_2d[:, 0], delta_2d[:, 1], c=ddgs,
                                  cmap='RdYlBu_r', alpha=0.6, s=30, edgecolors='black', linewidths=0.5)
    axes[1, 1].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
    axes[1, 1].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
    axes[1, 1].set_title('PCA of Δ Embeddings')
    plt.colorbar(scatter, ax=axes[1, 1], label='ΔΔG')

    # 6. Cosine distribution
    axes[1, 2].hist(cosines, bins=30, edgecolor='black', alpha=0.7, color='mediumpurple')
    axes[1, 2].set_xlabel('Cosine Similarity')
    axes[1, 2].set_title('WT-Mutant Similarity')
    axes[1, 2].axvline(np.mean(cosines), color='red', linestyle='--', label=f'Mean: {np.mean(cosines):.3f}')
    axes[1, 2].legend()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ Saved to {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

    print(f"\n{'='*60}")
    print(f"STATISTICS - {split.upper()}")
    print(f"{'='*60}")
    print(f"L2 distance: {np.mean(l2s):.3f}±{np.std(l2s):.3f}")
    print(f"Cosine sim: {np.mean(cosines):.3f}±{np.std(cosines):.3f}")
    print(f"Corr (L2 vs |ΔΔG|): r={corr_l2:.3f}, p={p_l2:.2e}")
    print(f"Corr (cos vs |ΔΔG|): r={corr_cos:.3f}, p={p_cos:.2e}")
    print(f"Δ norm: {np.mean(delta_norms):.3f}±{np.std(delta_norms):.3f}")

# Run visualization
# Uses the first 200 train samples and writes the figure under output_dir/visualizations.
if 'train' in datasets:
    visualize_embedding_quality(
        datasets['train'],
        'train',
        n_samples=200,
        save_path=output_dir / 'visualizations' / 'train_analysis.png'
    )

In [None]:
################################################################################
# CELL 10: USAGE EXAMPLE
################################################################################

print("\n" + "="*60)
print("USAGE EXAMPLE")
print("="*60)

# Delta-only configuration (recommended)
delta_config = EmbeddingConfig(
    load_wt=False,
    load_mut=False,
    load_delta=True,
    load_abs_delta=True,
    load_cosine=True,
    load_l2=True,
    normalize=True
)

# Build a train dataset that actually applies delta_config (previously the
# config was created but the loader still used the full-feature dataset).
# The statistics already held by the main train dataset are reused.
delta_train_dataset = MutationEmbeddingDataset(
    'train', output_dir, delta_config, datasets['train'].norm_stats
)

train_loader = DataLoader(delta_train_dataset, batch_size=32, shuffle=True, num_workers=0)
batch = next(iter(train_loader))

print("\nBatch contents:")
for key, val in batch.items():
    # 1-D entries are the per-sample scalars (cosine, L2, ddG)
    print(f"  {key}: {val.shape if val.dim() > 1 else 'scalar'}")

# Example model
class SimpleDeltaPredictor(nn.Module):
    """
    Example MLP regressor over the concatenated [delta, |delta|, cosine, L2] features.

    Args:
        embedding_dim: Width of each embedding feature; the network input is
            2 * embedding_dim + 2 wide.
    """

    def __init__(self, embedding_dim=1280):
        super().__init__()
        # delta + abs_delta + cos + L2
        input_dim = 2 * embedding_dim + 2
        layers = [
            nn.Linear(input_dim, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, batch):
        """Map a batch dict to predictions of shape (batch, 1)."""
        # Scalars arrive as (batch,) and need a feature axis before concatenation
        cosine_col = batch['cosine_similarity'].unsqueeze(1)
        l2_col = batch['l2_distance'].unsqueeze(1)
        stacked = torch.cat(
            (batch['delta_embedding'], batch['abs_delta_embedding'], cosine_col, l2_col),
            dim=1
        )
        return self.network(stacked)

model = SimpleDeltaPredictor().to(device)
print(f"\n✓ Model created: {sum(p.numel() for p in model.parameters()):,} parameters")
# Read the width from the model instead of repeating the 2*1280 + 2 arithmetic
print(f"  Input dim: {model.network[0].in_features}")

# Test forward pass (eval mode disables dropout; no gradients are needed)
model.eval()
with torch.no_grad():
    batch_gpu = {k: v.to(device) for k, v in batch.items()}
    preds = model(batch_gpu)
    print(f"\n✓ Forward pass: {preds.shape}")
    print(f"  Predictions: {preds[:3].squeeze().cpu().numpy()}")
    print(f"  True ΔΔG: {batch['ddg'][:3].numpy()}")

In [None]:
################################################################################
# CELL 11: SIGNAL QUALITY ASSESSMENT
################################################################################

def assess_signal_quality(dataset, n_samples=500):
    """
    Score how much mutation signal the cached features carry (0-4).

    Uses the first `n_samples` samples and awards one point for each of:
    mean cosine < 0.95, mean L2 >= 1.0, |Pearson r(L2, |ΔΔG|)| >= 0.2 with
    p < 0.01, and mean Δ-embedding norm >= 0.5. A report is printed.

    Args:
        dataset: Indexable dataset whose samples contain 'cosine_similarity',
            'l2_distance', 'ddg' and 'delta_embedding' tensors.
        n_samples: Maximum number of leading samples to inspect.

    Returns:
        Integer quality score between 0 and 4.
    """
    print(f"\nAssessing signal quality...")
    n_samples = min(n_samples, len(dataset))

    cosines, l2s, ddgs, delta_norms = [], [], [], []

    for i in range(n_samples):
        s = dataset[i]
        cosines.append(s['cosine_similarity'].item())
        l2s.append(s['l2_distance'].item())
        ddgs.append(s['ddg'].item())
        # NOTE(review): if the dataset z-scores delta embeddings, this is the norm
        # of the normalized vector -- confirm the 0.5 threshold is meant for that.
        delta_norms.append(torch.norm(s['delta_embedding']).item())

    cosines = np.array(cosines)
    l2s = np.array(l2s)
    ddgs = np.array(ddgs)
    delta_norms = np.array(delta_norms)

    corr_l2, p_l2 = pearsonr(l2s, np.abs(ddgs))
    corr_cos, p_cos = pearsonr(cosines, np.abs(ddgs))

    print(f"\n{'='*60}")
    print("SIGNAL QUALITY ASSESSMENT")
    print(f"{'='*60}")
    print(f"\n1. WT-MUTANT COLLINEARITY:")
    print(f"   Cosine: {np.mean(cosines):.3f}±{np.std(cosines):.3f}")
    print(f"   Status: {'⚠️  HIGH' if np.mean(cosines) > 0.95 else '✓ Good'}")

    print(f"\n2. MUTATION SEPARATION:")
    print(f"   L2: {np.mean(l2s):.3f}±{np.std(l2s):.3f}")
    print(f"   Status: {'⚠️  LOW' if np.mean(l2s) < 1.0 else '✓ Good'}")

    print(f"\n3. Δ EMBEDDING VARIANCE:")
    print(f"   ||Δ||: {np.mean(delta_norms):.3f}±{np.std(delta_norms):.3f}")
    print(f"   Status: {'⚠️  LOW' if np.mean(delta_norms) < 0.5 else '✓ Good'}")

    print(f"\n4. CORRELATION WITH ΔΔG:")
    print(f"   L2 vs |ΔΔG|: r={corr_l2:.3f} (p={p_l2:.2e})")
    print(f"   Cos vs |ΔΔG|: r={corr_cos:.3f} (p={p_cos:.2e})")

    quality_score = 0
    if np.mean(cosines) < 0.95: quality_score += 1
    if np.mean(l2s) >= 1.0: quality_score += 1
    if abs(corr_l2) >= 0.2 and p_l2 < 0.01: quality_score += 1
    if np.mean(delta_norms) >= 0.5: quality_score += 1

    print(f"\n5. OVERALL QUALITY: {quality_score}/4")
    if quality_score >= 3:
        print(f"   ✓ GOOD - Ready for training")
    elif quality_score == 2:
        print(f"   ⚠️  MODERATE - Consider additional features")
    else:
        print(f"   ❌ WEAK - May need fine-tuning")

    return quality_score

# Score the train split only; skipped when no train dataset was created.
if 'train' in datasets:
    assess_signal_quality(datasets['train'])

In [None]:
################################################################################
# CELL 12: SAVE CONFIGURATION
################################################################################

# Describe the extraction run. Model dimensions are read from the extractor so the
# saved metadata cannot drift from what was actually used.
embedding_dim = int(extractor.embedding_dim)

config_metadata = {
    'model': {
        # Falls back to the notebook's default when the extractor has no model_name attribute
        'name': getattr(extractor, 'model_name', 'esm2_t33_650M_UR50D'),
        'embedding_dim': embedding_dim,
        'num_layers': int(extractor.num_layers),
        'layers_averaged': list(extractor.layers_to_extract)
    },
    'features': {
        'wt_embeddings': f'Base WT ({embedding_dim}-dim)',
        'mut_embeddings': f'Base mutant ({embedding_dim}-dim)',
        'delta_embeddings': f'mut - wt ({embedding_dim}-dim)',
        'abs_delta_embeddings': f'|mut - wt| ({embedding_dim}-dim)',
        'cosine_similarities': 'cos(wt, mut)',
        'l2_distances': '||mut - wt||'
    },
    'enhancements': {
        'multi_layer_averaging': True,
        'num_layers': len(extractor.layers_to_extract),
        'derived_features': True,
        'normalization': 'Z-score (train only)'
    },
    'splits': {split: datasets[split].get_stats() for split in datasets},
    'date': pd.Timestamp.now().isoformat()
}

config_path = output_dir / 'embedding_config.json'
with open(config_path, 'w') as f:
    json.dump(config_metadata, f, indent=2)

print(f"\n✓ Configuration saved to {config_path}")

In [None]:
################################################################################
# SUMMARY
################################################################################

print("\n" + "="*60)
print("COMPLETE! 🎉")
print("="*60)
print("""
✓ ESM2-650M embeddings extracted (1280-dim)
✓ Multi-layer averaging (layers 30-33)
✓ Derived features: Δ, |Δ|, cosine, L2
✓ Z-normalization (train statistics only)
✓ HDF5 lazy loading enabled
✓ Datasets validated
✓ Visualizations generated

NEXT STEPS:
1. Use delta_config for training (most efficient)
2. Input: [Δ, |Δ|, cos, L2] = 2562 dims
3. Architecture: 2562 → 512 → 256 → 128 → 1
4. Monitor Pearson correlation on validation
5. Early stopping on best val correlation

FILES CREATED:
""")

# Per-split caches, with their size on disk
for split in ['train', 'val', 'test']:
    h5_path = output_dir / f'{split}_embeddings.h5'
    if h5_path.exists():
        size = h5_path.stat().st_size / 1e6
        print(f"  • {h5_path.name}: {size:.1f} MB")

# Only list auxiliary artifacts that are actually present on disk
for artifact in ('normalization_stats.npz', 'embedding_config.json', 'visualizations/train_analysis.png'):
    if (output_dir / artifact).exists():
        print(f"  • {artifact}")

print("\nREADY FOR STAGE 2 TRAINING! 🚀")