In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
################################################################################
# STAGE-2 ENHANCED ABYSSAL-STYLE ΔΔG PREDICTOR - TRAINING & EVALUATION
################################################################################
"""
This notebook extends Model_2.ipynb to exploit Stage-2 enhanced embeddings.

STAGE-1 → STAGE-2 CHANGES (INCREMENTAL ONLY):
1. ✅ Integrated Δ embeddings (Mut - WT) and |Δ| for amplified mutation signal
2. ✅ Added cosine similarity and L2 distance scalars
3. ✅ Implemented antisymmetry regularization in loss function
4. ✅ Switched from MSE to Huber loss (robust to outliers)
5. ✅ Added ONE hidden layer: 512 → 256 → 128 (gradual compression)
6. ✅ Enhanced diagnostics and sanity checks

PRESERVED FROM STAGE-1:
- Light Attention architecture (unchanged)
- Siamese network design (shared weights)
- Two-phase training strategy
- Input/output interfaces
- All existing evaluation metrics

PHYSICAL RATIONALE:
- Δ embeddings explicitly capture mutation effects (better than implicit learning)
- Antisymmetry ensures ΔΔG(WT→Mut) ≈ -ΔΔG(Mut→WT) (thermodynamic consistency)
- Huber loss reduces sensitivity to noisy experimental labels
- Deeper network exploits richer embedding signal from Stage-2
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import h5py
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
from dataclasses import dataclass
from typing import Tuple, Dict, List
import json
import warnings
warnings.filterwarnings('ignore')

# Fix for PyTorch 2.6+ weights_only default change
import torch.serialization
torch.serialization.add_safe_globals([np.ndarray, np.dtype, np.core.multiarray.scalar])

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 11

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

# Reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

print(f"✓ Random seed set to {SEED}")

Using device: cuda
GPU: NVIDIA A100-SXM4-40GB
Memory: 42.47 GB
✓ Random seed set to 42


In [3]:
################################################################################
# STAGE-2 DATASET CLASS (EXTENDED FOR ENHANCED EMBEDDINGS)
################################################################################
"""
STAGE-1 → STAGE-2 CHANGE:
Now loads:
- wt_embeddings (dim)
- mut_embeddings (dim)
- delta_embeddings (dim) = Mut - WT (explicit mutation signal)
- delta_abs_embeddings (dim) = |Mut - WT| (magnitude-only signal)
- cosine_similarity (scalar) = cosine(WT, Mut) (directional alignment)
- l2_distance (scalar) = ||Mut - WT||₂ (Euclidean separation)

RATIONALE:
- Δ embeddings: Directly encode mutation effects (amplified signal)
- |Δ|: Captures magnitude without direction (useful for symmetric mutations)
- Cosine: Measures semantic alignment (0 = orthogonal, 1 = identical)
- L2: Measures embedding space distance (complements cosine)

These features are ALREADY COMPUTED in Stage-2 embedding pipeline.
We only need to load and concatenate them properly.
"""

class MutationEmbeddingDatasetStage2(Dataset):
    """
    STAGE-2 Dataset: Loads enhanced embeddings with explicit mutation signals.

    Returns:
        wt_embedding: (dim,) - Wild-type embedding
        mut_embedding: (dim,) - Mutant embedding
        delta_embedding: (dim,) - Mut - WT (explicit mutation vector)
        delta_abs_embedding: (dim,) - |Mut - WT| (magnitude-only)
        cosine_sim: scalar - Cosine similarity between WT and Mut
        l2_dist: scalar - L2 distance between WT and Mut
        ddg: scalar - Experimental ΔΔG (kcal/mol)
    """

    def __init__(self, split: str, data_dir: Path):
        self.split = split
        self.h5_path = data_dir / f"{split}_embeddings.h5"

        with h5py.File(self.h5_path, 'r') as h5f:
            self.n_samples = h5f['wt_embeddings'].shape[0]
            self.embedding_dim = h5f['wt_embeddings'].shape[1]

            # Verify Stage-2 features exist
            required_keys = ['wt_embeddings', 'mut_embeddings',
                           'delta_embeddings', 'abs_delta_embeddings',
                           'cosine_similarities', 'l2_distances', 'ddg_values']
            missing = [k for k in required_keys if k not in h5f.keys()]
            if missing:
                raise ValueError(f"Missing Stage-2 features in {self.h5_path}: {missing}")

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
        if not hasattr(self, 'h5f'):
            self.h5f = h5py.File(self.h5_path, 'r')

        wt_emb = torch.from_numpy(self.h5f['wt_embeddings'][idx]).float()
        mut_emb = torch.from_numpy(self.h5f['mut_embeddings'][idx]).float()
        delta_emb = torch.from_numpy(self.h5f['delta_embeddings'][idx]).float()
        # Fix: Use 'abs_delta_embeddings' instead of 'delta_abs_embeddings'
        delta_abs_emb = torch.from_numpy(self.h5f['abs_delta_embeddings'][idx]).float()
        # Fix: Use 'cosine_similarities' instead of 'cosine_similarity'
        cos_sim = torch.tensor(self.h5f['cosine_similarities'][idx]).float()
        # Fix: Use 'l2_distances' instead of 'l2_distance'
        l2_dist = torch.tensor(self.h5f['l2_distances'][idx]).float()
        ddg = torch.tensor(self.h5f['ddg_values'][idx]).float()

        return wt_emb, mut_emb, delta_emb, delta_abs_emb, cos_sim, l2_dist, ddg

In [4]:
################################################################################
# LIGHT ATTENTION BLOCK (UNCHANGED FROM STAGE-1)
################################################################################
"""
STAGE-1 → STAGE-2: NO CHANGES
This component remains identical - proven to work well for embedding attention.
"""

class LightAttention(nn.Module):
    """
    Light Attention Block from ABYSSAL architecture.
    UNCHANGED from Stage-1 - already optimal for this task.
    """

    def __init__(self, embedding_dim: int = 480, kernel_size: int = 9):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) // 2

        self.feature_conv = nn.Conv1d(
            in_channels=1, out_channels=1,
            kernel_size=kernel_size, padding=self.padding, bias=True
        )
        self.attention_conv = nn.Conv1d(
            in_channels=1, out_channels=1,
            kernel_size=kernel_size, padding=self.padding, bias=True
        )

        nn.init.xavier_uniform_(self.feature_conv.weight, gain=0.5)
        nn.init.xavier_uniform_(self.attention_conv.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 3 and x.size(1) == 1

        features = self.feature_conv(x)
        attention_logits = self.attention_conv(x)
        attention_weights = torch.sigmoid(attention_logits)
        attended = features * attention_weights

        return attended.squeeze(1)

In [5]:
################################################################################
# STAGE-2 SIAMESE NETWORK (EXTENDED FOR ENHANCED EMBEDDINGS)
################################################################################
"""
STAGE-1 → STAGE-2 CHANGES:
1. ✅ Added input handling for Δ, |Δ|, cosine, L2 features
2. ✅ Concatenates [wt_att, mut_att, Δ_att, |Δ|_att, cos, L2] → richer input
3. ✅ Added ONE hidden layer: 512 → 256 → 128 (gradual compression)
4. ✅ Kept dropout tapering: 0.3 → 0.2 (less aggressive regularization)

