# Single-Protein DDPM Training

This notebook trains a DDPM on **one specific protein**. The model learns to denoise individual conformational states (frames) drawn from that protein's MD trajectory.

**Workflow:**
1. Specify protein name and parameters
2. Dynamically create/load trajectory data for that protein
3. Train DDPM to denoise frames
4. Generate and evaluate results

## Step 1: Environment Setup

In [None]:
# Check environment
try:
    import google.colab
    IN_COLAB = True
    print("✓ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("✓ Running locally")

# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q omegaconf pandas tqdm numpy matplotlib lightning mdtraj requests



import torch
print(f"\n✓ PyTorch {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")

In [None]:
import os

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

    # 1. Create a folder on your Drive for permanent storage
    !mkdir -p /content/drive/MyDrive/protein_data/data

    # 2. Map the local 'data' folder to your Drive folder
    if not os.path.exists('data'):
        !ln -s /content/drive/MyDrive/protein_data/data data
else:
    # Running locally: keep data in a plain local folder instead of Drive
    os.makedirs('data', exist_ok=True)

## Step 2: Get Code from GitHub

Clone the repository to get the `gen_model` code. The trajectory data is not part of the repository; it is fetched separately in Step 4.

In [None]:
# Clone repository
REPO_URL = "https://github.com/JiwonJJeong/winter-gen-pproject.git"
REPO_DIR = "winter-gen-pproject"

import os

# Enter the repository unless this cell was already run from inside it
if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.exists(REPO_DIR):
        print(f"Cloning from: {REPO_URL}")
        !git clone {REPO_URL}
        print("\n✓ Repository cloned")
    %cd {REPO_DIR}

# Pull latest and clear only gen_model bytecode cache (paths relative to repo root)
!git pull origin main
!find gen_model -name '*.pyc' -delete 2>/dev/null; true
!find gen_model -name '__pycache__' -type d -exec rm -rf {} + 2>/dev/null; true
print("✓ Code updated & bytecode cache cleared")

!ls -la gen_model/
print("\n✓ Code ready")

## Step 3: Configure Protein and Training
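
The protein settings in this step include three split ratios (`train_early_ratio`, `train_ratio`, `val_ratio`), with the remaining frames used for testing. The sketch below shows one way those ratios could map to frame index ranges. It assumes the splits are contiguous blocks in trajectory order; the function name `frame_splits` is hypothetical, and the actual assignment is defined by the repository's split files, which may differ.

```python
def frame_splits(num_frames, train_early_ratio, train_ratio, val_ratio):
    """Map split ratios to contiguous [start, stop) frame ranges (assumed layout)."""
    total = train_early_ratio + train_ratio + val_ratio
    if not 0.0 < total < 1.0:
        raise ValueError(f"split ratios sum to {total}, expected a value below 1.0")

    bounds = {}
    start, cumulative = 0, 0.0
    for name, ratio in [("train_early", train_early_ratio),
                        ("train", train_ratio),
                        ("val", val_ratio)]:
        cumulative += ratio
        # round() guards against floating-point drift in the running sum
        stop = round(cumulative * num_frames)
        bounds[name] = (start, stop)
        start = stop
    bounds["test"] = (start, num_frames)  # whatever is left over
    return bounds


splits = frame_splits(200, 0.3, 0.4, 0.2)
for name, (start, stop) in splits.items():
    print(f"{name:12s} frames [{start}, {stop})  n={stop - start}")
```

With the default values (200 frames, ratios 0.3 / 0.4 / 0.2) this leaves 20 frames for the test split, which matches the "~10%" noted in the configuration comments.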

**Customize these settings for your protein:**

In [None]:
from omegaconf import OmegaConf

protein_config = OmegaConf.create({

    # ========== Protein Settings ==========
    'protein': {
        'name': '4o66_C',           # Protein name (without _R suffix)
        'replica': 1,               # Replica number
        'num_frames': 200,          # Number of trajectory frames to generate
        'num_residues': 100,        # Number of residues in the protein

        # Data split ratios (must sum to < 1.0)
        'train_early_ratio': 0.3,   # Early training frames (30%)
        'train_ratio': 0.4,         # Main training frames (40%)
        'val_ratio': 0.2,           # Validation frames (20%)
        # Remaining frames are test (~10%)
    },

    # ========== SE(3) Diffusion Settings ==========
    'se3': {
        'diffuse_rot': True,        # Diffuse rotations on SO(3)
        'diffuse_trans': True,      # Diffuse translations on R^3
        'so3': {
            'schedule': 'logarithmic',
            'min_sigma': 0.1,
            'max_sigma': 1.5,
            'num_sigma': 1000,
            'use_cached_score': False,
            'cache_dir': '/tmp/igso3_cache',
            'num_omega': 1000,
        },
        'r3': {
            'min_b': 0.1,
            'max_b': 20.0,
            'coordinate_scaling': 0.1,  # Angstrom → normalized units
        },
    },

    # ========== Score Network Architecture ==========
    'score_model': {
        'node_embed_size': 256,
        'edge_embed_size': 128,
        'embed': {
            'index_embed_size': 32,
            'embed_self_conditioning': True,
            'num_bins': 22,
            'min_bin': 1e-5,
            'max_bin': 20.0,
        },
        'ipa': {
            'c_s': 256,             # Must match node_embed_size
            'c_z': 128,             # Must match edge_embed_size
            'c_hidden': 16,
            'no_heads': 12,
            'no_qk_points': 4,
            'no_v_points': 8,
            'c_skip': 64,
            'num_blocks': 4,
            'coordinate_scaling': 0.1,
            'seq_tfmr_num_heads': 4,
            'seq_tfmr_num_layers': 2,
        },
    },

    # ========== Training Settings ==========
    'training': {
        'batch_size': 8,            # Batch size (4-16 depending on GPU)
        'num_epochs': 100,          # Number of training epochs (50-200)
        'learning_rate': 1e-4,      # Learning rate (1e-5 to 5e-4)
        'rot_loss_weight': 1.0,     # Weight for rotation score loss
        'trans_loss_weight': 1.0,   # Weight for translation score loss
        'psi_loss_weight': 1.0,     # Weight for psi torsion angle loss
    },

    # ========== Inference Settings ==========
    'inference': {
        'num_samples': 5,           # Number of samples to test
        'num_steps': 200,           # Reverse diffusion steps
    }
})

print("Configuration Summary:")
print("="*80)
print(f"Protein: {protein_config.protein.name}_R{protein_config.protein.replica}")
print(f"Frames: {protein_config.protein.num_frames}")
print(f"SE(3) diffusion: rot={protein_config.se3.diffuse_rot}, trans={protein_config.se3.diffuse_trans}")
print(f"Score network: node_dim={protein_config.score_model.node_embed_size}, "
      f"edge_dim={protein_config.score_model.edge_embed_size}, "
      f"IPA blocks={protein_config.score_model.ipa.num_blocks}")
print(f"Loss weights: rot={protein_config.training.rot_loss_weight}, "
      f"trans={protein_config.training.trans_loss_weight}, "
      f"psi={protein_config.training.psi_loss_weight}")
print(f"Training epochs: {protein_config.training.num_epochs}")
print(f"Batch size: {protein_config.training.batch_size}")
print("="*80)


## Step 4: Create/Load Protein Data

This step downloads the trajectory for the specified protein and preprocesses it into the file layout that the dataset loader expects.
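
The verification at the end of the cell below relies on a fixed naming convention: `data/<name>/<name>_R<replica><suffix>.npy`. As a small sketch, the same convention can be wrapped in a helper so the path is built in one place. The helper name `expected_traj_path` is hypothetical and not part of `gen_model`.

```python
from pathlib import Path


def expected_traj_path(data_dir, name, replica, suffix="_latent"):
    """Build the trajectory path used by this notebook's naming convention."""
    full_name = f"{name}_R{replica}"
    return Path(data_dir) / name / f"{full_name}{suffix}.npy"


path = expected_traj_path("data", "4o66_C", 1)
print(path.as_posix(), "(exists)" if path.exists() else "(missing)")
```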

In [None]:
# Use the automatic download and prep script
prot_name = protein_config.protein.name
!python scripts/download_and_prep.py {prot_name} --data_dir ./data --out_dir ./data --suffix _latent
# Setup paths for verification
prot_cfg = protein_config.protein
PROTEIN_FULL_NAME = f"{prot_cfg.name}_R{prot_cfg.replica}"
protein_dir = f'data/{prot_cfg.name}'
traj_path = f'{protein_dir}/{PROTEIN_FULL_NAME}_latent.npy'

if os.path.exists(traj_path):
    print(f"✅ Data ready at: {traj_path}")
else:
    print(f"❌ Error: Data not found at {traj_path}")


## Step 5: Configure Dataset and Model
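
The Step 3 comments note that `ipa.c_s` must match `node_embed_size` and `ipa.c_z` must match `edge_embed_size`. A quick check before building the model can catch a mismatch early. The sketch below uses plain dictionaries and a hypothetical helper `check_ipa_dims`; the same key access should also work on the `score_model` section of the OmegaConf config, though that is an assumption.

```python
def check_ipa_dims(score_model):
    """Return a list of human-readable mismatches (empty if consistent)."""
    problems = []
    ipa = score_model["ipa"]
    if ipa["c_s"] != score_model["node_embed_size"]:
        problems.append(
            f"ipa.c_s={ipa['c_s']} != node_embed_size={score_model['node_embed_size']}"
        )
    if ipa["c_z"] != score_model["edge_embed_size"]:
        problems.append(
            f"ipa.c_z={ipa['c_z']} != edge_embed_size={score_model['edge_embed_size']}"
        )
    return problems


good = {"node_embed_size": 256, "edge_embed_size": 128,
        "ipa": {"c_s": 256, "c_z": 128}}
bad = {"node_embed_size": 256, "edge_embed_size": 128,
       "ipa": {"c_s": 128, "c_z": 128}}

print(check_ipa_dims(good))
print(check_ipa_dims(bad))
```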

In [None]:
# Import modules
import sys
sys.path.insert(0, '.')

from gen_model.simple_train import SE3Module
from gen_model.dataset import MDGenDataset
from gen_model.diffusion.se3_diffuser import SE3Diffuser

print("✓ Modules imported")

# Create dataset config
data_config = OmegaConf.create({
    'data_dir': 'data',
    'atlas_csv': 'gen_model/splits/atlas.csv',
    'train_split': 'gen_model/splits/frame_splits.csv',
    'suffix': '_latent',
    'frame_interval': None,
    'crop_ratio': 0.95,
    'min_t': 0.01,

    # Single protein filters
    'pep_name': prot_cfg.name,     # Only load this protein
    'replica': prot_cfg.replica,   # Only load this replica
})

print(f"\nDataset config:")
print(f"  Protein: {data_config.pep_name}")
print(f"  Replica: {data_config.replica}")


## Step 6: Load Dataset
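
The diffuser created in this step turns each clean frame into a noised frame plus target scores. As a rough illustration of what the `se3.r3` and `se3.so3` settings from Step 3 control, the sketch below implements a standard variance-preserving forward process with a linear beta schedule for one translation coordinate, and a logarithmic sigma schedule for rotations. This is an assumption about how `SE3Diffuser` uses `min_b`, `max_b`, `min_sigma`, and `max_sigma`; the actual implementation in `gen_model` may differ, and all function names here are hypothetical.

```python
import math
import random

MIN_B, MAX_B = 0.1, 20.0          # se3.r3.min_b / max_b from Step 3
MIN_SIGMA, MAX_SIGMA = 0.1, 1.5   # se3.so3.min_sigma / max_sigma from Step 3
COORD_SCALE = 0.1                 # se3.r3.coordinate_scaling from Step 3


def marginal_b(t):
    """Integral of the linear schedule b(s) = MIN_B + s * (MAX_B - MIN_B) over [0, t]."""
    return t * MIN_B + 0.5 * t ** 2 * (MAX_B - MIN_B)


def noise_translation(x0, t, rng):
    """Noise one coordinate (in Angstrom) to time t; return (x_t, conditional score)."""
    mean_coef = math.exp(-0.5 * marginal_b(t))
    var = 1.0 - math.exp(-marginal_b(t))
    x0_scaled = x0 * COORD_SCALE
    x_t = mean_coef * x0_scaled + math.sqrt(var) * rng.gauss(0.0, 1.0)
    # Score of the Gaussian N(mean_coef * x0_scaled, var) evaluated at x_t
    score = -(x_t - mean_coef * x0_scaled) / var
    return x_t, score


def so3_sigma(t):
    """Logarithmic interpolation between MIN_SIGMA (t=0) and MAX_SIGMA (t=1)."""
    return math.log(t * math.exp(MAX_SIGMA) + (1.0 - t) * math.exp(MIN_SIGMA))


rng = random.Random(0)
for t in (0.01, 0.5, 1.0):
    x_t, score = noise_translation(12.0, t, rng)
    print(f"t={t:.2f}  x_t={x_t:+.3f}  trans_score={score:+.3f}  so3_sigma={so3_sigma(t):.3f}")
```

Under these assumptions, small `t` leaves the coordinate close to its scaled clean value, while `t = 1` yields nearly pure unit-variance noise.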

In [None]:
# Create SE3Diffuser (initialises IGSO3 look-up tables — takes ~10 s first run)
print("Initialising SE3Diffuser...")
se3_diffuser = SE3Diffuser(protein_config.se3)
print("✓ SE3Diffuser ready\n")

# Create datasets with diffuser so each batch already contains
# rigids_t, rot_score, trans_score computed from the forward process
print("Loading datasets...\n")

train_dataset = MDGenDataset(
    args=data_config,
    diffuser=se3_diffuser,
    mode='train',
    repeat=1,
    num_consecutive=1,
    stride=1
)

val_dataset = MDGenDataset(
    args=data_config,
    diffuser=se3_diffuser,
    mode='val',
    repeat=1,
    num_consecutive=1,
    stride=1
)
val_dataset.coord_scale = float(train_dataset.coord_scale)

print(f"✓ Training frames: {len(train_dataset)}")
print(f"✓ Validation frames: {len(val_dataset)}")

# Inspect a sample to confirm keys
sample = train_dataset[0]
n_residues = sample['res_mask'].shape[0]
print(f"\nSample keys: {list(sample.keys())}")
print(f"Residues: {n_residues}")
print(f"rigids_t shape: {sample['rigids_t'].shape}  (per-residue SE3 frame at time t)")
print(f"rot_score shape: {sample['rot_score'].shape}")
print(f"trans_score shape: {sample['trans_score'].shape}")
print(f"t (noise level): {sample['t']:.4f}")


## Step 7: Create Model

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}\n")

# SE3Module wraps ScoreNetwork (IPA-based) + SE3Diffuser (for score computation)
model_pl = SE3Module(
    model_conf=protein_config.score_model,
    se3_conf=protein_config.se3,
    lr=protein_config.training.learning_rate,
    rot_loss_weight=protein_config.training.rot_loss_weight,
    trans_loss_weight=protein_config.training.trans_loss_weight,
    psi_loss_weight=protein_config.training.psi_loss_weight,
)

n_params = sum(p.numel() for p in model_pl.parameters())
print(f"✓ SE3 Score Network (IPA): {n_params:,} parameters (~{n_params*4/1024/1024:.1f} MB)")
print(f"  node_embed_size: {protein_config.score_model.node_embed_size}")
print(f"  edge_embed_size: {protein_config.score_model.edge_embed_size}")
print(f"  IPA blocks: {protein_config.score_model.ipa.num_blocks}")
print(f"  Loss: rot + trans (score-matching) + psi (sin/cos MSE)")


## Step 8: Train
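
Training minimises a weighted sum of three terms: rotation score loss, translation score loss, and a psi torsion loss computed as an MSE on (sin, cos) pairs. The sketch below shows how such a combination could look in plain Python. It is illustrative only: the helper names are hypothetical, and the real loss inside `SE3Module` may apply masking or time-dependent scaling that is not shown here.

```python
import math


def psi_sincos_mse(pred_sincos, target_psi):
    """MSE between predicted (sin, cos) pairs and those of target angles in radians."""
    total = 0.0
    for (s, c), psi in zip(pred_sincos, target_psi):
        total += (s - math.sin(psi)) ** 2 + (c - math.cos(psi)) ** 2
    return total / (2 * len(target_psi))


def total_loss(rot_loss, trans_loss, psi_loss,
               rot_w=1.0, trans_w=1.0, psi_w=1.0):
    """Weighted sum mirroring rot_loss_weight, trans_loss_weight, psi_loss_weight."""
    return rot_w * rot_loss + trans_w * trans_loss + psi_w * psi_loss


# Two residues: the first prediction is exact, the second is off by 90 degrees
pred = [(0.0, 1.0), (1.0, 0.0)]
target = [0.0, 0.0]
psi_loss = psi_sincos_mse(pred, target)
print(f"psi loss: {psi_loss:.3f}")
print(f"total:    {total_loss(0.5, 0.25, psi_loss):.3f}")
```

Comparing angles through their sine and cosine avoids the discontinuity at ±π that a direct angle difference would have.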

In [None]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

# 1. DataLoaders
train_loader = DataLoader(train_dataset, batch_size=protein_config.training.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=protein_config.training.batch_size, shuffle=False)

# 2. Checkpointing
checkpoint_callback = ModelCheckpoint(
    dirpath=f'checkpoints/{prot_cfg.name}_se3',
    filename='se3-{epoch:02d}-{val_loss:.4f}',
    save_top_k=3,
    monitor='val_loss',
    mode='min',
    save_last=True,
)

# 3. Train  (model_pl created in Step 7)
trainer = L.Trainer(
    max_epochs=protein_config.training.num_epochs,
    accelerator="auto",
    devices=1,
    callbacks=[checkpoint_callback],
    precision="16-mixed" if torch.cuda.is_available() else 32,
)

trainer.fit(model_pl, train_dataloaders=train_loader, val_dataloaders=val_loader)


## Step 9: Evaluate
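
When only the checkpoint files on disk are available (for example after a kernel restart), the best one can be recovered from the file names, because the `filename` pattern in Step 8 embeds `val_loss`. Lightning's `ModelCheckpoint` also exposes `best_model_path` while the callback object is still alive. The sketch below parses names of the form produced by that pattern; the example file names are invented for illustration and the helper `best_checkpoint` is hypothetical.

```python
import re

VAL_LOSS_RE = re.compile(r"val_loss=(\d+(?:\.\d+)?)")


def best_checkpoint(paths):
    """Return the path with the lowest val_loss in its name, or None if none match."""
    scored = []
    for p in paths:
        match = VAL_LOSS_RE.search(p)
        if match:  # skips files such as last.ckpt that carry no metric
            scored.append((float(match.group(1)), p))
    return min(scored)[1] if scored else None


# Invented example names following the 'se3-{epoch:02d}-{val_loss:.4f}' pattern
names = [
    "checkpoints/4o66_C_se3/last.ckpt",
    "checkpoints/4o66_C_se3/se3-epoch=12-val_loss=0.8421.ckpt",
    "checkpoints/4o66_C_se3/se3-epoch=47-val_loss=0.5130.ckpt",
    "checkpoints/4o66_C_se3/se3-epoch=90-val_loss=0.6002.ckpt",
]
print(best_checkpoint(names))
```

Note that taking the last entry of an alphabetically sorted list would select the highest epoch here, which is not the checkpoint with the lowest validation loss.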

In [None]:
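# Load best checkpoint
# NOTE: load_checkpoint, test_with_dataset, model, and diffusion are not defined in
# the cells above; import or define them before running this cell.
import glob

ckpt_dir = f"checkpoints/{prot_cfg.name}_se3"  # must match dirpath in Step 8
best_ckpt = checkpoint_callback.best_model_path  # lowest val_loss seen during training
if not best_ckpt:
    # Fall back to whatever is on disk (last file in alphabetical order)
    ckpts = sorted(glob.glob(f"{ckpt_dir}/*.ckpt"))
    best_ckpt = ckpts[-1] if ckpts else None

results = []
if best_ckpt:
    model, epoch, loss, coord_scale = load_checkpoint(best_ckpt, model, device)
    print(f"✓ Loaded checkpoint from epoch {epoch}")

    # Test with unscaling
    results = test_with_dataset(
        model, diffusion, val_dataset, device,
        coord_scale=coord_scale,
        num_samples=protein_config.inference.num_samples
    )
else:
    print(f"✗ No checkpoint found in {ckpt_dir}; skipping evaluation")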


## Step 10: Generate Samples
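
Sampling runs the reverse process over a discrete set of time points, controlled by `inference.num_steps` (200 in Step 3). A minimal sketch of such a time grid follows, assuming evenly spaced steps from `t = 1.0` down to the `min_t = 0.01` used in the dataset config. The spacing and the helper name `reverse_time_grid` are assumptions; the sampler in the repository may discretise time differently.

```python
def reverse_time_grid(num_steps, min_t=0.01, max_t=1.0):
    """Evenly spaced time points from max_t down to min_t (inclusive)."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    step = (max_t - min_t) / (num_steps - 1)
    return [max_t - i * step for i in range(num_steps)]


grid = reverse_time_grid(200)
print(f"{len(grid)} steps, first={grid[0]:.4f}, last={grid[-1]:.4f}, "
      f"dt={grid[0] - grid[1]:.5f}")
```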

In [None]:
# Generate new conformations from noise
# NOTE: sample_from_noise, model, diffusion, and in_channels are not defined in
# the cells above; import or define them before running this cell.
import os
import numpy as np

print("Generating new samples from noise...\n")

num_generated = 3
shape = (num_generated, n_residues, in_channels)
generated = sample_from_noise(
    model, diffusion, shape, device,
    num_steps=protein_config.inference.num_steps  # key defined in Step 3
)

# Save
output_dir = f'outputs/{prot_cfg.name}_ddpm'
os.makedirs(output_dir, exist_ok=True)

for i in range(num_generated):
    path = f'{output_dir}/generated_sample_{i}.npy'
    np.save(path, generated[i].cpu().numpy())
    print(f"  Saved: {path}")

# Save test results
for i, r in enumerate(results):
    np.save(f'{output_dir}/test_{i}_original.npy', r['original'])
    np.save(f'{output_dir}/test_{i}_denoised.npy', r['denoised'])

print(f"\n✓ Results saved to {output_dir}")

## Step 11: Visualize
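
The plots in this step summarise a per-sample `mse` value from the evaluation results. As a sketch of what such a number could represent, the snippet below computes the mean squared error between an original and a denoised set of 3D coordinates. It assumes `r['mse']` is a plain mean over all coordinates, which the notebook does not confirm; `coord_mse` is a hypothetical helper.

```python
def coord_mse(original, denoised):
    """Mean squared error over all x, y, z components of two equal-length point lists."""
    if len(original) != len(denoised):
        raise ValueError("point lists must have the same length")
    squared = [
        (a - b) ** 2
        for p, q in zip(original, denoised)
        for a, b in zip(p, q)
    ]
    return sum(squared) / len(squared)


original = [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
denoised = [(0.0, 0.0, 0.0), (1.0, 1.0, 2.0)]
print(f"MSE = {coord_mse(original, denoised):.4f}")
```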

In [None]:
import matplotlib.pyplot as plt
import numpy as np
mse_values = [r['mse'] for r in results]
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.bar(range(len(mse_values)), mse_values, color='steelblue')
plt.xlabel('Sample')
plt.ylabel('MSE')
plt.title(f'Reconstruction Error\n{PROTEIN_FULL_NAME}')
plt.grid(alpha=0.3)

plt.subplot(1, 3, 2)
plt.hist(mse_values, bins=10, edgecolor='black', color='steelblue')
plt.xlabel('MSE')
plt.ylabel('Count')
plt.title('MSE Distribution')
plt.grid(alpha=0.3)

plt.subplot(1, 3, 3)
plt.plot(mse_values, 'o-', color='steelblue')
plt.axhline(np.mean(mse_values), color='red', linestyle='--', label='Mean')
plt.xlabel('Sample')
plt.ylabel('MSE')
plt.title('MSE Trend')
plt.legend()
plt.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(f'{output_dir}/analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Saved: {output_dir}/analysis.png")

## Step 12: Download Results
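
The cell below relies on the `zip` command-line tool. Where that tool is unavailable (for example on some local Windows setups), the standard-library `zipfile` module offers a portable alternative. The sketch below is one possible approach, not what the notebook itself uses; `package_dirs` is a hypothetical helper, and the demo runs in a temporary directory so it does not touch real outputs.

```python
import tempfile
import zipfile
from pathlib import Path


def package_dirs(zip_path, dirs):
    """Write every file under each existing directory in dirs into zip_path."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for d in dirs:
            root = Path(d)
            if not root.is_dir():
                print(f"skipping missing directory: {root}")
                continue
            for f in sorted(root.rglob("*")):
                if f.is_file():
                    zf.write(f, f.as_posix())
    return zip_path


# Demo in a temporary directory with one placeholder file
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "outputs"
    out.mkdir()
    (out / "sample.npy").write_bytes(b"demo")
    archive = package_dirs(Path(tmp) / "results.zip", [out, Path(tmp) / "missing"])
    with zipfile.ZipFile(archive) as zf:
        names = zf.namelist()

print(names)
```

In the notebook this would be called with the checkpoint directory and `output_dir` in place of the temporary paths.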

In [None]:
# Package results
save_dir = f'checkpoints/{prot_cfg.name}_se3'  # checkpoint directory from Step 8
!zip -rq {prot_cfg.name}_results.zip {save_dir} {output_dir}

print(f"\nResults packaged: {prot_cfg.name}_results.zip")
print(f"  Checkpoints: {save_dir}/")
print(f"  Outputs: {output_dir}/")

if IN_COLAB:
    from google.colab import files
    files.download(f'{prot_cfg.name}_results.zip')
    print("\n✓ Download started")
else:
    print(f"\n✓ Saved locally as {prot_cfg.name}_results.zip")

## Summary

**What we did:**
1. ✓ Configured the protein (`protein_config.protein.name` and `protein_config.protein.replica`)
2. ✓ Downloaded and preprocessed trajectory data using `download_and_prep.py`
3. ✓ Trained the DDPM for `protein_config.training.num_epochs` epochs
4. ✓ Evaluated denoising on validation frames
5. ✓ Generated new conformations from noise

**Key insight:** The model learned the conformational space of **one specific protein** by training on different frames from its MD trajectory.

**To train on a different protein:**
- Change `protein_config.protein.name`
- Rerun from Step 3 onwards