# Sunfish Masked Diffusion — Colab Training

This notebook installs dependencies, optionally persists checkpoints to Google Drive, runs masked-diffusion training on a Colab GPU, and samples from the trained model.

In [None]:
# Check GPU
!nvidia-smi -L

## (Optional) Mount Google Drive
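Mounting Drive lets checkpoints survive across Colab sessions. The runtime's local disk under `/content` is not preserved once the runtime is disconnected or recycled.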

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Clone repo (or upload manually)
If you already uploaded the repo to `/content/Sunfish`, skip this cell: it deletes any existing `/content/Sunfish` before cloning.

In [None]:
!rm -rf /content/Sunfish
!git clone https://github.com/Sculptor-AI/Sunfish /content/Sunfish

## Install dependencies

In [None]:
%cd /content/Sunfish
!pip -q install -r requirements.txt

## (Optional) Persist checkpoints to Drive
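This replaces `checkpoints/` with a symlink to a Drive folder (`MyDrive/sunfish_checkpoints`). Run it after mounting Drive and cloning the repo. Any checkpoints already in the local `checkpoints/` folder are copied to Drive first.

In [None]:
import os
import shutil

drive_root = "/content/drive/MyDrive"
drive_ckpt = f"{drive_root}/sunfish_checkpoints"
ckpt_path = "/content/Sunfish/checkpoints"

if os.path.isdir(drive_root):
    os.makedirs(drive_ckpt, exist_ok=True)
    if os.path.islink(ckpt_path):
        # Remove a stale link from an earlier run; the Drive folder itself is untouched.
        os.unlink(ckpt_path)
    elif os.path.isdir(ckpt_path):
        # Copy any local checkpoints to Drive before deleting the local folder.
        shutil.copytree(ckpt_path, drive_ckpt, dirs_exist_ok=True)
        shutil.rmtree(ckpt_path)
    os.symlink(drive_ckpt, ckpt_path)
    print(f"Checkpoints -> {drive_ckpt}")
else:
    print("Drive not mounted, skipping checkpoint symlink.")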
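To confirm that new checkpoints will land on Drive, a quick check like the following can help. This is a minimal sketch; `describe_checkpoint_dir` is a hypothetical helper, not part of the Sunfish repo.

```python
import os


def describe_checkpoint_dir(ckpt_path: str) -> str:
    """Say whether ckpt_path is a symlink (and to where), a plain directory, or missing."""
    if os.path.islink(ckpt_path):
        # readlink returns the stored link target, e.g. the Drive folder from the cell above.
        return f"{ckpt_path} is a symlink to {os.readlink(ckpt_path)}"
    if os.path.isdir(ckpt_path):
        return f"{ckpt_path} is a local directory (not persisted to Drive)"
    return f"{ckpt_path} does not exist"


print(describe_checkpoint_dir("/content/Sunfish/checkpoints"))
```

If it reports a local directory, checkpoints written during training will be lost when the runtime is recycled.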


## Train (OpenWebText)
Adjust `--max-steps` to fit your Colab runtime. If the session disconnects, continue from the latest checkpoint with the *Resume training* cell.

In [None]:
%cd /content/Sunfish
!python train_masked.py --dataset openwebtext --max-steps 10000 --checkpoint-every 1000 --overwrite-last --accumulate 16 --num-workers 0 --save-top-k 2 --name colab-owt
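One rough way to choose `--max-steps` is to time a short run, then divide the usable session time by the measured seconds per step. The sketch below assumes a fresh run and that checkpoints are written every `--checkpoint-every` steps, so it rounds down to a multiple of that interval. `estimate_max_steps`, the example timing, and the 0.85 safety margin are hypothetical, not values from the repo.

```python
def estimate_max_steps(
    budget_hours: float,
    seconds_per_step: float,
    checkpoint_every: int = 1000,
    safety_margin: float = 0.85,
) -> int:
    """Largest multiple of checkpoint_every whose steps fit in the time budget.

    safety_margin reserves time for setup, data loading, and checkpoint saves;
    0.85 is an arbitrary assumption, not a measured value.
    """
    if seconds_per_step <= 0 or checkpoint_every <= 0:
        raise ValueError("seconds_per_step and checkpoint_every must be positive")
    usable_seconds = budget_hours * 3600 * safety_margin
    raw_steps = int(usable_seconds // seconds_per_step)
    # Round down so the last step lines up with a checkpoint interval.
    return (raw_steps // checkpoint_every) * checkpoint_every


# Hypothetical numbers: a 4-hour budget and 1.2 s per step measured on your GPU.
print(estimate_max_steps(budget_hours=4, seconds_per_step=1.2))
```

When resuming, check in `train_masked.py` whether `--max-steps` counts total or additional steps before reusing this number.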

## Resume training
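Use this to continue training in a new Colab session. Checkpoints survive a runtime reset only if they were persisted to Drive; in that case, first re-run the Drive mount, clone, install, and *Persist checkpoints to Drive* cells so that `checkpoints/masked/last.ckpt` resolves again.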

In [None]:
%cd /content/Sunfish
!python train_masked.py --resume checkpoints/masked/last.ckpt --dataset openwebtext --max-steps 20000 --checkpoint-every 1000 --overwrite-last --accumulate 16 --num-workers 0 --save-top-k 2 --name colab-owt
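The resume cell hard-codes `checkpoints/masked/last.ckpt`. Before running it, a small check like the one below can confirm that a checkpoint is actually there. This is a hypothetical sketch: the `checkpoints/masked` layout is taken from the cell above, and falling back to the most recently modified `*.ckpt` is an assumed heuristic, not behavior of `train_masked.py`.

```python
from pathlib import Path
from typing import Optional


def find_resume_checkpoint(ckpt_dir: str = "checkpoints/masked") -> Optional[Path]:
    """Return last.ckpt if present, else the newest *.ckpt by mtime, else None."""
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    last = directory / "last.ckpt"
    if last.is_file():
        return last
    # Fallback heuristic (assumption): pick the most recently modified checkpoint.
    candidates = [p for p in directory.glob("*.ckpt") if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


ckpt = find_resume_checkpoint()  # run from /content/Sunfish
print(ckpt if ckpt else "No checkpoint found: re-link Drive or start a fresh run.")
```

In Colab, the returned path can be passed to the training command through IPython's `{...}` expansion, for example `!python train_masked.py --resume {ckpt}` followed by the same remaining flags as above.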

## Sample (prompted generation)

In [None]:
%cd /content/Sunfish
!python sample_masked.py checkpoints/masked/last.ckpt --num-samples 1 --seq-len 128 --num-steps 150 --temperature 0.7 --top-k 20 --top-p 0.95 --prompt "Once upon a time"
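To try several prompts in one go, the command can also be built in Python. This sketch reuses only the flags from the cell above; `sample_command`, `RUN_SAMPLING`, and the second prompt are illustrative assumptions, not part of the repo.

```python
import shlex
import subprocess
from typing import List

RUN_SAMPLING = False  # set to True in Colab (cwd /content/Sunfish) to actually sample


def sample_command(ckpt: str, prompt: str, temperature: float = 0.7) -> List[str]:
    """Build an argv list for sample_masked.py with the flags used in the cell above."""
    return [
        "python", "sample_masked.py", ckpt,
        "--num-samples", "1",
        "--seq-len", "128",
        "--num-steps", "150",
        "--temperature", str(temperature),
        "--top-k", "20",
        "--top-p", "0.95",
        "--prompt", prompt,
    ]


for prompt in ["Once upon a time", "The weather today is"]:
    cmd = sample_command("checkpoints/masked/last.ckpt", prompt)
    print("$", shlex.join(cmd))
    if RUN_SAMPLING:
        subprocess.run(cmd, check=True)
```

Passing an argument list to `subprocess.run` sidesteps shell quoting, so prompts containing quotes or `$` reach `sample_masked.py` unchanged.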

## Sample (infill)

In [None]:
%cd /content/Sunfish
!python sample_masked.py checkpoints/masked/last.ckpt --mode infill --text "Q: The opposite of hot is [MASK]." --infill-len 1 --num-steps 150 --temperature 0.6 --top-k 10