PRESERVED:
- Siamese attention (shared weights for WT/Mut/Δ/|Δ|)
- Overall architecture philosophy
- Antisymmetry capability (tested via predict_antisymmetric)

RATIONALE FOR CHANGES:
- Concatenating attended Δ and |Δ| gives model explicit mutation signals
- Cosine/L2 scalars provide global similarity metrics (cheap but informative)
- Deeper network (256 layer) exploits richer 6-component input
- Gradual compression: 4*dim+2 → 512 → 256 → 128 → 1 (smooth bottleneck)
"""

class SiameseDDGPredictorStage2(nn.Module):
    """
    STAGE-2 Enhanced Siamese Network with explicit mutation signals.

    Architecture:
    1. Shared Light Attention → processes WT, Mut, Δ, |Δ| identically
    2. Concatenation: [wt_att, mut_att, Δ_att, |Δ|_att, cos_sim, l2_dist]
    3. MLP Regression: (4*dim + 2) → 512 → 256 → 128 → 1

    Key improvements over Stage-1:
    - Explicit Δ embeddings (amplified mutation signal)
    - Additional hidden layer (exploits richer input)
    - Preserved antisymmetry capability
    """

    def __init__(
        self,
        embedding_dim: int = 480,
        attention_kernel: int = 9,
        hidden_dims: List[int] = [512, 256, 128],  # STAGE-2: Added 256 layer
        dropout_rate: float = 0.3
    ):
        super().__init__()
        self.embedding_dim = embedding_dim

        # Shared Light Attention (UNCHANGED)
        self.attention = LightAttention(
            embedding_dim=embedding_dim,
            kernel_size=attention_kernel
        )

        # STAGE-2 CHANGE: Input is now [wt_att, mut_att, Δ_att, |Δ|_att, cos, L2]
        # = 4 * embedding_dim + 2 scalars
        input_dim = 4 * embedding_dim + 2

        # MLP Regressor (STAGE-2: ONE additional layer)
        layers = []
        current_dim = input_dim

        # STAGE-2: Taper dropout across layers (0.3 → 0.2 → 0.1)
        dropout_rates = [0.3, 0.2, 0.1] if len(hidden_dims) == 3 else [dropout_rate] * len(hidden_dims)

        for i, hidden_dim in enumerate(hidden_dims):
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rates[i] if i < len(dropout_rates) else 0.1)
            ])
            current_dim = hidden_dim

        # Final regression layer
        layers.append(nn.Linear(current_dim, 1))

        self.regressor = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize network weights using best practices."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(
        self,
        wt_embedding: torch.Tensor,
        mut_embedding: torch.Tensor,
        delta_embedding: torch.Tensor,
        delta_abs_embedding: torch.Tensor,
        cosine_sim: torch.Tensor,
        l2_dist: torch.Tensor
    ) -> torch.Tensor:
        """
        STAGE-2 Forward Pass with Enhanced Embeddings.

        Args:
            wt_embedding: (batch, dim) - Wild-type
            mut_embedding: (batch, dim) - Mutant
            delta_embedding: (batch, dim) - Mut - WT (explicit mutation)
            delta_abs_embedding: (batch, dim) - |Mut - WT| (magnitude-only)
            cosine_sim: (batch,) - Cosine similarity
            l2_dist: (batch,) - L2 distance

        Returns:
            ddg_pred: (batch,) - Predicted ΔΔG
        """
        # Sanity check: Detect NaN/Inf early
        assert not torch.isnan(wt_embedding).any(), "NaN in WT embedding"
        assert not torch.isnan(mut_embedding).any(), "NaN in Mut embedding"
        assert not torch.isinf(cosine_sim).any(), "Inf in cosine similarity"

        # Unsqueeze for Light Attention: (batch, dim) → (batch, 1, dim)
        wt_input = wt_embedding.unsqueeze(1)
        mut_input = mut_embedding.unsqueeze(1)
        delta_input = delta_embedding.unsqueeze(1)
        delta_abs_input = delta_abs_embedding.unsqueeze(1)

        # Apply shared attention (Siamese property preserved)
        wt_att = self.attention(wt_input)         # (batch, dim)
        mut_att = self.attention(mut_input)       # (batch, dim)
        delta_att = self.attention(delta_input)   # (batch, dim) - STAGE-2: Explicit Δ signal
        delta_abs_att = self.attention(delta_abs_input)  # (batch, dim) - STAGE-2: Magnitude signal

        # STAGE-2: Concatenate all attended features + scalars
        # Rationale:
        # - wt_att, mut_att: Original Stage-1 features (preserved)
        # - delta_att: Explicit mutation direction (amplified signal)
        # - delta_abs_att: Mutation magnitude (symmetric signal)
        # - cosine_sim, l2_dist: Global similarity metrics (cheap, informative)
        combined = torch.cat([
            wt_att,              # (batch, dim)
            mut_att,             # (batch, dim)
            delta_att,           # (batch, dim) - NEW
            delta_abs_att,       # (batch, dim) - NEW
            cosine_sim.unsqueeze(1),  # (batch, 1) - NEW
            l2_dist.unsqueeze(1)      # (batch, 1) - NEW
        ], dim=1)  # → (batch, 4*dim + 2)

        # Predict ΔΔG through MLP (now with deeper network)
        ddg_pred = self.regressor(combined).squeeze(-1)  # (batch,)

        return ddg_pred

    def predict_antisymmetric(
        self,
        wt_embedding: torch.Tensor,
        mut_embedding: torch.Tensor,
        delta_embedding: torch.Tensor,
        delta_abs_embedding: torch.Tensor,
        cosine_sim: torch.Tensor,
        l2_dist: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict both forward and reverse ΔΔG for antisymmetry check.
        PRESERVED from Stage-1 (critical for physical consistency).

        Forward: ΔΔG(WT→Mut)
        Reverse: ΔΔG(Mut→WT) should ≈ -ΔΔG(WT→Mut)

        Returns:
            forward_ddg: ΔΔG(WT→Mut)
            reverse_ddg: ΔΔG(Mut→WT)
        """
        forward_ddg = self.forward(
            wt_embedding, mut_embedding, delta_embedding,
            delta_abs_embedding, cosine_sim, l2_dist
        )

        # Reverse: swap WT ↔ Mut, negate Δ
        reverse_ddg = self.forward(
            mut_embedding, wt_embedding, -delta_embedding,
            delta_abs_embedding, cosine_sim, l2_dist  # |Δ| and distances unchanged
        )

        return forward_ddg, reverse_ddg

