# Gadgets visualization notebook
An interactive notebook for rendering Stitch compressor inventions.

Author: Gabe Grand

Date: 2022-03-03

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the parent directory importable so the project packages imported below
# (`primitives`, `dreamcoder_programs`) resolve from this notebook.
BASE_DIR = ".."
# Notebook cells are often re-run; without this check every run would prepend
# another copy of BASE_DIR to sys.path.
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

In [None]:
import json
import math
import os
import pickle
import re
from collections import defaultdict

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image
from primitives.gadgets_primitives import cc_shape, polygon_shape, r_shape
from primitives.object_primitives import (
    SYNTHESIS_TASK_CANVAS_WIDTH_HEIGHT,
    XYLIM,
    export_rendered_program,
    render_parsed_program,
    render_stroke_arrays_to_canvas,
)

from dreamcoder_programs.program import Curried, Program, tokeniseProgram

In [None]:
%config InlineBackend.figure_format = 'retina'

# Render programs and display them as image grids

In [None]:
def display_arrays_as_grid(
    rendered_arrays, suptitle=None, titles=None, ncols=4, transparent_background=True
):
    """Plot a sequence of image arrays on an ImageGrid of subplots.

    Args:
        rendered_arrays: Sequence of arrays accepted by ``Axes.imshow``; each is
            drawn with the ``"Greys"`` colormap and without axis ticks.
        suptitle: Optional bold title placed above the whole grid.
        titles: Optional per-panel titles; ``titles[i]`` labels
            ``rendered_arrays[i]``, so it must be at least as long.
        ncols: Maximum number of columns; fewer are used when there are fewer
            arrays than ``ncols``.
        transparent_background: If False, paint the figure background white.

    Returns:
        The matplotlib ``Figure``, or ``None`` if ``rendered_arrays`` is empty.
    """
    n_arrays = len(rendered_arrays)
    if n_arrays == 0:
        # min(0, ncols) would be 0 and computing the row count would divide by zero.
        return None
    ncols = min(n_arrays, ncols)
    nrows = math.ceil(n_arrays / ncols)

    fig = plt.figure(figsize=(4 * ncols, 4 * nrows))
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(nrows, ncols),
        axes_pad=0.4,  # pad between axes
    )

    for i, array in enumerate(rendered_arrays):
        ax = grid[i]
        ax.imshow(array, cmap="Greys")
        ax.set_xticks([])
        ax.set_yticks([])

    # ImageGrid always creates nrows * ncols axes; hide the unused trailing panels.
    for ax in grid.axes_all[n_arrays:]:
        ax.set_visible(False)

    # Titles are set only after every image is placed, so that adjust_title
    # measures the final layout.
    if titles is not None:
        for i in range(n_arrays):
            ax = grid[i]
            ax.set_title(titles[i])
            # Skip empty titles: there is nothing to shrink, and a zero-width
            # title cannot be measured against the axes width.
            if ax.get_title():
                adjust_title(ax)

    if not transparent_background:
        fig.patch.set_facecolor("white")

    if suptitle is not None:
        # Target this figure explicitly rather than pyplot's "current figure".
        fig.suptitle(suptitle, fontsize=24, fontweight="bold", va="top")
    fig.tight_layout()

    return fig


def adjust_title(ax):
    """Shrink the font of ``ax``'s title until it is narrower than the axes.

    The font size is reduced in 0.5-pt steps while the title's rendered width
    is at least the axes' width, stopping once the size reaches 1 pt. Empty
    titles are left untouched.

    Args:
        ax: A matplotlib ``Axes`` whose title has already been set.
    """
    title = ax.title
    if not title.get_text():
        # Nothing to fit, and an empty title has zero width (the ratio-based
        # check used to divide by it).
        return
    # Draw once so the window extents measured below reflect the current layout.
    ax.figure.canvas.draw()

    def _title_fits():
        title_width = title.get_window_extent().width
        # Guard against a zero-width extent instead of dividing by it.
        return title_width == 0 or title_width < ax.get_window_extent().width

    while not _title_fits() and title.get_fontsize() > 1:
        title.set_fontsize(title.get_fontsize() - 0.5)


