In [None]:
import argparse, json, os, re
import numpy as np
import pandas as pd
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Union
from PIL import Image, ImageDraw, ImageFont
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from enum import Enum
import warnings

# Notebook-wide noise control: silence warnings of every category from here on.
warnings.simplefilter("ignore")



In [None]:
class ARCVisualizer:
    """Render ARC tasks (grids of colour indices 0-9) for visual inspection.

    Two back ends are provided:

    * PIL: ``grid_to_image`` converts one grid to an RGB image (each cell
      enlarged to ``scale`` x ``scale`` pixels, with optional grid lines and a
      title banner). ``create_composite_image`` stacks every training and test
      pair of a task into one overview image.
    * matplotlib: ``visualize_task_matplotlib`` lays a task out as subplots.

    Both back ends take their colours from ``PALETTE``, so a colour index looks
    the same whichever renderer is used.
    """

    # Standard ARC palette as RGB uint8, indexed by cell value
    # (#000000 #0074D9 #FF4136 #2ECC40 #FFDC00 #AAAAAA #F012BE #FF851B #7FDBFF #870C25).
    PALETTE = np.array([
        [0,   0,   0],      # 0 black
        [0,   116, 217],    # 1 blue
        [255, 65,  54],     # 2 red
        [46,  204, 64],     # 3 green
        [255, 220, 0],      # 4 yellow
        [170, 170, 170],    # 5 gray
        [240, 18,  190],    # 6 magenta/pink
        [255, 133, 27],     # 7 orange
        [127, 219, 255],    # 8 cyan/sky
        [135, 12,  37],     # 9 maroon/brown
    ], dtype=np.uint8)

    def __init__(self, scale: int = 30, draw_grid: bool = True):
        """
        Args:
            scale: edge length in pixels of one cell in PIL renderings.
            draw_grid: draw cell boundary lines in PIL renderings. This is only
                applied when ``scale >= 10``, so lines never swamp tiny cells.
        """
        self.scale = scale
        self.draw_grid = draw_grid
        # Build the matplotlib colormap from PALETTE (floats in [0, 1]) so the
        # PIL and matplotlib renderings cannot drift apart.
        self.cmap = colors.ListedColormap(self.PALETTE / 255.0)
        # Map cell values 0..9 linearly onto the 10 colormap entries.
        self.norm = colors.Normalize(vmin=0, vmax=9)

    def grid_to_image(self, grid: Union[List[List[int]], np.ndarray],
                      title: Optional[str] = None) -> Image.Image:
        """Convert a grid to a PIL RGB image, optionally with a title banner.

        Args:
            grid: non-empty 2-D list or array of integers in 0-9 inclusive.
            title: optional text drawn in a banner above the grid.

        Returns:
            RGB image in which every cell is enlarged to ``scale`` x ``scale``
            pixels using nearest-neighbour resizing.

        Raises:
            ValueError: if the grid is not 2-D, is empty, or has values
                outside 0-9.
        """
        arr = np.array(grid, dtype=np.int16)
        if arr.ndim != 2:
            raise ValueError("Grid must be 2D")
        if arr.size == 0:
            raise ValueError("Grid must be non-empty")
        if arr.min() < 0 or arr.max() > 9:
            raise ValueError("Grid values must be in 0-9 inclusive")
        # Fancy-index the palette: (H, W) -> (H, W, 3) uint8, which PIL reads as RGB.
        img = Image.fromarray(self.PALETTE[arr])
        if self.scale != 1:
            img = img.resize((img.width * self.scale, img.height * self.scale),
                             resample=Image.NEAREST)
        if self.draw_grid and self.scale >= 10:
            self._add_gridlines(img)
        if title:
            img = self._add_title_banner(img, title)
        return img

    def _add_gridlines(self, img: Image.Image) -> None:
        """Draw cell boundary lines onto ``img`` in place, one every ``scale`` pixels."""
        draw = ImageDraw.Draw(img)
        grid_color = (40, 40, 40)
        # The ranges reach img.width / img.height exactly, which is one pixel past
        # the canvas. Clamp that final offset so the right and bottom borders are visible.
        for x in range(0, img.width + 1, self.scale):
            xx = min(x, img.width - 1)
            draw.line([(xx, 0), (xx, img.height - 1)], fill=grid_color)
        for y in range(0, img.height + 1, self.scale):
            yy = min(y, img.height - 1)
            draw.line([(0, yy), (img.width - 1, yy)], fill=grid_color)

    @staticmethod
    def _load_font():
        """Return DejaVu Sans at 14pt if available, otherwise PIL's default font."""
        try:
            return ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
        except Exception:
            # Deliberate best effort: the font file may be missing, or FreeType
            # may be unavailable.
            return ImageFont.load_default()

    @staticmethod
    def _text_right_edge(draw, origin: Tuple[int, int], text: str, font) -> int:
        """Return the x coordinate where ``text`` drawn at ``origin`` ends."""
        try:
            return int(draw.textbbox(origin, text, font=font)[2])
        except (AttributeError, ValueError, TypeError):
            # Older Pillow, or a bitmap font without bbox support: use a rough estimate.
            return origin[0] + 8 * len(text)

    def _add_title_banner(self, img: Image.Image, title: str) -> Image.Image:
        """Return a new image with a light banner containing ``title`` above ``img``.

        If the title is wider than ``img``, the canvas is widened so the text is
        not clipped, and ``img`` is centred horizontally below the banner.
        """
        banner_height = 30
        padding = 5
        origin = (10, 8)
        font = self._load_font()
        # Measure on a throwaway canvas so ``img`` itself is never touched.
        probe = ImageDraw.Draw(Image.new("RGB", (1, 1)))
        width = max(img.width, self._text_right_edge(probe, origin, title, font) + origin[0])
        new_img = Image.new("RGB", (width, banner_height + padding + img.height),
                            (255, 255, 255))
        draw = ImageDraw.Draw(new_img)
        draw.rectangle([0, 0, width, banner_height], fill=(245, 245, 245))
        draw.text(origin, title, fill=(30, 30, 30), font=font)
        new_img.paste(img, ((width - img.width) // 2, banner_height + padding))
        return new_img

    def visualize_task_matplotlib(self, task: Dict, task_id: Optional[str] = None,
                                  show_predictions: Optional[List] = None):
        """Plot a whole task as a matplotlib figure.

        Layout (two rows; each pair of adjacent subplots forms one slot):
          * row 0 shows every training pair as ``input | output``;
          * row 1 shows every test input, followed by its prediction when
            ``show_predictions`` supplies one. Extra predictions are ignored.

        Every subplot starts hidden and is only drawn into when it has content,
        so no empty axes frames are shown.

        Returns:
            The ``matplotlib.figure.Figure``, which the caller may save or modify.
        """
        train = task["train"]
        tests = task["test"]
        preds = list(show_predictions) if show_predictions is not None else []
        n_rows = 2
        # Enough slots for whichever row is longer (at least 2 to keep a sane figure).
        n_cols = max(len(train), len(tests), 2)
        # squeeze=False guarantees a 2-D axes array regardless of the layout size.
        fig, axes = plt.subplots(n_rows, n_cols * 2,
                                 figsize=(n_cols * 4, n_rows * 2), squeeze=False)
        for ax in axes.flat:
            ax.axis("off")

        def _draw(ax, grid, title):
            ax.imshow(grid, cmap=self.cmap, norm=self.norm)
            ax.set_title(title)

        for i, example in enumerate(train):
            _draw(axes[0, i * 2], example["input"], f"Train {i+1} Input")
            _draw(axes[0, i * 2 + 1], example["output"], f"Train {i+1} Output")
        for i, test in enumerate(tests):
            _draw(axes[1, i * 2], test["input"], f"Test {i+1} Input")
            if i < len(preds):
                _draw(axes[1, i * 2 + 1], preds[i], f"Test {i+1} Prediction")
        if task_id:
            fig.suptitle(f"Task: {task_id}", fontsize=16)
        plt.tight_layout()
        return fig

    def create_composite_image(self, task: Dict, task_id: Optional[str] = None,
                               solutions: Optional[List] = None) -> Image.Image:
        """Build one overview image of all training and test pairs.

        Each training example is rendered as an ``input | output`` pair. Each
        test input is paired with ``solutions[i]`` when that entry exists and is
        not None. Rows are stacked vertically and centred, with an optional
        ``task_id`` banner on top.

        Raises:
            ValueError: if the task has neither training nor test examples.
        """
        images = []
        for i, example in enumerate(task["train"]):
            in_img = self.grid_to_image(example["input"], f"Train {i+1} Input")
            out_img = self.grid_to_image(example["output"], f"Train {i+1} Output")
            images.append(self._hstack([in_img, out_img], gap=10))
        for i, test in enumerate(task["test"]):
            test_img = self.grid_to_image(test["input"], f"Test {i+1} Input")
            solution = None
            if solutions is not None and i < len(solutions):
                solution = solutions[i]
            if solution is not None:
                sol_img = self.grid_to_image(solution, f"Test {i+1} Solution")
                images.append(self._hstack([test_img, sol_img], gap=10))
            else:
                images.append(test_img)
        composite = self._vstack(images, gap=15)
        if task_id:
            composite = self._add_title_banner(composite, f"Task: {task_id}")
        return composite

    def _hstack(self, images: List[Image.Image], gap: int = 10) -> Image.Image:
        """Place images side by side (top-aligned) on white, ``gap`` pixels apart."""
        if not images:
            raise ValueError("No images to stack")
        height = max(img.height for img in images)
        width = sum(img.width for img in images) + gap * (len(images) - 1)
        result = Image.new("RGB", (width, height), (255, 255, 255))
        x = 0
        for img in images:
            result.paste(img, (x, 0))
            x += img.width + gap
        return result

    def _vstack(self, images: List[Image.Image], gap: int = 10) -> Image.Image:
        """Stack images vertically on white, ``gap`` pixels apart, each centred horizontally."""
        if not images:
            raise ValueError("No images to stack")
        width = max(img.width for img in images)
        height = sum(img.height for img in images) + gap * (len(images) - 1)
        result = Image.new("RGB", (width, height), (255, 255, 255))
        y = 0
        for img in images:
            x = (width - img.width) // 2  # centre horizontally
            result.paste(img, (x, y))
            y += img.height + gap
        return result

In [None]:
class DSLOperations:
    """Collection of pure functions that manipulate integer grids.

    These operations are the building blocks of the symbolic program search:
    geometric transforms (rotate, flip, transpose), colour replacement,
    cropping, padding, flood filling, object extraction and symmetry checks.

    Grids are 2-D numpy arrays indexed ``grid[y, x]`` (row, column). Inputs are
    never modified. However, ``rotate``, ``flip``, ``transpose``, ``crop`` and
    ``extract_object`` return numpy *views* that share memory with the input,
    so copy the result before mutating it in place.
    """

    @staticmethod
    def rotate(grid: np.ndarray, k: int = 1) -> np.ndarray:
        """Rotate ``grid`` by ``k`` quarter turns (``np.rot90``; counter-clockwise for k > 0)."""
        return np.rot90(grid, k)

    @staticmethod
    def flip(grid: np.ndarray, axis: int = 0) -> np.ndarray:
        """Mirror ``grid``.

        ``axis=0`` reverses the row order (upside down). ``axis=1`` reverses
        the column order (left-right mirror).
        """
        return np.flip(grid, axis=axis)

    @staticmethod
    def transpose(grid: np.ndarray) -> np.ndarray:
        """Swap rows and columns, i.e. reflect across the main diagonal."""
        return np.swapaxes(grid, 0, 1)

    @staticmethod
    def color_replace(grid: np.ndarray, from_color: int, to_color: int) -> np.ndarray:
        """Return a copy of ``grid`` with every ``from_color`` cell set to ``to_color``."""
        result = grid.copy()
        result[result == from_color] = to_color
        return result

    @staticmethod
    def crop(grid: np.ndarray, x1: int, y1: int, x2: int, y2: int) -> np.ndarray:
        """Return the sub-grid spanning columns x1..x2 and rows y1..y2, inclusive.

        The coordinates are passed straight to numpy slicing without validation.
        """
        return grid[y1:y2 + 1, x1:x2 + 1]

    @staticmethod
    def pad(grid: np.ndarray, top: int = 0, bottom: int = 0,
            left: int = 0, right: int = 0, fill: int = 0) -> np.ndarray:
        """Return ``grid`` surrounded by the given numbers of ``fill``-coloured rows/columns."""
        return np.pad(grid, ((top, bottom), (left, right)), constant_values=fill)

    @staticmethod
    def flood_fill(grid: np.ndarray, x: int, y: int, new_color: int) -> np.ndarray:
        """Return a copy of ``grid`` with the 4-connected region containing (x, y) recoloured.

        Raises:
            IndexError: if (x, y) lies outside the grid. Negative coordinates
                are rejected rather than wrapping around as numpy indexing would.
        """
        grid = grid.copy()
        h, w = grid.shape
        if not (0 <= x < w and 0 <= y < h):
            raise IndexError(f"Start cell ({x}, {y}) outside grid of shape {grid.shape}")
        target_color = grid[y, x]
        if target_color == new_color:
            return grid
        # Explicit stack instead of recursion, so large regions cannot hit the recursion limit.
        stack = [(x, y)]
        while stack:
            cx, cy = stack.pop()
            if cx < 0 or cy < 0 or cx >= w or cy >= h:
                continue
            if grid[cy, cx] != target_color:
                continue
            grid[cy, cx] = new_color
            stack.extend(((cx + 1, cy), (cx - 1, cy), (cx, cy + 1), (cx, cy - 1)))
        return grid

    @staticmethod
    def bounding_box(grid: np.ndarray, color: int) -> Optional[Tuple[int, int, int, int]]:
        """Return ``(x_min, y_min, x_max, y_max)`` of all ``color`` cells, or None if absent.

        The coordinates are plain Python ints.
        """
        ys, xs = np.nonzero(grid == color)
        if ys.size == 0:
            return None
        return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

    @staticmethod
    def extract_object(grid: np.ndarray, color: int,
                       background: int = 0) -> Optional[np.ndarray]:
        """Return the smallest sub-grid containing every ``color`` cell, or None if absent.

        The sub-grid also contains any other colours inside that box. ``background``
        is currently unused and is kept for API compatibility.
        """
        bbox = DSLOperations.bounding_box(grid, color)
        if bbox is None:
            return None
        x1, y1, x2, y2 = bbox
        return grid[y1:y2 + 1, x1:x2 + 1]

    @staticmethod
    def count_colors(grid: np.ndarray) -> Dict[int, int]:
        """Return ``{value: count}`` for each distinct value in ``grid``, as Python scalars."""
        values, counts = np.unique(grid, return_counts=True)
        # tolist() yields native Python scalars, which serialise to JSON cleanly.
        return dict(zip(values.tolist(), counts.tolist()))

    @staticmethod
    def get_symmetry(grid: np.ndarray) -> Dict[str, bool]:
        """Report which symmetries ``grid`` has.

        Keys:
            horizontal: unchanged by a top/bottom mirror (``flipud``).
            vertical: unchanged by a left/right mirror (``fliplr``).
            diagonal: equal to its transpose.
            rotational_90: unchanged by a quarter turn (square grids only).
            rotational_180: unchanged by a half turn.
        """
        return {
            "horizontal": np.array_equal(grid, np.flipud(grid)),
            "vertical": np.array_equal(grid, np.fliplr(grid)),
            "diagonal": np.array_equal(grid, grid.T),
            "rotational_90": np.array_equal(grid, np.rot90(grid, 1)),
            "rotational_180": np.array_equal(grid, np.rot90(grid, 2)),
        }


In [None]:
class SymbolicProgramSearch:
    """Search for short programs composed of DSL operations.

    A candidate program is a sequence of primitive grid operations. It is a
    VALID solution when it maps every training input to its expected output.
    ``search`` enumerates programs by iterative deepening: all length-1
    programs, then all length-2 programs, and so on up to ``max_depth``. The
    shortest valid program is therefore returned, with ties broken by the order
    of ``self.operations``.
    """

    def __init__(self, max_depth: int = 3):
        self.max_depth = max_depth
        self.dsl = DSLOperations()
        # Candidate operations as (name, function) pairs.
        self.operations = [
            ('rotate_90', lambda g: self.dsl.rotate(g, 1)),
            ('rotate_180', lambda g: self.dsl.rotate(g, 2)),
            ('rotate_270', lambda g: self.dsl.rotate(g, 3)),
            ('flip_h', lambda g: self.dsl.flip(g, 1)),
            ('flip_v', lambda g: self.dsl.flip(g, 0)),
            ('transpose', lambda g: self.dsl.transpose(g)),
        ]
        # Colour replacements between colours 1-3. fi/fj are bound as default
        # arguments so each lambda keeps its own pair (avoids late binding).
        for i in range(1, 4):
            for j in range(1, 4):
                if i != j:
                    self.operations.append(
                        (f'color_{i}_to_{j}', lambda g, fi=i, fj=j: self.dsl.color_replace(g, fi, fj))
                    )

    def search(self, task: Dict) -> Optional[List[Tuple[str, Any]]]:
        """Return the shortest program that solves every training pair of ``task``.

        Args:
            task: mapping whose ``'train'`` entry is a list of
                ``{'input': grid, 'output': grid}`` examples.

        Returns:
            A list of ``(name, function)`` steps. Returns None when no program
            of length 1..max_depth matches, or when the task has no training
            pairs (any program would then match vacuously).
        """
        train_examples = task['train']
        if not train_examples:
            return None
        # Iterative deepening: every shorter program has already been rejected by
        # the time a longer one can match, so the first hit is a shortest one.
        for depth_limit in range(1, self.max_depth + 1):
            program = self._dfs_search(train_examples, [], depth_limit)
            if program is not None:
                return program
        return None

    def _dfs_search(self, examples: List[Dict],
                    current_program: List[Tuple[str, Any]],
                    depth_limit: Optional[int] = None) -> Optional[List[Tuple[str, Any]]]:
        """Depth-first search for a program of length <= ``depth_limit`` matching all pairs.

        ``depth_limit`` defaults to ``self.max_depth``.
        """
        limit = self.max_depth if depth_limit is None else depth_limit
        if len(current_program) > limit:
            return None
        if current_program and self._program_matches(examples, current_program):
            return current_program
        if len(current_program) == limit:
            # Children would exceed the limit; do not expand them only to reject them.
            return None
        for op_name, op_func in self.operations:
            result = self._dfs_search(examples, current_program + [(op_name, op_func)], limit)
            if result is not None:
                return result
        return None

    def _program_matches(self, examples: List[Dict],
                         program: List[Tuple[str, Any]]) -> bool:
        """Return True if ``program`` maps every example input to its expected output."""
        for example in examples:
            result = self._apply_program(np.array(example['input']), program)
            if result is None or not np.array_equal(result, np.array(example['output'])):
                return False
        return True

    def _apply_program(self, grid: np.ndarray,
                       program: List[Tuple[str, Any]]) -> Optional[np.ndarray]:
        """Apply the program's operations to a copy of ``grid``; return None if any step fails."""
        try:
            result = grid.copy()
            for _, op_func in program:
                result = op_func(result)
            return result
        except Exception:
            # Deliberate best effort: an operation that cannot apply to this grid
            # simply disqualifies the candidate program.
            return None

    def apply(self, grid: np.ndarray,
              program: List[Tuple[str, Any]]) -> Optional[np.ndarray]:
        """Apply a discovered program to a new grid (None if any step fails)."""
        return self._apply_program(grid, program)

In [None]:
class SimpleARCNet(nn.Module):
    """Lightweight CNN that classifies the relation between an input/output grid pair.

    Both grids are one-hot encoded (see ``grid_to_tensor``), concatenated along
    the channel dimension (20 channels in total) and passed through a small
    conv encoder. A fully connected head then produces 10 logits. The model is
    an untrained starting point: no label set or training loop is defined here.
    """

    def __init__(self, max_grid_size: int = 30, hidden_dim: int = 256):
        """
        Args:
            max_grid_size: padded grid size inputs are expected to use. It is
                stored only; the adaptive pooling also accepts other sizes.
            hidden_dim: width of the first hidden layer of the classifier.
        """
        super().__init__()
        self.max_grid_size = max_grid_size
        # Encoder: two conv layers, one 2x max-pool, then adaptive pooling to 4x4
        # so the classifier input size does not depend on the grid size.
        self.encoder = nn.Sequential(
            nn.Conv2d(20, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4)
        )
        # Fully connected classifier producing 10 logits.
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 10)
        )

    def forward(self, input_grid: torch.Tensor, output_grid: torch.Tensor) -> torch.Tensor:
        """Return ``(batch, 10)`` logits for the relation between the two grids.

        Both tensors are expected to be ``(batch, 10, H, W)`` one-hot grids,
        e.g. stacked ``grid_to_tensor`` outputs. Their channel concatenation
        must have the 20 channels the first convolution expects.
        """
        x = torch.cat([input_grid, output_grid], dim=1)
        x = self.encoder(x)
        x = x.flatten(1)
        return self.fc(x)

    @staticmethod
    def grid_to_tensor(grid: np.ndarray, num_colors: int = 10,
                       max_size: int = 30) -> torch.Tensor:
        """One-hot encode an integer grid into a ``(num_colors, max_size, max_size)`` tensor.

        The grid is placed in the top-left corner. Padding cells, and any value
        outside ``0..num_colors-1``, are encoded as all-zero vectors, so they
        stay distinguishable from real cells of colour 0.

        Args:
            grid: 2-D array or nested list of integer colour indices.
            num_colors: number of one-hot channels.
            max_size: side length of the padded output.

        Returns:
            float32 tensor of shape ``(num_colors, max_size, max_size)``.

        Raises:
            ValueError: if the grid is not 2-D or exceeds ``max_size`` in
                either dimension.
        """
        grid = np.asarray(grid)
        if grid.ndim != 2:
            raise ValueError(f"Grid must be 2D, got shape {grid.shape}")
        h, w = grid.shape
        if h > max_size or w > max_size:
            raise ValueError(f"Grid of shape {grid.shape} exceeds max_size={max_size}")
        # -1 never equals a colour index, so padding encodes as an all-zero vector.
        padded = np.full((max_size, max_size), -1, dtype=np.int64)
        padded[:h, :w] = grid
        channels = np.arange(num_colors).reshape(-1, 1, 1)
        one_hot = (padded[np.newaxis, :, :] == channels).astype(np.float32)
        return torch.from_numpy(one_hot)