# ARC-AGI 2025: Training Notebook

This notebook trains a baseline sequence-to-sequence Transformer on the generated ARC-AGI dataset (`artifacts/datasets/*.jsonl`).

Sections:
1. Setup
2. Dependencies
3. Device & Determinism
4. Load Dataset
5. Visualize Samples
6. Tokenization & Augmentations
7. Datasets & DataLoaders
8. Transformer Model
9. Loss/Optimizer/Scheduler
10. Training Loop
11. Validation Metrics
12. Inference Solver
13. Save Artifacts
14. Unit Tests
15. Hyperparameter Sweep (optional)
16. Export to TorchScript/ONNX (optional)

In [2]:
# Set Up Environment and Paths
from __future__ import annotations
import os, json, time, random
from pathlib import Path

# Robust project root detection for notebooks (no __file__)
CWD = Path.cwd()
CANDIDATES = [CWD, *CWD.parents]
PROJECT_ROOT = None
for p in CANDIDATES:
    if (p / 'artifacts').exists() and (p / 'models').exists():
        PROJECT_ROOT = p
        break
if PROJECT_ROOT is None:
    # Fallback to two levels up from CWD
    PROJECT_ROOT = CWD if (CWD / 'artifacts').exists() else CWD.parent

DATASETS_DIR = PROJECT_ROOT / 'artifacts' / 'datasets'
MODELS_DIR = PROJECT_ROOT / 'models'
MODELS_DIR.mkdir(parents=True, exist_ok=True)

run_id = time.strftime('%Y%m%d-%H%M%S')
RUN_DIR = MODELS_DIR / f'run_{run_id}'
RUN_DIR.mkdir(parents=True, exist_ok=True)

print('Project root:', PROJECT_ROOT)
print('Datasets dir:', DATASETS_DIR)
print('Run dir:', RUN_DIR)

Project root: /home/aibe/Documents/Code/arc-agi
Datasets dir: /home/aibe/Documents/Code/arc-agi/artifacts/datasets
Run dir: /home/aibe/Documents/Code/arc-agi/models/run_20250809-185104


In [3]:
# Install and Import Dependencies
import sys

# Optional: install heavy deps if missing
try:
    import torch
except Exception:
    %pip install torch --quiet
    import torch

try:
    import einops
except Exception:
    %pip install einops --quiet
    import einops

try:
    import tqdm
except Exception:
    %pip install tqdm --quiet
    import tqdm

try:
    import matplotlib
except Exception:
    %pip install matplotlib --quiet
    import matplotlib

from typing import List, Tuple, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

print('Python', sys.version)
print('Torch', torch.__version__)
print('NumPy', np.__version__)

Python 3.13.5 (main, Jun 21 2025, 09:35:00) [GCC 15.1.1 20250425]
Torch 2.8.0+cu128
NumPy 2.3.2


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Detect Device and Configure Determinism
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.set_float32_matmul_precision('high')
    except Exception:
        pass
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print('Using device:', device)

Using device: cpu


In [5]:
# Load ARC-AGI Dataset (from artifacts/datasets/*.jsonl)
from dataclasses import dataclass

@dataclass
class Sample:
    split: str
    task_id: str
    subset: str
    index: int
    input: List[List[int]]
    output: List[List[int]]
    transform: dict


def read_jsonl(path: Path):
    with path.open('r') as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


def load_split(name: str):
    path = DATASETS_DIR / f'{name}.jsonl'
    if not path.exists():
        print(f'Warning: dataset not found: {path}')
        return []
    data = []
    for rec in read_jsonl(path):
        if 'input' in rec and 'output' in rec:
            data.append(Sample(
                split=rec['split'], task_id=rec['task_id'], subset=rec['subset'], index=rec['index'],
                input=rec['input'], output=rec['output'], transform=rec.get('transform', {})
            ))
    return data

train_samples = load_split('training')
val_samples = load_split('evaluation')  # use evaluation as validation if present
if not val_samples and len(train_samples) > 10:
    n = int(0.9 * len(train_samples))
    val_samples = train_samples[n:]
    train_samples = train_samples[:n]

print(f'Train samples: {len(train_samples)} | Val samples: {len(val_samples)}')