def display_programs_as_grid(programs, max_programs=16, **kwargs):
    """Render programs and show the successful renders with ``display_arrays_as_grid``.

    Rendering is best effort: a program whose rendering raises is skipped. If a
    ``titles`` list is passed through ``kwargs``, it is treated as aligned with
    ``programs``. Titles of skipped programs are dropped so every remaining
    title stays attached to its own panel.

    Args:
        programs: Iterable of programs passed to ``render_parsed_program``.
        max_programs: Stop after this many successful renders; ``None`` means
            no limit.
        **kwargs: Forwarded to ``display_arrays_as_grid`` (e.g. ``suptitle``,
            ``titles``, ``transparent_background``).

    Returns:
        The figure returned by ``display_arrays_as_grid``, or ``None`` when no
        program rendered.
    """
    titles = kwargs.pop("titles", None)
    rendered_arrays = []
    kept_titles = []
    for i, program in enumerate(programs):
        if max_programs is not None and len(rendered_arrays) >= max_programs:
            break
        try:
            array = render_parsed_program(program, allow_partial_rendering=True)
        except Exception:
            # Unrenderable programs are skipped on purpose. `Exception` (not a
            # bare `except:`) lets KeyboardInterrupt/SystemExit still propagate.
            continue
        rendered_arrays.append(array)
        if titles is not None:
            kept_titles.append(titles[i])

    if not rendered_arrays:
        print("No valid arrays to display.")
        return None
    if titles is not None:
        kwargs["titles"] = kept_titles
    return display_arrays_as_grid(rendered_arrays, **kwargs)

# Exploring Stitch inventions

In [None]:
# Stitch output file to visualize; switch INVENTIONS_FILE to explore another run.
INVENTIONS_DIR = "stitch_inventions"
INVENTIONS_FILE = "inventions_nuts_bolts.json"
# INVENTIONS_FILE = "inventions_wheels.json"

# Pin UTF-8 (the JSON standard encoding) instead of the platform-default encoding.
with open(os.path.join(INVENTIONS_DIR, INVENTIONS_FILE), "r", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
def _invention_name_pattern(names):
    """Compile a regex matching any of ``names`` as a whole token.

    A name only matches when it is not adjacent to another word character, so
    ``fn_1`` never matches inside ``fn_10``. Longer names are tried first.
    """
    alternatives = "|".join(
        re.escape(name) for name in sorted(names, key=len, reverse=True)
    )
    return re.compile(rf"(?<!\w)(?:{alternatives})(?!\w)")


inv_to_body = {}
inv_to_dreamcoder = {}
inv_to_program_fragment = defaultdict(list)
inv_to_program_full = {}
inv_to_use_context = defaultdict(list)

for inv in data["invs"]:
    name = inv["name"]
    inv_to_body[name] = inv["body"]
    inv_to_dreamcoder[name] = inv["dreamcoder"]
    inv_to_program_full[name] = inv["rewritten"]
    # Built after registering `name`, so this invention and all earlier ones
    # are inlined (same set the original sequential str.replace loop used).
    name_pattern = _invention_name_pattern(inv_to_dreamcoder)

    # Only the first entry of each use dict is read, as (use_context, fragment).
    # Sort primarily by context length, then alphanumerically.
    uses = sorted(
        (next(iter(use.items())) for use in inv["uses"]),
        key=lambda item: (len(item[0]), item[0]),
    )
    for use_context, program_fragment in uses:
        # Inline prior inventions in one pass; a function replacement keeps
        # re.sub from interpreting backslashes in the DreamCoder bodies.
        program_fragment = name_pattern.sub(
            lambda match: inv_to_dreamcoder[match.group(0)], program_fragment
        )
        inv_to_program_fragment[name].append(program_fragment)
        inv_to_use_context[name].append(use_context)

In [None]:
RENDER_DIR = f"renders/{os.path.splitext(INVENTIONS_FILE)[0]}"
os.makedirs(RENDER_DIR, exist_ok=True)

# Set to an int (e.g. 16) to render a random subset of that many uses per
# invention; None renders them all (display_programs_as_grid still caps each
# grid at its max_programs default).
SAMPLE_SIZE = None

np.random.seed(123)  # keeps the optional subsampling reproducible
for inv_name in inv_to_body:
    print(inv_name)
    fragments = inv_to_program_fragment[inv_name]
    use_contexts = inv_to_use_context[inv_name]
    if SAMPLE_SIZE is not None and len(fragments) > SAMPLE_SIZE:
        # Draw indices once and apply them to both lists so titles stay aligned.
        idxs = sorted(
            np.random.choice(len(fragments), size=SAMPLE_SIZE, replace=False)
        )
        fragments = [fragments[i] for i in idxs]
        use_contexts = [use_contexts[i] for i in idxs]
    fig = display_programs_as_grid(
        fragments,
        suptitle=inv_name,
        titles=use_contexts,
        transparent_background=False,
    )
    if fig is not None:
        # Save the returned figure explicitly rather than pyplot's current figure.
        fig.savefig(f"{RENDER_DIR}/{inv_name}.png", dpi=144, bbox_inches="tight")