# ARC-RL: Explore Environment & Agent

Interactive notebook to:
1. Load and visualize ARC tasks
2. Step through the environment manually
3. Run a random/trained policy and visualize rollouts

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML

from arc_rl.config import ModelConfig, TrainConfig
from arc_rl.dataset import ARCDataset
from arc_rl.env import BatchedARCEnv
from arc_rl.model import ARCPolicy, sample_resize, sample_paint
from arc_rl.renderer import grid_to_image, render_task, ARC_COLORS
from arc_rl.config import GRID_SIZE

%matplotlib inline
# Notebook-wide matplotlib defaults: figure size in inches and resolution.
# Cells that pass their own `figsize` override the size set here.
plt.rcParams['figure.figsize'] = (14, 6)
plt.rcParams['figure.dpi'] = 100

## 1. Load Dataset & Visualize a Task

In [None]:
# Open the training split of the ARC data and report how many tasks it holds.
dataset = ARCDataset('../references/ARC-AGI/data', split='training')
print(f'Loaded {len(dataset)} training tasks')

# The first task is the running example used by the cells that follow.
task = dataset[0]
print(
    f'\nTask: {task.task_id}',
    f'  Train pairs: {len(task.train_pairs)}',
    f'  Test pairs:  {len(task.test_pairs)}',
    sep='\n',
)

# Summarise every demonstration pair as rows x cols for both of its grids.
for pair_no, (grid_in, grid_out) in enumerate(task.train_pairs):
    in_shape = f'{len(grid_in)}x{len(grid_in[0])}'
    out_shape = f'{len(grid_out)}x{len(grid_out[0])}'
    print(f'  Pair {pair_no}: input {in_shape} -> output {out_shape}')

In [None]:
# Render the task: at most three demonstration pairs plus the first test pair,
# with the test output passed as the target.
test_input, test_output = task.test_pairs[0]
examples = task.train_pairs[:3]

img = render_task(examples, test_input, target=test_output)

# Single-axes figure showing the rendered image without tick marks.
fig, ax = plt.subplots(figsize=(16, 8))
ax.imshow(img)
ax.axis('off')
ax.set_title(f'Task: {task.task_id}')
fig.tight_layout()
plt.show()

## 2. Create Environment & Step Through

In [None]:
device = torch.device('cpu')  # Use CPU for notebook exploration

# Evaluation instance for the first test pair of the task: the demonstration
# pairs, the test input grid and the target output grid.
examples_inst, test_inp, target_out = task.get_eval_instance(test_idx=0)
print(f'Examples: {len(examples_inst)} pairs')
print(f'Test input:  {len(test_inp)}x{len(test_inp[0])}')
print(f'Target output: {len(target_out)}x{len(target_out[0])}')

# Create environment with K=4 parallel rollouts
K = 4
env = BatchedARCEnv(
    [(examples_inst, test_inp, target_out)],
    K=K, max_steps=50, device=device,
)

obs = env.reset()
print(f'\nObservation shape: {obs.shape}')
# Use the configured GRID_SIZE instead of a literal 30 so the hint follows the
# project constant. NOTE(review): assumes the observation canvas is
# GRID_SIZE x GRID_SIZE, matching the paint-position decoding used later in
# this notebook -- confirm against arc_rl.env.
print(f'  Expected: [{K}, {ModelConfig().in_channels}, {GRID_SIZE}, {GRID_SIZE}]')

In [None]:
# Step 0: RESIZE every rollout's grid to the target output size.
target_h, target_w = len(target_out), len(target_out[0])
print(f'Target size: {target_h}x{target_w}')

# Action tensors are created on the same device the environment was built
# with, so this cell keeps working if `device` is switched away from CPU.
h_tensor = torch.full((K,), target_h, dtype=torch.long, device=device)
w_tensor = torch.full((K,), target_w, dtype=torch.long, device=device)
env.resize(h_tensor, w_tensor)
print(f'Grid resized to {target_h}x{target_w}')

# Step 1+: Paint some cells manually.
# Copy the first few non-zero target cells (row-major order) into every
# rollout; the same action is broadcast to all K rollouts.
MAX_MANUAL_PAINTS = 5
nonzero_cells = [
    (r, c)
    for r in range(target_h)
    for c in range(target_w)
    if target_out[r][c] != 0
]

painted = 0
for r, c in nonzero_cells[:MAX_MANUAL_PAINTS]:
    color = torch.full((K,), target_out[r][c], dtype=torch.long, device=device)
    x = torch.full((K,), c, dtype=torch.long, device=device)  # column index
    y = torch.full((K,), r, dtype=torch.long, device=device)  # row index
    env.paint(color, x, y)
    painted += 1

print(f'Painted {painted} cells')

# Visualize current state
predicted_grids = env.get_predicted_grids()
rewards = env.compute_rewards()
print(f'Rewards: {rewards.tolist()}')

# Three panels: test input, rollout 0's current grid, and the target.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(grid_to_image(test_inp))
axes[0].set_title('Test Input')
axes[0].axis('off')
axes[1].imshow(grid_to_image(predicted_grids[0]))
axes[1].set_title(f'Agent Output (reward={rewards[0]:.3f})')
axes[1].axis('off')
axes[2].imshow(grid_to_image(target_out))
axes[2].set_title('Target')
axes[2].axis('off')
plt.tight_layout()
plt.show()

## 3. Run a Random Policy