Train samples: 3232 | Val samples: 358


In [None]:
# Visualize Sample Tasks
from itertools import islice

def show_grid(ax, grid, title=""):
    arr = np.array(grid, dtype=int)
    cmap = plt.get_cmap('tab10', 10)
    ax.imshow(arr, cmap=cmap, vmin=0, vmax=9)
    ax.set_title(title)
    ax.set_xticks([]); ax.set_yticks([])

fig, axes = plt.subplots(2, 4, figsize=(10,5))
axes = axes.ravel()
for i, s in enumerate(islice(train_samples, 4)):
    show_grid(axes[2*i], s.input, f"Train Input {i}")
    show_grid(axes[2*i+1], s.output, f"Train Output {i}")
plt.tight_layout()
plt.show()

In [6]:
# Encode Grids to Tokens and Augmentations
PAD, BOS, EOS, SEP = 10, 11, 12, 13
VOCAB_SIZE = 14  # 0-9 colors + 4 specials

def normalize_grid(grid: List[List[int]] | List[int] | int) -> np.ndarray:
    arr = np.array(grid, dtype=int)
    if arr.ndim == 0:
        arr = arr.reshape(1, 1)
    elif arr.ndim == 1:
        arr = arr.reshape(1, -1)
    return arr


def encode_grid(grid: List[List[int]] | List[int] | int) -> List[int]:
    arr = normalize_grid(grid)
    return arr.reshape(-1).tolist()

def decode_grid(tokens: List[int], h: int, w: int) -> List[List[int]]:
    seq = tokens[: h*w]
    return [seq[i*w:(i+1)*w] for i in range(h)]

AUG_ROT = [0, 1, 2, 3]
AUG_FLIP = [False, True]

def apply_aug(grid):
    arr = normalize_grid(grid)
    # Only apply geometric augs when >0 dims
    if arr.shape[0] > 0 and arr.shape[1] > 0:
        k = random.choice(AUG_ROT)
        if k:
            arr = np.rot90(arr, k)
        if random.choice(AUG_FLIP):
            arr = np.fliplr(arr)
        # random color permutation over observed colors
        vals = sorted(set(arr.ravel().tolist()))
        if len(vals) > 1:
            perm = vals[:]
            random.shuffle(perm)
            mp = {a:b for a,b in zip(vals, perm)}
            vfunc = np.vectorize(lambda x: mp.get(int(x), int(x)))
            arr = vfunc(arr)
    return arr.astype(int).tolist()

print('Vocab size:', VOCAB_SIZE)

Vocab size: 14


In [11]:
# PyTorch Dataset and DataLoaders
MAX_H, MAX_W = 30, 30  # ARC grids are typically <= 30

def to_2d(grid):
    if grid is None:
        return []
    if isinstance(grid, (int, np.integer)):
        return [[int(grid)]]
    if isinstance(grid, list):
        if not grid:
            return []
        if isinstance(grid[0], list):
            return grid
        else:
            return [grid]
    arr = np.array(grid)
    if arr.ndim == 0:
        return [[int(arr)]]
    if arr.ndim == 1:
        return [arr.astype(int).tolist()]
    return arr.astype(int).tolist()

class ArcSeqDataset(Dataset):
    def __init__(self, samples, augment=False):
        self.samples = samples
        self.augment = augment

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        inp = to_2d(s.input)
        out = to_2d(s.output)
        if self.augment:
            inp = apply_aug(inp)
            out = apply_aug(out)
        h_in, w_in = len(inp), len(inp[0]) if inp and len(inp) > 0 else 0
        h_out, w_out = len(out), len(out[0]) if out and len(out) > 0 else 0
        enc = encode_grid(inp)
        dec_tgt = encode_grid(out)
        dec_in = [BOS] + dec_tgt[:-1] if len(dec_tgt) > 0 else [BOS]
        return {
            'enc': torch.tensor(enc, dtype=torch.long),
            'dec_in': torch.tensor(dec_in, dtype=torch.long),
            'tgt': torch.tensor(dec_tgt, dtype=torch.long),
            'h_in': h_in, 'w_in': w_in, 'h_out': h_out, 'w_out': w_out
        }


