# 🧬 Molecular Cross-Temperature Transport - Google Colab

**A100 GPU Training**

Run the cells in order. Training takes roughly 1–2 hours for 500 epochs.

⚠️ **Make sure an A100 GPU is enabled:** Runtime → Change runtime type → A100 GPU

---

In [None]:
# Cell 1: Check GPU
# Report the attached accelerator and warn immediately if it is missing or is
# not the A100 this notebook's timing/memory estimates assume.
import torch

gpu_available = torch.cuda.is_available()
print(f'🎮 GPU: {gpu_available}')

if gpu_available:
    device_name = torch.cuda.get_device_name(0)
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'   Device: {device_name}')
    print(f'   Memory: {total_memory_gb:.1f} GB')
    if 'A100' in device_name:
        print('\n✅ A100 detected!')
    else:
        print(f'\n⚠️ Expected an A100 but got {device_name}: '
              'Runtime → Change runtime type → A100 GPU, then rerun this cell.')
else:
    print('\n⚠️ No GPU detected: Runtime → Change runtime type → A100 GPU, '
          'then rerun this cell.')

In [None]:
# Cell 2: Clone Repository
!git clone https://github.com/antoniofrancaib/tarflow-pt-molecular.git
%cd tarflow-pt-molecular
!git pull
!ls -lh datasets/AA/pt_AA.pt

In [None]:
# Cell 3: Install Miniforge
import os

print('📦 Installing Miniforge23...')
!wget -q https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Miniforge3-Linux-x86_64.sh
!bash Miniforge3-Linux-x86_64.sh -b -p $HOME/miniforge3

os.environ['PATH'] = f"{os.environ['HOME']}/miniforge3/bin:" + os.environ['PATH']

print('📦 Installing OpenMM, MDTraj, PyYAML...')
!$HOME/miniforge3/bin/mamba install -c conda-forge openmm openmmtools mdtraj pyyaml -y

# Fix imports
!sed -i 's/from simtk import openmm as mm/import openmm as mm/g' src/training/openmm_energy.py
!sed -i 's/from simtk import unit/from openmm import unit/g' src/training/openmm_energy.py
!sed -i 's/from simtk.openmm import app/from openmm import app/g' src/training/openmm_energy.py

print('\n✅ Miniforge + OpenMM installed!')

In [None]:
# Cell 4: Install PyTorch + Dependencies
# Uses Miniforge's own pip (not Colab's system pip) so the packages land in the
# same environment as the interpreter later cells run ($HOME/miniforge3/bin/python).
print('📦 Installing PyTorch and dependencies...')
!$HOME/miniforge3/bin/pip install torch matplotlib scipy scikit-learn tqdm nglview --quiet

print('\n✅ All dependencies installed!')

In [None]:
# Cell 5: Fix Scheduler
# Patches the cloned trainer in place by deleting every literal ", verbose=True"
# from molecular_pt_trainer.py. Presumably this targets an LR-scheduler call whose
# `verbose` argument the installed PyTorch no longer accepts — confirm in the source.
# Safe to rerun: once removed, the pattern simply no longer matches.
print('🔧 Fixing scheduler parameter...')
!sed -i 's/, verbose=True//g' src/training/molecular_pt_trainer.py
print('✅ Fixed!')

In [None]:
# Cell 6: Test OpenMM Energy
print('🧪 Testing OpenMM...')

test_script = '''import os
os.environ["MPLBACKEND"] = "Agg"
import sys
sys.path.insert(0, ".")
import torch, openmm
from src.training.openmm_energy import compute_potential_energy

data = torch.load("datasets/AA/pt_AA.pt", weights_only=False)
test_coords = data[0, 0, :5, :]
energies = compute_potential_energy(test_coords)

print(f"✅ Energy computation works!")
print(f"   Sample energies (kJ/mol): {[f'{e:.2f}' for e in energies.tolist()]}")
print(f"   Mean: {energies.mean():.2f} kJ/mol")
'''

with open('/tmp/test_openmm.py', 'w') as f:
    f.write(test_script)

!MPLBACKEND=Agg $HOME/miniforge3/bin/python /tmp/test_openmm.py

In [None]:
# Cell 7: Train on A100 (500 epochs, ~1-2 hours)
print('🚀 Training on A100 GPU (500 epochs)')
print('⏱️  Time: ~1-2 hours')
print('💰 Compute units: ~15-20 (15-20% of quota)\n')

!MPLBACKEND=Agg $HOME/miniforge3/bin/python main.py train-molecular \
    --preset aa_300_450 \
    --epochs 500 \
    --lr 5e-4 \
    --batch-size 32 \
    --validate

print('\n✅ Training complete!')

In [None]:
# Cell 8: Display Results
# Show each figure written by the training run; report any that are missing
# instead of silently skipping them.
from IPython.display import Image, display
import os

print('📊 Results:\n')

# (heading, image path, display width in px) for each expected figure.
figures = [
    ('📈 Loss Curves:', 'checkpoints/molecular_pt_aa_300_450/loss_curves_300_450.png', 800),
    ('🧬 Ramachandran:', 'plots/molecular_pt_aa_300_450/ramachandran_300_450.png', 900),
    ('⚡ Energy:', 'plots/molecular_pt_aa_300_450/energy_validation_300_450.png', 800),
]

for index, (heading, path, width) in enumerate(figures):
    # The header above already ends with a blank line; separate later figures.
    separator = '' if index == 0 else '\n'
    if os.path.exists(path):
        print(f'{separator}{heading}')
        display(Image(path, width=width))
    else:
        print(f'{separator}⚠️ Missing {path} — did training finish successfully?')

In [None]:
# Cell 9: Download Results
!zip -r molecular_pt_results.zip checkpoints/molecular_pt_aa_300_450 plots/molecular_pt_aa_300_450 -q

from google.colab import files
files.download('molecular_pt_results.zip')

print('✅ Download complete!')