In [None]:
model_cfg = ModelConfig(hidden_channels=32, num_blocks=4)  # Small for notebook
# Keep the policy on the same device as the environment's observations and in
# eval mode for inference, mirroring how the trained policy is handled in the
# checkpoint cell of this notebook.
policy = ARCPolicy(model_cfg).to(device)
policy.eval()
num_params = sum(p.numel() for p in policy.parameters())
print(f'Policy: {num_params/1e6:.2f}M parameters')
print(f'Input channels: {model_cfg.in_channels}')

# Run a rollout with the random policy
K = 8
max_steps = 30

env = BatchedARCEnv(
    [(examples_inst, test_inp, target_out)],
    K=K, max_steps=max_steps, device=device,
)
obs = env.reset()

# Step 0: RESIZE
with torch.no_grad():
    outputs = policy(obs)
    resize_actions = sample_resize(outputs)
    # The sampled values are shifted by one before being used as sizes
    # (presumably 0-based class indices -- confirm in arc_rl.model).
    pred_h = resize_actions.resize_h + 1
    pred_w = resize_actions.resize_w + 1
    env.resize(pred_h, pred_w)

predicted_sizes = list(zip(pred_h.tolist(), pred_w.tolist()))
print(f'Predicted sizes: {predicted_sizes}')

# Steps 1+: PAINT (step 0 was the resize, so max_steps - 1 paint actions).
with torch.no_grad():
    for _ in range(1, max_steps):
        obs = env.get_obs()
        masks = env.get_grid_masks()
        outputs = policy(obs)
        paint_actions = sample_paint(outputs, masks)
        # Decode the flat position into row (y) and column (x) on a
        # GRID_SIZE-wide canvas.
        y = paint_actions.position // GRID_SIZE
        x = paint_actions.position % GRID_SIZE
        env.paint(paint_actions.color, x, y)

rewards = env.compute_rewards()
predicted_grids = env.get_predicted_grids()

print(f'\nRewards: {[f"{r:.3f}" for r in rewards.tolist()]}')
print(f'Best reward: {rewards.max():.3f}')

In [None]:
# Visualize best rollout: pick the rollout with the highest reward.
best_idx = rewards.argmax().item()

# Three panels: test input, the best rollout's predicted grid, and the target.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(grid_to_image(test_inp))
axes[0].set_title('Test Input')
axes[0].axis('off')
axes[1].imshow(grid_to_image(predicted_grids[best_idx]))
axes[1].set_title(f'Best Random Prediction\n(reward={rewards[best_idx]:.3f})')
axes[1].axis('off')
axes[2].imshow(grid_to_image(target_out))
axes[2].set_title('Target')
axes[2].axis('off')
# The separator is a real em dash; the previous text was a mis-decoded UTF-8
# sequence that rendered as garbage in the figure title.
plt.suptitle(f'Task: {task.task_id} — Random Policy', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Load Trained Model (after training)

After running `scripts/train.py`, load the best checkpoint and visualize predictions.

In [None]:
import os

ckpt_path = '../checkpoints/best.pt'
eval_task_indices = [0, 10, 50, 100]
eval_max_steps = 150  # one budget shared by the env and the rollout loop

if os.path.exists(ckpt_path):
    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints produced
    # by your own training runs.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    trained_cfg = ModelConfig(**ckpt['model_cfg'])
    trained_policy = ARCPolicy(trained_cfg)
    trained_policy.load_state_dict(ckpt['model'])
    # Match the device the environments below are created on.
    trained_policy.to(device)
    trained_policy.eval()
    print(f'Loaded checkpoint from iteration {ckpt["iteration"]}')
    print(f'  Accuracy: {ckpt.get("accuracy", "N/A")}')

    # Run on a few tasks
    for task_idx in eval_task_indices:
        # Skip indices the loaded split does not contain instead of crashing.
        if task_idx >= len(dataset):
            print(f'Skipping task index {task_idx}: only {len(dataset)} tasks loaded')
            continue

        t = dataset[task_idx]
        examples, ti, to = t.get_eval_instance()

        env = BatchedARCEnv(
            [(examples, ti, to)],
            K=16, max_steps=eval_max_steps, device=device,
        )
        obs = env.reset()

        with torch.no_grad():
            # Step 0: resize, using the sampled values shifted by one.
            out = trained_policy(obs)
            ra = sample_resize(out)
            env.resize(ra.resize_h + 1, ra.resize_w + 1)

            # Steps 1+: paint one cell per step in every rollout.
            for s in range(1, eval_max_steps):
                obs = env.get_obs()
                masks = env.get_grid_masks()
                out = trained_policy(obs)
                pa = sample_paint(out, masks)
                # Flat position -> row (y) / column (x) on the GRID_SIZE canvas.
                y = pa.position // GRID_SIZE
                x = pa.position % GRID_SIZE
                env.paint(pa.color, x, y)

        rw = env.compute_rewards()
        pg = env.get_predicted_grids()
        bi = rw.argmax().item()  # best of the 16 rollouts

        img = render_task(examples, ti, prediction=pg[bi], target=to)
        plt.figure(figsize=(16, 6))
        plt.imshow(img)
        plt.axis('off')
        # Rewards of at least 2.0 are reported as solved; anything lower is
        # shown as a pixel accuracy capped at 1.0.
        # NOTE(review): the 2.0 threshold presumably reflects an exact-match
        # bonus in the reward -- confirm in arc_rl.env.
        status = 'SOLVED' if rw[bi] >= 2.0 else f'pixacc={min(rw[bi].item(), 1.0):.2f}'
        plt.title(f'Task {t.task_id}: {status}')
        plt.show()
else:
    print('No checkpoint found. Run scripts/train.py first!')