In [1]:
# Use autoreload for development
%load_ext autoreload
%autoreload 2

In [2]:
# Core imports
import torch
import torch.nn.functional as F
import numpy as np
import random
import os
import tqdm
from collections import defaultdict

# PyTorch Lightning imports
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# Project imports
from torch_geometric.data import DataLoader
from foldtree2.src import pdbgraph
from foldtree2.src import encoder as ecdr
from foldtree2.src.mono_decoders import MultiMonoDecoder
from foldtree2.src.losses.losses import recon_loss_diag, aa_reconstruction_loss, angles_reconstruction_loss

# Visualization imports
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

print("Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Lightning version: {pl.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
#if torch.cuda.is_available():
#    print(f"Number of GPUs: {torch.cuda.device_count()}")
#    for i in range(torch.cuda.device_count()):
#        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

Imports successful!
PyTorch version: 2.8.0+cu128
PyTorch Lightning version: 2.6.0
CUDA available: True


In [3]:
# Configuration parameters
config = {
    # Data parameters
    'dataset_path': 'structs_train_final.h5',
    'batch_size': 15,
    'num_workers': 4,
    
    # Model architecture parameters
    'num_embeddings': 40,
    'embedding_dim': 256,
    'hidden_size': 256,
    
    # Training parameters
    'num_epochs': 300,
    'learning_rate': 5e-5,
    'gradient_accumulation_steps': 1,
    'clip_grad': True,
    'mask_plddt': True,
    'plddt_threshold': 0.3,
    
    # Scheduler parameters
    'scheduler_type': 'plateau',  # Options: 'plateau', 'linear', 'cosine', 'cosine_with_restarts', 'polynomial'
    'warmup_steps': 20,
    'warmup_ratio': 0.05,
    
    # Loss weights
    'edgeweight': 0.1,
    'logitweight': 0.1,
    'xweight': 1.0,
    'fft2weight': 0.01,
    'vqweight': 0.1,
    'angles_weight': 0.1,
    'ss_weight': 0.1,
    
    # Optimizer parameters
    'use_muon': True,
    'muon_lr': 0.02,
    'adamw_lr': 1e-4,
    'weight_decay': 0.01,
    
    # Multi-GPU parameters
    'accelerator': 'gpu',  # 'gpu', 'cpu', or 'auto'
    'devices': 1,  # Use 1 GPU in Jupyter to avoid multiprocessing issues
    'strategy': 'auto',  # 'auto' selects best strategy for single GPU
    'precision': '16-mixed',  # Use mixed precision training
    
    # Logging parameters
    'log_every_n_steps': 10,
    'val_check_interval': 1.0,  # Validate every epoch
    'save_top_k': 3,  # Save top 3 checkpoints
}

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

Configuration:
  dataset_path: structs_train_final.h5
  batch_size: 15
  num_workers: 4
  num_embeddings: 40
  embedding_dim: 256
  hidden_size: 256
  num_epochs: 300
  learning_rate: 5e-05
  gradient_accumulation_steps: 1
  clip_grad: True
  mask_plddt: True
  plddt_threshold: 0.3
  scheduler_type: plateau
  warmup_steps: 20
  warmup_ratio: 0.05
  edgeweight: 0.1
  logitweight: 0.1
  xweight: 1.0
  fft2weight: 0.01
  vqweight: 0.1
  angles_weight: 0.1
  ss_weight: 0.1
  use_muon: True
  muon_lr: 0.02
  adamw_lr: 0.0001
  weight_decay: 0.01
  accelerator: gpu
  devices: 1
  strategy: auto
  precision: 16-mixed
  log_every_n_steps: 10
  val_check_interval: 1.0
  save_top_k: 3


In [4]:
cd /home/dmoi/projects/foldtree2/

/home/dmoi/projects/foldtree2


In [5]:
# Initialize data converter and dataset
converter = pdbgraph.PDB2PyG(aapropcsv='./foldtree2/config/aaindex1.csv')
struct_dat = pdbgraph.StructureDataset(config['dataset_path'])

# Get sample data for model initialization
temp_loader = DataLoader(struct_dat, batch_size=5, shuffle=False, num_workers=1)
data_sample = next(iter(temp_loader))

ndim = data_sample['res'].x.shape[1]
ndim_godnode = data_sample['godnode'].x.shape[1]
ndim_fft2i = data_sample['fourier2di'].x.shape[1]
ndim_fft2r = data_sample['fourier2dr'].x.shape[1]

print(f"Dataset loaded: {len(struct_dat)} structures")
print(f"Feature dimensions:")
print(f"  Residue features: {ndim}")
print(f"  Godnode features: {ndim_godnode}")
print(f"  FFT2 (imaginary): {ndim_fft2i}")
print(f"  FFT2 (real): {ndim_fft2r}")
print(f"\nSample data structure:")
print(data_sample)



Dataset loaded: 4999 structures
Feature dimensions:
  Residue features: 857
  Godnode features: 5
  FFT2 (imaginary): 1300
  FFT2 (real): 1300

Sample data structure:
HeteroDataBatch(
  identifier=[5],
  AA={
    x=[1405, 20],
    batch=[1405],
    ptr=[6],
  },
  R_true={
    x=[1405, 3, 3],
    batch=[1405],
    ptr=[6],
  },
  bondangles={
    x=[1405, 3],
    batch=[1405],
    ptr=[6],
  },
  coords={
    x=[1405, 3],
    batch=[1405],
    ptr=[6],
  },
  fourier1di={
    x=[1405, 80],
    batch=[1405],
    ptr=[6],
  },
  fourier1dr={
    x=[1405, 80],
    batch=[1405],
    ptr=[6],
  },
  fourier2di={
    x=[5, 1300],
    batch=[5],
    ptr=[6],
  },
  fourier2dr={
    x=[5, 1300],
    batch=[5],
    ptr=[6],
  },
  godnode={
    x=[5, 5],
    batch=[5],
    ptr=[6],
  },
  godnode4decoder={
    x=[5, 5],
    batch=[5],
    ptr=[6],
  },
  plddt={
    x=[1405, 1],
    batch=[1405],
    ptr=[6],
  },
  positions={
    x=[1405, 256],
    batch=[1405],
    ptr=[6],
  },
  res={
    

In [6]:
# Lightning Data Module
class ProteinDataModule(pl.LightningDataModule):
    """PyTorch Lightning DataModule for protein structure data."""
    
    def __init__(self, dataset, batch_size=8, num_workers=4, train_split=0.9):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_split = train_split
        
    def setup(self, stage=None):
        """Split dataset into train and validation sets."""
        dataset_size = len(self.dataset)
        train_size = int(self.train_split * dataset_size)
        val_size = dataset_size - train_size
        
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            self.dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )
        
        print(f"Dataset split: {train_size} train, {val_size} validation")
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=True if self.num_workers > 0 else False,
            pin_memory=True
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=True if self.num_workers > 0 else False,
            pin_memory=True
        )

print("ProteinDataModule defined")

ProteinDataModule defined


In [7]:
# Lightning Module for Protein Structure Model
class ProteinStructureModel(pl.LightningModule):
    """PyTorch Lightning module for encoder-decoder protein structure prediction."""
    
    def __init__(self, config, ndim, ndim_godnode, ndim_fft2i, ndim_fft2r, converter):
        super().__init__()
        self.save_hyperparameters(ignore=['converter'])
        self.config = config
        self.converter = converter
        
        # Initialize encoder
        self.encoder = ecdr.mk1_Encoder(
            in_channels=ndim,
            hidden_channels=[2*config['hidden_size'], 2*config['hidden_size'], 2*config['hidden_size']],
            out_channels=config['embedding_dim'],
            metadata={'edge_types': [('res','contactPoints','res')]},
            num_embeddings=config['num_embeddings'],
            commitment_cost=0.9,
            edge_dim=1,
            encoder_hidden=config['hidden_size'],
            EMA=True,
            nheads=10,
            dropout_p=0.01,
            reset_codes=False,
            flavor='transformer',
            fftin=True,
            use_commitment_scheduling=True,
            commitment_warmup_steps=1000,
            commitment_schedule='cosine_with_restart',
            commitment_start=0.5,
            concat_positions=True
        )
        
        # Initialize decoder (MultiMonoDecoder)
        mono_configs = {
            'sequence_transformer': {
                'in_channels': {'res': config['embedding_dim']},
                'xdim': 20,
                'concat_positions': True,
                'hidden_channels': {
                    ('res','backbone','res'): [config['hidden_size']], 
                    ('res','backbonerev','res'): [config['hidden_size']]
                },
                'layers': 2,
                'AAdecoder_hidden': [config['hidden_size'], config['hidden_size'], config['hidden_size']//2],
                'amino_mapper': converter.aaindex,
                'flavor': 'sage',
                'nheads': 4,
                'dropout': 0.001,
                'normalize': False,
                'residual': False,
                'use_cnn_decoder': True,
                'output_ss': False
            },
            'geometry_cnn': {
                'in_channels': {
                    'res': config['embedding_dim'], 
                    'godnode4decoder': ndim_godnode, 
                    'foldx': 23,
                    'fft2r': ndim_fft2r, 
                    'fft2i': ndim_fft2i
                },
                'concat_positions': False,
                'conv_channels': [config['hidden_size'], config['hidden_size']//2, config['hidden_size']//2],
                'kernel_sizes': [3, 3, 3],
                'FFT2decoder_hidden': [config['hidden_size']//2, config['hidden_size']//2],
                'contactdecoder_hidden': [config['hidden_size']//2, config['hidden_size']//4],
                'ssdecoder_hidden': [config['hidden_size']//2, config['hidden_size']//2],
                'Xdecoder_hidden': [config['hidden_size']//2, config['hidden_size']//4],
                'anglesdecoder_hidden': [config['hidden_size']//2, config['hidden_size']//4],
                'RTdecoder_hidden': [config['hidden_size']//2, config['hidden_size']//4],
                'metadata': converter.metadata,
                'dropout': 0.001,
                'output_fft': False,
                'output_rt': False,
                'output_angles': False,
                'output_ss': True,
                'normalize': True,
                'residual': False,
                'output_edge_logits': True,
                'ncat': 8,
                'contact_mlp': False,
                'pool_type': 'global_mean'
            },
        }
        self.decoder = MultiMonoDecoder(configs=mono_configs)
        
        # Store loss weights
        self.automatic_optimization = False  # Manual optimization for gradient accumulation
        
    def forward(self, data):
        """Forward pass through encoder and decoder."""
        z, vqloss = self.encoder(data)
        data['res'].x = z
        out = self.decoder(data, None)
        return out, vqloss
    
    def compute_losses(self, batch, out, vqloss):
        """Compute all loss components."""
        device = self.device
        
        # Edge reconstruction loss
        edge_index = batch.edge_index_dict.get(('res', 'contactPoints', 'res')) if hasattr(batch, 'edge_index_dict') else None
        logitloss = torch.tensor(0.0, device=device)
        edgeloss = torch.tensor(0.0, device=device)
        if edge_index is not None:
            edgeloss, logitloss = recon_loss_diag(
                batch, edge_index, self.decoder, 
                plddt=self.config['mask_plddt'], 
                key='edge_probs'
            )
        
        # Amino acid reconstruction loss
        xloss = aa_reconstruction_loss(batch['AA'].x, out['aa'])
        
        # FFT2 loss
        fft2loss = torch.tensor(0.0, device=device)
        if 'fft2pred' in out and out['fft2pred'] is not None:
            fft2loss = F.smooth_l1_loss(
                torch.cat([batch['fourier2dr'].x, batch['fourier2di'].x], axis=1), 
                out['fft2pred']
            )
        
        # Angles loss
        angles_loss = torch.tensor(0.0, device=device)
        if out.get('angles') is not None:
            angles_loss = angles_reconstruction_loss(
                out['angles'], 
                batch['bondangles'].x,
                plddt_mask=batch['plddt'].x if self.config['mask_plddt'] else None
            )
        
        # Secondary structure loss
        ss_loss = torch.tensor(0.0, device=device)
        if out.get('ss_pred') is not None:
            if self.config['mask_plddt']:
                mask = (batch['plddt'].x >= self.config['plddt_threshold']).squeeze()
                ss_loss = F.cross_entropy(out['ss_pred'][mask], batch['ss'].x[mask])
            else:
                ss_loss = F.cross_entropy(out['ss_pred'], batch['ss'].x)
        
        # Total loss
        total_loss = (
            self.config['xweight'] * xloss +
            self.config['edgeweight'] * edgeloss +
            self.config['vqweight'] * vqloss +
            self.config['fft2weight'] * fft2loss +
            self.config['angles_weight'] * angles_loss +
            self.config['ss_weight'] * ss_loss +
            self.config['logitweight'] * logitloss
        )
        
        losses = {
            'loss': total_loss,
            'aa_loss': xloss,
            'edge_loss': edgeloss,
            'vq_loss': vqloss if isinstance(vqloss, torch.Tensor) else torch.tensor(vqloss, device=device),
            'fft2_loss': fft2loss,
            'angles_loss': angles_loss,
            'ss_loss': ss_loss,
            'logit_loss': logitloss
        }
        
        return losses
    
    def training_step(self, batch, batch_idx):
        """Training step with manual optimization."""
        opt = self.optimizers()
        
        # Forward pass
        out, vqloss = self(batch)
        
        # Compute losses
        losses = self.compute_losses(batch, out, vqloss)
        
        # Scale loss by gradient accumulation steps
        scaled_loss = losses['loss'] / self.config['gradient_accumulation_steps']
        
        # Manual backward
        self.manual_backward(scaled_loss)
        
        # Update weights every gradient_accumulation_steps
        if (batch_idx + 1) % self.config['gradient_accumulation_steps'] == 0:
            if self.config['clip_grad']:
                self.clip_gradients(opt, gradient_clip_val=1.0, gradient_clip_algorithm="norm")
            
            opt.step()
            opt.zero_grad()
            
            # Step scheduler if using step-based
            sch = self.lr_schedulers()
            if sch is not None and self.config['scheduler_type'] != 'plateau':
                sch.step()
        
        # Log metrics
        for key, value in losses.items():
            self.log(f'train/{key}', value, on_step=True, on_epoch=True, prog_bar=(key == 'loss'), sync_dist=True)
        
        # Log learning rate
        self.log('lr', opt.param_groups[0]['lr'], on_step=True, prog_bar=True, sync_dist=True)
        
        return losses['loss']
    
    def validation_step(self, batch, batch_idx):
        """Validation step."""
        out, vqloss = self(batch)
        losses = self.compute_losses(batch, out, vqloss)
        
        # Log validation metrics
        for key, value in losses.items():
            self.log(f'val/{key}', value, on_epoch=True, prog_bar=(key == 'loss'), sync_dist=True)
        
        return losses['loss']
    
    def on_validation_epoch_end(self):
        """Update scheduler at end of validation epoch if using plateau scheduler."""
        sch = self.lr_schedulers()
        if sch is not None and self.config['scheduler_type'] == 'plateau':
            # Get validation loss
            val_loss = self.trainer.callback_metrics.get('val/loss')
            if val_loss is not None:
                sch.step(val_loss)
    
    def configure_optimizers(self):
        """Configure optimizers and learning rate schedulers."""
        from transformers import (
            get_linear_schedule_with_warmup,
            get_cosine_schedule_with_warmup,
            get_cosine_with_hard_restarts_schedule_with_warmup,
            get_polynomial_decay_schedule_with_warmup
        )
        
        if self.config['use_muon']:
            from muon import MuonWithAuxAdam
            
            # Separate parameters for Muon optimizer
            hidden_weights = []
            hidden_gains_biases = []
            nonhidden_params = []
            
            def has_modular_structure(model):
                return hasattr(model, 'input') and hasattr(model, 'body') and hasattr(model, 'head')
            
            # Process encoder
            if has_modular_structure(self.encoder):
                hidden_weights += [p for p in self.encoder.body.parameters() if p.ndim >= 2]
                hidden_gains_biases += [p for p in self.encoder.body.parameters() if p.ndim < 2]
                nonhidden_params += [*self.encoder.head.parameters(), *self.encoder.input.parameters()]
            else:
                nonhidden_params += list(self.encoder.parameters())
            
            # Process decoder
            if hasattr(self.decoder, 'decoders'):
                for name, subdecoder in self.decoder.decoders.items():
                    if has_modular_structure(subdecoder):
                        hidden_weights += [p for p in subdecoder.body.parameters() if p.ndim >= 2]
                        hidden_gains_biases += [p for p in subdecoder.body.parameters() if p.ndim < 2]
                        nonhidden_params += [*subdecoder.head.parameters(), *subdecoder.input.parameters()]
                    else:
                        nonhidden_params += list(subdecoder.parameters())
            else:
                nonhidden_params += list(self.decoder.parameters())
            
            param_groups = [
                dict(params=hidden_weights, use_muon=True,
                     lr=self.config['muon_lr'], weight_decay=self.config['weight_decay']),
                dict(params=hidden_gains_biases + nonhidden_params, use_muon=False,
                     lr=self.config['adamw_lr'], betas=(0.9, 0.95), weight_decay=self.config['weight_decay']),
            ]
            optimizer = MuonWithAuxAdam(param_groups)
        else:
            optimizer = torch.optim.AdamW(
                list(self.encoder.parameters()) + list(self.decoder.parameters()),
                lr=self.config['learning_rate']
            )
        
        # Calculate total training steps
        if self.trainer.max_epochs:
            num_training_steps = self.trainer.estimated_stepping_batches
        else:
            num_training_steps = self.config['num_epochs'] * len(self.trainer.datamodule.train_dataloader())
        
        num_warmup_steps = self.config['warmup_steps'] if self.config['warmup_steps'] else int(num_training_steps * self.config['warmup_ratio'])
        
        # Configure scheduler
        scheduler_type = self.config['scheduler_type']
        
        if scheduler_type == 'linear':
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
            return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
        elif scheduler_type == 'cosine':
            scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
            return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
        elif scheduler_type == 'cosine_with_restarts':
            scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
            return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
        elif scheduler_type == 'polynomial':
            scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, lr_end=0.0, power=1.0)
            return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
        elif scheduler_type == 'plateau':
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
            return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'epoch', 'monitor': 'val/loss'}}
        else:
            return optimizer

print("ProteinStructureModel defined")

ProteinStructureModel defined


In [8]:
# Initialize the data module
data_module = ProteinDataModule(
    dataset=struct_dat,
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    train_split=0.9
)

# Setup the data module to get train/val splits
data_module.setup()

print("Data module initialized and ready")

Dataset split: 4499 train, 500 validation
Data module initialized and ready


In [9]:
# Initialize the model
model = ProteinStructureModel(
    config=config,
    ndim=ndim,
    ndim_godnode=ndim_godnode,
    ndim_fft2i=ndim_fft2i,
    ndim_fft2r=ndim_fft2r,
    converter=converter
)

print("Model initialized")
print(f"Encoder: {type(model.encoder).__name__}")
print(f"Decoder: {type(model.decoder).__name__}")

# Count parameters
encoder_params = sum(p.numel() for p in model.encoder.parameters())
decoder_params = sum(p.numel() for p in model.decoder.parameters())
total_params = encoder_params + decoder_params

print(f"\nParameter counts:")
print(f"  Encoder: {encoder_params:,}")
print(f"  Decoder: {decoder_params:,}")
print(f"  Total:   {total_params:,}")

Seed set to 42


Seed set to 42
Seed set to 42


Initializing decoder for task: sequence_transformer
False True False False False
256 4 2 0.001
Initializing decoder for task: geometry_cnn
False False False False False
Model initialized
Encoder: mk1_Encoder
Decoder: MultiMonoDecoder

Parameter counts:
  Encoder: 19,198,372
  Decoder: 3,875,231
  Total:   23,073,603


In [10]:
# Configure callbacks
checkpoint_callback = ModelCheckpoint(
    dirpath='models/lightning_checkpoints',
    filename='protein-{epoch:02d}-{val/loss:.4f}',
    monitor='val/loss',
    mode='min',
    save_top_k=config['save_top_k'],
    save_last=True,
    verbose=True
)

lr_monitor = LearningRateMonitor(logging_interval='step')

# Configure logger
logger = TensorBoardLogger(
    save_dir='runs',
    name='protein_lightning',
    default_hp_metric=False
)

print("Callbacks and logger configured")
print(f"  Checkpoints will be saved to: models/lightning_checkpoints")
print(f"  TensorBoard logs will be saved to: runs/protein_lightning")
print(f"  Monitoring: val/loss (save top {config['save_top_k']} models)")

Callbacks and logger configured
  Checkpoints will be saved to: models/lightning_checkpoints
  TensorBoard logs will be saved to: runs/protein_lightning
  Monitoring: val/loss (save top 3 models)


In [11]:
# Configure the Trainer for multi-GPU training
trainer = pl.Trainer(
    max_epochs=config['num_epochs'],
    accelerator=config['accelerator'],
    devices=config['devices'],  # -1 uses all available GPUs
    strategy=config['strategy'],  # DDP for multi-GPU
    precision=config['precision'],  # Mixed precision training
    callbacks=[checkpoint_callback, lr_monitor],
    logger=logger,
    log_every_n_steps=config['log_every_n_steps'],
    val_check_interval=config['val_check_interval'],
    #gradient_clip_val=,
    accumulate_grad_batches=config['gradient_accumulation_steps'],
    deterministic=True,
    enable_progress_bar=True,
    enable_model_summary=True
)

print("PyTorch Lightning Trainer configured")
print(f"\nTraining Configuration:")
print(f"  Max epochs: {config['num_epochs']}")
print(f"  Accelerator: {config['accelerator']}")
#print(f"  Devices: {config['devices']} ({'all GPUs' if config['devices'] == -1 else f'{config['devices']} GPUs'})")
print(f"  Strategy: {config['strategy']}")
print(f"  Precision: {config['precision']}")
print(f"  Gradient accumulation steps: {config['gradient_accumulation_steps']}")
print(f"  Gradient clipping: {config['clip_grad']}")

''' 
if torch.cuda.is_available():
    print(f"\nAvailable GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
'''

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


PyTorch Lightning Trainer configured

Training Configuration:
  Max epochs: 300
  Accelerator: gpu
  Strategy: auto
  Precision: 16-mixed
  Gradient accumulation steps: 1
  Gradient clipping: True


' \nif torch.cuda.is_available():\n    print(f"\nAvailable GPUs: {torch.cuda.device_count()}")\n    for i in range(torch.cuda.device_count()):\n        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")\n        print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")\n'

In [12]:
# Start training!
print("="*80)
print("STARTING MULTI-GPU TRAINING WITH PYTORCH LIGHTNING")
print("="*80)

trainer.fit(model, data_module)

print("\n" + "="*80)
print("TRAINING COMPLETE!")
print("="*80)
print(f"Best model checkpoint: {checkpoint_callback.best_model_path}")
print(f"Best validation loss: {checkpoint_callback.best_model_score:.4f}")

You are using a CUDA device ('NVIDIA RTX PRO 4000 Blackwell') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


STARTING MULTI-GPU TRAINING WITH PYTORCH LIGHTNING


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.


Dataset split: 4499 train, 500 validation


/home/dmoi/miniforge3/envs/foldtree2/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py:317: The lr scheduler dict contains the key(s) ['monitor'], but the keys will be ignored. You need to call `lr_scheduler.step()` manually in manual optimization.
/home/dmoi/miniforge3/envs/foldtree2/lib/python3.9/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name    | Type             | Params | Mode  | FLOPs
-------------------------------------------------------------
0 | encoder | mk1_Encoder      | 19.2 M | train | 0    
1 | decoder | MultiMonoDecoder | 3.9 M  | train | 0    
-------------------------------------------------------------
23.1 M    Trainable params
0         Non-trainable params
23.1 M    Total params
92.294    Total estimated model params size (MB)
141       Modules in train mode
0         Modules in e

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 10.74 GiB. GPU 0 has a total capacity of 23.43 GiB of which 6.71 GiB is free. Including non-PyTorch memory, this process has 16.68 GiB memory in use. Of the allocated memory 15.54 GiB is allocated by PyTorch, and 873.63 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Load the best checkpoint for evaluation
best_model = ProteinStructureModel.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    config=config,
    ndim=ndim,
    ndim_godnode=ndim_godnode,
    ndim_fft2i=ndim_fft2i,
    ndim_fft2r=ndim_fft2r,
    converter=converter
)

best_model.eval()
best_model.freeze()

print(f"Loaded best model from: {checkpoint_callback.best_model_path}")
print(f"Best validation loss: {checkpoint_callback.best_model_score:.4f}")

In [None]:
# Visualize training metrics from TensorBoard logs
import pandas as pd
from tensorboard.backend.event_processing import event_accumulator

def load_tensorboard_logs(log_dir):
    """Load and parse TensorBoard event files."""
    ea = event_accumulator.EventAccumulator(log_dir)
    ea.Reload()
    
    # Get all scalar tags
    tags = ea.Tags()['scalars']
    
    data = {}
    for tag in tags:
        events = ea.Scalars(tag)
        data[tag] = pd.DataFrame([(e.step, e.value) for e in events], columns=['step', tag])
    
    return data

# Plot training curves
def plot_training_metrics():
    """Plot training and validation metrics."""
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Training Metrics - Multi-GPU Lightning', fontsize=16, fontweight='bold')
    
    # This is a placeholder - actual implementation would read from TensorBoard logs
    # or use the trainer's logged metrics
    
    metrics_to_plot = [
        ('train/aa_loss', 'val/aa_loss', 'Amino Acid Loss'),
        ('train/edge_loss', 'val/edge_loss', 'Edge Loss'),
        ('train/vq_loss', 'val/vq_loss', 'VQ Loss'),
        ('train/angles_loss', 'val/angles_loss', 'Angles Loss'),
        ('train/ss_loss', 'val/ss_loss', 'Secondary Structure Loss'),
        ('lr', None, 'Learning Rate')
    ]
    
    for idx, (train_metric, val_metric, title) in enumerate(metrics_to_plot):
        ax = axes[idx // 3, idx % 3]
        ax.set_title(title)
        ax.set_xlabel('Step')
        ax.set_ylabel('Value')
        ax.grid(True, alpha=0.3)
        ax.legend()
    
    plt.tight_layout()
    return fig

print("Visualization functions defined")
print("Note: Run plot_training_metrics() after training to visualize results")

In [None]:
# Test the trained model on a sample
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
best_model = best_model.to(device)

# Get a random sample from validation set
sample_idx = random.randint(0, len(data_module.val_dataset) - 1)
data_sample = data_module.val_dataset[sample_idx]
data_sample = data_sample.to(device)

# Make prediction
with torch.no_grad():
    out, vqloss = best_model(data_sample)
    
print(f"Sample ID: {data_sample.identifier}")
print(f"\nModel outputs:")
for key, value in out.items():
    if value is not None:
        print(f"  {key}: {value.shape if hasattr(value, 'shape') else type(value)}")
print(f"\nVQ Loss: {vqloss.item() if isinstance(vqloss, torch.Tensor) else vqloss:.4f}")

## How to View TensorBoard Logs

To visualize training metrics in TensorBoard, run the following command in a terminal:

```bash
tensorboard --logdir=runs/protein_lightning
```

Then open your browser to `http://localhost:6006` to view:
- Training and validation loss curves
- Learning rate schedules
- Individual loss components (AA, edge, VQ, angles, SS, etc.)
- GPU utilization and memory usage

## Multi-GPU Training Benefits

This Lightning implementation provides:
1. **Automatic Data Parallelism**: Data is automatically distributed across all available GPUs using DDP
2. **Gradient Accumulation**: Effective batch size = batch_size √ó gradient_accumulation_steps √ó num_gpus
3. **Mixed Precision**: Automatic mixed precision (FP16) for faster training and reduced memory
4. **Efficient Checkpointing**: Automatic model checkpointing with best model selection
5. **Distributed Logging**: All metrics are properly synchronized across GPUs
6. **Fault Tolerance**: Training can be resumed from checkpoints

## Next Steps

- Monitor training progress in TensorBoard
- Adjust hyperparameters in the config dictionary
- Experiment with different schedulers and optimizers
- Evaluate model on test set
- Generate comprehensive visualizations of predictions

In [None]:
# Summary of training configuration
print("="*80)
print("MULTI-GPU TRAINING CONFIGURATION SUMMARY")
print("="*80)
print(f"\nüìä Data Configuration:")
print(f"   Dataset: {config['dataset_path']}")
print(f"   Total structures: {len(struct_dat):,}")
print(f"   Training structures: {len(data_module.train_dataset):,}")
print(f"   Validation structures: {len(data_module.val_dataset):,}")
print(f"   Batch size per GPU: {config['batch_size']}")

if torch.cuda.is_available():
    print(f"\nüñ•Ô∏è  Hardware Configuration:")
    print(f"   Number of GPUs: {torch.cuda.device_count()}")
    print(f"   Effective batch size: {config['batch_size'] * torch.cuda.device_count() * config['gradient_accumulation_steps']}")
    for i in range(torch.cuda.device_count()):
        print(f"   GPU {i}: {torch.cuda.get_device_name(i)}")

print(f"\nüß† Model Configuration:")
print(f"   Encoder: mk1_Encoder")
print(f"   Decoder: MultiMonoDecoder")
print(f"   Hidden size: {config['hidden_size']}")
print(f"   Embedding dim: {config['embedding_dim']}")
print(f"   Num embeddings: {config['num_embeddings']}")
print(f"   Total parameters: {total_params:,}")

print(f"\n‚öôÔ∏è  Training Configuration:")
print(f"   Max epochs: {config['num_epochs']}")
print(f"   Optimizer: {'Muon + AdamW' if config['use_muon'] else 'AdamW'}")
print(f"   Learning rate: {config['learning_rate']}")
print(f"   Scheduler: {config['scheduler_type']}")
print(f"   Mixed precision: {config['precision']}")
print(f"   Gradient clipping: {config['clip_grad']}")
print(f"   Gradient accumulation steps: {config['gradient_accumulation_steps']}")

print(f"\nüíæ Output Locations:")
print(f"   Checkpoints: models/lightning_checkpoints/")
print(f"   TensorBoard logs: runs/protein_lightning/")

print("\n" + "="*80)

## Visualization and Analysis

The following cells provide utilities for visualizing model predictions and analyzing performance.

In [None]:
# Set seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print("Random seeds set for reproducibility")

In [None]:
cd /home/dmoi/projects/foldtree2/

In [None]:
# FoldTree2 Multi-GPU Training with PyTorch Lightning

This notebook trains a protein structure prediction model using PyTorch Lightning for efficient multi-GPU training. The implementation replicates the logic from test_monodecoders.ipynb but leverages Lightning's distributed training capabilities.

## Key Features
- **Multi-GPU Support**: Automatically uses all available GPUs with DDP (Distributed Data Parallel)
- **Vector Quantized Encoding**: Proteins encoded into discrete embedding sequences
- **Multi-task Decoding**: Predicts amino acid sequences, contact maps, and geometric properties
- **Mixed Precision Training**: Automatic mixed precision for faster training
- **Advanced Optimizers**: Support for Muon and AdamW optimizers
- **Learning Rate Scheduling**: Multiple scheduler options (linear, cosine, plateau, etc.)

## Training Components
The notebook demonstrates:
- Custom LightningModule for encoder-decoder architecture
- LightningDataModule for efficient data loading
- Multi-GPU distributed training with DDP strategy
- Comprehensive logging and visualization
- Automatic checkpointing and model saving