# Simple DDPM Training and Inference for Protein Frames

This notebook trains a basic denoising diffusion probabilistic model (DDPM) to denoise single protein frames taken from molecular dynamics (MD) trajectories.

**What this does:**
- Loads single frames from MD trajectory data
- Adds noise according to a diffusion schedule
- Trains a model to predict and remove the noise
- Performs inference to denoise frames
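
The noising step above can be sketched in closed form. This is a minimal NumPy illustration, assuming the linear beta schedule configured later in this notebook; `linear_beta_schedule` and `add_noise` are illustrative names and not necessarily what `SimpleDDPM` implements, and the random array is only a stand-in for a real frame.

```python
import numpy as np

def linear_beta_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Per-step noise variances, spaced linearly between the two endpoints."""
    return np.linspace(beta_start, beta_end, timesteps)

def add_noise(x0, t, alpha_bar, rng):
    """Sample x_t ~ q(x_t | x_0) in closed form; also return the noise used."""
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return x_t, eps

rng = np.random.default_rng(0)
betas = linear_beta_schedule()
alpha_bar = np.cumprod(1.0 - betas)   # fraction of signal variance left at each step

x0 = rng.standard_normal((64, 42))    # stand-in frame: 64 residues x 42 features
x_early, _ = add_noise(x0, 10, alpha_bar, rng)
x_late, _ = add_noise(x0, 999, alpha_bar, rng)

# Early steps stay close to the input; late steps are almost pure noise.
print(f"corr(x0, x_t) at t=10:  {np.corrcoef(x0.ravel(), x_early.ravel())[0, 1]:.3f}")
print(f"corr(x0, x_t) at t=999: {np.corrcoef(x0.ravel(), x_late.ravel())[0, 1]:.3f}")
```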

**Steps:**
1. Setup environment and upload data
2. Configure hyperparameters
3. Train the model
4. Run inference and evaluate

## 1. Setup Environment

In [None]:
# Check if running on Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("Not running on Google Colab")

# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q omegaconf pandas tqdm numpy matplotlib

print("✓ Dependencies installed")

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

## 2. Upload Project Files

You have two options:
1. **Option A**: Upload the entire `gen_model` directory as a zip file
2. **Option B**: Clone from GitHub (if you have a repo)

Choose one of the options below:

In [None]:
# Option A: Upload gen_model.zip
# 1. On your local machine: zip -r gen_model.zip gen_model/ data/
# 2. Upload the zip file using the code below

if IN_COLAB:
    from google.colab import files
    print("Upload gen_model.zip (containing gen_model/ and data/ directories)")
    uploaded = files.upload()
    
    # Extract
    !unzip -q -o gen_model.zip  # -o overwrites existing files so re-running the cell does not prompt
    print("✓ Files extracted")
else:
    print("Not on Colab - skip this step")

In [None]:
# Option B: Clone from GitHub
# Uncomment and modify the line below if you have a GitHub repo

# !git clone https://github.com/YOUR_USERNAME/YOUR_REPO.git
# %cd YOUR_REPO
# print("✓ Repository cloned")

In [None]:
# Verify files are present
import os
print("Checking for required files...")
print(f"gen_model exists: {os.path.exists('gen_model')}")
print(f"data exists: {os.path.exists('data')}")
print(f"simple_train.py exists: {os.path.exists('gen_model/simple_train.py')}")
print(f"simple_inference.py exists: {os.path.exists('gen_model/simple_inference.py')}")
print(f"dataset.py exists: {os.path.exists('gen_model/dataset.py')}")

# List data files
if os.path.exists('data'):
    !ls -lh data/

## 3. Configure Hyperparameters

**Key hyperparameters to tune:**

In [None]:
from omegaconf import OmegaConf

# ============================================================================
# HYPERPARAMETERS - Adjust these for your experiments
# ============================================================================

config = OmegaConf.create({
    # Dataset configuration
    'data': {
        'data_dir': 'data',
        'atlas_csv': 'data/atlas.csv',
        'train_split': 'gen_model/splits/frame_splits.csv',
        'suffix': '_latent',  # File suffix for .npy files
        'frame_interval': None,  # Sample every N frames (None = all frames)
        'crop_ratio': 0.95,  # Ratio of residues to keep (0.95 = 95%)
        'min_t': 0.01,  # Minimum diffusion timestep
    },
    
    # Diffusion configuration
    'diffusion': {
        'timesteps': 1000,  # Number of diffusion steps (100-1000 typical)
        'beta_start': 0.0001,  # Starting noise level (smaller = less noise initially)
        'beta_end': 0.02,  # Ending noise level (larger = more noise at end)
        # Note: Can try cosine schedule for potentially better results
    },
    
    # Model architecture
    'model': {
        'hidden_dim': 256,  # Hidden layer size (128-512 typical)
        'time_emb_dim': 128,  # Time embedding dimension (64-256 typical)
        # Larger = more capacity but slower training
    },
    
    # Training configuration
    'training': {
        'batch_size': 8,  # Batch size (reduce if OOM, increase for faster training)
        'num_epochs': 100,  # Number of training epochs (50-200 typical)
        'learning_rate': 1e-4,  # Learning rate (1e-5 to 1e-3 typical)
        'num_workers': 2,  # DataLoader workers (0-4 on Colab)
        'save_every': 10,  # Save checkpoint every N epochs
    },
    
    # Checkpoint configuration
    'checkpoint': {
        'save_dir': 'checkpoints/simple_ddpm',
        'load_from': None,  # Path to checkpoint to resume from (or None)
    },
    
    # Inference configuration
    'inference': {
        'num_samples': 5,  # Number of samples to test
        'output_dir': 'outputs/simple_ddpm',
        'denoise_steps': 1000,  # Number of denoising steps (can be < training steps)
    },
})

print("Configuration:")
print(OmegaConf.to_yaml(config))
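
The config comment mentions a cosine schedule as an alternative. Below is a minimal sketch of the commonly used cosine formulation, with offset `s=0.008` and betas clipped at 0.999. It is an illustration only: `cosine_beta_schedule` is a hypothetical helper, and whether `SimpleDDPM` can accept a custom list of betas is not shown in this notebook.

```python
import math

def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """Betas derived from a squared-cosine curve for the cumulative signal level."""
    def f(step):
        # Unnormalised alpha_bar at a given step (step runs from 0 to timesteps)
        return math.cos(((step / timesteps + s) / (1 + s)) * math.pi / 2) ** 2

    betas = []
    for t in range(timesteps):
        beta = 1.0 - f(t + 1) / f(t)      # beta_t = 1 - alpha_bar_{t+1} / alpha_bar_t
        betas.append(min(beta, max_beta)) # clip to avoid a singular final step
    return betas

betas = cosine_beta_schedule(1000)
print(f"first beta: {betas[0]:.6f}, last beta: {betas[-1]:.3f}")
```

Compared with the linear schedule, this one adds noise more slowly at the start and faster near the end.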

### Hyperparameter Tuning Guide

| Hyperparameter | Description | Typical Range | Impact |
|----------------|-------------|---------------|--------|
| **timesteps** | Number of diffusion steps | 100-1000 | More steps = better quality but slower |
| **beta_start** | Initial noise level | 0.0001-0.001 | Lower = more gradual noising in early steps |
| **beta_end** | Final noise level | 0.01-0.05 | Higher = more noise |
| **hidden_dim** | Model capacity | 128-512 | Larger = more capacity, slower |
| **batch_size** | Samples per iteration | 4-32 | Larger = faster but needs more memory |
| **num_epochs** | Passes over the training set | 50-200 | More = lower training loss, but can overfit |
| **learning_rate** | Optimization step size | 1e-5 to 1e-3 | Too high = unstable, too low = slow |

**Recommended starting points:**
- **Quick test**: timesteps=100, hidden_dim=128, batch_size=16, num_epochs=50
- **Balanced**: timesteps=500, hidden_dim=256, batch_size=8, num_epochs=100 (matches the default config above except for timesteps, which defaults to 1000)
- **High quality**: timesteps=1000, hidden_dim=512, batch_size=4, num_epochs=200
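
When changing `timesteps`, it can help to check how much of the original signal survives at the final step, since `timesteps`, `beta_start` and `beta_end` interact. The sketch below assumes a linear schedule with this notebook's default beta range; `final_signal_fraction` is an illustrative helper, not part of `gen_model`.

```python
import math

def final_signal_fraction(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return alpha_bar_T = prod(1 - beta_t) for a linear beta schedule."""
    log_alpha_bar = 0.0
    for i in range(timesteps):
        frac = i / (timesteps - 1) if timesteps > 1 else 0.0
        beta = beta_start + frac * (beta_end - beta_start)
        log_alpha_bar += math.log1p(-beta)   # sum of logs is more stable than a long product
    return math.exp(log_alpha_bar)

for T in (100, 500, 1000):
    a = final_signal_fraction(T)
    print(f"timesteps={T:4d}  alpha_bar_T={a:.2e}  signal scale={math.sqrt(a):.3f}")
```

With the same beta range, fewer timesteps leave noticeably more signal in the final noised frame, so samples started from pure Gaussian noise may not match what the model saw in training. Raising `beta_end` for short schedules is one option to consider.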

## 4. Import Training Code

In [None]:
import sys
import os

# Add the project root (the directory containing gen_model/) to the import path
project_root = os.path.abspath('.')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import training modules
from gen_model.simple_train import SimpleDDPM, SimpleDenoiseModel, train_ddpm
from gen_model.simple_inference import (
    denoise_step, 
    sample_from_noise, 
    denoise_frame,
    load_checkpoint,
    test_with_dataset
)
from gen_model.dataset import MDGenDataset

import torch
import numpy as np
from tqdm import tqdm

print("✓ Modules imported successfully")

## 5. Prepare Dataset

In [None]:
# Create training dataset
print("Loading training dataset...")
train_dataset = MDGenDataset(
    args=config.data,
    diffuser=None,  # No SE3 diffuser needed for basic DDPM
    split=config.data.train_split,
    mode='train',
    repeat=1,
    num_consecutive=1,
    stride=1
)

print(f"✓ Training dataset size: {len(train_dataset)}")

# Create validation dataset
print("Loading validation dataset...")
val_dataset = MDGenDataset(
    args=config.data,
    diffuser=None,
    split=config.data.train_split,
    mode='val',
    repeat=1,
    num_consecutive=1,
    stride=1
)

print(f"✓ Validation dataset size: {len(val_dataset)}")

# Get sample to determine input dimensions
sample = train_dataset[0]
print(f"\nSample keys: {sample.keys()}")

if 'atom14_pos' in sample:
    sample_data = sample['atom14_pos']
    data_key = 'atom14_pos'
elif 'rigids_0' in sample:
    sample_data = sample['rigids_0'][..., 4:]
    data_key = 'rigids_0'
else:
    raise ValueError("Sample must contain 'atom14_pos' or 'rigids_0'")

print(f"Using data key: {data_key}")
print(f"Sample data shape: {sample_data.shape}")

# Determine input dimensions
if len(sample_data.shape) == 3:
    # [N_res, N_atoms, 3]
    n_residues = sample_data.shape[0]
    in_channels = sample_data.shape[1] * sample_data.shape[2]
else:
    # [N_res, 3] or [N_res, C]
    n_residues = sample_data.shape[0]
    in_channels = sample_data.shape[1]

print(f"Number of residues: {n_residues}")
print(f"Input channels (flattened): {in_channels}")
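
The dimension logic above treats each residue as one row of flattened features. The sketch below shows that flattening and its inverse on a dummy array; the sizes are illustrative and not read from the dataset, and whether the model itself reshapes its input this way is an assumption.

```python
import numpy as np

n_res, n_atoms = 64, 14                  # illustrative sizes only
coords = np.arange(n_res * n_atoms * 3, dtype=np.float32).reshape(n_res, n_atoms, 3)

flat = coords.reshape(n_res, -1)         # [N_res, N_atoms * 3], i.e. in_channels = 42
restored = flat.reshape(n_res, n_atoms, 3)  # back to [N_res, N_atoms, 3]

print(flat.shape)
print(np.array_equal(coords, restored))
```

The same inverse reshape would apply to generated samples of shape `(n_residues, in_channels)` if they need to be read as per-atom coordinates.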

## 6. Initialize Model and Diffusion

In [None]:
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Create diffusion scheduler
diffusion = SimpleDDPM(
    timesteps=config.diffusion.timesteps,
    beta_start=config.diffusion.beta_start,
    beta_end=config.diffusion.beta_end
)
diffusion = diffusion.to(device)
print(f"✓ Diffusion scheduler created with {config.diffusion.timesteps} timesteps")

# Create model
model = SimpleDenoiseModel(
    in_channels=in_channels,
    hidden_dim=config.model.hidden_dim,
    time_emb_dim=config.model.time_emb_dim
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("✓ Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB")

## 7. Training

In [None]:
# Train the model
print("\n" + "="*80)
print("Starting training...")
print("="*80)

train_ddpm(
    dataset=train_dataset,
    model=model,
    diffusion=diffusion,
    device=device,
    batch_size=config.training.batch_size,
    num_epochs=config.training.num_epochs,
    lr=config.training.learning_rate,
    save_dir=config.checkpoint.save_dir,
    save_every=config.training.save_every
)

print("\n" + "="*80)
print("Training complete!")
print("="*80)
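
For orientation, the standard DDPM noise-prediction objective can be sketched as follows. This is an assumption about what `train_ddpm` optimises, since its internals are not shown here. `zero_predictor` is a placeholder for the network, and all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
timesteps = 1000
betas = np.linspace(1e-4, 0.02, timesteps)
alpha_bar = np.cumprod(1.0 - betas)

def zero_predictor(x_t, t):
    """Placeholder for the network: always predicts zero noise."""
    return np.zeros_like(x_t)

def noise_prediction_loss(predict_noise, x0_batch):
    """One loss evaluation: random t per sample, noise the batch, score the prediction."""
    batch_size = x0_batch.shape[0]
    t = rng.integers(0, timesteps, size=batch_size)
    eps = rng.standard_normal(x0_batch.shape)
    a = alpha_bar[t].reshape(batch_size, 1, 1)   # broadcast over residues and channels
    x_t = np.sqrt(a) * x0_batch + np.sqrt(1.0 - a) * eps
    return float(np.mean((predict_noise(x_t, t) - eps) ** 2))

x0_batch = rng.standard_normal((8, 64, 42))      # batch of 8 stand-in frames
loss = noise_prediction_loss(zero_predictor, x0_batch)
print(f"loss for the zero predictor: {loss:.3f}")
```

A predictor that always outputs zero scores close to 1 under this loss, because the target noise has unit variance. That gives a rough baseline: a training loss that stays near 1 suggests the model has not learned much.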

## 8. Inference and Evaluation

In [None]:
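# Load the most recent checkpoint, ordered by modification time
# (lexicographic sorting would place e.g. epoch 100 before epoch 20)
import glob
import os

checkpoint_files = sorted(glob.glob(f"{config.checkpoint.save_dir}/*.pt"), key=os.path.getmtime)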
if checkpoint_files:
    latest_checkpoint = checkpoint_files[-1]
    print(f"Loading checkpoint: {latest_checkpoint}")
    
    model, epoch, loss = load_checkpoint(latest_checkpoint, model, device)
    model = model.to(device)
    print(f"✓ Loaded checkpoint from epoch {epoch}")
else:
    print("No checkpoints found. Using current model state.")
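
If checkpoint filenames encode the epoch number, the latest one can also be picked by parsing that number rather than relying on string order. The naming pattern below is hypothetical, since the names written by `train_ddpm` are not shown in this notebook.

```python
import os
import re

def epoch_of(path):
    """Return the last integer found in the filename, or -1 if there is none."""
    numbers = re.findall(r"\d+", os.path.basename(path))
    return int(numbers[-1]) if numbers else -1

# Hypothetical checkpoint names, used only to illustrate the ordering issue
names = ["ckpt_epoch_10.pt", "ckpt_epoch_100.pt", "ckpt_epoch_20.pt"]

print(sorted(names)[-1])          # string order picks epoch 20
print(max(names, key=epoch_of))   # numeric order picks epoch 100
```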

In [None]:
# Test denoising on validation set
print("\nTesting denoising on validation samples...")
print("="*80)

results = test_with_dataset(
    model=model,
    diffusion=diffusion,
    dataset=val_dataset,
    device=device,
    num_samples=config.inference.num_samples
)

# Calculate statistics
mse_values = [r['mse'] for r in results]
print("\n" + "="*80)
print("Results Summary:")
print(f"  Average MSE: {np.mean(mse_values):.6f}")
print(f"  Std MSE: {np.std(mse_values):.6f}")
print(f"  Min MSE: {np.min(mse_values):.6f}")
print(f"  Max MSE: {np.max(mse_values):.6f}")
print("="*80)
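
If the evaluated arrays hold Cartesian coordinates, a coordinate-wise MSE can be converted into a per-atom RMSD, which is easier to interpret for structures. This is an assumption: with `suffix='_latent'` the data may be latent features, in which case the conversion does not apply. The sketch uses synthetic data and performs no structural superposition.

```python
import numpy as np

def mse_to_rmsd(mse):
    """Per-point RMSD implied by an MSE averaged over x, y and z coordinates."""
    return float(np.sqrt(3.0 * mse))

rng = np.random.default_rng(0)
original = rng.standard_normal((64, 14, 3))                       # synthetic coordinates
denoised = original + 0.1 * rng.standard_normal(original.shape)   # small synthetic error

mse = float(np.mean((denoised - original) ** 2))
# Direct RMSD: squared distance per point, averaged over points, then square root
rmsd = float(np.sqrt(np.mean(np.sum((denoised - original) ** 2, axis=-1))))

print(f"MSE={mse:.4f}  RMSD={rmsd:.4f}  sqrt(3*MSE)={mse_to_rmsd(mse):.4f}")
```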

In [None]:
# Save results
import os
os.makedirs(config.inference.output_dir, exist_ok=True)

for i, result in enumerate(results):
    np.save(f"{config.inference.output_dir}/test_{i}_original.npy", result['original'])
    np.save(f"{config.inference.output_dir}/test_{i}_noisy.npy", result['noisy'])
    np.save(f"{config.inference.output_dir}/test_{i}_denoised.npy", result['denoised'])

print(f"✓ Results saved to {config.inference.output_dir}")

In [None]:
# Visualize reconstruction quality
import matplotlib.pyplot as plt

# Plot MSE distribution
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.bar(range(len(mse_values)), mse_values)
plt.xlabel('Sample Index')
plt.ylabel('MSE')
plt.title('Reconstruction Error per Sample')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.hist(mse_values, bins=10, edgecolor='black')
plt.xlabel('MSE')
plt.ylabel('Count')
plt.title('MSE Distribution')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{config.inference.output_dir}/mse_analysis.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Visualization saved to {config.inference.output_dir}/mse_analysis.png")

## 9. Generate New Samples from Noise

In [None]:
# Generate samples from pure noise
print("Generating samples from random noise...")

num_samples_to_generate = 3
shape = (num_samples_to_generate, n_residues, in_channels)

generated_samples = sample_from_noise(
    model=model,
    diffusion=diffusion,
    shape=shape,
    device=device,
    num_steps=config.inference.denoise_steps
)

# Save generated samples
for i in range(num_samples_to_generate):
    sample_path = f"{config.inference.output_dir}/generated_sample_{i}.npy"
    np.save(sample_path, generated_samples[i].detach().cpu().numpy())
    print(f"  Saved: {sample_path}")

print(f"\n✓ Generated {num_samples_to_generate} samples from noise")
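
Sampling from noise typically repeats one reverse step from the last timestep down to zero. The sketch below shows the standard DDPM ancestral update with variance `beta_t`. It is an assumption about what `sample_from_noise` does internally, and `zero_predictor` stands in for the trained model, so the output here is not meaningful.

```python
import numpy as np

rng = np.random.default_rng(0)
timesteps = 1000
betas = np.linspace(1e-4, 0.02, timesteps)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def zero_predictor(x_t, t):
    """Placeholder for the trained network; a real model returns predicted noise."""
    return np.zeros_like(x_t)

def reverse_step(predict_noise, x_t, t):
    """One ancestral sampling step x_t -> x_{t-1}."""
    eps_hat = predict_noise(x_t, t)
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_hat) / np.sqrt(alphas[t])
    if t == 0:
        return mean                        # no noise is added at the final step
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

x = rng.standard_normal((3, 64, 42))       # start from pure Gaussian noise
for t in reversed(range(timesteps)):
    x = reverse_step(zero_predictor, x, t)

print(x.shape, bool(np.isfinite(x).all()))
```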

## 10. Download Results

In [None]:
# Create zip file with results
!zip -r results.zip {config.checkpoint.save_dir} {config.inference.output_dir}

if IN_COLAB:
    from google.colab import files
    files.download('results.zip')
    print("✓ Results downloaded")
else:
    print("Results saved locally in results.zip")

## Summary

**What we did:**
1. ✓ Set up environment and uploaded data
2. ✓ Configured hyperparameters
3. ✓ Loaded MD trajectory dataset
4. ✓ Created DDPM model and diffusion scheduler
5. ✓ Trained the model
6. ✓ Tested denoising on validation set
7. ✓ Generated new samples from noise
8. ✓ Saved and downloaded results

**Next steps to improve:**
- Tune hyperparameters (see tuning guide above)
- Try longer training (more epochs)
- Experiment with different model architectures
- Add SE3 equivariance for better protein structure modeling
- Implement conditional generation (condition on sequence, etc.)

**Files saved:**
- Checkpoints: `checkpoints/simple_ddpm/` (the value of `config.checkpoint.save_dir`)
- Results: `outputs/simple_ddpm/` (the value of `config.inference.output_dir`)
- Visualizations: `outputs/simple_ddpm/mse_analysis.png`