def make_row_col_indices(h, w):
    rows = np.repeat(np.arange(h), w) if (h > 0 and w > 0) else np.array([], dtype=int)
    cols = np.tile(np.arange(w), h) if (h > 0 and w > 0) else np.array([], dtype=int)
    return rows, cols


def collate_batch(batch):
    B = len(batch)
    enc_lens = [len(b['enc']) for b in batch]
    dec_lens = [len(b['dec_in']) for b in batch]
    max_enc = max(enc_lens) if enc_lens else 0
    max_dec = max(dec_lens) if dec_lens else 0

    enc = torch.full((B, max_enc), PAD, dtype=torch.long)
    dec_in = torch.full((B, max_dec), PAD, dtype=torch.long)
    tgt = torch.full((B, max_dec), PAD, dtype=torch.long)

    enc_pad_mask = torch.ones((B, max_enc), dtype=torch.bool)  # True for pad
    dec_pad_mask = torch.ones((B, max_dec), dtype=torch.bool)

    row_idx = torch.zeros((B, max_enc), dtype=torch.long)
    col_idx = torch.zeros((B, max_enc), dtype=torch.long)

    meta = []
    for i, b in enumerate(batch):
        L_e = len(b['enc']); L_d = len(b['dec_in'])
        enc[i, :L_e] = b['enc']
        dec_in[i, :L_d] = b['dec_in']
        tgt[i, :len(b['tgt'])] = b['tgt']
        enc_pad_mask[i, :L_e] = False
        dec_pad_mask[i, :L_d] = False
        r, c = make_row_col_indices(b['h_in'], b['w_in'])
        if L_e > 0 and len(r) == L_e:
            row_idx[i, :L_e] = torch.tensor(r, dtype=torch.long)
            col_idx[i, :L_e] = torch.tensor(c, dtype=torch.long)
        meta.append((b['h_in'], b['w_in'], b['h_out'], b['w_out']))

    return {
        'enc': enc, 'dec_in': dec_in, 'tgt': tgt,
        'enc_pad_mask': enc_pad_mask, 'dec_pad_mask': dec_pad_mask,
        'row_idx': row_idx, 'col_idx': col_idx, 'meta': meta
    }

train_ds = ArcSeqDataset(train_samples, augment=True)
val_ds = ArcSeqDataset(val_samples, augment=False)

BATCH_SIZE = 16
NUM_WORKERS = 0
PIN_MEM = (device.type == 'cuda')

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEM, collate_fn=collate_batch)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEM, collate_fn=collate_batch)

len(train_loader), len(val_loader)

(202, 23)

In [8]:
# Define Transformer Model for ARC
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)
    def forward(self, x):
        L = x.size(1)
        return x + self.pe[:, :L]

class ArcTransformer(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, d_model=128, nhead=4, num_layers=3, dim_ff=256, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.row_emb = nn.Embedding(64, d_model)
        self.col_emb = nn.Embedding(64, d_model)
        self.pos_enc = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.proj = nn.Linear(d_model, vocab_size)

    def encode(self, enc_tokens, row_idx, col_idx, src_key_padding_mask=None):
        x = self.tok_emb(enc_tokens) + self.row_emb(row_idx) + self.col_emb(col_idx)
        x = self.pos_enc(x)
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        return x

    def decode(self, dec_tokens, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        y = self.tok_emb(dec_tokens)
        y = self.pos_enc(y)
        L = y.size(1)
        causal_mask = torch.triu(torch.ones(L, L, device=y.device, dtype=torch.bool), diagonal=1)
        y = self.decoder(y, memory, tgt_mask=causal_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)
        return self.proj(y)

    def forward(self, batch):
        memory = self.encode(batch['enc'], batch['row_idx'], batch['col_idx'], src_key_padding_mask=batch['enc_pad_mask'])
        logits = self.decode(batch['dec_in'], memory,
                             tgt_key_padding_mask=batch['dec_pad_mask'],
                             memory_key_padding_mask=batch['enc_pad_mask'])
        return logits

    @torch.no_grad()
    def generate(self, enc, row_idx, col_idx, enc_pad_mask, max_len=256):
        self.eval()
        memory = self.encode(enc, row_idx, col_idx, src_key_padding_mask=enc_pad_mask)
        B = enc.size(0)
        ys = torch.full((B, 1), BOS, dtype=torch.long, device=enc.device)
        for _ in range(max_len):
            logits = self.decode(ys, memory,
                                 tgt_key_padding_mask=torch.zeros_like(ys, dtype=torch.bool),
                                 memory_key_padding_mask=enc_pad_mask)
            next_tok = logits[:, -1].argmax(-1, keepdim=True)
            ys = torch.cat([ys, next_tok], dim=1)
            if (next_tok == EOS).all():
                break
        return ys[:, 1:]  # drop BOS

In [9]:
# Configure Loss, Optimizer, and Scheduler
model = ArcTransformer().to(device)

criterion = nn.CrossEntropyLoss(ignore_index=PAD)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))

