# Train Attention UNet (MyUNet_Attn) on Colab

Trains the `repaint_gaussian_attn` experiment — Guided-Diffusion-style UNet with
self-attention for ocean velocity inpainting via RePaint.

**Architecture:** MyUNet_Attn (23.1M params)
- ResBlocks with AdaGN time conditioning + residual skip connections
- Multi-head self-attention at 16×32, 8×16, and 4×8 resolutions
- Channel schedule: 64 → 128 → 256 → 256, bottleneck 4×8

**Prerequisites:**
1. Push latest code to GitHub (including `unet_xl_attn.py`)
2. Upload `stjohn_hourly_5m_velocity_ramhead_v2.mat` (942 MB) to Google Drive
   - Put it in: `My Drive/Ocean Inpainting/`
3. Upload `boundaries.yaml` to the same Drive folder

**Runtime:** Select GPU runtime (Runtime → Change runtime type → A100 or V100)

**Estimated time:** ~4-6 hours for 1000 epochs on A100 (Colab Pro)

## 1. Setup: Mount Drive, Clone Repo, Install Deps

In [None]:
# Mount Google Drive at /content/drive: the raw data files are read from there
# and training results are backed up there (see sections 6 and 7).
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# Clone the repo
import os
REPO_URL = 'https://github.com/DrCaley/difussionInpaintingVectorFields.git'
REPO_DIR = '/content/diffusionInpaintingVectorFields'

if os.path.exists(REPO_DIR):
    print(f'Repo already cloned at {REPO_DIR}, pulling latest...')
    !cd {REPO_DIR} && git pull
else:
    !git clone {REPO_URL} {REPO_DIR}

%cd {REPO_DIR}
!pwd

In [None]:
# Install dependencies (Colab already has torch, numpy, matplotlib)
!pip install -q tqdm pyyaml scipy gpytorch

In [None]:
# Symlink data from Google Drive into the expected location
DRIVE_DATA = '/content/drive/MyDrive/Ocean Inpainting'
LOCAL_DATA = f'{REPO_DIR}/data/rams_head'

os.makedirs(LOCAL_DATA, exist_ok=True)

# Symlink the .mat file (942 MB — don't copy, just link)
mat_src = f'{DRIVE_DATA}/stjohn_hourly_5m_velocity_ramhead_v2.mat'
mat_dst = f'{LOCAL_DATA}/stjohn_hourly_5m_velocity_ramhead_v2.mat'
bounds_src = f'{DRIVE_DATA}/boundaries.yaml'
bounds_dst = f'{LOCAL_DATA}/boundaries.yaml'

for src, dst in [(mat_src, mat_dst), (bounds_src, bounds_dst)]:
    if not os.path.exists(dst):
        assert os.path.exists(src), f'Missing: {src}\nUpload to Google Drive first!'
        os.symlink(src, dst)
        print(f'Linked {dst} → {src}')
    else:
        print(f'Already exists: {dst}')

!ls -lh {LOCAL_DATA}/

In [None]:
# Generate the data.pickle file (train/val/test split from raw .mat)
# This is required by DDInitializer and only needs to run once per Colab session
import os
PICKLE_PATH = f'{REPO_DIR}/data.pickle'

if not os.path.exists(PICKLE_PATH):
    print('Generating data.pickle from .mat file...')
    %cd {REPO_DIR}
    !python data_prep/spliting_data_sets.py
    assert os.path.exists(PICKLE_PATH), 'data.pickle was not created!'
    print(f'Created: {PICKLE_PATH}')
else:
    print(f'data.pickle already exists: {PICKLE_PATH}')

!ls -lh {PICKLE_PATH}

In [None]:
# Verify GPU is available, report its memory, and suggest a batch size for section 3.
import torch


def recommend_batch_size(mem_gb: float) -> int:
    """Return a suggested training batch size for a GPU with ``mem_gb`` GB of memory.

    ``mem_gb`` is expected in decimal gigabytes (bytes / 1e9), as computed below.
    Thresholds sit a little under the marketed sizes: a "16 GB" T4 typically
    reports slightly less than 16e9 usable bytes through torch, and a check of
    ``>= 16`` would wrongly put it in the smallest tier.
    """
    if mem_gb >= 40:     # A100 (40/80 GB)
        return 128
    if mem_gb >= 15:     # V100 / T4 (16 GB class)
        return 64
    return 32


print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'Memory: {mem_gb:.1f} GB')
    rec_bs = recommend_batch_size(mem_gb)
    print(f'Recommended batch_size for this GPU: {rec_bs}')
else:
    print('WARNING: No GPU! Go to Runtime → Change runtime type → GPU')

## 2. Validate Experiment Config

In [None]:
# Dry-run to validate config
!PYTHONPATH=. python experiments/run_experiment.py \
    --dry-run experiments/02_inpaint_algorithm/repaint_gaussian_attn/config.yaml

## 3. Create Colab-Optimized Config

The repo config uses `batch_size: 16` (for local MPS training).
This cell writes a Colab-specific override with GPU-appropriate batch size.

Adjust `COLAB_BATCH_SIZE` below if needed (64 for T4/V100, 128 for A100).

In [None]:
# ── Adjust this based on your GPU (see cell above) ──
COLAB_BATCH_SIZE = 64   # 64 for T4/V100 (16 GB), 128 for A100 (40 GB)

# Write a Colab-specific config that overrides batch_size
COLAB_CFG = f'{REPO_DIR}/experiments/02_inpaint_algorithm/repaint_gaussian_attn/colab_config.yaml'

with open(COLAB_CFG, 'w') as f:
    f.write(f"""# Auto-generated Colab config — overrides batch_size for GPU
model_name: repaint_gaussian_attn_eps_t250
noise_function: gaussian
unet_type: standard_attn
prediction_target: eps
mask_xt: false
batch_size: {COLAB_BATCH_SIZE}
lr: 0.0003
max_grad_norm: 1.0
""")

print(f'Wrote Colab config with batch_size={COLAB_BATCH_SIZE}')
print(f'  → {COLAB_CFG}')
!cat {COLAB_CFG}

## 4. Quick Smoke Test (3 epochs)

In [None]:
# Smoke test — verify everything works on GPU before committing to full run
!PYTHONPATH=. python experiments/run_experiment.py \
    --smoke {COLAB_CFG}

## 5. Full Training (1000 epochs)

Uses `batch_size: COLAB_BATCH_SIZE` from section 3 (64 by default; set it to 128 by hand on an A100), `lr: 0.0003`, `max_grad_norm: 1.0`.

With Colab Pro (A100), this should take ~4-6 hours for 1000 epochs.
With T4, expect ~10-12 hours.

Checkpoints are saved automatically (best + periodic).

**Tip:** Keep the browser tab active to avoid disconnection.

In [None]:
# Full training run
!PYTHONPATH=. python experiments/run_experiment.py {COLAB_CFG}

## 6. Save Results to Google Drive

Copy checkpoints to Drive so they survive Colab shutdown.

In [None]:
import shutil

SRC = f'{REPO_DIR}/experiments/02_inpaint_algorithm/repaint_gaussian_attn/results'
DST = '/content/drive/MyDrive/Ocean Inpainting/training_results/repaint_gaussian_attn'

os.makedirs(DST, exist_ok=True)

# Copy all checkpoint and log files
copied = 0
for f in os.listdir(SRC):
    if f.endswith(('.pt', '.yaml', '.csv', '.png')):
        shutil.copy2(os.path.join(SRC, f), os.path.join(DST, f))
        print(f'  Copied: {f}')
        copied += 1

print(f'\nCopied {copied} files to {DST}')
!ls -lh {DST}/

## 7. Resume Training (if interrupted)

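If Colab disconnected, re-run sections 1–3 (mount Drive, clone the repo, install deps, link the data, generate `data.pickle`, and write the Colab config), then run this cell.
It restores the `.pt` checkpoints backed up to Google Drive in section 6 and resumes training from the best checkpoint.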

In [None]:
import shutil

# First, restore checkpoint from Drive if results folder is empty
SRC_DRIVE = '/content/drive/MyDrive/Ocean Inpainting/training_results/repaint_gaussian_attn'
DST_LOCAL = f'{REPO_DIR}/experiments/02_inpaint_algorithm/repaint_gaussian_attn/results'
os.makedirs(DST_LOCAL, exist_ok=True)

if os.path.exists(SRC_DRIVE):
    for f in os.listdir(SRC_DRIVE):
        if f.endswith('.pt'):
            dst_path = os.path.join(DST_LOCAL, f)
            if not os.path.exists(dst_path):
                shutil.copy2(os.path.join(SRC_DRIVE, f), dst_path)
                print(f'  Restored: {f}')

# Find best checkpoint
ckpt = None
for f in os.listdir(DST_LOCAL):
    if 'best_checkpoint' in f and f.endswith('.pt'):
        ckpt = os.path.join(DST_LOCAL, f)
        break

if ckpt:
    print(f'Resuming from: {ckpt}')
    # Create a temporary resume config
    resume_yaml = os.path.join(DST_LOCAL, 'resume_config.yaml')
    with open(resume_yaml, 'w') as f:
        f.write(f"""# Auto-generated resume config
model_name: repaint_gaussian_attn_eps_t250
noise_function: gaussian
unet_type: standard_attn
prediction_target: eps
mask_xt: false
batch_size: 64
lr: 0.0003
max_grad_norm: 1.0
retrain_mode: true
model_to_retrain: {ckpt}
reset_best: false
""")
    !PYTHONPATH=. python experiments/run_experiment.py {resume_yaml}
else:
    print('No checkpoint found — starting fresh')
    !PYTHONPATH=. python experiments/run_experiment.py \
        experiments/02_inpaint_algorithm/repaint_gaussian_attn/config.yaml

## 8. Download Best Checkpoint Locally

After training completes, download the checkpoint to use on your local machine.

In [None]:
from google.colab import files

RESULTS = f'{REPO_DIR}/experiments/02_inpaint_algorithm/repaint_gaussian_attn/results'

# Collect best-checkpoint files; the results folder may not exist if training never ran.
best_ckpts = (
    [name for name in os.listdir(RESULTS) if 'best_checkpoint' in name and name.endswith('.pt')]
    if os.path.isdir(RESULTS)
    else []
)

if best_ckpts:
    # os.listdir() order is arbitrary; download the most recently modified match.
    newest = max(best_ckpts, key=lambda name: os.path.getmtime(os.path.join(RESULTS, name)))
    print(f'Downloading: {newest}')
    files.download(os.path.join(RESULTS, newest))
else:
    print(f'No best checkpoint found in {RESULTS}')