In [4]:
# riddle synth prerequisites
import torch
import json
import matplotlib.pyplot as plt
from matplotlib import colors

from arc.utils import dataset
from riddle_synth import NodeFactory, load_configuration_file, register_functions, InputSampler, SynthRiddleGen1, print_image


# Configuration file that drives input sampling and riddle generation.
CONFIG_PATH = '../../../riddle_synth/riddle_synth/config/all_depth_3_10k.json'
cfg = load_configuration_file(CONFIG_PATH)

# Node factory; `register_functions` fills `f.functions` (name -> function descriptor).
f = NodeFactory()
register_functions(f)

function_names = sorted(f.functions.keys())
print(f"Number of functions: {len(function_names)}")

# Riddles whose boards are sampled as inputs, optionally truncated to the first `first_n`.
sampler_cfg = cfg.input_sampler
eval_riddle_ids = dataset.get_riddle_ids(sampler_cfg.subdirs)
first_n = sampler_cfg.first_n
if first_n is not None and first_n > 0:
    eval_riddle_ids = eval_riddle_ids[:first_n]

input_sampler = InputSampler(
    eval_riddle_ids,
    include_outputs=sampler_cfg.include_outputs,
    include_test=sampler_cfg.include_test,
    color_permutation=sampler_cfg.color_permutation,
    random_offsets=sampler_cfg.random_offsets,
    add_noise_p=sampler_cfg.add_noise_p,
    noise_p=sampler_cfg.noise_p,
    add_parts_p=sampler_cfg.add_parts_p,
    parts_min=sampler_cfg.parts_min,
    parts_max=sampler_cfg.parts_max,
    min_width=sampler_cfg.min_width,
    min_height=sampler_cfg.min_height,
    max_width=sampler_cfg.max_width,
    max_height=sampler_cfg.max_height,
)

# Generator restricted to the (sorted) registered function names.
riddle_gen = SynthRiddleGen1(
    node_factory=f,
    input_sampler=input_sampler,
    sample_node_count=cfg.sample_node_count,
    min_depth=cfg.min_depth,
    max_depth=cfg.max_depth,
    max_input_sample_tries=cfg.max_input_sample_tries,
    min_examples=cfg.min_examples,
    max_examples=cfg.max_examples,
    function_names=function_names,
)

Number of functions: 180


In [5]:
# collect input output examples for functions (this may take some minutes...)

from riddle_synth import FunctionNode, Image, badImg

# Number of informative input/output examples to collect per function.
num_samples_per_func = 5
# Maximum number of successfully generated riddles to inspect (e.g. 250 for a quick run).
N = 5000

func_examples = {fn: [] for fn in function_names}
# Functions that still need examples. A name is removed as soon as its example
# list is full, so the loop below can stop once every function is covered.
remaining = set(function_names)


def _is_informative_example(input_args, output_value):
    """Return True if one function application is worth keeping as an example.

    Rejected cases:
    - an empty list output,
    - an Image output equal to ``badImg``,
    - a single-Image-argument call whose output equals its input,
    - a two-Image-argument call whose inputs are equal, or whose output equals
      one of its inputs.
    """
    if isinstance(output_value, list) and len(output_value) == 0:
        return False
    if not isinstance(output_value, Image):
        return True
    if output_value == badImg:
        return False
    if len(input_args) == 1 and isinstance(input_args[0], Image) and input_args[0] == output_value:
        return False  # no change in image
    if len(input_args) == 2 and isinstance(input_args[0], Image) and isinstance(input_args[1], Image):
        lhs, rhs = input_args
        if lhs == rhs:
            return False  # binary function with identical inputs
        if output_value == lhs or output_value == rhs:
            return False  # binary function returned one of its inputs
    return True


i = 0
print('initial remaining', len(remaining))
# Stop early once every function has enough examples.
while i < N and remaining:
    xs, g, _ = riddle_gen.generate_riddle()
    if xs is None:
        continue  # generation failed; failed attempts do not count towards N
    i += 1

    if i % 10 == 0:
        print('remaining', i, len(remaining))

    # Skip graphs without any function we still need examples for (avoids evaluating them).
    if not any(isinstance(n, FunctionNode) and n.fn.name in remaining for n in g.nodes):
        continue

    # NOTE(review): assumes xs[0][0] is the input image of the first example — confirm.
    input_image = xs[0][0]
    # Indexed by node id: the value each node produced for `input_image`.
    graph_outputs = g.evaluate(input_image)

    for n in g.nodes:
        if not isinstance(n, FunctionNode):
            continue
        fn_name = n.fn.name
        if fn_name not in remaining:
            continue  # already has enough examples

        # Function inputs (outputs of its source nodes) and its own output.
        input_args = [graph_outputs[src_node.id] for src_node in n.input_nodes]
        output_value = graph_outputs[n.id]
        if not _is_informative_example(input_args, output_value):
            continue

        examples = func_examples[fn_name]
        examples.append({'input': input_args, 'output': output_value})
        # Retire the function as soon as its list is full (previously this only
        # happened when a surplus example arrived, leaving full functions "remaining").
        if len(examples) >= num_samples_per_func:
            remaining.discard(fn_name)

