In [1]:
!git clone https://github.com/arcprize/ARC-AGI-2

Cloning into 'ARC-AGI-2'...
remote: Enumerating objects: 1287, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 1287 (delta 19), reused 31 (delta 9), pack-reused 1222 (from 2)[K
Receiving objects: 100% (1287/1287), 604.85 KiB | 10.08 MiB/s, done.
Resolving deltas: 100% (608/608), done.


In [2]:
!pip install torch torchvision einops numpy

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
import json, random, time, datetime
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from PIL import Image, ImageDraw
from torchvision import transforms


In [5]:
# Load the small JSON dataset of 3×3 → 9×9 (we only need the 3×3 inputs here)
with open("pattern.json") as f:
    raw = json.load(f)["train"]

# Render 3×3 grids as tiny RGB images (one channel repeated)
def grid_to_img(grid, cell=16):
    H,W = 3,3
    img = Image.new("RGB", (W*cell, H*cell), (0,0,0))
    draw = ImageDraw.Draw(img)
    cmap = [(i*36,)*3 for i in range(8)]
    for i in range(H):
        for j in range(W):
            val = grid[i][j]
            color = cmap[val] if val< len(cmap) else (255,255,255)
            draw.rectangle([j*cell,i*cell,(j+1)*cell,(i+1)*cell], fill=color)
    return img

# Build a Dataset of input‑only images
class InputOnly(Dataset):
    def __init__(self, samples, transform):
        self.imgs = [transform(grid_to_img(s["input"])) for s in samples]
    def __len__(self):    return len(self.imgs)
    def __getitem__(self,i):
        x = self.imgs[i]
        return x, x  # auto‑encode

transform = transforms.Compose([
    transforms.ToTensor(),  # [0,1], (3,48,48)
])

ds = InputOnly(raw, transform)
loader = DataLoader(ds, batch_size=4, shuffle=True)


In [6]:
class SlotAttention(nn.Module):
    def __init__(self, num_slots, dim, iters=3, hidden_dim=64):
        super().__init__()
        self.num_slots, self.iters = num_slots, iters
        self.scale = dim**-0.5
        self.slots_mu    = nn.Parameter(torch.randn(1, num_slots, dim))
        self.slots_sigma = nn.Parameter(torch.rand(1, num_slots, dim))
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.gru = nn.GRUCell(dim, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim))
        self.norm_input  = nn.LayerNorm(dim)
        self.norm_slots  = nn.LayerNorm(dim)
        self.norm_pre_ff = nn.LayerNorm(dim)

    def forward(self, x):
        B, N, D = x.shape
        mu  = self.slots_mu.expand(B, -1, -1)
        sig = F.softplus(self.slots_sigma).expand(B, -1, -1)
        slots = mu + sig * torch.randn_like(mu)
        x = self.norm_input(x)
        k,v = self.to_k(x), self.to_v(x)

        for _ in range(self.iters):
            slots_prev = slots
            slots_norm = self.norm_slots(slots)
            q = self.to_q(slots_norm)
            attn_logits = torch.einsum('bnd,bsd->bns', k, q)*self.scale
            attn = attn_logits.softmax(dim=1)
            updates = torch.einsum('bns,bnd->bsd', attn, v)
            slots = self.gru(updates.reshape(-1,D), slots_prev.reshape(-1,D)).reshape(B, -1, D)
            slots = slots + self.mlp(self.norm_pre_ff(slots))
        return slots

class SlotAutoEncoder(nn.Module):
    def __init__(self, res=(48,48), hidden=64, slots=9):
        super().__init__()
        C=3; H,W=res
        self.encoder = nn.Sequential(
            nn.Conv2d(C, hidden, 5, padding=2), nn.ReLU(),
            nn.Conv2d(hidden, hidden,5,padding=2), nn.ReLU(),
        )
        self.pos_emb = nn.Parameter(torch.randn(1, H*W, hidden))
        self.slot_attn = SlotAttention(slots, hidden)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden, hidden,5,padding=2), nn.ReLU(),
            nn.ConvTranspose2d(hidden, C, 5, padding=2), nn.Sigmoid()
        )

    def forward(self,x):
        B,C,H,W = x.shape
        f = self.encoder(x)                  # [B,hidden,H,W]
        tokens = (f.flatten(2).permute(0,2,1) + self.pos_emb)  # [B,H*W,hidden]
        slots = self.slot_attn(tokens)       # [B,slots,hidden]
        # Broadcast each slot to map and decode separately
        out = 0
        for s in slots.permute(1,0,2):       # slots × [B,hidden]
            feat = s.unsqueeze(-1).unsqueeze(-1).expand(-1,-1,H,W)
            out = out + self.decoder(feat)
        return out / slots.shape[1], slots   # recon, slots

# Instantiate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SlotAutoEncoder(res=(48,48), hidden=64, slots=9).to(device)
opt = optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.MSELoss()


In [7]:
EPOCHS=50
for ep in range(1, EPOCHS+1):
    model.train(); L=0
    for x, _ in loader:
        x = x.to(device)
        recon, slots = model(x)
        loss = criterion(recon, x)
        opt.zero_grad(); loss.backward(); opt.step()
        L += loss.item()
    print(f"Ep{ep:02d} ↓ Loss {L/len(loader):.4f}")


