# 🧬 Molecular Cross-Temperature Transport Training

**Local Mac Training - Memory Efficient**

Run the cells in order. Estimated time on CPU: ~20-40 minutes for 100 epochs (the default set in Cell 3), 2-4 hours for 500 epochs.

---

In [None]:
# Cell 1: Check Environment
# Report the interpreter, PyTorch build and host platform, then verify that
# OpenMM and MDTraj can be imported.
import platform
import sys

import torch

print(f'Python: {sys.version}')
print(f'PyTorch: {torch.__version__}')
print(f'Platform: {platform.system()} {platform.machine()}')

try:
    import openmm
    import mdtraj
except ImportError as missing:
    # Either package being absent lands here; show the install hint.
    print(f'❌ Missing: {missing}')
    print('Run: conda install -c conda-forge openmm openmmtools mdtraj')
else:
    # Both imports succeeded: report the installed versions.
    print(f'✅ OpenMM: {openmm.version.short_version}')
    print(f'✅ MDTraj: {mdtraj.__version__}')

In [None]:
# Cell 2: Set Backend & Imports
import os
# MPLBACKEND is the environment variable matplotlib reads to pick its backend;
# 'Agg' renders to image files without a display. It is set before the project
# imports below, presumably because some of them import matplotlib — confirm.
os.environ['MPLBACKEND'] = 'Agg'

import yaml
import numpy as np
from pathlib import Path

# Put the current working directory first on the import path so that the
# `src.*` imports below can resolve (assumes the notebook is started from the
# directory that contains `src/` — confirm).
# NOTE: `sys` is not imported in this cell; it comes from Cell 1, so Cell 1
# must have been run first.
sys.path.insert(0, str(Path.cwd()))

# Project-local building blocks used by the remaining cells.
from src.distributions.molecular_pt import MolecularPTDataset
from src.models.scalable_transformer_flow import ScalableTransformerFlow  
from src.training.molecular_pt_trainer import MolecularPTTrainer
from src.training.molecular_validation import MolecularValidator

print('✅ All modules loaded')

In [None]:
# Cell 3: Configuration (ADJUST EPOCHS HERE)
# Load the experiment definitions and pick the 'aa_300_450' entry of the
# 'molecular_pt' section.
with open('configs/experiments.yaml', 'r') as config_file:
    config_data = yaml.safe_load(config_file)

config = config_data['molecular_pt']['aa_300_450']

# Override for local Mac training. `training_cfg` is the same dict object as
# config['training'], so the assignments below modify `config` in place.
training_cfg = config['training']
training_cfg['epochs'] = 100  # ⬅️ CHANGE THIS (100 for testing, 500 for production)
training_cfg['batch_size'] = 16  # Small batch for Mac memory
training_cfg['eval_interval'] = 25

# Echo the settings that will be used.
print('📋 Config:')
print(f'  Epochs: {training_cfg["epochs"]}')
print(f'  Batch: {training_cfg["batch_size"]}')
print(f'  LR: {training_cfg["learning_rate"]}')

In [None]:
# Cell 4: Create Dataset
# Constructor arguments gathered in one mapping and expanded into the call.
dataset_options = {
    'data_path': 'datasets/AA/pt_AA.pt',
    'source_temp_idx': 0,  # 300K
    'target_temp_idx': 1,  # 450K
    'normalize': True,
    'normalize_mode': 'per_atom',
}
dataset = MolecularPTDataset(**dataset_options)

# Report the temperature pair and the number of samples.
print(f'✅ Dataset: {dataset.source_temp}K → {dataset.target_temp}K')
print(f'   Samples: {len(dataset)}')

In [None]:
# Cell 5: Create Model
# Build the flow model on the chosen device and report its size.
device = 'cpu'  # Force CPU for Mac stability

# NOTE(review): input_dim is hard-coded; presumably it has to match the
# per-sample feature size of the dataset from Cell 4 — confirm.
model = ScalableTransformerFlow(
    input_dim=69,
    num_flow_layers=8,
    embed_dim=192,
    num_heads=8,
    num_transformer_layers=5,
    dropout=0.1
).to(device)