print('after remaining', len(remaining))



initial remaining 180
remaining 10 180
remaining 20 180
remaining 30 180
remaining 40 180
remaining 50 180
remaining 60 180
remaining 70 179
remaining 80 177
remaining 90 177
remaining 100 176


KeyboardInterrupt: 

In [3]:
import pickle
from pathlib import Path

print('remaining:', remaining)

# Persist collected examples plus the still-uncovered functions so the slow
# collection cell need not be re-run.
snapshot = {'func_examples': func_examples, 'remaining': remaining}
filename = Path('./func_examples.pkl')
filename.write_bytes(pickle.dumps(snapshot))

remaining: {'rigid_0', 'pick_not_maxes_7', 'to_origin', 'get_size0'}


In [106]:
from IPython.display import display, display_markdown

def plot_image(ax, image, title=''):
    """Draw one grid image onto `ax`.

    Cell values 0..9 are mapped onto a fixed 10-colour palette. Light grey grid
    lines outline every cell, and tick labels are hidden.

    Args:
        ax: matplotlib Axes to draw on.
        image: object exposing the cell grid via ``.np``.
        title: axes title.
    """
    palette = [
        '#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
        '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25',
    ]
    value_cmap = colors.ListedColormap(palette)
    value_norm = colors.Normalize(vmin=0, vmax=9)

    grid = image.np
    ax.imshow(grid, cmap=value_cmap, norm=value_norm)
    ax.grid(True, which='both', color='lightgrey', linewidth=0.5)
    # Ticks are offset by half a cell so the grid lines fall on cell borders.
    ax.set_yticks([row - 0.5 for row in range(len(grid) + 1)])
    ax.set_xticks([col - 0.5 for col in range(len(grid[0]) + 1)])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title)

def plot_imagelist(p, title):
    """Plot an Image, or a list of Images in one row, under a figure title.

    - An Image gets its own 2x2 figure.
    - A list with more than one element is drawn as a single row of subplots.
    - A one-element list is unwrapped and plotted recursively.
    - Empty lists and values of any other type produce no figure.
    """
    if isinstance(p, Image):
        fig, ax = plt.subplots(figsize=(2, 2), constrained_layout=True)
        fig.suptitle(title, fontsize=16)
        plot_image(ax, p, str(p))
        return
    if not isinstance(p, list):
        return

    count = len(p)
    if count == 1:
        plot_imagelist(p[0], title)
    elif count > 1:
        fig, axes = plt.subplots(ncols=count, nrows=1, figsize=(count, 1), constrained_layout=True)
        fig.suptitle(title, fontsize=16)
        for cell_ax, item in zip(axes, p):
            plot_image(cell_ax, item, str(item))


def plot_fn_example(fn_name, input_args, output_value):
    """Plot each argument of one function call, then its result, and show the figures."""
    panels = [(arg, f'{fn_name} arg{k}') for k, arg in enumerate(input_args)]
    panels.append((output_value, f'{fn_name} out'))
    for value, caption in panels:
        plot_imagelist(value, caption)
    plt.show()


# visualize func_examples
def show_function_outputs(fn_name, v):
    """Render a markdown header for `fn_name` and plot each collected example.

    Args:
        fn_name: key into ``f.functions``.
        v: list of example dicts with keys ``'input'`` (list of argument values)
            and ``'output'`` (the function's return value), as built in
            ``func_examples``.
    """
    display_markdown(f'# {fn_name}\n------', raw=True)
    if not v:
        # Functions never sampled informatively would otherwise show only a header.
        print(f'No examples collected for {fn_name}')
        return

    # The signature is identical for every example, so format it once.
    fn_dsc = f.functions[fn_name]
    param_types = ','.join(t.name for t in fn_dsc.parameter_types)
    s = f'{fn_name} ({param_types}) -> {fn_dsc.return_type.name}'
    for j, example in enumerate(v, start=1):
        print(f'Example: {j}/{len(v)}: {s}')
        plot_fn_example(fn_name, example['input'], example['output'])


# Icecuber DSL function types

```
1. Unary:  Image -> Image
2. Binary: (Image, Image) -> Image
3. Split:  Image -> List[Image]
4. Join:   List[Image] -> Image
5. Vector: List[Image] -> List[Image]
```

In [None]:
available_functions = sorted(func_examples)
print('total:', len(available_functions))

# Render only the first batch of alphabetically sorted functions to keep the
# output manageable. Per an earlier note this batch spans 'border' to 'fill'
# (TODO confirm if the function set changes).
SHOW_FIRST_N = 40
for name in available_functions[:SHOW_FIRST_N]:
    show_function_outputs(name, func_examples[name])