# ARC-AGI PPO Training on Colab Pro+ (A100)

This notebook trains a PerceiverActorCritic model (~4.8M params) on the ARC-AGI-2 dataset using PPO.

**Before running:**
1. Go to `Runtime` → `Change runtime type`
2. Set **Hardware accelerator** to **GPU**
3. Select **A100 GPU** (Colab Pro+ required)
4. Click **Save**

Then run all cells in order (`Runtime` → `Run all`).


## 1. Install Dependencies


In [None]:
%%capture
!sudo apt-get update
%pip install -U pip
%pip install -U "jax[cuda12]"
%pip install flax optax orbax-checkpoint tensorstore imageio einops matplotlib pillow


## 2. Verify GPU is Available


In [None]:
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"Device type: {jax.devices()[0].device_kind}")
print(f"Device platform: {jax.devices()[0].platform}")

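# Quick GPU test: fail loudly if JAX silently fell back to CPU
assert jax.devices()[0].platform == "gpu", "JAX is not using a GPU; check the runtime type and the JAX install."
x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
jnp.dot(x, x).block_until_ready()  # forces the matmul to actually run on the device
print("\n✅ GPU is working!")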

# Check GPU memory
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv


## 3. Clone Repository


In [None]:
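# Clone only once; re-running this cell won't create a nested ArcX/ArcX copy
%cd /content
![ -d ArcX ] || git clone https://github.com/Maharishiva/ArcX.git
%cd /content/ArcX
!git status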


## 4. Prepare Simple Datasets (Optional)

Fetch the latest ARC-AGI-2 data, copy it into this repo, and generate intentionally easy `training_simple` / `val_simple` splits.


In [None]:
%%bash
set -euo pipefail
cd /content/ArcX
mkdir -p data/training data/evaluation
if [ ! -d /content/arc-agi-2 ]; then
  git clone https://github.com/arcprize/ARC-AGI-2.git /content/arc-agi-2
else
  echo 'ARC-AGI-2 repository already exists; pulling latest changes.'
  git -C /content/arc-agi-2 pull --ff-only
fi
cp -f /content/arc-agi-2/data/training/*.json data/training/
cp -f /content/arc-agi-2/data/evaluation/*.json data/evaluation/
echo 'training count:'
ls -1 data/training | wc -l
echo 'evaluation count:'
ls -1 data/evaluation | wc -l
python scripts/make_simple_datasets.py
ls data
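To sanity-check the copied data, the sketch below lists each demonstration and test pair in a task file together with its grid shapes. It assumes the standard ARC JSON layout: `train`/`test` lists of pairs with `input`/`output` grids of integers 0-9. `summarize_task` is a hypothetical helper and is not part of the repo.

```python
import json
from pathlib import Path

def summarize_task(task):
    """Return (split, pair_index, input_shape, output_shape) for every pair in an ARC task dict.

    ASSUMPTION: standard ARC layout {"train": [...], "test": [...]}, where each pair holds
    "input"/"output" grids as lists of rows of ints 0-9.
    """
    rows = []
    for split in ("train", "test"):
        for i, pair in enumerate(task.get(split, [])):
            in_shape = (len(pair["input"]), len(pair["input"][0]))
            out = pair.get("output")  # a test output may be missing in some task files
            out_shape = (len(out), len(out[0])) if out else None
            rows.append((split, i, in_shape, out_shape))
    return rows

# Inspect the first training task, if the data directory is present
if Path("data/training").exists():
    first = sorted(Path("data/training").glob("*.json"))[0]
    print(first.name)
    for row in summarize_task(json.loads(first.read_text())):
        print(row)
```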


## 5. Quick Smoke Test (Optional, ~2 minutes)

Run this to verify everything is working before starting a long training run.


In [None]:
!PYTHONPATH=. python scripts/ppo_train.py \
    --preset debug \
    --device cuda \
    --total-updates 5 \
    --num-envs 4 \
    --rollout-length 8


## 6. Main Training Run

Choose one of the options below based on how long you want to train:

- **Quick test (~30 min):** 100 updates
- **Medium run (~2-3 hours):** 500 updates
- **Full training (~6-8 hours):** 1000+ updates

Option A runs by default. To run Option B or C instead, uncomment that cell, then skip or comment out Option A so you don't train twice.


In [None]:
# Option A: Quick test run (~30 minutes)
!PYTHONPATH=. python scripts/ppo_train.py \
    --preset a100 \
    --device cuda \
    --total-updates 100 \
    --eval-interval 20 \
    --checkpoint-interval 25


In [None]:
# Option B: Medium run (~2-3 hours) - RECOMMENDED FOR FIRST RUN
# !PYTHONPATH=. python scripts/ppo_train.py \
#     --preset a100 \
#     --device cuda \
#     --total-updates 500 \
#     --eval-interval 25 \
#     --checkpoint-interval 50


In [None]:
# Option C: Full training run (~6-8 hours)
# !PYTHONPATH=. python scripts/ppo_train.py \
#     --preset a100 \
#     --device cuda \
#     --total-updates 1000 \
#     --eval-interval 50 \
#     --checkpoint-interval 50
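The sketch below compares the three options above. It counts how many evaluations and checkpoints each option triggers and converts a per-update time into hours. Two points are assumptions, not documented behavior of `scripts/ppo_train.py`: that evals and checkpoints fire at every positive multiple of their interval, and that each update takes a constant amount of time. For a real estimate, measure seconds per update from the first few updates of your own run.

```python
def run_plan(total_updates, eval_interval, checkpoint_interval):
    """Count evals and checkpoints in a run.

    ASSUMPTION: they fire at every positive multiple of their interval up to
    total_updates; the real script may also evaluate/save at update 0 or at the end.
    """
    return {
        "evals": total_updates // eval_interval,
        "checkpoints": total_updates // checkpoint_interval,
    }

def eta_hours(total_updates, seconds_per_update):
    """Wall-clock estimate assuming constant time per update (ignores eval/checkpoint overhead)."""
    return total_updates * seconds_per_update / 3600

# (total_updates, eval_interval, checkpoint_interval) for Options A-C above
OPTIONS = {"A": (100, 20, 25), "B": (500, 25, 50), "C": (1000, 50, 50)}
for name, (updates, every_eval, every_ckpt) in OPTIONS.items():
    print(f"Option {name}: {run_plan(updates, every_eval, every_ckpt)}")

# ~18 s/update is implied by Option A's estimate (30 min / 100 updates)
print(f"500 updates at 18 s/update: {eta_hours(500, 18):.1f} h")
```

The second estimate lands at 2.5 hours, inside the 2-3 hour range quoted for Option B.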


## 7. List Available Checkpoints


In [None]:
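import re
from pathlib import Path

def checkpoint_step(path):
    """Trailing update number of a checkpoint dir (ppo_50 -> 50); -1 if there is none."""
    match = re.search(r"_(\d+)$", path.name)
    return int(match.group(1)) if match else -1

checkpoint_dir = Path("checkpoints/ppo_a100")
if checkpoint_dir.exists():
    # Sort numerically: plain sorted() would put ppo_100 before ppo_25
    checkpoints = sorted((d for d in checkpoint_dir.iterdir() if d.is_dir()), key=checkpoint_step)
    print(f"Found {len(checkpoints)} checkpoints:")
    for ckpt in checkpoints:
        size_mb = sum(f.stat().st_size for f in ckpt.rglob('*') if f.is_file()) / 1024 / 1024
        print(f"  - {ckpt.name} ({size_mb:.1f} MB)")
else:
    print("No checkpoints found yet. Run training first!")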


## 8. Evaluate Model & Generate GIFs

This creates GIF visualizations of the latest checkpoint's policy attempting ARC puzzles.


In [None]:
import re
from pathlib import Path

# Find the latest checkpoint by its numeric update suffix (ppo_100 is newer than ppo_50)
checkpoint_dir = Path("checkpoints/ppo_a100")
checkpoints = []
if checkpoint_dir.exists():
    checkpoints = sorted(
        (d for d in checkpoint_dir.iterdir() if d.is_dir() and re.search(r"_(\d+)$", d.name)),
        key=lambda d: int(re.search(r"_(\d+)$", d.name).group(1)),
    )
latest_checkpoint = checkpoints[-1] if checkpoints else None

if latest_checkpoint:
    print(f"Evaluating checkpoint: {latest_checkpoint}")
    !PYTHONPATH=. python scripts/ppo_eval_viz.py \
        --checkpoint {latest_checkpoint} \
        --device cuda \
        --num-episodes 10 \
        --rollout-horizon 100 \
        --output-dir artifacts/eval_viz
else:
    print("No checkpoints found. Run training first!")
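The sketch below shows one way to turn ARC grids into GIF frames, which is handy if you want to render your own grids next to the ones `scripts/ppo_eval_viz.py` produces. It assumes the color palette commonly used by ARC viewers (values 0-9) and uses Pillow to write the GIF. The script's actual rendering may differ, and `render_grid`/`save_gif` are hypothetical helpers.

```python
import numpy as np

# ASSUMPTION: the palette commonly used by ARC viewers; ppo_eval_viz.py may use different colors.
ARC_PALETTE = np.array([
    [0x00, 0x00, 0x00],  # 0 black
    [0x00, 0x74, 0xD9],  # 1 blue
    [0xFF, 0x41, 0x36],  # 2 red
    [0x2E, 0xCC, 0x40],  # 3 green
    [0xFF, 0xDC, 0x00],  # 4 yellow
    [0xAA, 0xAA, 0xAA],  # 5 grey
    [0xF0, 0x12, 0xBE],  # 6 magenta
    [0xFF, 0x85, 0x1B],  # 7 orange
    [0x7F, 0xDB, 0xFF],  # 8 light blue
    [0x87, 0x0C, 0x25],  # 9 maroon
], dtype=np.uint8)

def render_grid(grid, cell_px=16):
    """Map an HxW grid of color indices 0-9 to an (H*cell_px, W*cell_px, 3) uint8 image."""
    rgb = ARC_PALETTE[np.asarray(grid, dtype=np.int64)]  # HxWx3 lookup
    return np.repeat(np.repeat(rgb, cell_px, axis=0), cell_px, axis=1)

def save_gif(grids, path, cell_px=16, frame_ms=200):
    """Write one frame per grid to a looping GIF; all grids should share the same shape."""
    from PIL import Image  # imported lazily so render_grid works without Pillow
    frames = [Image.fromarray(render_grid(g, cell_px)) for g in grids]
    frames[0].save(path, save_all=True, append_images=frames[1:], duration=frame_ms, loop=0)
```

For example, `save_gif([[[0, 1], [2, 3]], [[3, 2], [1, 0]]], "demo.gif")` writes a two-frame animation of a 2x2 grid.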


## 9. Display Generated GIFs

View the generated GIFs directly in the notebook.


In [None]:
from IPython.display import Image as IPImage, display
from pathlib import Path

output_dir = Path("artifacts/eval_viz")
if output_dir.exists():
    gifs = sorted(output_dir.glob("*.gif"))
    print(f"Found {len(gifs)} GIFs:\n")
    
    for gif in gifs[:5]:  # Show first 5
        print(f"Episode: {gif.name}")
        display(IPImage(filename=str(gif)))
        print("\n" + "="*80 + "\n")
else:
    print("No GIFs found. Run evaluation first!")


## 10. Download Results

Download GIFs and checkpoints to your local machine.


In [None]:
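# Zip the evaluation GIFs and checkpoints (zip fails and creates nothing if a path is missing)
!zip -qr eval_results.zip artifacts/eval_viz/ || echo "No GIFs to zip yet."
!zip -qr checkpoints.zip checkpoints/ppo_a100/ || echo "No checkpoints to zip yet."

from pathlib import Path
from google.colab import files

for archive in ["eval_results.zip", "checkpoints.zip"]:
    if Path(archive).exists():
        print(f"Downloading {archive}...")
        files.download(archive)
    else:
        print(f"{archive} not found; nothing to download yet.")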


## 11. Advanced: Mount Google Drive (Optional)

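Save checkpoints directly to Google Drive to avoid losing them if Colab disconnects. If you do, change `checkpoint_dir` in Sections 7 and 8 to the same Drive path so those cells can find the checkpoints.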


In [None]:
# Uncomment to enable Google Drive backup:

# from google.colab import drive
# drive.mount('/content/drive')

# # Create directory in Drive
# !mkdir -p /content/drive/MyDrive/arc_agi_checkpoints

# # Run training with Drive checkpoint directory
# !PYTHONPATH=. python scripts/ppo_train.py \
#     --preset a100 \
#     --device cuda \
#     --total-updates 1000 \
#     --eval-interval 50 \
#     --checkpoint-dir /content/drive/MyDrive/arc_agi_checkpoints/ppo_a100
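If you already trained with the default local directory, the sketch below copies each checkpoint folder into the Drive path used above and skips folders that are already there. It assumes Drive is mounted at `/content/drive` by the cell above. `backup_checkpoints` is a hypothetical helper and is not part of the repo.

```python
import shutil
from pathlib import Path

def backup_checkpoints(src="checkpoints/ppo_a100",
                       dst="/content/drive/MyDrive/arc_agi_checkpoints/ppo_a100"):
    """Copy every checkpoint directory under src into dst; return the names newly copied."""
    src, dst = Path(src), Path(dst)
    if not src.exists():
        print(f"No local checkpoints at {src}")
        return []
    dst.mkdir(parents=True, exist_ok=True)
    copied = []
    for ckpt in sorted(p for p in src.iterdir() if p.is_dir()):
        target = dst / ckpt.name
        if not target.exists():  # already backed up (or partially copied) folders are skipped
            shutil.copytree(ckpt, target)
            copied.append(ckpt.name)
    return copied
```

A copy interrupted by a disconnect leaves a partial folder in Drive, and later calls will skip it. Delete that folder before calling `backup_checkpoints()` again.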


---

## 📊 Understanding the Output

During training, you'll see logs like:
```
Update 50/1000 | Train: return=0.234, steps=45.2 | Test: return=0.189, steps=38.5
Checkpoint saved: checkpoints/ppo_a100/ppo_50
```

- **return**: Average episode return, i.e. total reward per episode (higher is better, max ~2.5)
- **steps**: Average episode length
- **Train vs Test**: Train uses training tasks, Test uses held-out evaluation tasks
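To plot a learning curve, the sketch below extracts the numbers from `Update ...` lines. It assumes the trainer prints exactly the format in the example above. `parse_update_line` is a hypothetical helper, and any line that doesn't match the format returns `None`.

```python
import re

# ASSUMPTION: log lines match the example format shown above exactly.
UPDATE_RE = re.compile(
    r"Update (\d+)/(\d+) \| Train: return=([-\d.]+), steps=([-\d.]+)"
    r" \| Test: return=([-\d.]+), steps=([-\d.]+)"
)

def parse_update_line(line):
    """Return a dict of metrics for an 'Update ...' log line, or None for any other line."""
    m = UPDATE_RE.search(line)
    if m is None:
        return None
    update, total = int(m.group(1)), int(m.group(2))
    train_ret, train_steps, test_ret, test_steps = map(float, m.groups()[2:])
    return {"update": update, "total": total,
            "train_return": train_ret, "train_steps": train_steps,
            "test_return": test_ret, "test_steps": test_steps}

print(parse_update_line(
    "Update 50/1000 | Train: return=0.234, steps=45.2 | Test: return=0.189, steps=38.5"))
```

To collect these lines:

1. Append `| tee train.log` to the training command.
2. Run Python with `-u` so output isn't buffered.
3. Parse each line of `train.log`.
4. Plot `test_return` against `update` with matplotlib.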

## 🎯 What to Expect

- **Random policy**: return ~0.0-0.2
- **Learning progress**: return should increase to 0.5-1.0 over hundreds of updates
- **Good performance**: return > 1.0 indicates the model is learning to match targets

## 🔧 Troubleshooting

**If GPU isn't detected:**
- Check that the runtime type is set to GPU
- Restart the runtime and re-run the dependency installation

**If training is slow:**
- Verify you're using an A100 GPU (not a T4)
- Check GPU utilization with `!nvidia-smi`

**If Colab disconnects:**
- Use Google Drive mounting (Section 11)
- Download checkpoints periodically
- Keep the browser tab active

## 📚 Documentation

- [Repository](https://github.com/Maharishiva/ArcX)
- [PPO Training Overview](https://github.com/Maharishiva/ArcX/blob/main/docs/ppo_training_overview.md)
- [ARC-AGI Challenge](https://arcprize.org/)

---

**Happy training! 🚀**