# Measure the footprint from each tensor's actual element size instead of
# assuming 4 bytes (float32) per parameter, so the estimate stays correct if
# the model is ever created or cast with a different dtype. For float32
# parameters the printed value is unchanged.
params = list(model.parameters())
num_params = sum(p.numel() for p in params)
param_bytes = sum(p.numel() * p.element_size() for p in params)
print(f'✅ Model: {num_params:,} parameters (~{param_bytes/1e6:.1f} MB)')

In [None]:
# Cell 6: Create Trainer
# Hand the model (Cell 5), dataset (Cell 4) and device (Cell 5) to the trainer.
trainer = MolecularPTTrainer(
    model=model, dataset=dataset, device=device,
    use_energy=True,  # OpenMM energy evaluation
)

print('✅ Trainer initialized')

In [None]:
# Cell 7: TRAIN (⚠️ Takes 20 min - 4 hours depending on epochs)
# Runs trainer.train() with the config from Cell 3 and passes it `save_dir`
# as the output location. Defines `trained_model` and `history` for the
# cells that follow.
print(f'🚀 Training for {config["training"]["epochs"]} epochs...')
print('⏱️  Estimated time:')
print('   100 epochs: ~20-40 min')
print('   500 epochs: ~2-4 hours\n')

save_dir = 'checkpoints/molecular_pt_aa_300_450'
os.makedirs(save_dir, exist_ok=True)

try:
    # Only the call that can fail lives inside the try block.
    trained_model, history = trainer.train(config, save_dir=save_dir)
except KeyboardInterrupt:
    # NOTE(review): whether the trainer really saves the model when it is
    # interrupted is not visible from this notebook — confirm.
    print('\n⚠️ Training interrupted (model saved)')
    # Keep the names Cell 8 relies on defined, so a manual stop does not end
    # in a NameError there. Assumes the trainer updates `model` (passed to it
    # in Cell 6) in place — TODO confirm.
    trained_model, history = model, []
except Exception as e:
    print(f'\n❌ Error: {e}')
    # Re-raise so the notebook shows the full traceback and "Run All" stops
    # here instead of continuing into validation without a trained model.
    raise
else:
    print('\n✅ Training completed!')
    # Guard against an empty history so a finished run is not reported as an
    # error by the summary line.
    if history:
        print(f'   Final loss: {history[-1]["total_loss"]:.4f}')

In [None]:
# Cell 8: Validation
# Build a validator around the model returned by Cell 7 and run its full
# validation pass.
print('🔬 Running validation...\n')

reference_pdb = 'datasets/AA/ref.pdb'
plots_dir = 'plots/molecular_pt_aa_300_450'

validator = MolecularValidator(
    model=trained_model, dataset=dataset, pdb_path=reference_pdb, device=device
)

# 1000 samples: smaller for speed on Mac
metrics = validator.full_validation(num_samples=1000, save_dir=plots_dir)

print('\n✅ Validation complete!')

In [None]:
# Cell 9: View Results
from IPython.display import Image, display

print('📊 Results:\n')

# Image files expected from the training and validation cells.
loss_path = 'checkpoints/molecular_pt_aa_300_450/loss_curves_300_450.png'
rama_path = 'plots/molecular_pt_aa_300_450/ramachandran_300_450.png'
energy_path = 'plots/molecular_pt_aa_300_450/energy_validation_300_450.png'

# (heading, file) pairs shown in this order; a file that does not exist is
# skipped without any output.
result_images = (
    ('📈 Loss Curves:', loss_path),
    ('\n🧬 Ramachandran:', rama_path),
    ('\n⚡ Energy:', energy_path),
)

for heading, image_path in result_images:
    if not os.path.exists(image_path):
        continue
    print(heading)
    display(Image(image_path))

## 📋 Summary

**Files Generated:**
- Model: `checkpoints/molecular_pt_aa_300_450/molecular_pt_300_450.pt`
- Loss curves: `checkpoints/molecular_pt_aa_300_450/loss_curves_300_450.png`
- Ramachandran: `plots/molecular_pt_aa_300_450/ramachandran_300_450.png`
- Energy validation: `plots/molecular_pt_aa_300_450/energy_validation_300_450.png`

**To run full training:** In Cell 3, change the override to `config['training']['epochs'] = 500`, then re-run all cells.