In [6]:
################################################################################
# STAGE-2 TRAINING CONFIGURATION
################################################################################

@dataclass
class TrainingConfigStage2:
    """
    STAGE-2 Training Configuration.

    STAGE-1 → STAGE-2 CHANGES:
    - Added lambda_antisym: Weight for antisymmetry regularization
    - Adjusted patience (8→10, 5→7): Deeper network needs more convergence time
    - Added enable_sanity_checks: Optional assertions for debugging
    - Added enable_visualizations: Toggle for diagnostic plots
    """
    # Learning rates (unchanged)
    lr_phase1: float = 3e-4
    lr_phase2: float = 1e-5

    # Training epochs (unchanged)
    epochs_phase1: int = 30
    epochs_phase2: int = 20

    # Batch size (unchanged)
    batch_size: int = 128

    # Regularization (unchanged)
    weight_decay: float = 1e-4
    gradient_clip: float = 1.0

    # STAGE-2: Antisymmetry regularization weight
    # Rationale: Enforce ΔΔG(WT→Mut) + ΔΔG(Mut→WT) ≈ 0 (thermodynamic consistency)
    # Typical values: 0.01 - 0.1 (too high → dominates main loss, too low → ignored)
    lambda_antisym: float = 0.02

    # Early stopping (STAGE-2: Increased patience for deeper network)
    patience_phase1: int = 10  # Was 8 in Stage-1
    patience_phase2: int = 7   # Was 5 in Stage-1
    min_delta: float = 0.002

    # Debugging and visualization (STAGE-2: NEW)
    enable_sanity_checks: bool = True   # Enable assertions
    enable_visualizations: bool = True  # Save diagnostic plots

    # Paths (unchanged)
    save_dir: Path = Path("/content/drive/MyDrive/Protein_prediction_model/Saved_model_Stage2")
    device: torch.device = device

In [7]:
################################################################################
# STAGE-2 TRAINING UTILITIES (EXTENDED)
################################################################################
"""
STAGE-1 → STAGE-2 CHANGES:
1. ✅ Switched from MSE to Huber Loss (SmoothL1Loss)
2. ✅ Added antisymmetry regularization term
3. ✅ Added sanity checks (NaN/Inf detection)
4. ✅ Extended metrics computation (no changes needed)

PRESERVED:
- EarlyStopping logic (uses Pearson correlation)
- Evaluation metrics (PCC, SCC, RMSE, MAE, Accuracy)
"""