print(sum(p.numel() for p in model.parameters())/1e6, 'M params')

1.013774 M params


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))


In [12]:
# Train Loop with Mixed Precision and Checkpointing
EPOCHS = 1
ACCUM_STEPS = 1
BEST_VAL = float('inf')

best_path = RUN_DIR / 'best.pt'
last_path = RUN_DIR / 'last.pt'

for epoch in range(1, EPOCHS+1):
    model.train()
    pbar = tqdm(train_loader, desc=f'Epoch {epoch} [train]')
    total_loss = 0.0

    for step, batch in enumerate(pbar, 1):
        for k in ['enc','dec_in','tgt','enc_pad_mask','dec_pad_mask','row_idx','col_idx']:
            batch[k] = batch[k].to(device)
        with torch.amp.autocast('cuda', enabled=(device.type=='cuda')):
            logits = model(batch)
            B, L, V = logits.shape
            loss = criterion(logits.view(B*L, V), batch['tgt'].view(B*L)) / ACCUM_STEPS
        scaler.scale(loss).backward()

        if step % ACCUM_STEPS == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
        total_loss += loss.item() * ACCUM_STEPS
        pbar.set_postfix(loss=total_loss/step)

    scheduler.step()

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f'Epoch {epoch} [val]'):
            for k in ['enc','dec_in','tgt','enc_pad_mask','dec_pad_mask','row_idx','col_idx']:
                batch[k] = batch[k].to(device)
            logits = model(batch)
            B, L, V = logits.shape
            loss = criterion(logits.view(B*L, V), batch['tgt'].view(B*L))
            val_loss += loss.item()
    val_loss /= max(1, len(val_loader))
    print(f'Epoch {epoch} val_loss={val_loss:.4f}')

    # Checkpoint
    torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
                'epoch': epoch}, last_path)
    if val_loss < BEST_VAL:
        BEST_VAL = val_loss
        torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
                    'epoch': epoch}, best_path)
        print('Saved new best to', best_path)

Epoch 1 [train]: 100%|██████████| 202/202 [17:00<00:00,  5.05s/it, loss=0.857]
Epoch 1 [train]: 100%|██████████| 202/202 [17:00<00:00,  5.05s/it, loss=0.857]
Epoch 1 [val]: 100%|██████████| 23/23 [00:18<00:00,  1.24it/s]

Epoch 1 val_loss=0.8127
Saved new best to /home/aibe/Documents/Code/arc-agi/models/run_20250809-185104/best.pt





In [None]:
# Validate and Compute ARC Metrics
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    exact = 0
    total = 0
    cell_correct = 0
    cell_total = 0
    for batch in tqdm(loader, desc='Eval'): 
        for k in ['enc','dec_in','tgt','enc_pad_mask','dec_pad_mask','row_idx','col_idx']:
            batch[k] = batch[k].to(device)
        logits = model(batch)
        preds = logits.argmax(-1)
        mask = batch['tgt'] != PAD
        equal = (preds == batch['tgt']) & mask
        cell_correct += equal.sum().item()
        cell_total += mask.sum().item()
        # exact match per sequence
        seq_equal = (equal.sum(dim=1) == mask.sum(dim=1))
        exact += seq_equal.sum().item()
        total += preds.size(0)
    return {
        'exact_match': exact / max(1, total),
        'cell_accuracy': cell_correct / max(1, cell_total)
    }

