# ARC-AGI PPO Training on Colab Pro+ (A100)

This notebook trains a PerceiverActorCritic model (~4.8M params) on the ARC-AGI-2 dataset using PPO.



## 1. Install Dependencies


In [None]:
%%capture
!sudo apt-get update
%pip install -U pip wandb
%pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
%pip install flax optax orbax-checkpoint tensorstore imageio einops matplotlib pillow


## 2. Verify GPU is Available


In [None]:
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"Device type: {jax.devices()[0].device_kind}")
print(f"Device platform: {jax.devices()[0].platform}")

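# Quick GPU test: run a matmul, then confirm JAX is not silently using the CPU
x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
_ = jnp.dot(x, x).block_until_ready()
if jax.default_backend() == "gpu":
    print("✅ GPU is working!")
else:
    print(f"JAX is running on '{jax.default_backend()}', not a GPU. Check the Colab runtime type.")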

# Check GPU memory
!nvidia-smi
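
If you would rather read the GPU name and memory programmatically than scan the `nvidia-smi` table, something like the following may help. It is a minimal sketch using only the standard library; the helper names `parse_gpu_line` and `query_gpus` are illustrative and not part of the ArcX repository.

```python
import shutil
import subprocess
from typing import List, Tuple


def parse_gpu_line(line: str) -> Tuple[str, int]:
    """Split one 'name, memory' CSV row into (name, total memory in MiB)."""
    name, memory_mib = (field.strip() for field in line.rsplit(",", 1))
    return name, int(memory_mib)


def query_gpus() -> List[Tuple[str, int]]:
    """Return (name, total memory in MiB) per visible GPU, or [] if none."""
    if shutil.which("nvidia-smi") is None:
        return []  # no NVIDIA driver tools on this runtime
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return []
    gpus = []
    for row in result.stdout.splitlines():
        try:
            gpus.append(parse_gpu_line(row))
        except ValueError:
            continue  # skip blank rows or rows with non-numeric memory
    return gpus


for gpu_name, gpu_memory_mib in query_gpus():
    print(f"{gpu_name}: {gpu_memory_mib / 1024:.1f} GiB total")
```

An empty result means no NVIDIA GPU is visible, which on Colab usually points to the runtime type rather than to JAX.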

## 3. Clone Repository


In [None]:
!git clone https://github.com/Maharishiva/ArcX.git
%cd ArcX
!git status


In [None]:
import os, pathlib

repo_root = pathlib.Path.cwd()
current_path = os.environ.get('PYTHONPATH')
paths = [str(repo_root)] + ([current_path] if current_path else [])
os.environ['PYTHONPATH'] = ':'.join(paths)
os.environ.setdefault('FLAX_USE_ORBAX_CHECKPOINTING', '0')
print('PYTHONPATH set to:', os.environ['PYTHONPATH'])
print('FLAX_USE_ORBAX_CHECKPOINTING =', os.environ['FLAX_USE_ORBAX_CHECKPOINTING'])
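
Re-running the cell above prepends the repository root to `PYTHONPATH` again each time, so the variable accumulates duplicate entries. This is harmless, but if it bothers you, a duplicate-free variant could look like the sketch below. The helper name `prepend_to_pythonpath` is hypothetical.

```python
import os
from typing import MutableMapping, Optional


def prepend_to_pythonpath(
    path: str, env: Optional[MutableMapping[str, str]] = None
) -> str:
    """Put `path` first in PYTHONPATH without creating duplicate entries."""
    env = os.environ if env is None else env
    existing = [p for p in env.get("PYTHONPATH", "").split(os.pathsep) if p]
    # Keep `path` at the front and drop any later copy of it.
    merged = [path] + [p for p in existing if p != path]
    env["PYTHONPATH"] = os.pathsep.join(merged)
    return env["PYTHONPATH"]


# Demonstrated on a plain dict so the real environment is left untouched.
demo_env = {}
prepend_to_pythonpath("/content/ArcX", demo_env)
prepend_to_pythonpath("/content/ArcX", demo_env)  # second call adds nothing
print(demo_env["PYTHONPATH"])
```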

## 4. Prepare Simple Datasets (Optional)

Fetch the latest ARC-AGI-2 data, copy it into this repo, and generate intentionally easy `training_simple` / `val_simple` splits.


In [None]:
%%bash
set -euo pipefail
cd /content/ArcX
mkdir -p data/training data/evaluation
if [ ! -d /content/arc-agi-2 ]; then
  git clone https://github.com/arcprize/ARC-AGI-2.git /content/arc-agi-2
else
  echo 'ARC-AGI-2 repository already exists; pulling latest changes.'
  git -C /content/arc-agi-2 pull --ff-only
fi
cp -f /content/arc-agi-2/data/training/*.json data/training/
cp -f /content/arc-agi-2/data/evaluation/*.json data/evaluation/
echo 'training count:'
ls -1 data/training | wc -l
echo 'evaluation count:'
ls -1 data/evaluation | wc -l
python scripts/make_simple_datasets.py
ls data
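
To sanity-check the copied data beyond a file count, you can inspect individual tasks. The sketch below assumes the standard ARC task layout, where each JSON file holds `train` and `test` lists of `input`/`output` grids. It does not reproduce whatever criteria `scripts/make_simple_datasets.py` applies, and the helper names are illustrative.

```python
import json
from pathlib import Path
from typing import Dict, List


def summarize_task(task_path: Path) -> Dict[str, int]:
    """Count demonstration/test pairs and find the largest grid in one task."""
    task = json.loads(task_path.read_text())
    pairs = task["train"] + task["test"]
    grids = [pair["input"] for pair in pairs]
    # Test outputs may be withheld in some releases, so guard the lookup.
    grids += [pair["output"] for pair in pairs if "output" in pair]
    return {
        "train_pairs": len(task["train"]),
        "test_pairs": len(task["test"]),
        "max_height": max(len(grid) for grid in grids),
        "max_width": max(len(grid[0]) for grid in grids),
    }


def summarize_split(split_dir: Path) -> List[Dict[str, int]]:
    """Summarize every *.json task in a split directory."""
    return [summarize_task(p) for p in sorted(split_dir.glob("*.json"))]


for split in ("data/training", "data/evaluation"):
    split_dir = Path(split)
    if split_dir.is_dir():
        print(split, "tasks:", len(summarize_split(split_dir)))
```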


## 5. (Optional) Enable Weights & Biases Logging

Authenticate with wandb so training runs can report metrics.
Add `--wandb-mode online --wandb-project <your_project>` to any training command when you want logging.
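
If you launch runs from Python rather than typing commands by hand, the two logging flags can be appended conditionally. This is only a sketch: `with_wandb` is a hypothetical helper and `my_project` is a placeholder project name. The flags themselves are the ones named above.

```python
import shlex
from typing import List, Optional


def with_wandb(args: List[str], project: Optional[str]) -> List[str]:
    """Return a copy of `args`, adding the wandb flags when a project is given."""
    if not project:
        return list(args)
    return list(args) + ["--wandb-mode", "online", "--wandb-project", project]


base = ["python", "scripts/ppo_train.py", "--preset", "debug", "--device", "cuda"]
print(shlex.join(with_wandb(base, "my_project")))
```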


In [None]:
import wandb
wandb.login()


## 6. Quick Smoke Test (Optional, ~2 minutes)

Run this to verify everything is working before starting a long training run.


In [None]:
!python scripts/ppo_train.py \
    --preset debug \
    --device cuda \
    --total-updates 5 \
    --num-envs 4 \
    --rollout-length 8

## 7. Main Training Run

Choose one of the options below based on how long you want to train:

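- **Option A, quick test (~30 min):** 100 updates
- **Option B, medium run (~2-3 hours):** 500 updates
- **Option C, full training (~6-8 hours):** 1000+ updates

Option A is active by default. To run a different option, comment out Option A and uncomment the one you want.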


### Long-run PPO training (simple data split)

Uses `data/training_simple` to accelerate convergence.
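
To get a feel for the scale of the long-run command in this section, the arithmetic below works out its batch sizes from the flag values `--num-envs 16`, `--rollout-length 64`, `--num-minibatches 4`, `--num-epochs 2`, and `--total-updates 10000`. It assumes `scripts/ppo_train.py` follows the common PPO convention (one update collects `num_envs × rollout_length` transitions and splits them into equal minibatches for each epoch); the script itself may differ.

```python
from typing import Dict


def ppo_batch_stats(
    num_envs: int,
    rollout_length: int,
    num_minibatches: int,
    num_epochs: int,
    total_updates: int,
) -> Dict[str, int]:
    """Derive batch sizes and totals under the common PPO convention."""
    batch = num_envs * rollout_length
    if batch % num_minibatches:
        raise ValueError("batch size must be divisible by num_minibatches")
    steps_per_update = num_epochs * num_minibatches
    return {
        "transitions_per_update": batch,
        "minibatch_size": batch // num_minibatches,
        "gradient_steps_per_update": steps_per_update,
        "total_transitions": batch * total_updates,
        "total_gradient_steps": steps_per_update * total_updates,
    }


stats = ppo_batch_stats(16, 64, 4, 2, 10_000)
for key, value in stats.items():
    print(f"{key}: {value:,}")
```

Under that assumption, each update gathers 1,024 transitions, trains on minibatches of 256, and the full run sees about 10.2 million transitions.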


In [None]:

# Long-run PPO training (simple data split)
!python scripts/ppo_train.py \
    --device cuda \
    --num-envs 16 \
    --rollout-length 64 \
    --num-minibatches 4 \
    --num-epochs 2 \
    --eval-envs 8 \
    --eval-horizon 512 \
    --total-updates 10000 \
    --eval-interval 25 \
    --checkpoint-interval 100 \
    --checkpoint-dir ${drive_dir:-checkpoints/ppo_a100} \
    --data-dir data/training_simple \
    --wandb-mode online \
    --wandb-project arc-agi-ppo \
    --wandb-run-name colab_long_run_simple

# Option A: Quick test run (~30 minutes)
# Add --wandb-mode online --wandb-project <your_project> to enable wandb logging
!python scripts/ppo_train.py \
    --preset a100 \
    --device cuda \
    --total-updates 100 \
    --eval-interval 20 \
    --checkpoint-interval 25 \
    --checkpoint-dir ${drive_dir:-checkpoints/ppo_a100} \
    --data-dir data/training_simple

# Option C: Full training run (~6-8 hours)
# !python scripts/ppo_train.py \
#     --preset a100 \
#     --device cuda \
#     --total-updates 1000 \
#     --eval-interval 50 \
#     --checkpoint-interval 50 \
#     --checkpoint-dir ${drive_dir:-checkpoints/ppo_a100}


In [None]:
# Option A: Quick test run (~30 minutes)
# Add --wandb-mode online --wandb-project <your_project> to enable wandb logging
!PYTHONPATH=. python scripts/ppo_train.py \
    --preset a100 \
    --device cuda \
    --total-updates 100 \
    --eval-interval 20 \
    --checkpoint-interval 25 \
    --data-dir data/training_simple


In [None]:
# Option B: Medium run (~2-3 hours) - RECOMMENDED FOR FIRST RUN
# !PYTHONPATH=. python scripts/ppo_train.py \
#     --preset a100 \
#     --device cuda \
#     --total-updates 500 \
#     --eval-interval 25 \
#     --checkpoint-interval 50


In [None]:
# Option C: Full training run (~6-8 hours)
# !PYTHONPATH=. python scripts/ppo_train.py \
#     --preset a100 \
#     --device cuda \
#     --total-updates 1000 \
#     --eval-interval 50 \
#     --checkpoint-interval 50


## 8. List Available Checkpoints


In [None]:
import os
from pathlib import Path

checkpoint_dir = Path("checkpoints/ppo_a100")
if checkpoint_dir.exists():
    checkpoints = sorted([d for d in checkpoint_dir.iterdir() if d.is_dir()])
    print(f"Found {len(checkpoints)} checkpoints:")
    for ckpt in checkpoints:
        size_mb = sum(f.stat().st_size for f in ckpt.rglob('*') if f.is_file()) / 1024 / 1024
        print(f"  - {ckpt.name} ({size_mb:.1f} MB)")
else:
    print("No checkpoints found yet. Run training first!")
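
Note that `sorted` on directory names is lexicographic, so a name ending in `100` sorts before one ending in `25`. If the checkpoint directories carry their update number in the name, sorting on that number gives chronological order. The sketch below assumes names such as `step_25`; that naming is a guess for illustration, not something confirmed by the training script.

```python
import re
from pathlib import Path
from typing import Iterable, List


def checkpoint_step(path: Path) -> int:
    """Extract the last integer in a directory name, or -1 if there is none."""
    numbers = re.findall(r"\d+", path.name)
    return int(numbers[-1]) if numbers else -1


def sort_checkpoints(paths: Iterable[Path]) -> List[Path]:
    """Order checkpoint directories by their numeric step, oldest first."""
    return sorted(paths, key=checkpoint_step)


example = [Path("step_100"), Path("step_25"), Path("step_50")]
print([p.name for p in sort_checkpoints(example)])
# ['step_25', 'step_50', 'step_100']
```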


## 9. Evaluate Model & Generate GIFs

This cell runs the latest checkpoint on ARC puzzles and saves GIF visualizations of its rollouts to `artifacts/eval_viz`.


In [None]:
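# Find the latest checkpoint. Sort by modification time, because sorting by
# name can place e.g. step 100 before step 25.
from pathlib import Path

checkpoint_dir = Path("checkpoints/ppo_a100")
checkpoints = []
if checkpoint_dir.exists():
    checkpoints = sorted(
        (d for d in checkpoint_dir.iterdir() if d.is_dir()),
        key=lambda d: d.stat().st_mtime,
    )
latest_checkpoint = checkpoints[-1] if checkpoints else None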

if latest_checkpoint:
    print(f"Evaluating checkpoint: {latest_checkpoint}")
    !python scripts/ppo_eval_viz.py \
        --checkpoint {latest_checkpoint} \
        --device cuda \
        --num-episodes 10 \
        --rollout-horizon 100 \
        --output-dir artifacts/eval_viz
else:
    print("No checkpoints found. Run training first!")


## 10. Display Generated GIFs

View the generated GIFs directly in the notebook.


In [None]:
from IPython.display import Image as IPImage, display
from pathlib import Path

output_dir = Path("artifacts/eval_viz")
if output_dir.exists():
    gifs = sorted(output_dir.glob("*.gif"))
    print(f"Found {len(gifs)} GIFs:\n")
    
    for gif in gifs[:5]:  # Show first 5
        print(f"Episode: {gif.name}")
        display(IPImage(filename=str(gif)))
        print("\n" + "="*80 + "\n")
else:
    print("No GIFs found. Run evaluation first!")


## 11. Download Results

Zip the generated GIFs and download them to your local machine. The cell below archives only `artifacts/eval_viz/`; checkpoints are not included.


In [None]:
# Zip and download GIFs
!zip -r eval_results.zip artifacts/eval_viz/ 2>/dev/null || echo "zip failed: is artifacts/eval_viz/ empty or missing?"

from pathlib import Path
from google.colab import files
if Path("eval_results.zip").exists():
    print("Downloading GIFs...")
    files.download('eval_results.zip')
else:
    print("No results to download yet.")
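
If the `zip` binary is unavailable, or you want the archive step to fail quietly when there is nothing to pack, the standard library's `shutil.make_archive` can build the same kind of archive. A minimal sketch, with the hypothetical helper name `zip_directory`:

```python
import shutil
from pathlib import Path
from typing import Optional


def zip_directory(src: Path, archive_stem: str) -> Optional[Path]:
    """Zip `src` into `<archive_stem>.zip`; return None if there is nothing to zip."""
    if not src.is_dir() or not any(src.iterdir()):
        return None
    archive = shutil.make_archive(
        archive_stem,
        "zip",
        root_dir=str(src.parent),  # archive paths are relative to this directory
        base_dir=src.name,         # only this subdirectory is included
    )
    return Path(archive)


archive_path = zip_directory(Path("artifacts/eval_viz"), "eval_results")
print("Archive:", archive_path if archive_path else "nothing to zip yet")
```

The resulting file can be passed to `files.download` in the same way as the archive produced by the shell command.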


## 12. Advanced: Mount Google Drive (Optional)

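Run the next cell before starting training in Section 7. It mounts Google Drive, creates a timestamped directory, and exports its path as the `drive_dir` environment variable. Training commands that pass `--checkpoint-dir ${drive_dir:-checkpoints/ppo_a100}` then write checkpoints to Drive; if `drive_dir` is not set, they fall back to `checkpoints/ppo_a100`.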


In [None]:
from datetime import datetime, timezone
from google.colab import drive
import os

drive.mount('/content/drive', force_remount=True)
# datetime.utcnow() is deprecated since Python 3.12; use an aware UTC timestamp
run_name = f"arcx_ppo_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}"
drive_dir = f"/content/drive/MyDrive/arcx_ppo/{run_name}"
os.makedirs(drive_dir, exist_ok=True)
print('Saving checkpoints to:', drive_dir)

# Export path for subsequent shell commands
os.environ['drive_dir'] = drive_dir
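
The training commands read this variable through the shell expansion `${drive_dir:-checkpoints/ppo_a100}`, which substitutes the default when `drive_dir` is unset or empty. Python cells that need the same directory, for example to list or evaluate checkpoints saved on Drive, can mirror that rule. The helper name `resolve_checkpoint_dir` below is hypothetical.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_checkpoint_dir(
    env: Optional[Mapping[str, str]] = None,
    default: str = "checkpoints/ppo_a100",
) -> Path:
    """Mirror the shell's ${drive_dir:-default}: fall back if unset or empty."""
    env = os.environ if env is None else env
    # `or` (rather than a .get() default) also covers the empty-string case.
    return Path(env.get("drive_dir") or default)


print("Checkpoints will be read from:", resolve_checkpoint_dir())
```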