# Train Equalized Div-Free Noise DDPM on Colab

Trains the `fwd_diff_eq_divfree` experiment — spectrally-equalized divergence-free noise
that fixes the low-frequency spectral gap in standard div-free noise.
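
To make the spectral gap concrete, here is a minimal NumPy sketch of one way such noise could be built. It is an illustration under stated assumptions, not the repo's noise class: the names `equalized_divfree_noise` and `spectral_divergence`, the periodic FFT construction, and the `1/|k|` equalization are all hypothetical. Deriving a velocity field from a white-noise stream function gives power that grows like `|k|^2`, so low frequencies are underrepresented; dividing the stream function by `|k|` flattens the velocity spectrum while keeping the field divergence-free.

```python
import numpy as np


def _wavenumbers(h, w):
    """Angular wavenumbers for an h x w periodic grid with unit spacing."""
    ky = 2 * np.pi * np.fft.fftfreq(h)[:, None]
    kx = 2 * np.pi * np.fft.fftfreq(w)[None, :]
    return kx, ky


def equalized_divfree_noise(h, w, seed=0):
    """Hypothetical sketch: div-free noise with a flattened velocity spectrum."""
    rng = np.random.default_rng(seed)
    kx, ky = _wavenumbers(h, w)
    k_mag = np.sqrt(kx**2 + ky**2)
    k_mag[0, 0] = 1.0  # placeholder; the mean mode vanishes anyway since kx = ky = 0

    # White-noise stream function in Fourier space.
    psi_hat = np.fft.fft2(rng.standard_normal((h, w)))

    # Drop Nyquist modes on even grids so the spectral derivatives stay real-valued.
    if h % 2 == 0:
        psi_hat[h // 2, :] = 0
    if w % 2 == 0:
        psi_hat[:, w // 2] = 0

    # Assumed equalization step: cancel the |k| factor introduced by the derivative.
    psi_hat = psi_hat / k_mag

    # u = d(psi)/dy, v = -d(psi)/dx, so du/dx + dv/dy = 0 by construction.
    u = np.fft.ifft2(1j * ky * psi_hat).real
    v = np.fft.ifft2(-1j * kx * psi_hat).real
    return u, v


def spectral_divergence(u, v):
    """Divergence du/dx + dv/dy computed with FFT derivatives (periodic grid)."""
    kx, ky = _wavenumbers(*u.shape)
    div_hat = 1j * kx * np.fft.fft2(u) + 1j * ky * np.fft.fft2(v)
    return np.fft.ifft2(div_hat).real
```

The field is divergence-free in the spectral (periodic) sense; a finite-difference divergence would show a small residual. The sketch also leaves the output variance unnormalized, which a real noise class for DDPM training would presumably handle.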

**Prerequisites:**
1. Push the latest code to GitHub (including the equalized noise class)
2. Upload `stjohn_hourly_5m_velocity_ramhead_v2.mat` (942 MB) to Google Drive
   - Put it in: `My Drive/research_data/rams_head/`
3. Upload `boundaries.yaml` to the same Drive folder

**Runtime:** Select GPU runtime (Runtime → Change runtime type → T4 GPU)

**Estimated time:** ~12-15 hours for 1000 epochs on T4
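
That estimate works out to roughly 43 to 54 seconds per epoch (12 h and 15 h divided by 1000 epochs). A small helper like the one below can turn a per-epoch time measured during the smoke test into a projected total. The function names are hypothetical, and actual throughput depends on the GPU Colab assigns.

```python
import math


def projected_hours(seconds_per_epoch, epochs=1000):
    """Project total training time in hours from a measured per-epoch time."""
    return seconds_per_epoch * epochs / 3600.0


def sessions_needed(total_hours, session_limit_hours):
    """Number of Colab sessions needed if each one is cut off at the limit."""
    return math.ceil(total_hours / session_limit_hours)
```

For example, `projected_hours(54)` gives 15 hours, and `sessions_needed(15, 4)` gives 4 sessions if every session were cut off at the short end of the disconnect window.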

## 1. Setup: Mount Drive, Clone Repo, Install Deps

In [None]:
# Mount Google Drive (for data + saving checkpoints)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone the repo
import os
REPO_URL = 'https://github.com/DrCaley/difussionInpaintingVectorFields.git'
REPO_DIR = '/content/diffusionInpaintingVectorFields'

if os.path.exists(REPO_DIR):
    print(f'Repo already cloned at {REPO_DIR}, pulling latest...')
    !cd {REPO_DIR} && git pull
else:
    !git clone {REPO_URL} {REPO_DIR}

%cd {REPO_DIR}
!pwd

In [None]:
# Install dependencies (Colab already has torch, numpy, matplotlib)
!pip install -q tqdm pyyaml scipy gpytorch

In [None]:
# Symlink data from Google Drive into the expected location
DRIVE_DATA = '/content/drive/MyDrive/research_data/rams_head'
LOCAL_DATA = f'{REPO_DIR}/data/rams_head'

os.makedirs(LOCAL_DATA, exist_ok=True)

# Symlink the .mat file (942 MB — don't copy, just link)
mat_src = f'{DRIVE_DATA}/stjohn_hourly_5m_velocity_ramhead_v2.mat'
mat_dst = f'{LOCAL_DATA}/stjohn_hourly_5m_velocity_ramhead_v2.mat'
bounds_src = f'{DRIVE_DATA}/boundaries.yaml'
bounds_dst = f'{LOCAL_DATA}/boundaries.yaml'

for src, dst in [(mat_src, mat_dst), (bounds_src, bounds_dst)]:
    if not os.path.exists(dst):
        assert os.path.exists(src), f'Missing: {src}\nUpload to Google Drive first!'
        os.symlink(src, dst)
        print(f'Linked {dst} → {src}')
    else:
        print(f'Already exists: {dst}')

!ls -lh {LOCAL_DATA}/

In [None]:
# Verify GPU is available
import torch
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
else:
    print('WARNING: No GPU! Go to Runtime → Change runtime type → T4 GPU')

## 2. Validate Experiment Config

In [None]:
# Dry-run to validate config
!PYTHONPATH=. python experiments/run_experiment.py \
    --dry-run experiments/01_noise_strategy/fwd_divfree_equalized/config.yaml

## 3. Quick Smoke Test (3 epochs)

In [None]:
# Smoke test — verify everything works on GPU before committing to full run
!PYTHONPATH=. python experiments/run_experiment.py \
    --smoke experiments/01_noise_strategy/fwd_divfree_equalized/config.yaml

## 4. Full Training (1000 epochs)

**Important:** Colab may disconnect after ~4-12 hours depending on your plan.
Training saves the best checkpoint automatically, so you can resume if interrupted.

Tips to avoid disconnection:
- Keep the browser tab open and active
- Colab Pro gives longer runtime (~24h)
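- Checkpoints are saved to the experiment's `results/` folder on the Colab VM, which is wiped when the runtime is recycled; copy them to Drive (Section 5) to keep them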
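
The `results/` folder lives on the Colab VM's local disk (under `/content`), so a checkpoint only survives a recycled runtime once it has been copied to Drive. One option is to sync the folder to Drive periodically while training runs. The sketch below is an assumption, not part of the repo: `sync_new_files` and `start_periodic_sync` are hypothetical helpers, and whether a background thread gets scheduled while a `!` cell is running should be verified on a short run first.

```python
import os
import shutil
import threading

CKPT_EXTS = ('.pt', '.yaml', '.csv', '.png')


def sync_new_files(src, dst, exts=CKPT_EXTS):
    """Copy files from src to dst when missing or older in dst; return copied names."""
    if not os.path.isdir(src):
        return []
    os.makedirs(dst, exist_ok=True)
    copied = []
    for name in sorted(os.listdir(src)):
        src_path = os.path.join(src, name)
        dst_path = os.path.join(dst, name)
        if not (os.path.isfile(src_path) and name.endswith(exts)):
            continue
        if (not os.path.exists(dst_path)
                or os.path.getmtime(src_path) > os.path.getmtime(dst_path)):
            shutil.copy2(src_path, dst_path)  # copy2 preserves the modification time
            copied.append(name)
    return copied


def start_periodic_sync(src, dst, interval_s=600):
    """Run sync_new_files every interval_s seconds in a daemon thread."""
    stop = threading.Event()

    def loop():
        # Event.wait returns False on timeout, True once stop.set() is called.
        while not stop.wait(interval_s):
            sync_new_files(src, dst)

    threading.Thread(target=loop, daemon=True).start()
    return stop  # call stop.set() to end the loop
```

A call such as `stop = start_periodic_sync(SRC, DST)` before the training cell would use the same `SRC` and `DST` paths as Section 5. A copy that starts while a checkpoint is still being written can produce a truncated file, so treat the latest synced checkpoint with some caution.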

In [None]:
# Full training run
!PYTHONPATH=. python experiments/run_experiment.py \
    experiments/01_noise_strategy/fwd_divfree_equalized/config.yaml

## 5. Save Results to Google Drive

Copy checkpoints to Drive so they survive Colab shutdown.

In [None]:
import shutil

SRC = f'{REPO_DIR}/experiments/01_noise_strategy/fwd_divfree_equalized/results'
DST = '/content/drive/MyDrive/research_data/training_results/fwd_divfree_equalized'

os.makedirs(DST, exist_ok=True)

# Copy all checkpoint and log files
copied = 0
for f in os.listdir(SRC):
    if f.endswith(('.pt', '.yaml', '.csv', '.png')):
        shutil.copy2(os.path.join(SRC, f), os.path.join(DST, f))
        print(f'  Copied: {f}')
        copied += 1

print(f'\nCopied {copied} files to {DST}')
!ls -lh {DST}/
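
Since these Drive copies are what a resumed run depends on, it can be worth confirming that they match the originals before the runtime shuts down. The following sketch compares SHA-256 digests of the local and Drive files; `sha256_of` and `find_mismatches` are hypothetical helpers, not part of the repo.

```python
import hashlib
import os


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in chunks so large checkpoints are not read into memory at once."""
    digest = hashlib.sha256()
    with open(path, 'rb') as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def find_mismatches(src_dir, dst_dir, exts=('.pt',)):
    """Return names in src_dir (matching exts) that are missing or differ in dst_dir."""
    bad = []
    for name in sorted(os.listdir(src_dir)):
        if not name.endswith(exts):
            continue
        dst_path = os.path.join(dst_dir, name)
        if (not os.path.isfile(dst_path)
                or sha256_of(os.path.join(src_dir, name)) != sha256_of(dst_path)):
            bad.append(name)
    return bad
```

After the copy cell has run, `find_mismatches(SRC, DST)` should return an empty list; any name it returns is a file to copy again.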

## 6. Resume Training (if interrupted)

If Colab disconnected, re-run all of Section 1 (mount Drive, clone the repo, install dependencies, symlink the data), then run this cell.
It restores any checkpoints saved on Drive and resumes from the best checkpoint in the results folder.

In [None]:
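# Imports repeated here because a fresh runtime may not have run the Section 5 cell
import os
import shutil

# Restore any checkpoints from Drive that are missing in the local results folder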
SRC_DRIVE = '/content/drive/MyDrive/research_data/training_results/fwd_divfree_equalized'
DST_LOCAL = f'{REPO_DIR}/experiments/01_noise_strategy/fwd_divfree_equalized/results'
os.makedirs(DST_LOCAL, exist_ok=True)

if os.path.exists(SRC_DRIVE):
    for f in os.listdir(SRC_DRIVE):
        if f.endswith('.pt'):
            dst_path = os.path.join(DST_LOCAL, f)
            if not os.path.exists(dst_path):
                shutil.copy2(os.path.join(SRC_DRIVE, f), dst_path)
                print(f'  Restored: {f}')

# Find best checkpoint
ckpt = None
for f in os.listdir(DST_LOCAL):
    if 'best_checkpoint' in f and f.endswith('.pt'):
        ckpt = os.path.join(DST_LOCAL, f)
        break

if ckpt:
    print(f'Resuming from: {ckpt}')
    # Create a temporary resume config
    resume_yaml = os.path.join(DST_LOCAL, 'resume_config.yaml')
    with open(resume_yaml, 'w') as f:
        f.write(f"""# Auto-generated resume config
model_name: fwd_diff_eq_divfree_eps_t250
noise_function: fwd_diff_eq_divfree
unet_type: standard
prediction_target: eps
mask_xt: false
p_uncond: 0.0
retrain_mode: true
model_to_retrain: {ckpt}
reset_best: false
""")
    !PYTHONPATH=. python experiments/run_experiment.py {resume_yaml}
else:
    print('No checkpoint found — starting fresh')
    !PYTHONPATH=. python experiments/run_experiment.py \
        experiments/01_noise_strategy/fwd_divfree_equalized/config.yaml
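
The cell above takes the first file whose name contains `best_checkpoint`, and `os.listdir` returns entries in arbitrary order. If the results folder ever holds several such files, a helper along these lines would pick the most recently modified one instead. The name `newest_checkpoint` and the glob pattern are assumptions based on the filename check used above.

```python
import glob
import os


def newest_checkpoint(results_dir, pattern='*best_checkpoint*.pt'):
    """Return the most recently modified matching file, or None if there is none."""
    matches = glob.glob(os.path.join(results_dir, pattern))
    return max(matches, key=os.path.getmtime) if matches else None
```

Replacing the search loop with `ckpt = newest_checkpoint(DST_LOCAL)` keeps the rest of the cell unchanged. Because `shutil.copy2` preserves modification times, files restored from Drive keep their original ordering.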

## 7. Download Best Checkpoint Locally

After training completes, download the checkpoint to use on your local machine.

In [None]:
from google.colab import files

RESULTS = f'{REPO_DIR}/experiments/01_noise_strategy/fwd_divfree_equalized/results'
for f in os.listdir(RESULTS):
    if 'best_checkpoint' in f and f.endswith('.pt'):
        print(f'Downloading: {f}')
        files.download(os.path.join(RESULTS, f))
        break