# Single-Protein DDPM Training

This notebook trains a denoising diffusion probabilistic model (DDPM) on **one specific protein** by learning to denoise different conformational states (frames) from a molecular dynamics (MD) trajectory.

**Workflow:**
1. Specify protein name and parameters
2. Dynamically create/load trajectory data for that protein
3. Train DDPM to denoise frames
4. Generate and evaluate results

## Step 1: Environment Setup

In [None]:
# Check environment
try:
    import google.colab
    IN_COLAB = True
    print("✓ Running on Google Colab")
except:
    IN_COLAB = False
    print("✓ Running locally")

# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q omegaconf pandas tqdm numpy matplotlib

import torch
print(f"\n✓ PyTorch {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")

## Step 2: Upload gen_model Code

Upload **only** `gen_model.zip` (just the code, NO data needed)

In [None]:
# Upload gen_model code
# Create gen_model.zip locally: zip -r gen_model.zip gen_model/

if IN_COLAB:
    from google.colab import files
    print("Upload gen_model.zip (code only, no data)")
    uploaded = files.upload()
    !unzip -q gen_model.zip
    print("\n✓ Code extracted")
    !ls gen_model/
else:
    print("Not on Colab - assuming gen_model/ exists")
    !ls gen_model/

## Step 3: Protein Configuration

**Specify your protein here** - data will be created automatically

In [None]:
from omegaconf import OmegaConf

# ============================================================================
# PROTEIN CONFIGURATION - edit the values below for your protein
# ============================================================================

protein_config = OmegaConf.create(dict(
    # Which protein / trajectory to work on
    protein=dict(
        name='4o66_C',           # Protein name (will create folder data/4o66_C/)
        replica=1,               # Replica number
        num_frames=200,          # Number of frames in trajectory
        num_residues=100,        # Number of residues

        # Frame-index splits; whatever is left after these three is the test set
        train_early_ratio=0.3,   # First 30% of frames for early training
        train_ratio=0.4,         # Next 40% for main training
        val_ratio=0.15,          # Next 15% for validation
    ),

    # Real-data switch: set to True and provide a download method
    use_real_data=False,         # Set to True if downloading real data
    data_source=None,            # URL, Google Drive path, etc.

    # Diffusion schedule
    diffusion=dict(
        timesteps=500,           # 100=fast test, 500=balanced, 1000=high quality
        beta_start=0.0001,
        beta_end=0.02,
    ),

    # Model architecture
    model=dict(
        hidden_dim=256,          # 128=small, 256=balanced, 512=large
        time_emb_dim=128,
    ),

    # Training loop
    training=dict(
        batch_size=8,            # Reduce if OOM (8→4→2)
        num_epochs=100,          # 50=quick, 100=balanced, 200=thorough
        learning_rate=1e-4,
        num_workers=2,
        save_every=10,
    ),

    # Inference
    inference=dict(
        num_samples=5,
        denoise_steps=500,
    ),
))

# Echo the settings that matter most so a mis-edit is visible immediately.
summary_rows = (
    ("Name", protein_config.protein.name),
    ("Replica", protein_config.protein.replica),
    ("Frames", protein_config.protein.num_frames),
    ("Residues", protein_config.protein.num_residues),
    ("Diffusion steps", protein_config.diffusion.timesteps),
    ("Training epochs", protein_config.training.num_epochs),
)
print("Protein Configuration:")
for label, value in summary_rows:
    print(f"  {label}: {value}")

## Step 4: Create/Load Protein Data

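This creates the on-disk inputs for the protein specified in Step 3: a synthetic trajectory (unless `use_real_data` is `True`), `data/atlas.csv`, and `gen_model/splits/frame_splits.csv`.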

In [None]:
import numpy as np
import pandas as pd
import os

# Prepare the on-disk inputs for the protein configured in `protein_config`:
#   - a synthetic trajectory .npy (only when use_real_data is False)
#   - data/atlas.csv
#   - gen_model/splits/frame_splits.csv
prot_cfg = protein_config.protein
PROTEIN_FULL_NAME = f"{prot_cfg.name}_R{prot_cfg.replica}"

# Reject impossible split ratios up front, before any file is written.
# A sum above 1.0 would push val_end past num_frames and leave a negative
# number of test frames.
split_ratios = (
    prot_cfg.train_early_ratio,
    prot_cfg.train_ratio,
    prot_cfg.val_ratio,
)
if min(split_ratios) < 0 or sum(split_ratios) > 1.0 + 1e-9:
    raise ValueError(
        "train_early_ratio, train_ratio and val_ratio must be non-negative "
        f"and sum to at most 1.0, got {split_ratios}"
    )

print(f"Setting up data for: {PROTEIN_FULL_NAME}")
print("="*80)

if protein_config.use_real_data:
    # ========== OPTION A: Download/Load Real Data ==========
    # NOTE: this branch is a template. Nothing is fetched until one of the
    # commented examples below is enabled and adapted.
    print("Loading real data...")

    # Customize this based on your data source
    # Example options:

    # From URL:
    # !wget -O data.tar.gz {protein_config.data_source}
    # !tar -xzf data.tar.gz

    # From Google Drive:
    # from google.colab import drive
    # drive.mount('/content/drive')
    # !cp -r /content/drive/MyDrive/md_data/{prot_cfg.name} data/

    # From Google Drive file ID:
    # !pip install -q gdown
    # !gdown {protein_config.data_source} -O data.zip
    # !unzip -q data.zip

    print("✓ Real data loaded")

else:
    # ========== OPTION B: Create Synthetic Data ==========
    print("Creating synthetic trajectory data...")

    # Create directory
    protein_dir = f'data/{prot_cfg.name}'
    os.makedirs(protein_dir, exist_ok=True)

    # Create trajectory: [num_frames, num_residues, 14 atoms, xyz]
    # The generator is not seeded, so every run produces different data.
    trajectory = np.random.randn(
        prot_cfg.num_frames,
        prot_cfg.num_residues,
        14,  # atom14 representation
        3    # x, y, z
    ).astype(np.float32)

    # Scale the unit-variance noise to a reasonable protein scale
    trajectory = trajectory * 5.0  # ~5 Angstrom std deviation

    # Save trajectory
    traj_path = f'{protein_dir}/{PROTEIN_FULL_NAME}_latent.npy'
    np.save(traj_path, trajectory)
    print(f"  ✓ Created: {traj_path}")
    print(f"    Shape: {trajectory.shape}")
    print(f"    Size: {trajectory.nbytes / 1024 / 1024:.2f} MB")

# Create atlas.csv (sequence mapping)
atlas_data = {
    'name': [PROTEIN_FULL_NAME],
    'seqres': ['A' * prot_cfg.num_residues]  # Dummy sequence
}
os.makedirs('data', exist_ok=True)
pd.DataFrame(atlas_data).to_csv('data/atlas.csv', index=False)
print("\n  ✓ Created: data/atlas.csv")

# Create frame splits.
# Each boundary is floor(num_frames * cumulative_ratio). The small tolerance
# keeps binary floating-point error from dropping a frame: without it,
# int(100 * 0.29) evaluates to 28 rather than 29.
SPLIT_EPS = 1e-9
cum_train_early = prot_cfg.train_early_ratio
cum_train = cum_train_early + prot_cfg.train_ratio
cum_val = cum_train + prot_cfg.val_ratio

train_early_end = int(prot_cfg.num_frames * cum_train_early + SPLIT_EPS)
train_end = int(prot_cfg.num_frames * cum_train + SPLIT_EPS)
val_end = int(prot_cfg.num_frames * cum_val + SPLIT_EPS)

splits_data = {
    'name': [PROTEIN_FULL_NAME],
    'train_early_end': [train_early_end],
    'train_end': [train_end],
    'val_end': [val_end],
}

os.makedirs('gen_model/splits', exist_ok=True)
pd.DataFrame(splits_data).to_csv('gen_model/splits/frame_splits.csv', index=False)
print("  ✓ Created: gen_model/splits/frame_splits.csv")

print("\nData splits (by frame index):")
print(f"  Train early: frames 0-{train_early_end} ({train_early_end} frames)")
print(f"  Train: frames {train_early_end}-{train_end} ({train_end - train_early_end} frames)")
print(f"  Val: frames {train_end}-{val_end} ({val_end - train_end} frames)")
print(f"  Test: frames {val_end}-{prot_cfg.num_frames} ({prot_cfg.num_frames - val_end} frames)")

print("\n" + "="*80)
print(f"✓ Data ready for protein: {PROTEIN_FULL_NAME}")
print("="*80)

## Step 5: Configure Dataset and Model

In [None]:
# Make the extracted gen_model package importable from the working directory.
import sys
sys.path.insert(0, '.')

from gen_model.simple_train import (
    SimpleDDPM,
    SimpleDenoiseModel,
    train_ddpm,
)
from gen_model.simple_inference import (
    sample_from_noise,
    denoise_frame,
    load_checkpoint,
    test_with_dataset,
)
from gen_model.dataset import MDGenDataset

print("✓ Modules imported")

# Dataset settings; paths match the files written in Step 4.
data_config = OmegaConf.create(dict(
    data_dir='data',
    atlas_csv='data/atlas.csv',
    train_split='gen_model/splits/frame_splits.csv',
    suffix='_latent',
    frame_interval=None,
    crop_ratio=0.95,
    min_t=0.01,

    # Restrict loading to the single configured protein and replica
    pep_name=prot_cfg.name,
    replica=prot_cfg.replica,
))

print("\nDataset config:")
print(f"  Protein: {data_config.pep_name}")
print(f"  Replica: {data_config.replica}")

## Step 6: Load Dataset

In [None]:
# Build the train and validation datasets and derive the model input size.
print("Loading datasets...\n")

# The two datasets differ only in `mode`; everything else is shared.
shared_dataset_kwargs = dict(
    args=data_config,
    diffuser=None,
    repeat=1,
    num_consecutive=1,
    stride=1,
)
train_dataset = MDGenDataset(mode='train', **shared_dataset_kwargs)
val_dataset = MDGenDataset(mode='val', **shared_dataset_kwargs)

print(f"✓ Training frames: {len(train_dataset)}")
print(f"✓ Validation frames: {len(val_dataset)}")

# Inspect one training item to find out which field holds the data.
sample = train_dataset[0]
print(f"\nSample keys: {list(sample.keys())}")

if 'atom14_pos' not in sample and 'rigids_0' not in sample:
    raise ValueError("Unknown data format")

# Prefer 'atom14_pos'; otherwise take 'rigids_0' without the first 4 entries
# of its last axis.
sample_data = (
    sample['atom14_pos']
    if 'atom14_pos' in sample
    else sample['rigids_0'][..., 4:]
)

# Axis 0 is used as the residue count. A 3-D sample has its two trailing
# axes flattened into one channel dimension.
n_residues = sample_data.shape[0]
in_channels = sample_data.shape[1]
if len(sample_data.shape) == 3:
    in_channels *= sample_data.shape[2]

print(f"\nData shape: {sample_data.shape}")
print(f"Residues: {n_residues}")
print(f"Flattened channels: {in_channels}")

## Step 7: Create Model

In [None]:
# Pick the compute device: CUDA when available, CPU otherwise.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Device: {device}\n")

# Diffusion scheduler, configured from protein_config.diffusion
diffusion_cfg = protein_config.diffusion
diffusion = SimpleDDPM(
    timesteps=diffusion_cfg.timesteps,
    beta_start=diffusion_cfg.beta_start,
    beta_end=diffusion_cfg.beta_end,
).to(device)

print(f"✓ Diffusion: {diffusion_cfg.timesteps} steps")

# Denoising network; the input width comes from the dataset sample above
model_cfg = protein_config.model
model = SimpleDenoiseModel(
    in_channels=in_channels,
    hidden_dim=model_cfg.hidden_dim,
    time_emb_dim=model_cfg.time_emb_dim,
)

# Count parameters; the size estimate uses 4 bytes per parameter.
n_params = 0
for param in model.parameters():
    n_params += param.numel()
print(f"✓ Model: {n_params:,} parameters (~{n_params*4/1024/1024:.1f} MB)")

## Step 8: Train

In [None]:
# Run training for the configured protein and write checkpoints to save_dir.
banner = "=" * 80

print("\n" + banner)
print(f"Training on {PROTEIN_FULL_NAME}")
print(banner + "\n")

save_dir = f'checkpoints/{prot_cfg.name}_ddpm'

# Hyperparameters all come from protein_config.training.
training_cfg = protein_config.training
train_ddpm(
    dataset=train_dataset,
    model=model,
    diffusion=diffusion,
    device=device,
    batch_size=training_cfg.batch_size,
    num_epochs=training_cfg.num_epochs,
    lr=training_cfg.learning_rate,
    save_dir=save_dir,
    save_every=training_cfg.save_every,
)

print("\n" + banner)
print("✓ Training complete!")
print(banner)

## Step 9: Evaluate

In [None]:
# Load the latest checkpoint (highest-sorting file name) and evaluate
# denoising on validation frames.
import glob
import re


def _natural_key(path):
    """Return a sort key that compares the digit runs in *path* numerically.

    A plain string sort places '..._90.pt' after '..._100.pt', so the last
    element would not be the newest epoch. re.split with a capture group
    alternates text and digit tokens, so odd positions are always digits
    and keys stay comparable.
    """
    return [
        int(tok) if i % 2 else tok
        for i, tok in enumerate(re.split(r'(\d+)', path))
    ]


ckpts = sorted(glob.glob(f"{save_dir}/*.pt"), key=_natural_key)
if ckpts:
    model, epoch, _ = load_checkpoint(ckpts[-1], model, device)
    print(f"✓ Loaded checkpoint from epoch {epoch}")
else:
    # Fall through with the in-memory model, but say so.
    print(f"⚠ No checkpoints found in {save_dir}/ - evaluating the in-memory model")

# Test denoising
print(f"\nTesting denoising on {protein_config.inference.num_samples} validation frames...\n")

results = test_with_dataset(
    model, diffusion, val_dataset, device,
    num_samples=protein_config.inference.num_samples
)

# Statistics
mse_values = [r['mse'] for r in results]
print(f"\n{'='*80}")
print(f"Results for {PROTEIN_FULL_NAME}:")
if mse_values:
    print(f"  Mean MSE: {np.mean(mse_values):.6f}")
    print(f"  Std MSE:  {np.std(mse_values):.6f}")
    print(f"  Min MSE:  {np.min(mse_values):.6f}")
    print(f"  Max MSE:  {np.max(mse_values):.6f}")
else:
    # np.min / np.max raise on an empty sequence; report the situation instead.
    print("  No evaluation results were returned - nothing to summarize")
print("="*80)

## Step 10: Generate Samples

In [None]:
# Generate new conformations from pure noise and save them, together with
# the evaluation pairs, under outputs/<protein>_ddpm/.
print("Generating new samples from noise...\n")

# Single source of truth for the number of samples, so the requested batch
# size and the save loop cannot drift apart.
num_generated = 3

shape = (num_generated, n_residues, in_channels)
generated = sample_from_noise(
    model, diffusion, shape, device,
    num_steps=protein_config.inference.denoise_steps
)

# Save
output_dir = f'outputs/{prot_cfg.name}_ddpm'
os.makedirs(output_dir, exist_ok=True)

for i in range(num_generated):
    path = f'{output_dir}/generated_sample_{i}.npy'
    # Move to host memory before converting to a NumPy array.
    np.save(path, generated[i].cpu().numpy())
    print(f"  Saved: {path}")

# Save test results (original / denoised pairs from the evaluation step)
for i, r in enumerate(results):
    np.save(f'{output_dir}/test_{i}_original.npy', r['original'])
    np.save(f'{output_dir}/test_{i}_denoised.npy', r['denoised'])

print(f"\n✓ Results saved to {output_dir}")

## Step 11: Visualize

In [None]:
import matplotlib.pyplot as plt

# Three side-by-side views of the per-sample reconstruction error.
fig, (ax_bar, ax_hist, ax_trend) = plt.subplots(1, 3, figsize=(12, 4))

# Panel 1: one bar per evaluated sample
ax_bar.bar(range(len(mse_values)), mse_values, color='steelblue')
ax_bar.set_xlabel('Sample')
ax_bar.set_ylabel('MSE')
ax_bar.set_title(f'Reconstruction Error\n{PROTEIN_FULL_NAME}')
ax_bar.grid(alpha=0.3)

# Panel 2: histogram of the errors
ax_hist.hist(mse_values, bins=10, edgecolor='black', color='steelblue')
ax_hist.set_xlabel('MSE')
ax_hist.set_ylabel('Count')
ax_hist.set_title('MSE Distribution')
ax_hist.grid(alpha=0.3)

# Panel 3: error by sample index, with the mean as a dashed reference line
ax_trend.plot(mse_values, 'o-', color='steelblue')
ax_trend.axhline(np.mean(mse_values), color='red', linestyle='--', label='Mean')
ax_trend.set_xlabel('Sample')
ax_trend.set_ylabel('MSE')
ax_trend.set_title('MSE Trend')
ax_trend.legend()
ax_trend.grid(alpha=0.3)

fig.tight_layout()
fig.savefig(f'{output_dir}/analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Saved: {output_dir}/analysis.png")

## Step 12: Download Results

In [None]:
# Package results
!zip -rq {prot_cfg.name}_results.zip {save_dir} {output_dir}

print(f"\nResults packaged: {prot_cfg.name}_results.zip")
print(f"  Checkpoints: {save_dir}/")
print(f"  Outputs: {output_dir}/")

if IN_COLAB:
    from google.colab import files
    files.download(f'{prot_cfg.name}_results.zip')
    print("\n✓ Download started")
else:
    print(f"\n✓ Saved locally as {prot_cfg.name}_results.zip")

## Summary

**What we did:**
1. ✓ Configured the protein (`<name>_R<replica>`, taken from `protein_config.protein.name` and `protein_config.protein.replica`)
2. ✓ Created `protein_config.protein.num_frames` frames of trajectory data
3. ✓ Trained the DDPM for `protein_config.training.num_epochs` epochs
4. ✓ Evaluated denoising on validation frames
5. ✓ Generated new conformations from noise

**Key insight:** The model learned the conformational space of **one specific protein** by training on different frames from its MD trajectory.

**To train on a different protein:**
- Change `protein_config.protein.name`
- Rerun from Step 3 onwards