metrics = evaluate(model, val_loader)
print(metrics)

In [None]:
# Inference: Solve Unseen Tasks
@torch.no_grad()
def solve_batch(model, batch, max_len=256):
    for k in ['enc','enc_pad_mask','row_idx','col_idx']:
        batch[k] = batch[k].to(device)
    gen = model.generate(batch['enc'], batch['row_idx'], batch['col_idx'], batch['enc_pad_mask'], max_len=max_len)
    preds = gen.cpu().numpy().tolist()
    outputs = []
    for i, (h_in, w_in, h_out, w_out) in enumerate(batch['meta']):
        outputs.append(decode_grid(preds[i], h_out, w_out))
    return outputs

In [None]:
# Save Artifacts to models/
config = {
    'model': 'ArcTransformer', 'vocab_size': VOCAB_SIZE,
    'd_model': 256, 'nhead': 8, 'num_layers': 4, 'dim_ff': 512,
    'batch_size': BATCH_SIZE, 'epochs': EPOCHS, 'seed': SEED,
}

metrics = evaluate(model, val_loader)
with (RUN_DIR / 'metrics.json').open('w') as f:
    json.dump(metrics, f, indent=2)

torch.save({'model': model.state_dict(), 'config': config}, RUN_DIR / 'model.pt')
with (RUN_DIR / 'config.json').open('w') as f:
    json.dump(config, f, indent=2)

print('Saved to', RUN_DIR)

In [None]:
# Lightweight Unit Tests
# 1) Encode/Decode roundtrip
_grid = [[1,2,3],[4,5,6]]
assert decode_grid(encode_grid(_grid), 2, 3) == _grid
print('Encode/Decode test passed')

# 2) Batch forward pass sanity
batch = next(iter(train_loader))
for k in ['enc','dec_in','tgt','enc_pad_mask','dec_pad_mask','row_idx','col_idx']:
    batch[k] = batch[k].to(device)
with torch.no_grad():
    logits = model(batch)
assert logits.shape[:2] == batch['tgt'].shape
print('Forward pass shape test passed')

In [None]:
# Optional: Hyperparameter Sweep Hook
from itertools import product

def sweep(grid):
    results = []
    for (lr, layers, heads, dropout) in product(grid['lr'], grid['layers'], grid['heads'], grid['dropout']):
        m = ArcTransformer(num_layers=layers, nhead=heads).to(device)
        opt = torch.optim.AdamW(m.parameters(), lr=lr)
        # One minibatch quick score
        batch = next(iter(train_loader))
        for k in ['enc','dec_in','tgt','enc_pad_mask','dec_pad_mask','row_idx','col_idx']:
            batch[k] = batch[k].to(device)
        with torch.no_grad():
            logits = m(batch)
            B, L, V = logits.shape
            loss = criterion(logits.view(B*L, V), batch['tgt'].view(B*L)).item()
        results.append({'lr': lr, 'layers': layers, 'heads': heads, 'dropout': dropout, 'loss': loss})
    return sorted(results, key=lambda x: x['loss'])

# Example sweep grid (commented)
# grid = {'lr':[1e-4,3e-4], 'layers':[3,4], 'heads':[4,8], 'dropout':[0.0,0.1]}
# sweep_results = sweep(grid)
# sweep_results[:5]

In [None]:
# Optional: Export to TorchScript/ONNX
try:
    example = next(iter(val_loader))
    for k in ['enc','dec_in','tgt','enc_pad_mask','dec_pad_mask','row_idx','col_idx']:
        example[k] = example[k].to(device)
    ts_path = RUN_DIR / 'model_ts.pt'
    scripted = torch.jit.trace(model, (example))  # may fail for dict input
    scripted.save(str(ts_path))
    print('Saved TorchScript to', ts_path)
except Exception as e:
    print('TorchScript export skipped:', e)

try:
    import onnx
    onnx_path = RUN_DIR / 'model.onnx'
    # ONNX export with dynamic axes is non-trivial for dict inputs; skipping here
    print('ONNX export not implemented in this baseline')
except Exception as e:
    pass