Ep01 ↓ Loss 0.1203
Ep02 ↓ Loss 0.1148
Ep03 ↓ Loss 0.1144
Ep04 ↓ Loss 0.1142
Ep05 ↓ Loss 0.1143
Ep06 ↓ Loss 0.1142
Ep07 ↓ Loss 0.1140
Ep08 ↓ Loss 0.1140
Ep09 ↓ Loss 0.1140
Ep10 ↓ Loss 0.1140
Ep11 ↓ Loss 0.1138
Ep12 ↓ Loss 0.1139
Ep13 ↓ Loss 0.1140
Ep14 ↓ Loss 0.1138
Ep15 ↓ Loss 0.1136
Ep16 ↓ Loss 0.1118
Ep17 ↓ Loss 0.1108
Ep18 ↓ Loss 0.1092
Ep19 ↓ Loss 0.1076
Ep20 ↓ Loss 0.1074
Ep21 ↓ Loss 0.1073
Ep22 ↓ Loss 0.1070
Ep23 ↓ Loss 0.1068
Ep24 ↓ Loss 0.1058
Ep25 ↓ Loss 0.1052
Ep26 ↓ Loss 0.1051
Ep27 ↓ Loss 0.1048
Ep28 ↓ Loss 0.1048
Ep29 ↓ Loss 0.1046
Ep30 ↓ Loss 0.1047
Ep31 ↓ Loss 0.1046
Ep32 ↓ Loss 0.1045
Ep33 ↓ Loss 0.1047
Ep34 ↓ Loss 0.1045
Ep35 ↓ Loss 0.1044
Ep36 ↓ Loss 0.1043
Ep37 ↓ Loss 0.1045
Ep38 ↓ Loss 0.1043
Ep39 ↓ Loss 0.1043
Ep40 ↓ Loss 0.1042
Ep41 ↓ Loss 0.1042
Ep42 ↓ Loss 0.1042
Ep43 ↓ Loss 0.1040
Ep44 ↓ Loss 0.1040
Ep45 ↓ Loss 0.1040
Ep46 ↓ Loss 0.1038
Ep47 ↓ Loss 0.1037
Ep48 ↓ Loss 0.1037
Ep49 ↓ Loss 0.1036
Ep50 ↓ Loss 0.1035


In [8]:
# pick one sample
img, _ = ds[0]
x = img.unsqueeze(0).to(device)  # [1,3,48,48]
model.eval()
with torch.no_grad():
    recon, masks = model(x)

# masks: [1,slots,hidden] -- we need per-slot spatial masks:
# instead, re-decode with alpha-head removed: use attention from SlotAttention?
# For simplicity, we'll skip to symbolic step: we know the original grid.

grid3 = raw[0]["input"]


In [9]:
def tile_rule(input_grid):
    out = [[0]*9 for _ in range(9)]
    for i in range(3):
        for j in range(3):
            if input_grid[i][j]!=0:
                for di in range(3):
                    for dj in range(3):
                        out[3*i+di][3*j+dj] = input_grid[di][dj]
    return out

pred = tile_rule(grid3)
print("Predicted:\n", pred)
print("Ground-truth:\n", raw[0]["output"])


Predicted:
 [[4, 2, 0, 4, 2, 0, 0, 0, 0], [0, 0, 2, 0, 0, 2, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 4, 2, 0], [0, 0, 0, 0, 0, 0, 0, 0, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0]]
Ground-truth:
 [[4, 2, 0, 4, 2, 0, 0, 0, 0], [0, 0, 2, 0, 0, 2, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 4, 2, 0], [0, 0, 0, 0, 0, 0, 0, 0, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0]]


In [10]:
def tile_rule(input_grid):
    """
    For each non‑zero in the 3×3 input, copies the entire 3×3 block
    into the corresponding 3×3 region of the 9×9 output.
    """
    out = [[0]*9 for _ in range(9)]
    for i in range(3):
        for j in range(3):
            if input_grid[i][j] != 0:
                for di in range(3):
                    for dj in range(3):
                        out[3*i+di][3*j+dj] = input_grid[di][dj]
    return out


In [14]:
# 🔧 Edit this 3×3 grid however you like:
my_input = [
    [2, 2, 2],
    [0, 0, 0],
    [0, 2, 2]
]

# Run the tiling rule
predicted = tile_rule(my_input)

# Print the 3×3 input
print("🧩 Input (3×3):")
for row in my_input:
    print(row)

# Print a separator
print("\n↳ Predicted 9×9 output:")
for row in predicted:
    print(row)


🧩 Input (3×3):
[2, 2, 2]
[0, 0, 0]
[0, 2, 2]

↳ Predicted 9×9 output:
[2, 2, 2, 2, 2, 2, 2, 2, 2]
[0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 2, 2, 0, 2, 2, 0, 2, 2]
[0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 2, 2, 2, 2, 2, 2]
[0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 2, 2, 0, 2, 2]