class EarlyStopping:
    """Early stopping based on validation Pearson correlation (UNCHANGED)."""

    def __init__(self, patience: int = 10, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = -np.inf
        self.early_stop = False
        self.best_epoch = 0

    def __call__(self, score: float, epoch: int) -> bool:
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.counter = 0
            self.best_epoch = epoch
            return False
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                return True
            return False

def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Compute all evaluation metrics (UNCHANGED from Stage-1)."""
    pcc, pcc_pval = pearsonr(y_true, y_pred)
    scc, scc_pval = spearmanr(y_true, y_pred)
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mae = np.mean(np.abs(y_true - y_pred))

    true_signs = (y_true < 0).astype(int)
    pred_signs = (y_pred < 0).astype(int)
    accuracy = np.mean(true_signs == pred_signs)

    return {
        'pearson': pcc, 'pearson_pval': pcc_pval,
        'spearman': scc, 'spearman_pval': scc_pval,
        'rmse': rmse, 'mae': mae, 'accuracy': accuracy
    }

def train_epoch_stage2(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    config: TrainingConfigStage2,
    scaler: torch.cuda.amp.GradScaler = None
) -> Tuple[float, float, float, Dict[str, float]]:
    """
    STAGE-2 Training Epoch with Antisymmetry Regularization.

    Loss Function:
        L_total = L_huber(pred, target) + λ * L_antisym

        where L_antisym = || ΔΔG(WT→Mut) + ΔΔG(Mut→WT) ||²

    RATIONALE:
    - Huber loss: Robust to outliers (less sensitive than MSE to noisy labels)
    - Antisymmetry: Enforces thermodynamic consistency (physical constraint)
    - λ = 0.05: Balance between accuracy and physical validity

    Returns:
        avg_huber_loss: Average Huber loss
        avg_antisym_loss: Average antisymmetry loss
        avg_total_loss: Average combined loss
        metrics: Training metrics
    """
    model.train()

    # STAGE-2: Huber Loss (robust to outliers)
    # beta=1.0 means: |error| < 1 → L2 loss, |error| ≥ 1 → L1 loss
    # Rationale: Noisy experimental ΔΔG labels benefit from reduced outlier influence
    criterion_huber = nn.SmoothL1Loss(beta=1.0, reduction='mean')

    total_huber_loss = 0
    total_antisym_loss = 0
    total_combined_loss = 0
    all_preds = []
    all_targets = []

    pbar = tqdm(dataloader, desc="Training", leave=False)
    for wt_emb, mut_emb, delta_emb, delta_abs_emb, cos_sim, l2_dist, ddg in pbar:
        # Move to device
        wt_emb = wt_emb.to(config.device, non_blocking=True)
        mut_emb = mut_emb.to(config.device, non_blocking=True)
        delta_emb = delta_emb.to(config.device, non_blocking=True)
        delta_abs_emb = delta_abs_emb.to(config.device, non_blocking=True)
        cos_sim = cos_sim.to(config.device, non_blocking=True)
        l2_dist = l2_dist.to(config.device, non_blocking=True)
        ddg = ddg.to(config.device, non_blocking=True)

        # STAGE-2: Sanity checks (optional, lightweight)
        if config.enable_sanity_checks:
            assert not torch.isnan(wt_emb).any(), "NaN in WT embeddings"
            assert not torch.isinf(cos_sim).any(), "Inf in cosine similarity"
            assert (cos_sim >= -1.01).all() and (cos_sim <= 1.01).all(), \
                f"Cosine similarity out of bounds: {cos_sim.min():.3f}, {cos_sim.max():.3f}"

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast():
            # Forward prediction: ΔΔG(WT→Mut)
            pred_ddg_forward = model(wt_emb, mut_emb, delta_emb,
                                    delta_abs_emb, cos_sim, l2_dist)

            # Main regression loss (Huber)
            huber_loss = criterion_huber(pred_ddg_forward, ddg)

            # STAGE-2: Antisymmetry regularization
            # Compute reverse prediction: ΔΔG(Mut→WT)
            pred_ddg_reverse = model(mut_emb, wt_emb, -delta_emb,
                                    delta_abs_emb, cos_sim, l2_dist)

            # Antisymmetry loss: || ΔΔG(WT→Mut) + ΔΔG(Mut→WT) ||²
            # Physical meaning: Thermodynamic cycle closure
            # If model is perfectly antisymmetric, this term = 0
            antisym_loss = torch.mean((pred_ddg_forward + pred_ddg_reverse) ** 2)

            # STAGE-2: Combined loss
            total_loss = huber_loss + config.lambda_antisym * antisym_loss

        # Backward pass
        if scaler is not None:
            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            optimizer.step()

        # Track losses
        total_huber_loss += huber_loss.item()
        total_antisym_loss += antisym_loss.item()
        total_combined_loss += total_loss.item()
        all_preds.extend(pred_ddg_forward.detach().cpu().numpy())
        all_targets.extend(ddg.cpu().numpy())

        pbar.set_postfix({
            'huber': huber_loss.item(),
            'antisym': antisym_loss.item(),
            'total': total_loss.item()
        })

    avg_huber = total_huber_loss / len(dataloader)
    avg_antisym = total_antisym_loss / len(dataloader)
    avg_total = total_combined_loss / len(dataloader)
    metrics = compute_metrics(np.array(all_targets), np.array(all_preds))

    return avg_huber, avg_antisym, avg_total, metrics

@torch.no_grad()
def evaluate_stage2(
    model: nn.Module,
    dataloader: DataLoader,
    config: TrainingConfigStage2
) -> Tuple[float, Dict[str, float]]:
    """
    STAGE-2 Evaluation (no antisymmetry loss, only Huber).
    Uses forward pass only (single direction).
    """
    model.eval()
    criterion_huber = nn.SmoothL1Loss(beta=0.8, reduction='mean')

    total_loss = 0
    all_preds = []
    all_targets = []

    for wt_emb, mut_emb, delta_emb, delta_abs_emb, cos_sim, l2_dist, ddg in \
        tqdm(dataloader, desc="Evaluating", leave=False):

        wt_emb = wt_emb.to(config.device, non_blocking=True)
        mut_emb = mut_emb.to(config.device, non_blocking=True)
        delta_emb = delta_emb.to(config.device, non_blocking=True)
        delta_abs_emb = delta_abs_emb.to(config.device, non_blocking=True)
        cos_sim = cos_sim.to(config.device, non_blocking=True)
        l2_dist = l2_dist.to(config.device, non_blocking=True)
        ddg = ddg.to(config.device, non_blocking=True)

        with torch.cuda.amp.autocast():
            pred_ddg = model(wt_emb, mut_emb, delta_emb,
                           delta_abs_emb, cos_sim, l2_dist)
            loss = criterion_huber(pred_ddg, ddg)

        total_loss += loss.item()
        all_preds.extend(pred_ddg.cpu().numpy())
        all_targets.extend(ddg.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    metrics = compute_metrics(np.array(all_targets), np.array(all_preds))

    return avg_loss, metrics

In [8]:
################################################################################
# STAGE-2 VISUALIZATION UTILITIES (NEW)
################################################################################
"""
STAGE-2 Addition: Diagnostic visualizations for monitoring training.
These are NON-BLOCKING and TOGGLEABLE via config.enable_visualizations.
"""

def plot_training_curves(history: Dict, save_path: Path):
    """STAGE-2: Plot training dynamics (loss and Pearson correlation)."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    epochs = range(1, len(history['train_loss']) + 1)

    # Plot 1: Huber Loss
    axes[0].plot(epochs, history['train_loss'], 'b-', label='Train', linewidth=2)
    axes[0].plot(epochs, history['val_loss'], 'r-', label='Val', linewidth=2)
    if history.get('phase1_early_stop_epoch', 0) > 0:
        axes[0].axvline(history['phase1_early_stop_epoch'],
                       color='gray', linestyle='--', label='Phase 2 Start')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Huber Loss')
    axes[0].set_title('Training Loss (Huber)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Plot 2: Antisymmetry Loss
    if 'train_antisym_loss' in history:
        axes[1].plot(epochs, history['train_antisym_loss'], 'g-', linewidth=2)
        if history.get('phase1_early_stop_epoch', 0) > 0:
            axes[1].axvline(history['phase1_early_stop_epoch'],
                           color='gray', linestyle='--')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Antisymmetry Loss')
        axes[1].set_title('Antisymmetry Regularization')
        axes[1].grid(True, alpha=0.3)

    # Plot 3: Pearson Correlation
    axes[2].plot(epochs, history['train_pearson'], 'b-', label='Train', linewidth=2)
    axes[2].plot(epochs, history['val_pearson'], 'r-', label='Val', linewidth=2)
    axes[2].axhline(history['best_val_pearson'], color='green',
                   linestyle=':', label=f'Best={history["best_val_pearson"]:.3f}')
    if history.get('phase1_early_stop_epoch', 0) > 0:
        axes[2].axvline(history['phase1_early_stop_epoch'],
                       color='gray', linestyle='--', label='Phase 2 Start')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Pearson Correlation')
    axes[2].set_title('Model Performance (PCC)')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Training curves saved: {save_path}")

def plot_evaluation_results(
    test_targets: np.ndarray,
    test_preds: np.ndarray,
    forward_preds: np.ndarray,
    reverse_preds: np.ndarray,
    results: Dict,
    save_path: Path
):
    """STAGE-2: Comprehensive evaluation plots."""
    fig = plt.figure(figsize=(18, 5))

    # Plot 1: Predicted vs Experimental
    ax1 = fig.add_subplot(131)
    scatter = ax1.scatter(test_targets, test_preds, alpha=0.6, s=50,
                         c=np.abs(test_targets - test_preds),
                         cmap='viridis', edgecolors='black', linewidth=0.5)

    min_val = min(test_targets.min(), test_preds.min())
    max_val = max(test_targets.max(), test_preds.max())
    ax1.plot([min_val, max_val], [min_val, max_val],
            'r--', linewidth=2, label='Perfect', alpha=0.8)

    textstr = '\n'.join([
        f'PCC = {results["Test"]["pearson"]:.3f}',
        f'RMSE = {results["Test"]["rmse"]:.3f}',
        f'MAE = {results["Test"]["mae"]:.3f}',
        f'n = {len(test_targets)}'
    ])
    ax1.text(0.05, 0.95, textstr, transform=ax1.transAxes,
            fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    ax1.set_xlabel('Experimental ΔΔG (kcal/mol)')
    ax1.set_ylabel('Predicted ΔΔG (kcal/mol)')
    ax1.set_title('Predicted vs Experimental')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_aspect('equal', adjustable='box')
    plt.colorbar(scatter, ax=ax1, label='|Error|')

    # Plot 2: Antisymmetry Check
    ax2 = fig.add_subplot(132)
    scatter2 = ax2.scatter(forward_preds, -reverse_preds, alpha=0.6, s=50,
                          c=np.abs(forward_preds + reverse_preds),
                          cmap='coolwarm', edgecolors='black', linewidth=0.5)

    min_val = min(forward_preds.min(), (-reverse_preds).min())
    max_val = max(forward_preds.max(), (-reverse_preds).max())
    ax2.plot([min_val, max_val], [min_val, max_val],
            'r--', linewidth=2, label='Perfect Antisymmetry', alpha=0.8)

    textstr = '\n'.join([
        f'PCC = {results["Antisymmetry"]["correlation"]:.3f}',
        f'Mean bias = {results["Antisymmetry"]["mean_bias"]:.4f}',
        f'|Bias| = {results["Antisymmetry"]["abs_bias"]:.4f}'
    ])
    ax2.text(0.05, 0.95, textstr, transform=ax2.transAxes,
            fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

    ax2.set_xlabel('ΔΔG(WT→Mut)')
    ax2.set_ylabel('-ΔΔG(Mut→WT)')
    ax2.set_title('Antisymmetry Check')
    ax2.legend(loc='lower right')
    ax2.grid(True, alpha=0.3)
    ax2.set_aspect('equal', adjustable='box')
    plt.colorbar(scatter2, ax=ax2, label='|Forward + Reverse|')

    # Plot 3: Residual Distribution
    ax3 = fig.add_subplot(133)
    residuals = test_targets - test_preds
    ax3.hist(residuals, bins=30, edgecolor='black', alpha=0.7,
            color='steelblue', density=True)

    mu, sigma = residuals.mean(), residuals.std()
    x = np.linspace(residuals.min(), residuals.max(), 100)
    gaussian = (1/(sigma * np.sqrt(2*np.pi))) * np.exp(-0.5*((x-mu)/sigma)**2)
    ax3.plot(x, gaussian, 'r-', linewidth=2, label=f'N({mu:.2f}, {sigma:.2f}²)')
    ax3.axvline(0, color='green', linestyle='--', linewidth=2, label='Zero Error')

    textstr = '\n'.join([
        f'Mean = {mu:.3f}',
        f'Std = {sigma:.3f}',
        f'MAE = {results["Test"]["mae"]:.3f}'
    ])
    ax3.text(0.65, 0.95, textstr, transform=ax3.transAxes,
            fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))

    ax3.set_xlabel('Residual (Exp - Pred)')
    ax3.set_ylabel('Density')
    ax3.set_title('Residual Distribution')
    ax3.legend()
    ax3.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Evaluation plots saved: {save_path}")

In [9]:
################################################################################
# STAGE-2 MAIN TRAINING LOOP
################################################################################

def train_model_stage2(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: TrainingConfigStage2
) -> Dict:
    """
    STAGE-2 Two-phase training with antisymmetry regularization.

    STAGE-1 → STAGE-2 CHANGES:
    - Uses train_epoch_stage2 (with Huber + antisymmetry loss)
    - Tracks antisymmetry loss separately
    - Extended patience for deeper network

    PRESERVED:
    - Two-phase learning rate schedule
    - Early stopping on Pearson correlation
    - Checkpoint saving logic
    """
    scaler = torch.cuda.amp.GradScaler() if config.device.type == 'cuda' else None

    history = {
        'train_loss': [], 'val_loss': [],
        'train_antisym_loss': [],  # STAGE-2: NEW
        'train_pearson': [], 'val_pearson': [],
        'train_mae': [], 'val_mae': [],
        'best_epoch': 0, 'best_val_pearson': -np.inf,
        'phase1_early_stop_epoch': 0, 'phase2_early_stop_epoch': 0
    }

    print("\n" + "="*60)
    print("STAGE-2 TRAINING STARTED")
    print("="*60)
    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")
    print(f"Lambda antisymmetry: {config.lambda_antisym}")
    print(f"Loss: Huber (beta=1.0) + Antisymmetry regularization")

    # ========================================================================
    # PHASE 1: Initial Training
    # ========================================================================
    print(f"\nPHASE 1: lr={config.lr_phase1}, epochs={config.epochs_phase1}")
    print("-" * 60)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr_phase1,
                                weight_decay=config.weight_decay)
    early_stopping = EarlyStopping(patience=config.patience_phase1,
                                  min_delta=config.min_delta)

    for epoch in range(config.epochs_phase1):
        # Train
        train_huber, train_antisym, train_total, train_metrics = \
            train_epoch_stage2(model, train_loader, optimizer, config, scaler)

        # Validate
        val_loss, val_metrics = evaluate_stage2(model, val_loader, config)

        # Record history
        history['train_loss'].append(train_huber)
        history['val_loss'].append(val_loss)
        history['train_antisym_loss'].append(train_antisym)
        history['train_pearson'].append(train_metrics['pearson'])
        history['val_pearson'].append(val_metrics['pearson'])
        history['train_mae'].append(train_metrics['mae'])
        history['val_mae'].append(val_metrics['mae'])

        print(f"Epoch {epoch+1:3d}/{config.epochs_phase1}: "
              f"train_huber={train_huber:.4f}, train_antisym={train_antisym:.4f}, "
              f"val_loss={val_loss:.4f}, val_PCC={val_metrics['pearson']:.3f}")

        # Save best model
        if val_metrics['pearson'] > history['best_val_pearson']:
            history['best_val_pearson'] = val_metrics['pearson']
            history['best_epoch'] = epoch + 1
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_pearson': val_metrics['pearson'],
                'config': config
            }, config.save_dir / 'best_model_phase1.pt')
            print(f"  → Saved (val_PCC={val_metrics['pearson']:.3f})")

        # Early stopping
        if early_stopping(val_metrics['pearson'], epoch):
            history['phase1_early_stop_epoch'] = epoch + 1
            print(f"\nEarly stop at epoch {epoch+1}")
            break

    # ========================================================================
    # PHASE 2: Fine-tuning
    # ========================================================================
    print(f"\nPHASE 2: lr={config.lr_phase2}, epochs={config.epochs_phase2}")
    print("-" * 60)

    checkpoint = torch.load(config.save_dir / 'best_model_phase1.pt',
                           weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded Phase 1 best (val_PCC={checkpoint['val_pearson']:.3f})")

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr_phase2,
                                weight_decay=config.weight_decay)
    early_stopping = EarlyStopping(patience=config.patience_phase2,
                                  min_delta=config.min_delta)

    phase2_start = len(history['train_loss'])

    for epoch in range(config.epochs_phase2):
        train_huber, train_antisym, train_total, train_metrics = \
            train_epoch_stage2(model, train_loader, optimizer, config, scaler)
        val_loss, val_metrics = evaluate_stage2(model, val_loader, config)

        history['train_loss'].append(train_huber)
        history['val_loss'].append(val_loss)
        history['train_antisym_loss'].append(train_antisym)
        history['train_pearson'].append(train_metrics['pearson'])
        history['val_pearson'].append(val_metrics['pearson'])
        history['train_mae'].append(train_metrics['mae'])
        history['val_mae'].append(val_metrics['mae'])

        print(f"Epoch {epoch+1:3d}/{config.epochs_phase2}: "
              f"train_huber={train_huber:.4f}, val_loss={val_loss:.4f}, "
              f"val_PCC={val_metrics['pearson']:.3f}")

        if val_metrics['pearson'] > history['best_val_pearson']:
            history['best_val_pearson'] = val_metrics['pearson']
            history['best_epoch'] = phase2_start + epoch + 1
            print(f"  → Best (val_PCC={val_metrics['pearson']:.3f})")

        torch.save({
            'epoch': phase2_start + epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_pearson': val_metrics['pearson'],
            'config': config
        }, config.save_dir / 'best_model_final.pt')

        if early_stopping(val_metrics['pearson'], epoch):
            history['phase2_early_stop_epoch'] = epoch + 1
            print(f"\nEarly stop at epoch {epoch+1}")
            break

    print("\n" + "="*60)
    print("TRAINING COMPLETED")
    print(f"Best val Pearson: {history['best_val_pearson']:.3f} @ epoch {history['best_epoch']}")
    print("="*60)

    return history

