# FoldTree2 Model Training and Analysis (with Config Support)

This notebook trains a protein structure prediction model using FoldTree2's encoder-decoder architecture. The model learns to encode protein structures into discrete embeddings and decode them back to predict amino acid sequences and structural contacts.

**New Features:**
- **Consolidated Hyperparameters**: All hyperparameters are now defined in a single cell for easy modification
- **Config File Support**: Can load hyperparameters from YAML or JSON config files
- **Config Override**: Values loaded from a config file override the defaults defined in the hyperparameter cell

## Training Process
The notebook demonstrates:
- **Vector Quantized Encoding**: Proteins are encoded into discrete embedding sequences using a transformer-based encoder
- **Multi-task Decoding**: The decoder predicts amino acid sequences, contact maps, and geometric properties
- **Progressive Learning**: Training occurs over multiple epochs with various loss components (reconstruction, contact prediction, VQ regularization)
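As a rough illustration of the vector-quantized encoding step, the sketch below assigns each per-residue embedding to its nearest codebook entry. It is a toy, pure-Python stand-in: the `quantize` helper, the 3-entry codebook, and the 2-dimensional embeddings are all hypothetical and are not taken from FoldTree2's encoder.

```python
# Hypothetical sketch of nearest-neighbour vector quantization (not FoldTree2 code).
def quantize(vectors, codebook):
    """Return (indices, quantized): each vector is replaced by its nearest code."""
    indices = []
    for v in vectors:
        # Squared Euclidean distance from v to every codebook entry.
        dists = [sum((a - b) ** 2 for a, b in zip(v, code)) for code in codebook]
        indices.append(dists.index(min(dists)))
    quantized = [codebook[i] for i in indices]
    return indices, quantized


# Toy codebook with 3 codes of dimension 2 (the notebook uses 30 codes of dimension 128).
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
residue_embeddings = [[0.9, 1.2], [0.1, -0.2], [-0.8, 0.7]]

indices, quantized = quantize(residue_embeddings, codebook)
print(indices)    # discrete "alphabet" states, one per residue
print(quantized)  # the codebook vectors handed to the decoder
```

The list of indices is what the notebook later visualizes as the embedding alphabet sequence.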

## Training Visualizations
During training, the notebook generates comprehensive analysis plots showing:
- **Contact Prediction**: Predicted vs. true contact maps for protein residue interactions
- **Distance Analysis**: True distance matrices and binary contact classifications
- **Performance Metrics**: ROC curves and precision-recall analysis for contact prediction accuracy
- **Sequence Embedding**: Color-coded visualization of the discrete embedding alphabet learned by the model
- **3D Structure**: Interactive molecular visualization colored by embedding states
- **Bond Angles**: Comparison of predicted vs. true backbone bond angles

This provides real-time feedback on model performance across sequence, contact, and geometric prediction tasks.
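The relationship between the distance matrix and the binary contact map shown in these plots can be sketched as follows. This is a minimal pure-Python example under stated assumptions: the `contact_map` helper is hypothetical, the coordinates are made up, and the 8 Å cutoff is an illustrative value that may differ from the definition of `contactPoints` in the dataset.

```python
import math


def contact_map(coords, threshold=8.0):
    """Build a pairwise distance matrix and a binary contact matrix.

    coords: list of (x, y, z) positions, e.g. one C-alpha atom per residue.
    threshold: assumed contact cutoff in Angstroms (illustrative value).
    """
    n = len(coords)
    dist = [[math.dist(coords[i], coords[j]) for j in range(n)] for i in range(n)]
    # A residue is not counted as being in contact with itself.
    contacts = [
        [1 if i != j and dist[i][j] <= threshold else 0 for j in range(n)]
        for i in range(n)
    ]
    return dist, contacts


# Four made-up residues along a line; the last one is far from the rest.
coords = [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0), (7.6, 0.0, 0.0), (20.0, 0.0, 0.0)]
dist, contacts = contact_map(coords)
for row in contacts:
    print(row)
```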

In [None]:
#use autoreload
%load_ext autoreload
%autoreload 2

In [None]:
cd /home/dmoi/projects/foldtree2/

In [None]:
# Imports
import torch
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated in PyG 2.x
import numpy as np
from foldtree2.src import pdbgraph
from foldtree2.src import encoder as ecdr
from foldtree2.src.losses.losses import recon_loss_diag , aa_reconstruction_loss, angles_reconstruction_loss
import os
import tqdm
import random
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import yaml
import json
from pathlib import Path

In [None]:
# Import transformers schedulers
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup

## Configuration Management

This cell loads and merges configuration from two sources:
1. Cell-defined parameters (the defaults)
2. Config file (if provided; its values override the matching cell defaults)

Specify a config file path below, or leave it as `None` to use only the cell parameters.
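The merge performed in the next cell can be sketched in isolation: keys from the config file replace matching defaults, and unrecognized keys are reported rather than applied. The `merge_config` helper and the `DEFAULTS` dictionary are hypothetical names used only for illustration, and an inline JSON string stands in for a config file.

```python
import json

# Hypothetical subset of the notebook's cell-defined defaults.
DEFAULTS = {"num_epochs": 300, "batch_size": 10, "learning_rate": 1e-5}


def merge_config(defaults, overrides):
    """Return (merged, unknown): overrides win for known keys, unknown keys are collected."""
    merged = dict(defaults)  # copy so the defaults stay untouched
    unknown = []
    for key, value in overrides.items():
        if key in merged:
            merged[key] = value
        else:
            unknown.append(key)
    return merged, unknown


# Stand-in for the contents of a JSON config file.
overrides = json.loads('{"num_epochs": 1000, "dropout": 0.1}')
merged, unknown = merge_config(DEFAULTS, overrides)
print(merged)
print("ignored keys:", unknown)
```

Keeping the parameters in a dictionary, as in this sketch, avoids writing into the notebook's global namespace; the cell below takes the simpler route of overwriting the variables directly.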

In [None]:
# ============================================================================
# CONSOLIDATED HYPERPARAMETERS
# ============================================================================
# All hyperparameters are defined in this single cell for easy management
# These can be overridden by loading a config file (YAML or JSON)
# ============================================================================

# --- CONFIG FILE LOADING ---
config_file = None  # Set to path like 'config.yaml' to load from file
# Example: config_file = 'config_notebook_1k_epochs.yaml'

# --- DATA PARAMETERS ---
datadir = '../../datasets/foldtree2/'
dataset_path = 'structs_train_final.h5'
aapropcsv = './foldtree2/config/aaindex1.csv'

# --- MODEL ARCHITECTURE PARAMETERS ---
# Alphabet/Embedding parameters
num_embeddings = 30
embedding_dim = 128

# Network size
hidden_size = 150

# Encoder parameters
encoder_type = 'mk1_Encoder'  # Options: 'mk1_Encoder', 'mk1_MuonEncoder'
encoder_hidden_channels = [150, 150, 150]
encoder_nheads = 16
encoder_dropout = 0.005
encoder_flavor = 'transformer'
encoder_fftin = True
encoder_learn_positions = True
encoder_concat_positions = False

# Decoder parameters
use_monodecoder = True  # True for MultiMonoDecoder, False for single decoder
use_muon_decoders = True  # Use Muon-compatible decoders

# --- TRAINING PARAMETERS ---
num_epochs = 300
batch_size = 10
gradient_accumulation_steps = 2
clip_grad = True
mask_plddt = True
plddt_threshold = 0.3
num_workers = 4

# Learning rate and scheduler
learning_rate = 1e-5
scheduler_type = 'plateau'  # Options: 'plateau', 'linear', 'cosine', 'cosine_with_restarts', 'polynomial'
warmup_steps = 20
warmup_ratio = 0.05  # Alternative: ratio of total training steps for warmup

# Optimizer parameters
use_muon = True  # Use Muon optimizer (hybrid Muon+AdamW)
muon_lr = 0.02  # Learning rate for Muon (hidden weights)
adamw_lr = 1e-4  # Learning rate for AdamW (gains/biases/other params)
weight_decay = 0.01

# --- LOSS WEIGHTS ---
edgeweight = 0.1
logitweight = 0.1
xweight = 0.1
fft2weight = 0.01
vqweight = 0.005
angles_weight = 0.1
ss_weight = 0.1

# Loss weight scheduler
use_weight_scheduler = True
loss_scheduler_type = 'linear'  # Options: 'linear', 'cosine', 'cosine_restarts', 'polynomial', 'constant'
loss_warmup_steps = 20

# --- COMMITMENT COST SCHEDULING ---
use_commitment_scheduling = True
commitment_cost_final = 0.9
commitment_warmup_steps = 1000
commitment_schedule = 'linear'  # Options: 'cosine', 'linear', 'none'
commitment_start = 0.5

# --- MIXED PRECISION TRAINING ---
use_mixed_precision = True

# --- REPRODUCIBILITY ---
random_seed = 0

# ============================================================================
# LOAD CONFIG FILE (if specified)
# ============================================================================
if config_file is not None and os.path.exists(config_file):
    print(f"Loading configuration from: {config_file}")
    
    with open(config_file, 'r') as f:
        if config_file.endswith('.yaml') or config_file.endswith('.yml'):
            config = yaml.safe_load(f)
        elif config_file.endswith('.json'):
            config = json.load(f)
        else:
            raise ValueError("Config file must be YAML (.yaml/.yml) or JSON (.json)")
    
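    # An empty YAML file parses to None; treat it as an empty config
    config = config or {}

    # Override cell parameters with config file values.
    # Only keys that already exist as notebook variables are accepted.
    # globals() is used because assigning through locals() is only reliable at module scope.
    for key, value in config.items():
        if key in globals():
            globals()[key] = value
            print(f"  {key}: {value} (from config)")
        else:
            print(f"  Warning: Unknown config key '{key}' - ignoring")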
    
    print("Configuration loaded successfully!")
else:
    if config_file is not None:
        print(f"Warning: Config file '{config_file}' not found. Using cell parameters.")
    else:
        print("No config file specified. Using cell parameters.")

# ============================================================================
# DISPLAY CONFIGURATION
# ============================================================================
print("\n" + "="*60)
print("TRAINING CONFIGURATION")
print("="*60)
print(f"Dataset: {dataset_path}")
print(f"Num Epochs: {num_epochs}")
print(f"Batch Size: {batch_size}")
print(f"Gradient Accumulation: {gradient_accumulation_steps}")
print(f"Learning Rate: {learning_rate}")
print(f"Scheduler: {scheduler_type}")
print(f"Optimizer: {'Muon+AdamW' if use_muon else 'AdamW'}")
print(f"Num Embeddings: {num_embeddings}")
print(f"Embedding Dim: {embedding_dim}")
print(f"Hidden Size: {hidden_size}")
print(f"Mixed Precision: {use_mixed_precision}")
print(f"Mask pLDDT: {mask_plddt} (threshold: {plddt_threshold})")
print("\nLoss Weights:")
print(f"  Edge: {edgeweight}, Logit: {logitweight}, X: {xweight}")
print(f"  FFT2: {fft2weight}, VQ: {vqweight}, Angles: {angles_weight}, SS: {ss_weight}")
print("="*60)

In [None]:
# Set seeds for reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Data setup
converter = pdbgraph.PDB2PyG(aapropcsv=aapropcsv)
struct_dat = pdbgraph.StructureDataset(dataset_path)
train_loader = DataLoader(struct_dat, batch_size=batch_size, shuffle=True, num_workers=num_workers)

data_sample = next(iter(train_loader))

# Calculate training steps
training_steps = len(train_loader) * num_epochs

print(f"Loaded {len(struct_dat)} structures")
print(f"Training steps per epoch: {len(train_loader)}")
print(f"Total training steps: {training_steps}")

In [None]:
#print the cuda devices available and their specs
if torch.cuda.is_available():
	for i in range(torch.cuda.device_count()):
		print(f"Device {i}: {torch.cuda.get_device_name(i)}")
		print(f"  Memory Allocated: {torch.cuda.memory_allocated(i)} bytes")
		print(f"  Memory Reserved: {torch.cuda.memory_reserved(i)} bytes")
else:
	print("No CUDA devices available.")

In [None]:
# Get dimensions from data sample
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

ndim = data_sample['res'].x.shape[1]
ndim_godnode = data_sample['godnode'].x.shape[1]
ndim_fft2i = data_sample['fourier2di'].x.shape[1]
ndim_fft2r = data_sample['fourier2dr'].x.shape[1]

print(f"Using device: {device}")
print(f"Residue features dim: {ndim}")
print(f"Godnode dim: {ndim_godnode}")
print(f"FFT dimensions: {ndim_fft2r} (real), {ndim_fft2i} (imag)")

In [None]:
import math
from functools import partial	

def loss_weight_scheduler(step, total_steps, schedule_type='linear', warmup_steps=0, power=1.0, num_cycles=1):
	"""
	Loss weight scheduler that modulates weights during training.
	
	Args:
		step: Current training step
		total_steps: Total number of training steps
		schedule_type: Type of schedule ('linear', 'cosine', 'cosine_restarts', 'polynomial', 'constant')
		warmup_steps: Number of steps to warmup from 0 to 1
		power: Power for polynomial decay (only used for 'polynomial')
		num_cycles: Number of cycles for cosine with restarts
		
	Returns:
		weight: Scalar weight multiplier in range [0, 1]
	"""
	
	# Warmup phase
	if step < warmup_steps:
		return step / warmup_steps
	
	# Adjust step for post-warmup scheduling
	# max(1, ...) guards against division by zero when warmup covers the whole run
	progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
	progress = min(progress, 1.0)
	
	if schedule_type == 'constant':
		return 1.0
	
	elif schedule_type == 'linear':
		# Linear decay from 1.0 to 0.0
		return 1.0 - progress
	
	elif schedule_type == 'cosine':
		# Cosine annealing from 1.0 to 0.0
		return 0.5 * (1.0 + math.cos(math.pi * progress))
	
	elif schedule_type == 'cosine_restarts':
		# Cosine with hard restarts (SGDR)
		cycle_progress = (progress * num_cycles) % 1.0
		return 0.5 * (1.0 + math.cos(math.pi * cycle_progress))
	
	elif schedule_type == 'polynomial':
		# Polynomial decay
		return (1.0 - progress) ** power
	
	else:
		raise ValueError(f"Unknown schedule_type: {schedule_type}")

# Define partial functions for each loss weight scheduler
if use_weight_scheduler:
	x_scheduler = partial(loss_weight_scheduler,
						  total_steps=training_steps, 
						  schedule_type=loss_scheduler_type,
						  warmup_steps=loss_warmup_steps)

	logit_scheduler = partial(loss_weight_scheduler,
						  total_steps=training_steps, 
						  schedule_type=loss_scheduler_type,
						  warmup_steps=loss_warmup_steps)

	edgeweight_scheduler = partial(loss_weight_scheduler,
						  total_steps=training_steps, 
						  schedule_type='cosine_restarts',
						  num_cycles=3,
						  warmup_steps=loss_warmup_steps)

	ss_scheduler = partial(loss_weight_scheduler,
						  total_steps=training_steps, 
						  schedule_type=loss_scheduler_type,
						  warmup_steps=loss_warmup_steps,
						  power=2.0)

	vq_scheduler = partial(loss_weight_scheduler,
						  total_steps=training_steps, 
						  schedule_type='cosine_restarts',	
						  num_cycles=10,
						  warmup_steps=loss_warmup_steps)
	
	fft2_scheduler = partial(loss_weight_scheduler,
						  total_steps=training_steps, 
						  schedule_type='cosine_restarts',	
						  num_cycles=5,
						  warmup_steps=loss_warmup_steps)
	
	angles_scheduler = partial(loss_weight_scheduler,
							total_steps=training_steps,
							schedule_type=loss_scheduler_type,
							warmup_steps=loss_warmup_steps)
	
	print("Loss weight schedulers initialized")
else:
	print("Loss weight scheduling disabled - using constant weights")

In [None]:
def get_scheduler(optimizer, scheduler_type, num_warmup_steps, num_training_steps, **kwargs):
	if scheduler_type == 'linear':
		return get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
	elif scheduler_type == 'cosine':
		return get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
	elif scheduler_type == 'cosine_with_restarts':
		return get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
	elif scheduler_type == 'polynomial':
		return get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, lr_end=0.0, power=1.0)
	elif scheduler_type == 'plateau':
		# ReduceLROnPlateau doesn't require distributed process groups - it only monitors loss values
		return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, **kwargs)
	else:
		raise ValueError(f"Unknown scheduler type: {scheduler_type}")

# Calculate warmup steps
num_training_steps = num_epochs * len(train_loader)
if isinstance(warmup_ratio, float) and warmup_ratio > 0:
	num_warmup_steps = int(num_training_steps * warmup_ratio)
else:
	num_warmup_steps = warmup_steps

print(f"Scheduler: {scheduler_type}")
print(f"Total training steps: {num_training_steps}")
print(f"Warmup steps: {num_warmup_steps}")

In [None]:
# Initialize Encoder
# Note: encoder_type is only reported here; the cell always instantiates ecdr.mk1_Encoder.
print(f"Initializing encoder: {encoder_type}")

encoder = ecdr.mk1_Encoder(
	in_channels=ndim,
	hidden_channels=encoder_hidden_channels,
	out_channels=embedding_dim,
	metadata={'edge_types': [('res','contactPoints','res')]},
	num_embeddings=num_embeddings,
	commitment_cost=commitment_cost_final,
	edge_dim=1,
	encoder_hidden=hidden_size,
	EMA=True,
	nheads=encoder_nheads,
	dropout_p=encoder_dropout,
	reset_codes=False,
	flavor=encoder_flavor,
	fftin=encoder_fftin,
	use_commitment_scheduling=use_commitment_scheduling,
	commitment_warmup_steps=commitment_warmup_steps,
	commitment_schedule=commitment_schedule,
	commitment_start=commitment_start,
	concat_positions=encoder_concat_positions,
	learn_positions=encoder_learn_positions
)

print(encoder)
encoder = encoder.to(device)

In [None]:
from foldtree2.src.mono_decoders import MultiMonoDecoder

# Initialize Decoder
if use_monodecoder:
	print("Initializing MultiMonoDecoder")
	
	mono_configs = {
		'sequence_transformer': {
			'in_channels': {'res': embedding_dim},
			'xdim': 20,
			'concat_positions': False,
			'hidden_channels': {('res','backbone','res'): [hidden_size], ('res','backbonerev','res'): [hidden_size]},
			'layers': 2,
			'AAdecoder_hidden': [hidden_size, hidden_size, hidden_size//2], 
			'amino_mapper': converter.aaindex,
			'nheads': 2,
			'dropout': 0.001,
			'normalize': False,
			'residual': False,
			'use_cnn_decoder': False,
			'output_ss': False,
			'learn_positions': True
		},
		
		'geometry_transformer': {
			'in_channels': {'res': embedding_dim},
			'concat_positions': False,
			'hidden_channels': {('res','backbone','res'): [hidden_size], ('res','backbonerev','res'): [hidden_size]},
			'layers': 2,
			'nheads': 2,
			'RTdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
			'ssdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
			'anglesdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
			'dropout': 0.001,
			'normalize': False,
			'residual': False,
			'learn_positions': True,
			'use_cnn_decoder': True,
			'output_rt': False,
			'output_ss': True,
			'output_angles': True
		},
		
		'geometry_cnn': {
			'in_channels': {'res': embedding_dim, 'godnode4decoder': ndim_godnode, 'foldx': 23, 'fft2r': ndim_fft2r, 'fft2i': ndim_fft2i},
			'concat_positions': False,
			'conv_channels': [2*hidden_size, hidden_size, hidden_size],
			'kernel_sizes': [3, 3, 3],
			'FFT2decoder_hidden': [hidden_size//2, hidden_size//2],
			'contactdecoder_hidden': [hidden_size//2, hidden_size//4],
			'ssdecoder_hidden': [hidden_size//2, hidden_size//2],
			'Xdecoder_hidden': [hidden_size, hidden_size], 
			'anglesdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
			'RTdecoder_hidden': [hidden_size//2, hidden_size//4],
			'metadata': converter.metadata, 
			'dropout': 0.001,
			'output_fft': False,
			'output_rt': False,
			'output_angles': False,
			'output_ss': False,
			'normalize': True,
			'residual': False,
			'output_edge_logits': True,
			'ncat': 8,
			'contact_mlp': False,
			'pool_type': 'global_mean',
			'learn_positions': True
		},
	}

	decoder = MultiMonoDecoder(configs=mono_configs)
else:
	print("Initializing single HeteroGAE_Decoder")
	decoder = ecdr.HeteroGAE_Decoder(
		in_channels={'res': embedding_dim, 'godnode4decoder': ndim_godnode, 'foldx': 23},
		concat_positions=False,
		hidden_channels={('res','backbone','res'): [hidden_size]*5, ('res','backbonerev','res'): [hidden_size]*5},
		layers=3,
		AAdecoder_hidden=[hidden_size, hidden_size, hidden_size//2],
		Xdecoder_hidden=[hidden_size, hidden_size, hidden_size],
		contactdecoder_hidden=[hidden_size//2, hidden_size//2],
		anglesdecoder_hidden=[hidden_size//2, hidden_size//2, hidden_size//4],
		nheads=5,
		amino_mapper=converter.aaindex,
		flavor='sage',
		dropout=0.005,
		normalize=True,
		residual=False,
		contact_mlp=False
	)

decoder = decoder.to(device)
print(decoder)

In [None]:
# Training loop setup
import time
from collections import defaultdict
from muon import MuonWithAuxAdam

encoder.device = device
encoder.train()
decoder.train()

if not use_muon:
	print("Using AdamW optimizer")
	optimizer = torch.optim.AdamW(
		list(encoder.parameters()) + list(decoder.parameters()), 
		lr=learning_rate,
		weight_decay=weight_decay
	)
else:
	print("Using Muon+AdamW hybrid optimizer")
	hidden_weights = []
	hidden_gains_biases = []
	nonhidden_params = []
	
	# Helper function to check if a model has modular structure
	def has_modular_structure(model):
		return hasattr(model, 'input') and hasattr(model, 'body') and hasattr(model, 'head')
	
	# Process encoder
	if has_modular_structure(encoder):
		print("  Encoder: modular structure detected")
		hidden_weights += [p for p in encoder.body.parameters() if p.ndim >= 2]
		hidden_gains_biases += [p for p in encoder.body.parameters() if p.ndim < 2]
		nonhidden_params += [*encoder.head.parameters(), *encoder.input.parameters()]
	else:
		print("  Encoder: non-modular, using AdamW")
		nonhidden_params += list(encoder.parameters())
	
	# Process decoder
	if hasattr(decoder, 'decoders'):
		print(f"  Decoder: MultiMonoDecoder with {len(decoder.decoders)} sub-decoders")
		for name, subdecoder in decoder.decoders.items():
			if has_modular_structure(subdecoder):
				hidden_weights += [p for p in subdecoder.body.parameters() if p.ndim >= 2]
				hidden_gains_biases += [p for p in subdecoder.body.parameters() if p.ndim < 2]
				nonhidden_params += [*subdecoder.head.parameters(), *subdecoder.input.parameters()]
			else:
				nonhidden_params += list(subdecoder.parameters())
	elif has_modular_structure(decoder):
		print("  Decoder: modular structure detected")
		hidden_weights += [p for p in decoder.body.parameters() if p.ndim >= 2]
		hidden_gains_biases += [p for p in decoder.body.parameters() if p.ndim < 2]
		nonhidden_params += [*decoder.head.parameters(), *decoder.input.parameters()]
	else:
		print("  Decoder: non-modular, using AdamW")
		nonhidden_params += list(decoder.parameters())
	
	print(f"\nParameter groups:")
	print(f"  Hidden weights (Muon): {len(hidden_weights)} tensors")
	print(f"  Hidden gains/biases (AdamW): {len(hidden_gains_biases)} tensors")
	print(f"  Non-hidden params (AdamW): {len(nonhidden_params)} tensors")
	
	param_groups = [
		dict(params=hidden_weights, use_muon=True, lr=muon_lr, weight_decay=weight_decay),
		dict(params=hidden_gains_biases+nonhidden_params, use_muon=False,
			 lr=adamw_lr, betas=(0.9, 0.95), weight_decay=weight_decay),
	]

	optimizer = MuonWithAuxAdam(param_groups)

# Initialize scheduler
scheduler = get_scheduler(
	optimizer, 
	scheduler_type=scheduler_type,
	num_training_steps=num_training_steps,
	num_warmup_steps=num_warmup_steps,
)

scheduler_step_mode = 'epoch' if scheduler_type == 'plateau' else 'step'

print(f"\nScheduler: {scheduler_type} (step mode: {scheduler_step_mode})")
print(f"Mixed precision: {use_mixed_precision}")

## Muon Optimizer Training

The Muon optimizer uses a hybrid approach:
- **Muon optimizer** for hidden layer weights (2D+ tensors in body modules): it orthogonalizes each momentum update with a Newton-Schulz iteration
- **AdamW optimizer** for gains, biases, and non-hidden parameters (input/head modules)

This configuration is intended for deep networks with modular architectures where:
- Body modules contain the core transformations (graph convolutions, CNNs, transformers)
- Input/Head modules handle preprocessing and task-specific outputs
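To give a feel for the orthogonalization idea behind Muon, the sketch below applies a Newton-Schulz iteration to a small matrix in pure Python. It is a simplified illustration under stated assumptions: it uses the textbook cubic iteration rather than the tuned polynomial of the actual optimizer, and `matmul`, `transpose`, and `newton_schulz_orthogonalize` are hypothetical helpers, not part of the `muon` package.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a):
    return [list(row) for row in zip(*a)]


def newton_schulz_orthogonalize(g, steps=20):
    """Push all singular values of g towards 1 with X <- 1.5*X - 0.5*X*X^T*X."""
    # Dividing by the Frobenius norm keeps every singular value at or below 1,
    # which is inside the convergence region of the cubic iteration.
    norm = sum(v * v for row in g for v in row) ** 0.5
    x = [[v / norm for v in row] for row in g]
    for _ in range(steps):
        xxt_x = matmul(matmul(x, transpose(x)), x)
        x = [
            [1.5 * x[i][j] - 0.5 * xxt_x[i][j] for j in range(len(x[0]))]
            for i in range(len(x))
        ]
    return x


# A made-up 2x2 "update" matrix standing in for a momentum buffer.
update = [[3.0, 1.0], [1.0, 2.0]]
q = newton_schulz_orthogonalize(update)
qqt = matmul(q, transpose(q))  # close to the identity if q is orthogonal
print(qqt)
```

Because the result has all singular values near 1, the update direction is kept while its scale is equalized across directions, which is why this treatment is reserved for 2D+ weight matrices.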

In [None]:
# Set up a single-process torch.distributed group, which the distributed Muon optimizer expects
if use_muon:
	print("\n" + "="*60)
	print("MUON OPTIMIZER INITIALIZATION")
	print("="*60)
	
	import os as dist_os
	import torch.distributed as dist
	
	# Clean up any existing process group
	if dist.is_initialized():
		dist.destroy_process_group()
		print("Destroyed existing process group")
	
	# Find an available port
	import socket
	def find_free_port():
		with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
			s.bind(('', 0))
			s.listen(1)
			port = s.getsockname()[1]
		return port
	
	free_port = find_free_port()
	
	dist_os.environ['MASTER_ADDR'] = 'localhost'
	dist_os.environ['MASTER_PORT'] = str(free_port)
	dist_os.environ['RANK'] = '0'
	dist_os.environ['WORLD_SIZE'] = '1'
	dist.init_process_group(backend='gloo', init_method='env://')
	print(f"Initialized process group on port {free_port}")
	print("="*60)

In [None]:
# Initialize mixed precision scaler if enabled
if use_mixed_precision:
	scaler = GradScaler()
	print("Mixed precision training enabled (GradScaler initialized)")
else:
	scaler = None
	print("Mixed precision training disabled")

## Training Loop

The training loop below continues from the original notebook but uses the consolidated hyperparameters defined above.

In [None]:
# Pick a random structure from the dataset for a quick encoder smoke test
train_loader_test = DataLoader(struct_dat, batch_size=1, shuffle=True, num_workers=4)
randint = random.randint(0, len(train_loader_test) - 1)
print(f"Randomly selected dataset index: {randint}")
data_sample = struct_dat[randint]
print(data_sample)
data = data_sample.to(device)
optimizer.zero_grad()
z, vqloss = encoder(data, debug=True)
print('Encoded z shape:', z.shape)

In [None]:
# Filter proteins with at least 200 amino acids and average pLDDT > 0.7
# (json is already imported at the top of the notebook)
check_plddt = False
if check_plddt or not os.path.exists('plddt_dataset.json'):
	# Collect valid protein indices
	valid_indices = []
	for idx in range(len(struct_dat)):
		try:
			data = struct_dat[idx]
			num_residues = data['AA'].x.shape[0]
			avg_plddt = data['plddt'].x.mean().item()
			
			if num_residues >= 200 and avg_plddt > 0.7:
				valid_indices.append(idx)
				print(f"Index {idx}: {data.identifier}, {num_residues} residues, avg pLDDT={avg_plddt:.3f}")
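		except Exception as e:
			# Skip unreadable entries, but report them instead of failing silently
			print(f"Skipping index {idx}: {e}")
			continue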

	print(f"\nFound {len(valid_indices)} proteins with ≥200 residues and avg pLDDT > 0.7")
	with open('plddt_dataset.json', 'w') as fileout:
		fileout.write(json.dumps(valid_indices))
else:
	with open('plddt_dataset.json') as filein:
		valid_indices = json.load(filein)

if len(valid_indices) > 0:
	# Select random protein from valid ones
	selected_idx = random.choice(valid_indices)
	selected_protein = struct_dat[selected_idx]
	
	print(f"\n{'='*60}")
	print(f"SELECTED PROTEIN:")
	print(f"  Identifier: {selected_protein.identifier}")
	print(f"  Number of residues: {selected_protein['AA'].x.shape[0]}")
	print(f"  Average pLDDT: {selected_protein['plddt'].x.mean().item():.3f}")
	print(f"  Dataset index: {selected_idx}")
	print(f"{'='*60}")
else:
	print("No proteins found matching the criteria!")
	selected_protein = None
	selected_idx = None

In [None]:
from Bio import PDB
from Bio.PDB import PDBParser
from foldtree2.src.AFDB_tools import grab_struct

def getCAatoms(pdb_file):
	parser = PDBParser(QUIET=True)
	structure = parser.get_structure('structure', pdb_file)
	ca_atoms = []
	for model in structure:
		for chain in model:
			for residue in chain:
				if 'CA' in residue and PDB.is_aa(residue):
					ca_atoms.append(residue['CA'])
	return ca_atoms

In [None]:
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated in PyG 2.x
from scipy import sparse
from matplotlib import pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

def get_backbone(naa):
	backbone_mat = np.zeros((naa, naa))
	backbone_rev_mat = np.zeros((naa, naa))
	np.fill_diagonal(backbone_mat[1:], 1)
	np.fill_diagonal(backbone_rev_mat[:, 1:], 1)
	return backbone_mat, backbone_rev_mat

def sparse2pairs(sparsemat):
	sparsemat = sparse.find(sparsemat)
	return np.vstack([sparsemat[0], sparsemat[1]])

def decoder_reconstruction2aa(ords, device, verbose=False):
	decoder.eval()
	if verbose:
		print(ords)
	
	z = encoder.vector_quantizer.embeddings(ords).to('cpu')
	
	edge_index = torch.tensor([[i, j] for i in range(z.shape[0]) for j in range(z.shape[0])], dtype=torch.long).T
	godnode_index = np.vstack([np.zeros(z.shape[0]), [i for i in range(z.shape[0])]])
	godnode_rev = np.vstack([[i for i in range(z.shape[0])], np.zeros(z.shape[0])])
	
	# Generate a backbone for the decoder
	data = HeteroData()
	
	data['res'].x = z
	backbone, backbone_rev = get_backbone(z.shape[0])
	backbone = sparse.csr_matrix(backbone)
	backbone_rev = sparse.csr_matrix(backbone_rev)
	backbone = sparse2pairs(backbone)
	backbone_rev = sparse2pairs(backbone_rev)
	positional_encoding = converter.get_positional_encoding(z.shape[0], 256)
	
	if verbose:
		print('positional encoding shape:', positional_encoding.shape)
	
	data['res'].batch = torch.tensor([0 for i in range(z.shape[0])], dtype=torch.long)
	data['positions'].x = torch.tensor(positional_encoding, dtype=torch.float32)
	data['res', 'backbone', 'res'].edge_index = torch.tensor(backbone, dtype=torch.long)
	# Edge type spelled 'backbonerev' to match the decoder configs' hidden_channels keys
	data['res', 'backbonerev', 'res'].edge_index = torch.tensor(backbone_rev, dtype=torch.long)
	
	if verbose:
		print(data['res'].x.shape)
	
	# Add the godnode
	data['godnode'].x = torch.tensor(np.ones((1, 5)), dtype=torch.float32)
	data['godnode4decoder'].x = torch.tensor(np.ones((1, 5)), dtype=torch.float32)
	data['godnode4decoder', 'informs', 'res'].edge_index = torch.tensor(godnode_index, dtype=torch.long)
	data['res', 'informs', 'godnode4decoder'].edge_index = torch.tensor(godnode_rev, dtype=torch.long)
	data['res', 'informs', 'godnode'].edge_index = torch.tensor(godnode_rev, dtype=torch.long)
	edge_index = edge_index.to(device)
	
	if verbose:
		print(data)
	
	data = data.to(device)
	# All residue pairs (i, j), placed on the same device as the graph data
	allpairs = torch.tensor([[i, j] for i in range(z.shape[0]) for j in range(z.shape[0])], dtype=torch.long).T.to(device)
	out = decoder(data, allpairs)
	recon_x = out['aa'] if 'aa' in out else None
	edge_probs = out['edge_probs'] if 'edge_probs' in out else None
	logits = out['edge_logits'] if 'edge_logits' in out else None

	if verbose and edge_probs is not None:
		print(edge_probs.shape)
	
	aastr = None
	
	if edge_probs is not None:
		edge_probs = edge_probs.reshape((z.shape[0], z.shape[0]))
	if logits is not None:
		logits = torch.sum(logits, dim=1).squeeze()
		logits = logits.reshape((z.shape[0], z.shape[0]))
	
	return aastr, edge_probs, logits, out

## Visualization Functions

These functions provide comprehensive visualization of encoder-decoder performance:
- Contact map predictions vs true contacts
- ROC curves and precision-recall analysis
- Embedding sequence visualization
- Bond angle and secondary structure predictions
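The ROC AUC reported in these plots can be read as the probability that a true contact receives a higher predicted score than a non-contact. The sketch below computes it directly from that definition in pure Python; the `roc_auc` helper and the toy labels and scores are hypothetical, and the notebook itself relies on scikit-learn's `roc_curve` and `auc`.

```python
def roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for label, s in zip(labels, scores) if label == 1]
    neg = [s for label, s in zip(labels, scores) if label == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos
        for n in neg
    )
    return wins / (len(pos) * len(neg))


# Flattened toy contact map (1 = contact) and made-up predicted probabilities.
true_contacts = [1, 0, 1, 0, 0, 1]
pred_probs = [0.9, 0.4, 0.35, 0.2, 0.1, 0.8]

auc_value = roc_auc(true_contacts, pred_probs)
print(f"AUC = {auc_value:.3f}")
```

The pairwise form is quadratic in the number of residue pairs, so it is only practical for small examples; it is shown here to make the metric's meaning concrete.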

In [None]:
def plot_logits_sequence_on_ax(selected_indices, num_embeddings, ax):
    """
    Visualize embedding sequence as colored bands
    
    Args:
        selected_indices: Discrete embedding indices
        num_embeddings: Total number of embeddings
        ax: Matplotlib axis to plot on
    """
    from colour import Color
    
    # Generate color gradient
    ord_colors = Color("red").range_to(Color("blue"), num_embeddings)
    ord_colors = np.array([c.get_rgb() for c in ord_colors])
    
    # Map indices to colors
    sequence_colors = ord_colors[selected_indices.cpu().numpy()]
    
    # Create canvas with wrapping for long sequences
    max_width = 64
    seq_len = len(sequence_colors)
    rows = int(np.ceil(seq_len / max_width))
    canvas = np.ones((rows, max_width, 3))
    
    for i in range(rows):
        start = i * max_width
        end = min((i + 1) * max_width, seq_len)
        row_colors = sequence_colors[start:end]
        canvas[i, :len(row_colors), :] = row_colors
    
    ax.imshow(canvas, aspect='auto')
    ax.set_title('FT2 Alphabet State Sequence')
    ax.axis('off')

In [None]:
def visualize_decoder_reconstruction(encoder, decoder, data_sample, device, num_embeddings, converter, epoch=None, save_path=None):
    """
    Comprehensive visualization of decoder reconstruction performance
    
    Creates a 3x3 subplot grid showing:
    - Row 1: Contact maps and distance matrices
    - Row 2: ROC curves, precision-recall curves, correlation plots
    - Row 3: Bond angles/SS prediction, edge logits, embedding sequence
    
    Args:
        encoder: Trained encoder model
        decoder: Trained decoder model
        data_sample: Input data sample
        device: PyTorch device
        num_embeddings: Number of discrete embeddings
        converter: PDB2PyG converter
        epoch: Current epoch (optional, for title)
        save_path: Path to save figure (optional)
    
    Returns:
        fig: Matplotlib figure
        metrics_dict: Dictionary of computed metrics
        sample_out: Decoder outputs
    """
    from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
    from scipy.stats import pearsonr
    
    encoder.eval()
    decoder.eval()
    
    with torch.no_grad():
        # Forward pass through encoder
        z, vqloss = encoder(data_sample)
        
        # Get discrete indices
        selected_indices = encoder.vector_quantizer.discretize_z(z.detach())[0]
        
        # Reconstruct using decoder
        aastr, edge_probs, logits, sample_out = decoder_reconstruction2aa(selected_indices, device)
    
    # Create figure
    fig, axs = plt.subplots(3, 3, figsize=(18, 15))
    
    # Title
    epoch_str = f"Epoch {epoch} - " if epoch is not None else ""
    fig.suptitle(f'{epoch_str}Decoder Reconstruction Analysis', fontsize=16, y=0.995)
    
    # ============================================================================
    # ROW 1: CONTACT MAPS AND DISTANCE MATRICES
    # ============================================================================
    
    # Plot 1: True contact map
    true_contacts = data_sample.edge_index_dict[('res', 'contactPoints', 'res')].cpu().numpy()
    naa = data_sample['AA'].x.shape[0]
    true_contact_mat = np.zeros((naa, naa))
    true_contact_mat[true_contacts[0], true_contacts[1]] = 1
    
    im0 = axs[0, 0].imshow(true_contact_mat, cmap='hot', interpolation='nearest')
    axs[0, 0].set_title('True Contacts')
    axs[0, 0].set_xlabel('Residue Index')
    axs[0, 0].set_ylabel('Residue Index')
    plt.colorbar(im0, ax=axs[0, 0], fraction=0.046, pad=0.04)
    
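    # Plot 2: Predicted contact probabilities
    # Convert predictions to a NumPy array so matplotlib and scikit-learn can consume them
    if torch.is_tensor(edge_probs):
        edge_probs = edge_probs.detach().cpu().numpy()
    im1 = axs[0, 1].imshow(edge_probs, cmap='hot', interpolation='nearest')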
    axs[0, 1].set_title('Predicted Contact Probabilities')
    axs[0, 1].set_xlabel('Residue Index')
    axs[0, 1].set_ylabel('Residue Index')
    plt.colorbar(im1, ax=axs[0, 1], fraction=0.046, pad=0.04)
    
    # Plot 3: True distance matrix
    coords = data_sample['coords'].x.cpu().numpy()
    distance_matrix = np.sqrt(((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=2))
    
    im2 = axs[0, 2].imshow(distance_matrix, cmap='viridis', interpolation='nearest')
    axs[0, 2].set_title('True Distance Matrix (Å)')
    axs[0, 2].set_xlabel('Residue Index')
    axs[0, 2].set_ylabel('Residue Index')
    plt.colorbar(im2, ax=axs[0, 2], fraction=0.046, pad=0.04)
    
    # ============================================================================
    # ROW 2: ROC CURVES, PRECISION-RECALL, CORRELATION
    # ============================================================================
    
    # Flatten matrices for ROC/PR curves
    true_flat = true_contact_mat.flatten()
    pred_flat = edge_probs.flatten()
    
    # Plot 4: ROC Curve
    fpr, tpr, thresholds = roc_curve(true_flat, pred_flat)
    roc_auc = auc(fpr, tpr)
    
    axs[1, 0].plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC (AUC = {roc_auc:.3f})')
    axs[1, 0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    axs[1, 0].set_xlim([0.0, 1.0])
    axs[1, 0].set_ylim([0.0, 1.05])
    axs[1, 0].set_xlabel('False Positive Rate')
    axs[1, 0].set_ylabel('True Positive Rate')
    axs[1, 0].set_title('ROC Curve - Contact Prediction')
    axs[1, 0].legend(loc="lower right")
    axs[1, 0].grid(alpha=0.3)
    
    # Plot 5: Precision-Recall Curve
    precision, recall, _ = precision_recall_curve(true_flat, pred_flat)
    avg_precision = average_precision_score(true_flat, pred_flat)
    
    axs[1, 1].plot(recall, precision, color='blue', lw=2, label=f'PR (AP = {avg_precision:.3f})')
    axs[1, 1].set_xlim([0.0, 1.0])
    axs[1, 1].set_ylim([0.0, 1.05])
    axs[1, 1].set_xlabel('Recall')
    axs[1, 1].set_ylabel('Precision')
    axs[1, 1].set_title('Precision-Recall Curve')
    axs[1, 1].legend(loc="lower left")
    axs[1, 1].grid(alpha=0.3)
    
    # Plot 6: Predicted contact probability vs. true distance
    # Only residue pairs closer than contact_threshold are included (this also keeps the zero-distance diagonal)
    contact_threshold = 15.0  # Ångströms
    true_distances_contacts = distance_matrix[distance_matrix < contact_threshold]
    pred_probs_contacts = edge_probs[distance_matrix < contact_threshold]
    
    if len(true_distances_contacts) > 0:
        correlation, _ = pearsonr(true_distances_contacts, pred_probs_contacts)
        axs[1, 2].scatter(true_distances_contacts, pred_probs_contacts, alpha=0.5, s=10)
        axs[1, 2].set_xlabel('True Distance (Å)')
        axs[1, 2].set_ylabel('Predicted Contact Probability')
        axs[1, 2].set_title(f'Distance vs Contact Prob (r = {correlation:.3f})')
        axs[1, 2].grid(alpha=0.3)
    else:
        axs[1, 2].text(0.5, 0.5, 'No contacts < 15Å', ha='center', va='center', transform=axs[1, 2].transAxes)
        axs[1, 2].set_title('Distance vs Contact Prob')
    
    # ============================================================================
    # ROW 3: BOND ANGLES/SS, EDGE LOGITS, EMBEDDING SEQUENCE
    # ============================================================================
    
    # Plot 7: Bond angles and secondary structure
    if 'angles' in sample_out and sample_out['angles'] is not None:
        angles = sample_out['angles'].cpu().numpy()
        true_angles = data_sample['bondangles'].x.cpu().numpy()
        
        # Plot predicted and true angles
        for i, angle_name in enumerate(['Phi', 'Psi', 'Omega']):
            axs[2, 0].plot(angles[:, i], label=f'Pred {angle_name}', alpha=0.7, linestyle='--')
            axs[2, 0].plot(true_angles[:, i], label=f'True {angle_name}', alpha=0.7)
        
        axs[2, 0].set_xlabel('Residue Index')
        axs[2, 0].set_ylabel('Angle (radians)')
        axs[2, 0].set_title('Bond Angles Prediction')
        axs[2, 0].legend(loc='upper right', fontsize=8)
        axs[2, 0].grid(alpha=0.3)
    else:
        axs[2, 0].text(0.5, 0.5, 'No angles predicted', ha='center', va='center', transform=axs[2, 0].transAxes)
        axs[2, 0].set_title('Bond Angles Prediction')
    
    # Plot 8: Edge logits heatmap
    if logits is not None and logits.size > 0:
        im7 = axs[2, 1].imshow(logits, cmap='hot', interpolation='nearest')
        axs[2, 1].set_title('Edge Logits')
        axs[2, 1].set_xlabel('Residue Index')
        axs[2, 1].set_ylabel('Residue Index')
        plt.colorbar(im7, ax=axs[2, 1], fraction=0.046, pad=0.04)
    else:
        axs[2, 1].text(0.5, 0.5, 'No edge logits', ha='center', va='center', transform=axs[2, 1].transAxes)
        axs[2, 1].set_title('Edge Logits')
    
    # Plot 9: Embedding sequence visualization
    plot_logits_sequence_on_ax(selected_indices, num_embeddings, axs[2, 2])
    
    plt.tight_layout()
    
    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=150)
        print(f"Figure saved to {save_path}")
    
    # Compute metrics
    metrics_dict = {
        'roc_auc': roc_auc,
        'avg_precision': avg_precision,
        'num_residues': naa,
        'num_true_contacts': float(true_contact_mat.sum()),
        'vq_loss': float(vqloss) if isinstance(vqloss, torch.Tensor) else vqloss
    }
    
    if len(true_distances_contacts) > 0:
        # Cast to a built-in float so the metrics stay JSON-serializable
        metrics_dict['distance_correlation'] = float(correlation)
    
    encoder.train()
    decoder.train()
    
    return fig, metrics_dict, sample_out
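
The contact metrics above are computed by turning the edge list into a dense matrix, flattening it, and scoring every residue pair. A minimal pure-Python sketch of that idea follows. `contact_matrix` and `roc_auc_from_pairs` are hypothetical helpers written for illustration (the notebook itself relies on NumPy and `sklearn.metrics`), and the three-residue example data is made up.

```python
def contact_matrix(edge_index, num_residues):
    """Build a dense 0/1 matrix from a (sources, targets) edge list."""
    sources, targets = edge_index
    mat = [[0] * num_residues for _ in range(num_residues)]
    for i, j in zip(sources, targets):
        mat[i][j] = 1
    return mat


def roc_auc_from_pairs(labels, scores):
    """ROC AUC as the probability that a positive outranks a negative.

    Ties count as half a win. This pairwise form is O(n_pos * n_neg), so it
    is only meant for small illustrative inputs.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy example: 3 residues, contacts 0-1 and 1-2 stored in both directions.
edge_index = ([0, 1, 1, 2], [1, 0, 2, 1])
true_mat = contact_matrix(edge_index, 3)

# Made-up predicted probabilities for the same 3x3 grid.
pred_mat = [
    [0.1, 0.9, 0.2],
    [0.8, 0.1, 0.7],
    [0.3, 0.6, 0.1],
]

# Flatten row by row so labels and scores stay aligned pair-for-pair.
true_flat = [v for row in true_mat for v in row]
pred_flat = [v for row in pred_mat for v in row]

toy_auc = roc_auc_from_pairs(true_flat, pred_flat)
print(f"toy ROC AUC = {toy_auc:.3f}")
```

Because the flattened arrays cover the full matrix, each pair appears twice, as (i, j) and (j, i), and the diagonal enters as a negative whenever the edge list has no self-loops. That is worth keeping in mind when comparing these scores with metrics computed on the upper triangle only.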

## Main Training Loop

The training loop performs the following:
1. Forward pass through encoder and decoder
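2. Compute the multi-task losses (amino acid sequence, contacts and edge logits, FFT2, bond angles, secondary structure) and add the VQ loss returned by the encoder
3. Accumulate gradients over `gradient_accumulation_steps` batches, then clip them (when `clip_grad` is set) and step the optimizer
4. Update the learning rate per optimizer step or per epoch, depending on `scheduler_step_mode`
5. Save checkpoints and generate visualizations every 10 epochs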
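
Step 3 can be sketched without any deep-learning library. The pure-Python example below uses a hypothetical one-parameter model and a helper named `optimizer_step_batches`, neither of which is part of FoldTree2. It shows (a) which batches trigger an optimizer step when gradients are accumulated, including the end-of-epoch flush for a partial group, and (b) that summing the gradients of `loss / k` over `k` micro-batches matches the gradient of their mean loss.

```python
def optimizer_step_batches(num_batches, accumulation_steps):
    """Return the batch indices after which the optimizer would step."""
    steps = [
        idx for idx in range(num_batches)
        if (idx + 1) % accumulation_steps == 0
    ]
    # A trailing partial group is flushed once the epoch ends.
    if num_batches % accumulation_steps != 0:
        steps.append(num_batches - 1)
    return steps


def grad_squared_error(w, x, y):
    """d/dw of (w * x - y) ** 2 for a single sample."""
    return 2.0 * (w * x - y) * x


# Hypothetical data: four micro-batches of one sample each.
samples = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]
w = 0.5
k = len(samples)

# Accumulate gradients of the scaled loss, as the loop does with loss / k.
accumulated = 0.0
for x, y in samples:
    accumulated += grad_squared_error(w, x, y) / k

# Gradient of the mean loss over all four samples at once.
full_batch = sum(grad_squared_error(w, x, y) for x, y in samples) / k

print(optimizer_step_batches(10, 4))
print(abs(accumulated - full_batch) < 1e-12)
```

One consequence of this scheme: when the trailing group holds fewer than `k` batches, each loss was still divided by `k`, so that final update is proportionally smaller than a full one.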

In [None]:
import tqdm.notebook
# Reset dataloader with configured batch size
train_loader = DataLoader(struct_dat, batch_size=batch_size, shuffle=True, num_workers=num_workers)

encoder.train()
decoder.train()

figurestack = []
metrics_history = []

print(f"Starting training for {num_epochs} epochs...")
print(f"Batch size: {batch_size}")
print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
print(f"Steps per epoch: {len(train_loader)}")
print()

for epoch in range(num_epochs):
    total_loss_x = 0
    total_loss_edge = 0
    total_vq = 0
    total_angles_loss = 0
    total_loss_fft2 = 0
    total_logit_loss = 0
    total_ss_loss = 0
    
    for batch_idx, data in enumerate(tqdm.notebook.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        data = data.to(device)
        
        # Forward pass with autocast for mixed precision
        if use_mixed_precision:
            with autocast():
                z, vqloss = encoder(data)
                data['res'].x = z
                
                # Forward pass through decoder
                out = decoder(data, None)
                edge_index = data.edge_index_dict.get(('res', 'contactPoints', 'res')) if hasattr(data, 'edge_index_dict') else None

                # Edge reconstruction loss
                logitloss = torch.tensor(0.0, device=device)
                edgeloss = torch.tensor(0.0, device=device)
                if edge_index is not None:
                    edgeloss, logitloss = recon_loss_diag(data, edge_index, decoder, plddt=mask_plddt, key='edge_probs')
                
                # Amino acid reconstruction loss
                xloss = aa_reconstruction_loss(data['AA'].x, out['aa'])
                
                # FFT2 loss
                fft2loss = torch.tensor(0.0, device=device)
                if 'fft2pred' in out and out['fft2pred'] is not None:
                    fft2loss = F.smooth_l1_loss(torch.cat([data['fourier2dr'].x, data['fourier2di'].x], axis=1), out['fft2pred'])

                # Angles loss
                angles_loss = torch.tensor(0.0, device=device)
                if out.get('angles') is not None:
                    angles_loss = angles_reconstruction_loss(out['angles'], data['bondangles'].x, plddt_mask=data['plddt'].x if mask_plddt else None)

                # Secondary structure loss
                ss_loss = torch.tensor(0.0, device=device)
                if out.get('ss_pred') is not None:
                    if mask_plddt:
                        mask = (data['plddt'].x >= plddt_threshold).squeeze()
                        if mask.sum() > 0:
                            ss_loss = F.cross_entropy(out['ss_pred'][mask], data['ss'].x[mask])
                    else:
                        ss_loss = F.cross_entropy(out['ss_pred'], data['ss'].x)
                
                # Total loss
                loss = (xweight * xloss + edgeweight * edgeloss + vqweight * vqloss + 
                        fft2weight * fft2loss + angles_weight * angles_loss + 
                        ss_weight * ss_loss + logitweight * logitloss)
                
                # Scale loss by gradient accumulation steps
                loss = loss / gradient_accumulation_steps
        else:
            # Non-mixed precision path
            z, vqloss = encoder(data)
            data['res'].x = z
            
            out = decoder(data, None)
            edge_index = data.edge_index_dict.get(('res', 'contactPoints', 'res')) if hasattr(data, 'edge_index_dict') else None

            logitloss = torch.tensor(0.0, device=device)
            edgeloss = torch.tensor(0.0, device=device)
            if edge_index is not None:
                edgeloss, logitloss = recon_loss_diag(data, edge_index, decoder, plddt=mask_plddt, key='edge_probs')
            
            xloss = aa_reconstruction_loss(data['AA'].x, out['aa'])
            
            fft2loss = torch.tensor(0.0, device=device)
            if 'fft2pred' in out and out['fft2pred'] is not None:
                fft2loss = F.smooth_l1_loss(torch.cat([data['fourier2dr'].x, data['fourier2di'].x], axis=1), out['fft2pred'])

            angles_loss = torch.tensor(0.0, device=device)
            if out.get('angles') is not None:
                angles_loss = angles_reconstruction_loss(out['angles'], data['bondangles'].x, plddt_mask=data['plddt'].x if mask_plddt else None)

            ss_loss = torch.tensor(0.0, device=device)
            if out.get('ss_pred') is not None:
                if mask_plddt:
                    mask = (data['plddt'].x >= plddt_threshold).squeeze()
                    if mask.sum() > 0:
                        ss_loss = F.cross_entropy(out['ss_pred'][mask], data['ss'].x[mask])
                else:
                    ss_loss = F.cross_entropy(out['ss_pred'], data['ss'].x)

            loss = (xweight * xloss + edgeweight * edgeloss + vqweight * vqloss + 
                    fft2weight * fft2loss + angles_weight * angles_loss + 
                    ss_weight * ss_loss + logitweight * logitloss)
            
            loss = loss / gradient_accumulation_steps
        
        # Backward pass with gradient scaling
        if use_mixed_precision:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        
        # Only update weights every gradient_accumulation_steps
        if (batch_idx + 1) % gradient_accumulation_steps == 0:
            if clip_grad:
                if use_mixed_precision:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
            
            # Step optimizer with scaler
            if use_mixed_precision:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
            
            # Step scheduler if it's a step-based scheduler
            if scheduler_step_mode == 'step':
                scheduler.step()
        
        # Accumulate losses (unscaled for reporting)
        total_loss_x += xloss.item()
        total_logit_loss += logitloss.item()
        total_loss_edge += edgeloss.item()
        total_loss_fft2 += fft2loss.item()
        total_angles_loss += angles_loss.item()
        total_vq += vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)
        total_ss_loss += ss_loss.item()
    
    # Apply any leftover accumulated gradients at epoch end (partial accumulation group)
    if len(train_loader) % gradient_accumulation_steps != 0:
        if clip_grad:
            if use_mixed_precision:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
        if use_mixed_precision:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()
    
    # Update learning rate for epoch-based schedulers
    if scheduler_step_mode == 'epoch':
        if scheduler_type == 'plateau':
            scheduler.step(total_loss_x)
        else:
            scheduler.step()
    
    # Get current learning rate for logging
    current_lr = optimizer.param_groups[0]['lr']
    
    # Compute average losses
    avg_losses = {
        'aa_loss': total_loss_x / len(train_loader),
        'edge_loss': total_loss_edge / len(train_loader),
        'vq_loss': total_vq / len(train_loader),
        'fft2_loss': total_loss_fft2 / len(train_loader),
        'angles_loss': total_angles_loss / len(train_loader),
        'ss_loss': total_ss_loss / len(train_loader),
        'logit_loss': total_logit_loss / len(train_loader)
    }
    
    # Print epoch summary
    print(f"\n{'='*80}")
    print(f"Epoch {epoch+1}/{num_epochs} | Learning Rate: {current_lr:.2e}")
    print(f"{'='*80}")
    print(f"  AA Loss:     {avg_losses['aa_loss']:.4f}")
    print(f"  Edge Loss:   {avg_losses['edge_loss']:.4f}")
    print(f"  VQ Loss:     {avg_losses['vq_loss']:.4f}")
    print(f"  FFT2 Loss:   {avg_losses['fft2_loss']:.4f}")
    print(f"  Angles Loss: {avg_losses['angles_loss']:.4f}")
    print(f"  SS Loss:     {avg_losses['ss_loss']:.4f}")
    print(f"  Logit Loss:  {avg_losses['logit_loss']:.4f}")
    
    # Save checkpoints and visualize every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"\n{'─'*80}")
        print(f"Saving checkpoint at epoch {epoch+1}...")
        print(f"{'─'*80}")
        
        # Save models
        os.makedirs('models', exist_ok=True)
        torch.save(encoder.state_dict(), f'models/notebook_encoder_epoch_{epoch+1}.pt')
        torch.save(decoder.state_dict(), f'models/notebook_decoder_epoch_{epoch+1}.pt')
        
        # Generate visualization
        print(f"Generating reconstruction visualization...")
        os.makedirs('figures', exist_ok=True)
        
        # Use selected protein or random sample
        viz_sample = selected_protein if selected_protein is not None else data_sample
        
        fig, metrics, sample_output = visualize_decoder_reconstruction(
            encoder, decoder, viz_sample, device, num_embeddings, 
            converter, epoch=epoch+1, save_path=f'figures/reconstruction_epoch_{epoch+1}.png'
        )
        
        figurestack.append(fig)
        
        # Add epoch info to metrics
        metrics['epoch'] = epoch + 1
        metrics.update(avg_losses)
        metrics_history.append(metrics)
        
        # Print reconstruction metrics
        print(f"\nReconstruction Metrics:")
        print(f"  ROC AUC:           {metrics['roc_auc']:.4f}")
        print(f"  Average Precision: {metrics['avg_precision']:.4f}")
        if 'distance_correlation' in metrics:
            print(f"  Distance Corr:     {metrics['distance_correlation']:.4f}")
        print(f"  Num Residues:      {metrics['num_residues']}")
        print(f"  True Contacts:     {metrics['num_true_contacts']:.0f}")
        
        plt.show()
        print()

print("\n" + "="*80)
print("Training Complete!")
print("="*80)
print(f"Generated {len(figurestack)} visualization figures")
print(f"\nFinal Training Summary:")
print(f"{'─'*80}")
for m in metrics_history[-5:]:  # Show last 5 checkpoints
    print(f"Epoch {m['epoch']:3d}: ROC AUC={m['roc_auc']:.4f}, AP={m['avg_precision']:.4f}")


## Training Metrics Visualization

This cell plots how the metrics recorded at each checkpoint evolved during training:
- ROC AUC, Average Precision, and the distance correlation (when available)
- Loss components over time: amino acid, edge, VQ, bond angles, secondary structure, and edge logit
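
Each entry of `metrics_history` is a plain dictionary, so pulling out one curve is a matter of collecting matching keys. The sketch below uses a hypothetical helper, `metric_series`, and made-up numbers. It skips checkpoints that lack a key rather than substituting a placeholder value.

```python
def metric_series(history, key):
    """Return (epochs, values) for checkpoints that recorded `key`."""
    epochs, values = [], []
    for record in history:
        if key in record:
            epochs.append(record["epoch"])
            values.append(record[key])
    return epochs, values


# Made-up history in the same shape as the checkpoint dictionaries.
example_history = [
    {"epoch": 10, "roc_auc": 0.71, "aa_loss": 2.4},
    {"epoch": 20, "roc_auc": 0.78, "aa_loss": 2.1, "distance_correlation": -0.35},
    {"epoch": 30, "roc_auc": 0.83, "aa_loss": 1.9, "distance_correlation": -0.41},
]

corr_epochs, corr_values = metric_series(example_history, "distance_correlation")
print(corr_epochs, corr_values)
```

Returning the epochs alongside the values keeps the x and y series the same length, which matters when a metric is missing from some checkpoints.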

In [None]:
if len(metrics_history) > 0:
    fig, axes = plt.subplots(3, 3, figsize=(18, 14))
    fig.suptitle('Training Metrics Over Time', fontsize=16, fontweight='bold')
    
    # Extract data
    epochs = [m['epoch'] for m in metrics_history]
    roc_auc = [m['roc_auc'] for m in metrics_history]
    avg_precision = [m['avg_precision'] for m in metrics_history]
    aa_loss = [m['aa_loss'] for m in metrics_history]
    edge_loss = [m['edge_loss'] for m in metrics_history]
    vq_loss = [m['vq_loss'] for m in metrics_history]
    angles_loss = [m['angles_loss'] for m in metrics_history]
    ss_loss = [m['ss_loss'] for m in metrics_history]
    logit_loss = [m['logit_loss'] for m in metrics_history]
    
    # ROC AUC
    axes[0, 0].plot(epochs, roc_auc, 'b-o', linewidth=2, markersize=6)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('ROC AUC')
    axes[0, 0].set_title('ROC AUC Score')
    axes[0, 0].grid(True, alpha=0.3)
    
    # Average Precision
    axes[0, 1].plot(epochs, avg_precision, 'g-o', linewidth=2, markersize=6)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Average Precision')
    axes[0, 1].set_title('Average Precision Score')
    axes[0, 1].grid(True, alpha=0.3)
    
    # Distance Correlation (if available)
    if 'distance_correlation' in metrics_history[0]:
        distance_corr = [m.get('distance_correlation', 0) for m in metrics_history]
        axes[0, 2].plot(epochs, distance_corr, 'purple', marker='o', linewidth=2, markersize=6)
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].set_ylabel('Correlation')
        axes[0, 2].set_title('Distance Correlation')
        axes[0, 2].grid(True, alpha=0.3)
    else:
        axes[0, 2].text(0.5, 0.5, 'No correlation data', ha='center', va='center', transform=axes[0, 2].transAxes)
    
    # AA Loss
    axes[1, 0].plot(epochs, aa_loss, 'orange', marker='o', linewidth=2, markersize=6)
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].set_title('Amino Acid Reconstruction Loss')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Edge Loss
    axes[1, 1].plot(epochs, edge_loss, 'cyan', marker='o', linewidth=2, markersize=6)
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Loss')
    axes[1, 1].set_title('Contact Prediction Loss')
    axes[1, 1].grid(True, alpha=0.3)
    
    # VQ Loss
    axes[1, 2].plot(epochs, vq_loss, 'magenta', marker='o', linewidth=2, markersize=6)
    axes[1, 2].set_xlabel('Epoch')
    axes[1, 2].set_ylabel('Loss')
    axes[1, 2].set_title('Vector Quantization Loss')
    axes[1, 2].grid(True, alpha=0.3)
    
    # Angles Loss
    axes[2, 0].plot(epochs, angles_loss, 'red', marker='o', linewidth=2, markersize=6)
    axes[2, 0].set_xlabel('Epoch')
    axes[2, 0].set_ylabel('Loss')
    axes[2, 0].set_title('Bond Angles Reconstruction Loss')
    axes[2, 0].grid(True, alpha=0.3)
    
    # SS Loss
    axes[2, 1].plot(epochs, ss_loss, 'brown', marker='o', linewidth=2, markersize=6)
    axes[2, 1].set_xlabel('Epoch')
    axes[2, 1].set_ylabel('Loss')
    axes[2, 1].set_title('Secondary Structure Prediction Loss')
    axes[2, 1].grid(True, alpha=0.3)
    
    # Logit Loss
    axes[2, 2].plot(epochs, logit_loss, 'green', marker='o', linewidth=2, markersize=6)
    axes[2, 2].set_xlabel('Epoch')
    axes[2, 2].set_ylabel('Loss')
    axes[2, 2].set_title('Edge Logit Loss')
    axes[2, 2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('figures/training_metrics_summary.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\nSaved training metrics summary to figures/training_metrics_summary.png")
else:
    print("No metrics history available. Train for at least one checkpoint epoch.")

## Save Training Summary

Save a comprehensive JSON summary of the training run for future analysis and reproducibility.
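
`json.dump` only encodes built-in Python types, so some library scalars that can end up in `metrics_history` (for example `numpy.float32` values or zero-dimensional tensors) raise a `TypeError`. One possible safeguard, sketched below, is a `default=` hook that calls `.item()` on such objects. `to_builtin` and `FakeScalar` are hypothetical names used only for this illustration; `FakeScalar` stands in for a library scalar so the example runs with the standard library alone.

```python
import json


class FakeScalar:
    """Stand-in for a NumPy/PyTorch scalar exposing `.item()` (illustrative)."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


def to_builtin(obj):
    """Fallback for json.dumps: unwrap scalar-like objects via `.item()`."""
    if hasattr(obj, "item"):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


summary = {"epoch": 10, "roc_auc": FakeScalar(0.83), "note": None}

# `default` is only consulted for objects json cannot encode on its own.
encoded = json.dumps(summary, default=to_builtin, indent=2)
decoded = json.loads(encoded)
print(decoded)
```

The same hook can be passed to `json.dump` when writing to a file. Multi-element arrays would need `.tolist()` rather than `.item()`.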

In [None]:
import json
from datetime import datetime

# Create comprehensive training summary
training_summary = {
    'timestamp': datetime.now().isoformat(),
    'config_file': config_file if config_file else None,
    'hyperparameters': {
        'num_epochs': num_epochs,
        'batch_size': batch_size,
        'gradient_accumulation_steps': gradient_accumulation_steps,
        'learning_rate': learning_rate,
        'num_embeddings': num_embeddings,
        'embedding_dim': embedding_dim,
        'hidden_size': hidden_size,
        'use_mixed_precision': use_mixed_precision,
        'mask_plddt': mask_plddt,
        'plddt_threshold': plddt_threshold if mask_plddt else None,
        'use_muon': use_muon,
        'muon_lr': muon_lr if use_muon else None,
        'adamw_lr': adamw_lr if use_muon else None,
        'use_commitment_scheduling': use_commitment_scheduling,
        'commitment_cost_final': commitment_cost_final,
        'commitment_warmup_steps': commitment_warmup_steps if use_commitment_scheduling else None,
        'scheduler_type': scheduler_type,
        'warmup_steps': warmup_steps,
    },
    'loss_weights': {
        'edgeweight': edgeweight,
        'logitweight': logitweight,
        'xweight': xweight,
        'fft2weight': fft2weight,
        'vqweight': vqweight,
        'angles_weight': angles_weight,
        'ss_weight': ss_weight,
    },
    'dataset': {
        'path': dataset_path,
        'total_samples': len(struct_dat),
    },
    'metrics_history': metrics_history,
    'final_metrics': metrics_history[-1] if metrics_history else None,
}

# Save summary to JSON
os.makedirs('models', exist_ok=True)
summary_filename = f"models/training_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(summary_filename, 'w') as f:
    json.dump(training_summary, f, indent=2)

print(f"Training summary saved to: {summary_filename}")
print(f"\nFinal Performance:")
if metrics_history:
    final = metrics_history[-1]
    print(f"  Epoch: {final['epoch']}")
    print(f"  ROC AUC: {final['roc_auc']:.4f}")
    print(f"  Average Precision: {final['avg_precision']:.4f}")
    print(f"  AA Loss: {final['aa_loss']:.4f}")
    print(f"  Edge Loss: {final['edge_loss']:.4f}")
    print(f"  VQ Loss: {final['vq_loss']:.4f}")
else:
    print("  No metrics available")