# Sunfish Masked Diffusion — Colab TPU Training

This notebook is TPU-only. In Colab: **Runtime → Change runtime type → TPU**.

In [None]:
import os
tpu_addr = os.environ.get("COLAB_TPU_ADDR")
print("COLAB_TPU_ADDR:", tpu_addr)

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    print("XLA devices:", xm.get_xla_supported_devices())
except Exception as exc:
    # Include and chain the original error so the real cause stays visible.
    raise SystemExit(
        f"torch_xla not available ({exc!r}). Switch runtime to TPU and restart."
    ) from exc


## Clone repo

In [None]:
!rm -rf /content/Sunfish
!git clone https://github.com/Sculptor-AI/Sunfish /content/Sunfish

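## Install dependencies (skip torch on TPU)
`torch` is left out of the install list so pip does not replace the build that the TPU runtime provides alongside `torch_xla`.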

In [None]:
%cd /content/Sunfish
!pip -q install "pytorch-lightning>=2.1.0" "transformers>=4.40.0" "accelerate>=0.27.0" "datasets>=2.16.0" "numpy>=1.24.0" "tqdm>=4.66.0" "wandb>=0.16.0" "tensorboard>=2.15.0"

## (Optional) Mount Google Drive for checkpoints

In [None]:
from google.colab import drive
drive.mount('/content/drive')

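## (Optional) Persist checkpoints to Drive
This replaces `checkpoints/` with a symlink to a Drive-backed folder. Any existing local `checkpoints/` directory is deleted first, so run this cell before training, not after.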

In [None]:
import os
import shutil
drive_root = "/content/drive/MyDrive"
drive_ckpt = f"{drive_root}/sunfish_checkpoints"
ckpt_path = "/content/Sunfish/checkpoints"
if os.path.isdir(drive_root):
    os.makedirs(drive_ckpt, exist_ok=True)
    if os.path.islink(ckpt_path):
        os.unlink(ckpt_path)
    elif os.path.exists(ckpt_path):
        shutil.rmtree(ckpt_path)
    os.symlink(drive_ckpt, ckpt_path)
    print(f"Checkpoints -> {drive_ckpt}")
else:
    print("Drive not mounted, skipping checkpoint symlink.")
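
The cell above removes any existing local `checkpoints/` directory before creating the symlink. If you have already trained locally and want to keep those files, one option is to merge them into the Drive folder first. The following is a minimal sketch under that assumption; `link_checkpoints` is a hypothetical helper and is not part of the Sunfish repo.

```python
import os
import shutil


def link_checkpoints(ckpt_path: str, drive_ckpt: str) -> str:
    """Point ckpt_path at drive_ckpt via a symlink, keeping existing files.

    Hypothetical helper (not part of Sunfish). Requires Python 3.8+ for
    the dirs_exist_ok argument of shutil.copytree.
    """
    os.makedirs(drive_ckpt, exist_ok=True)
    if os.path.islink(ckpt_path):
        # A previous symlink: remove the link only, never its target.
        os.unlink(ckpt_path)
    elif os.path.isdir(ckpt_path):
        # Merge local files into the Drive folder. Same-named files in
        # Drive are overwritten by the local copy. Then drop the local dir.
        shutil.copytree(ckpt_path, drive_ckpt, dirs_exist_ok=True)
        shutil.rmtree(ckpt_path)
    os.symlink(drive_ckpt, ckpt_path)
    # Return the resolved target so the caller can confirm where writes go.
    return os.path.realpath(ckpt_path)
```

In this notebook it would be called as `link_checkpoints("/content/Sunfish/checkpoints", "/content/drive/MyDrive/sunfish_checkpoints")` after Drive is mounted. Copying to Drive can be slow for large checkpoints, so this is only worth doing when there are local files to preserve.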


## Train on TPU (OpenWebText)
Adjust `--max-steps` to fit your Colab runtime.

In [None]:
%cd /content/Sunfish
!python train_masked.py --tpu --dataset openwebtext --max-steps 10000 --checkpoint-every 1000 --overwrite-last --accumulate 16 --num-workers 0 --save-top-k 2 --name colab-tpu-owt

## Resume training
Continues from `checkpoints/masked/last.ckpt`, with `--max-steps` raised from 10000 to 20000.

In [None]:
%cd /content/Sunfish
!python train_masked.py --tpu --resume checkpoints/masked/last.ckpt --dataset openwebtext --max-steps 20000 --checkpoint-every 1000 --overwrite-last --accumulate 16 --num-workers 0 --save-top-k 2 --name colab-tpu-owt
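
The resume command assumes `checkpoints/masked/last.ckpt` exists, which may not be the case after a Colab disconnect or if Drive was not mounted. A small sketch for checking that before launching a long run follows. `find_resume_checkpoint` is a hypothetical helper, and falling back to the most recently modified `.ckpt` file is an assumption about what you might want, not something Sunfish does itself.

```python
from pathlib import Path
from typing import Optional


def find_resume_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Return last.ckpt if present, else the newest *.ckpt, else None.

    Hypothetical helper (not part of Sunfish).
    """
    directory = Path(ckpt_dir)
    last = directory / "last.ckpt"
    if last.is_file():
        return last
    # Fall back to the most recently modified checkpoint, if any exist.
    candidates = sorted(directory.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None
```

For example, `print(find_resume_checkpoint("/content/Sunfish/checkpoints/masked"))` prints the path that could be passed to `--resume`, or `None` if there is nothing to resume from.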

## Sample (runs on CPU in TPU runtime)

In [None]:
%cd /content/Sunfish
!python sample_masked.py checkpoints/masked/last.ckpt --mode infill --text "Q: The opposite of hot is [MASK]." --infill-len 1 --num-steps 150 --temperature 0.6 --top-k 10