In [None]:
################################################################################
# MAIN EXECUTION
################################################################################

# Configure paths
DATA_DIR = Path("/content/drive/MyDrive/Protein_prediction_model/abyssal_embeddings/stage_2_8M")
config = TrainingConfigStage2()
config.save_dir.mkdir(exist_ok=True, parents=True)

print("="*60)
print("STAGE-2: LOADING ENHANCED EMBEDDINGS")
print("="*60)

# Check files
required_files = ['train_embeddings.h5', 'val_embeddings.h5', 'test_embeddings.h5']
missing = [f for f in required_files if not (DATA_DIR / f).exists()]
if missing:
    print(f"ERROR: Missing files: {missing}")
    print("Please run Stage-2 embedding extraction first")
else:
    print("✓ All embedding files found")

    # Load datasets
    train_dataset = MutationEmbeddingDatasetStage2('train', DATA_DIR)
    val_dataset = MutationEmbeddingDatasetStage2('val', DATA_DIR)
    test_dataset = MutationEmbeddingDatasetStage2('test', DATA_DIR)

    print(f"\nDatasets:")
    print(f"  Train: {len(train_dataset)} samples")
    print(f"  Val:   {len(val_dataset)} samples")
    print(f"  Test:  {len(test_dataset)} samples")

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                             shuffle=True, num_workers=2, pin_memory=True,
                             persistent_workers=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
                           shuffle=False, num_workers=2, pin_memory=True,
                           persistent_workers=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                            shuffle=False, num_workers=2, pin_memory=True,
                            persistent_workers=True)

    # Initialize model
    # FIX: Use the actual embedding_dim from the dataset
    model = SiameseDDGPredictorStage2(
        embedding_dim=train_dataset.embedding_dim, # Changed from hardcoded 480
        attention_kernel=9,
        hidden_dims=[512, 256, 128],  # STAGE-2: Added 256 layer
        dropout_rate=0.3
    ).to(config.device)

    print(f"\nModel: {sum(p.numel() for p in model.parameters()):,} parameters")

    # Train
    history = train_model_stage2(model, train_loader, val_loader, config)

    # Plot training curves
    if config.enable_visualizations:
        plot_training_curves(history, config.save_dir / 'training_curves_stage2.png')

    # Final evaluation
    print("\n" + "="*60)
    print("FINAL EVALUATION")
    print("="*60)

    checkpoint = torch.load(config.save_dir / 'best_model_final.pt', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    results = {}
    for name, loader in [('Train', train_loader), ('Val', val_loader), ('Test', test_loader)]:
        loss, metrics = evaluate_stage2(model, loader, config)
        results[name] = {'loss': loss, **metrics}
        print(f"✓ {name} evaluated")

    # Antisymmetry evaluation
    print("\nAntisymmetry Evaluation...")
    forward_preds, reverse_preds, true_ddg = [], [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Antisymmetry"):
            wt, mut, delta, delta_abs, cos, l2, ddg = batch
            wt = wt.to(config.device)
            mut = mut.to(config.device)
            delta = delta.to(config.device)
            delta_abs = delta_abs.to(config.device)
            cos = cos.to(config.device)
            l2 = l2.to(config.device)

            fwd = model(wt, mut, delta, delta_abs, cos, l2)
            rev = model(mut, wt, -delta, delta_abs, cos, l2)

            forward_preds.extend(fwd.cpu().numpy())
            reverse_preds.extend(rev.cpu().numpy())
            true_ddg.extend(ddg.numpy())

    forward_preds = np.array(forward_preds)
    reverse_preds = np.array(reverse_preds)
    true_ddg = np.array(true_ddg)

    antisym_corr, _ = pearsonr(forward_preds, -reverse_preds)
    mean_bias = np.mean(forward_preds + reverse_preds)
    abs_bias = np.mean(np.abs(forward_preds + reverse_preds))

    results['Antisymmetry'] = {
        'correlation': antisym_corr,
        'mean_bias': mean_bias,
        'abs_bias': abs_bias
    }

    # Print results table
    print("\n" + "="*60)
    print("RESULTS TABLE")
    print("="*60)
    print("\nDataset   | PCC ↑  | SCC ↑  | RMSE ↓ | MAE ↓  | Acc ↑  |")
    print("----------|--------|--------|--------|--------|--------|")

    for name in ['Train', 'Val', 'Test']:
        r = results[name]
        print(f"{name:9s} | {r['pearson']:6.3f} | {r['spearman']:6.3f} | "
              f"{r['rmse']:6.3f} | {r['mae']:6.3f} | {r['accuracy']:6.3f} |")

    print(f"\nAntisymmetry: PCC={antisym_corr:.3f}, Bias={mean_bias:.4f}")

    # Save results
    results_dict = {k: {kk: float(vv) if isinstance(vv, (np.floating, float)) else vv
                       for kk, vv in v.items()}
                   for k, v in results.items()}

    with open(config.save_dir / 'results_stage2.json', 'w') as f:
        json.dump(results_dict, f, indent=2)

    # Evaluation plots
    if config.enable_visualizations:
        test_preds = []
        test_targets = []

        with torch.no_grad():
            for batch in test_loader:
                wt, mut, delta, delta_abs, cos, l2, ddg = batch
                wt = wt.to(config.device)
                mut = mut.to(config.device)
                delta = delta.to(config.device)
                delta_abs = delta_abs.to(config.device)
                cos = cos.to(config.device)
                l2 = l2.to(config.device)

                pred = model(wt, mut, delta, delta_abs, cos, l2)
                test_preds.extend(pred.cpu().numpy())
                test_targets.extend(ddg.numpy())

        test_preds = np.array(test_preds)
        test_targets = np.array(test_targets)

        plot_evaluation_results(
            test_targets, test_preds, forward_preds, reverse_preds,
            results, config.save_dir / 'evaluation_plots_stage2.png'
        )

    print("\n" + "="*60)
    print("STAGE-2 COMPLETE")
    print("="*60)
    print(f"✓ Model: {config.save_dir / 'best_model_final.pt'}")
    print(f"✓ Results: {config.save_dir / 'results_stage2.json'}")
    if config.enable_visualizations:
        print(f"✓ Plots: {config.save_dir / 'training_curves_stage2.png'}")
        print(f"✓ Plots: {config.save_dir / 'evaluation_plots_stage2.png'}")

STAGE-2: LOADING ENHANCED EMBEDDINGS
✓ All embedding files found

Datasets:
  Train: 277012 samples
  Val:   58883 samples
  Test:  39665 samples

Model: 821,269 parameters

STAGE-2 TRAINING STARTED
Model: 821,269 parameters
Lambda antisymmetry: 0.02
Loss: Huber (beta=1.0) + Antisymmetry regularization

PHASE 1: lr=0.0003, epochs=30
------------------------------------------------------------


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   1/30: train_huber=0.2550, train_antisym=1.4323, val_loss=0.4638, val_PCC=0.561
  → Saved (val_PCC=0.561)


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   2/30: train_huber=0.1986, train_antisym=1.2679, val_loss=0.4440, val_PCC=0.573
  → Saved (val_PCC=0.573)


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   3/30: train_huber=0.1679, train_antisym=1.2440, val_loss=0.4297, val_PCC=0.575
  → Saved (val_PCC=0.575)


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   4/30: train_huber=0.1494, train_antisym=1.2219, val_loss=0.4224, val_PCC=0.590
  → Saved (val_PCC=0.590)


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   5/30: train_huber=0.1386, train_antisym=1.1851, val_loss=0.4528, val_PCC=0.573


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   6/30: train_huber=0.1303, train_antisym=1.1390, val_loss=0.4595, val_PCC=0.567


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   7/30: train_huber=0.1254, train_antisym=1.1025, val_loss=0.4590, val_PCC=0.568


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   8/30: train_huber=0.1208, train_antisym=1.0651, val_loss=0.4493, val_PCC=0.569


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch   9/30: train_huber=0.1179, train_antisym=1.0408, val_loss=0.4413, val_PCC=0.578


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch  10/30: train_huber=0.1151, train_antisym=1.0206, val_loss=0.4414, val_PCC=0.573


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch  11/30: train_huber=0.1131, train_antisym=0.9920, val_loss=0.4398, val_PCC=0.582


Training:   0%|          | 0/2165 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/461 [00:00<?, ?it/s]

Epoch  12/30: train_huber=0.1111, train_antisym=0.9735, val_loss=0.4407, val_PCC=0.574


Training:   0%|          | 0/2165 [00:00<?, ?it/s]