## Naive Enumerative Search in Python

This notebook is a quick way to test whether a program *can* be synthesized for a given task through naive, recursive enumeration up to a given depth.

Author: Sean Flannery
Email: sflanner@purdue.edu

### Currently supported DSL
```
program ::= single
single  ::= filterColor(single, color)
          | recolor(single, color)
          | orthogonal(single, axis)
          | image
color   ::= 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
axis    ::= Y_AXIS | X_AXIS | ROT_90
```

XY_AXIS, YX_AXIS, ROT_270, and ROT_180 were removed since they can already be expressed by composing the remaining operations in our DSL.
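
The redundancy of the removed transforms can be checked directly with NumPy. The sketch below assumes that XY_AXIS and YX_AXIS denote reflections across the main diagonal and the anti-diagonal, which the notebook does not spell out; the rotations follow the counter-clockwise convention of `np.rot90`.

```python
import numpy as np

a = np.array([[1, 2], [3, 4], [0, 2]])

# ROT_180: X_AXIS followed by Y_AXIS
rot_180 = np.fliplr(np.flipud(a))
# ROT_270: ROT_90 applied three times
rot_270 = np.rot90(np.rot90(np.rot90(a)))
# Assumed XY_AXIS (main-diagonal reflection): ROT_90 followed by X_AXIS
main_diag = np.flipud(np.rot90(a))
# Assumed YX_AXIS (anti-diagonal reflection): ROT_90 followed by Y_AXIS
anti_diag = np.fliplr(np.rot90(a))

print(np.array_equal(rot_180, np.rot90(a, 2)))        # True
print(np.array_equal(rot_270, np.rot90(a, 3)))        # True
print(np.array_equal(main_diag, a.T))                 # True
print(np.array_equal(anti_diag, np.rot90(a, 2).T))    # True
```

Each removed transform costs two or three DSL calls, so dropping them trades a smaller branching factor for a slightly deeper search.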

In [1]:
import os
import json
import numpy as np
from itertools import product
from copy import copy
from functools import lru_cache
from tqdm.notebook import tqdm
from multiprocessing import Pool

### Utility for Fetching our Task Data

In [2]:
# given path to a file containing our data, return 2 lists
# 1 containing all training examples as tuple pairs of numpy arrays
# 1 containing all test examples as tuple pairs of numpy arrays 
def get_task_data(json_path):
    with open(json_path) as json_file_reader:
        json_dict = json.loads(json_file_reader.read())
        # get our input/output examples
        train_examples = []
        for d in json_dict['train']:
            train_examples.append((np.array(d['input']),np.array(d['output'])))
        test_examples = []
        for d in json_dict['test']:
            test_examples.append((np.array(d['input']),np.array(d['output'])))   
        return train_examples, test_examples
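
To make the expected file layout concrete, here is a small sketch that writes a made-up task (not a real ARC task) to a temporary file and loads it back. `load_pairs` is a hypothetical compact equivalent of `get_task_data`, included only so the block runs on its own.

```python
import json
import os
import tempfile

import numpy as np

def load_pairs(json_path):
    # Hypothetical stand-in for get_task_data: same JSON layout, same return shape
    with open(json_path) as f:
        task = json.load(f)

    def to_pairs(split):
        return [(np.array(d['input']), np.array(d['output'])) for d in task[split]]

    return to_pairs('train'), to_pairs('test')

# Made-up task: every output is the input flipped left-to-right
toy_task = {
    'train': [{'input': [[1, 0]], 'output': [[0, 1]]},
              {'input': [[2, 0], [0, 3]], 'output': [[0, 2], [3, 0]]}],
    'test': [{'input': [[0, 5]], 'output': [[5, 0]]}],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'toy_task.json')
    with open(path, 'w') as f:
        json.dump(toy_task, f)
    train_pairs, test_pairs = load_pairs(path)

print(len(train_pairs), len(test_pairs))  # 2 1
print(train_pairs[1][0].shape)            # (2, 2)
```

Grids within one task may differ in shape, which is why each example is kept as its own array rather than stacked.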

#### Sample: Reflection on X, then Y axis

In [3]:
train, test = get_task_data('data/training/9dfd6313.json') # 9dfd6313.json is a reflection on XY axis task

In [4]:
def identity(image, extra_arg=None):
    # return the image unchanged; extra_arg is ignored and only keeps the (image, arg) signature
    return image

In [5]:
def filterColor(image, color):
    # anywhere the color is, keep it, else set to 0
    return np.where(image == color, image, 0)

In [6]:
def recolor(image, color):
    # set all nonzero pixels to color, else keep it the same (0)
    return np.where(image != 0, color, 0)

In [7]:
def orthogonal(image, axis):
    if   axis == 'Y_AXIS':
        return np.fliplr(image)
    elif axis == 'X_AXIS':
        return np.flipud(image)
    elif axis == 'ROT_90':
        return np.rot90(image,1) # rotate once
    #elif axis == 'ROT_180': 
    #    return np.rot90(image,2) # rotate twice
    #elif axis == 'ROT_270':
    #    return np.rot90(image,3) # rotate thrice
    else:
        raise NotImplementedError(f"unsupported axis: {axis}")

In [8]:
a = np.array([[1,2],[3,4],[0,2]])
print(a)

[[1 2]
 [3 4]
 [0 2]]


In [9]:
print(orthogonal(a, 'Y_AXIS'))

[[2 1]
 [4 3]
 [2 0]]


In [10]:
print(orthogonal(a, 'X_AXIS'))

[[0 2]
 [3 4]
 [1 2]]


In [11]:
print(orthogonal(a, 'ROT_90'))

[[2 4 2]
 [1 3 0]]


In [12]:
filterColor(a,2)

array([[0, 2],
       [0, 0],
       [0, 2]])

In [13]:
recolor(a,4)

array([[4, 4],
       [4, 4],
       [0, 4]])

### Compose all options that our program could take at any level

incomplete: We still have to make this recursive.

idea: Structure each enumeration as a tuple like

if we want to use the raw input:
```
(function, None, args)
```
if we want to use the output of a nested (deeper) program as the input:
```
(function, program_output_from_greater_depth, args)
```
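
As an illustration of this tuple layout, the sketch below builds a two-level program by hand and evaluates it. `flip_lr`, `keep_color`, and `evaluate` are hypothetical stand-ins written for this example; in particular, `evaluate` simply applies the nested calls and has none of the validity checks used by the search later in the notebook.

```python
import numpy as np

def flip_lr(image, _unused):
    # Stand-in for an axis flip; the second argument is ignored
    return np.fliplr(image)

def keep_color(image, color):
    # Keep pixels equal to `color`, zero out everything else
    return np.where(image == color, image, 0)

def evaluate(program, image):
    # program is (function, inner_program_or_None, arg)
    func, inner, arg = program
    if inner is not None:
        image = evaluate(inner, image)  # run the deeper program first
    return func(image, arg)

inner = (keep_color, None, 2)    # keep_color(input, 2)
outer = (flip_lr, inner, None)   # flip_lr(keep_color(input, 2), None)

grid = np.array([[1, 2], [3, 4], [0, 2]])
result = evaluate(outer, grid)
print(result)
```

Because the tuples contain only functions, hashable arguments, and other such tuples, a whole program is itself hashable, which matters once results are cached.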

In [14]:
function_options = [filterColor, recolor, orthogonal]
color_options = list(range(10)) #todo: remove last option?
orth_options = ['Y_AXIS', 'X_AXIS', 'ROT_90']

In [15]:
color_enumerations = list(product([filterColor, recolor], color_options))
orth_enumerations = list(product([orthogonal], orth_options))
#identity_options = [(identity, None)]

In [16]:
total_search_options = color_enumerations + orth_enumerations #+ identity_options

## Important globals to always set with each run of the below functions
```
SEARCH_OPTIONS - the possible functions that could be applied at any level 
NUM_SEARCH_OPTIONS - len of the prior argument
TASK_PAIRS - a copy of the task pairs for the desired task
MAX_LRU_SIZE - max number of LRU entries to keep in memory for run_program_on_task
```

In [17]:
SEARCH_OPTIONS = total_search_options
NUM_SEARCH_OPTIONS = len(total_search_options)
TASK_PAIRS = copy(train) # make a copy of our training data and specify what task we want to work on 
MAX_LRU_SIZE=1024

#### Function to generate a program from a number

This decodes a program index into a nested program tuple of arbitrary depth, so a Python generator expression can enumerate programs WITHOUT loading them all into memory. This is especially important when our enumeration size is massive...

In [18]:
def programGenerator(num, depth):
    # given a number, and the depth level, we can determine which of the options they wanted
    inner_arg = None # None indicates we should pass the input itself
    for _ in range(depth):
        option = SEARCH_OPTIONS[num % NUM_SEARCH_OPTIONS]
        # get the function, and any additional arguments
        func = option[0]
        args = option[1] # TODO: Expand this beyond just the 1
        inner_arg = (func, inner_arg, args)
        num = num // NUM_SEARCH_OPTIONS
    return inner_arg

In [19]:
gen_maybe = (programGenerator(num,depth=6) for num in range(NUM_SEARCH_OPTIONS**6))

In [20]:
next(gen_maybe) # This is what would happen if we picked 0 at each level

(<function __main__.filterColor(image, color)>,
 (<function __main__.filterColor(image, color)>,
  (<function __main__.filterColor(image, color)>,
   (<function __main__.filterColor(image, color)>,
    (<function __main__.filterColor(image, color)>,
     (<function __main__.filterColor(image, color)>, None, 0),
     0),
    0),
   0),
  0),
 0)
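
The decoding in `programGenerator` amounts to reading the index as a number in base `NUM_SEARCH_OPTIONS`: the lowest digit picks the innermost call and each higher digit wraps it. A minimal sketch of just that digit arithmetic follows. `decode_index` and `option_names` are illustrative names, not part of the notebook, and the option order assumes the `color_enumerations + orth_enumerations` layout built above (23 options).

```python
def decode_index(num, depth, num_options):
    """Return option indices ordered from the innermost call to the outermost."""
    digits = []
    for _ in range(depth):
        digits.append(num % num_options)  # lowest digit = innermost call
        num //= num_options
    return digits

# Assumed option order: 10 filterColor, 10 recolor, then the 3 orthogonal options
option_names = ([f"filterColor,{c}" for c in range(10)]
                + [f"recolor,{c}" for c in range(10)]
                + ["orthogonal,Y_AXIS", "orthogonal,X_AXIS", "orthogonal,ROT_90"])

digits = decode_index(505, depth=2, num_options=len(option_names))
print(digits)                             # [22, 21]
print([option_names[d] for d in digits])  # ['orthogonal,ROT_90', 'orthogonal,X_AXIS']
```

Index 505 decodes to ROT_90 innermost and X_AXIS outermost, i.e. `orthogonal(orthogonal(input,ROT_90),X_AXIS)`. That appears consistent with the 528 programs reported for task 9dfd6313 further down (23 depth-1 programs plus index 505 at depth 2), though this correspondence is an inference from the printed outputs.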

### Function to Run a Given Program on our Task

We use `functools.lru_cache` to memoize results automatically, so that we do not repeat work on programs (and sub-programs) we have seen before...

For example, a `maxsize` of 1024 means that we hold at most 1024 distinct argument tuples and their corresponding outputs in memory.
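
The eviction behaviour can be seen on a toy function. This is a generic sketch of `functools.lru_cache`, unrelated to the task data; `slow_square` and `calls` are made-up names.

```python
from functools import lru_cache

calls = []  # records every real (uncached) evaluation

@lru_cache(maxsize=2)
def slow_square(x):
    calls.append(x)
    return x * x

slow_square(3)  # miss: computed and stored
slow_square(3)  # hit: served from the cache
slow_square(4)  # miss
slow_square(5)  # miss: cache is full, so the least recently used key (3) is evicted
slow_square(3)  # miss again, because 3 was evicted

info = slow_square.cache_info()
print(info.hits, info.misses, info.currsize)  # 1 4 2
print(calls)                                  # [3, 4, 5, 3]
```

With a cache that is too small for the number of distinct sub-programs, recently useful entries get evicted and work is repeated, so `maxsize` is a memory versus recomputation trade-off.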

### Version 1: No LRU Cache (np.arrays are unhashable, so they can't be cache keys)
This version is more intuitive to use for checking a program by hand after it has been generated.

In [21]:
def run_program_on_task(task_input, true_output, program):
    # Recurse here for sub-programs (note: this version is NOT memoized)
    if program[1] is not None: # check if we have an inner arg
        inner_was_valid, task_input = run_program_on_task(task_input, true_output, program[1])
        if inner_was_valid: # a subprogram already produced the true output
            return True, task_input # hand back the matching output, not the program tuple
    # check that program_output and true_output are the same
    # - program[0] is a reference to a function
    # - program[2] are the additional arguments for the function
    program_output = program[0](task_input,program[2]) 
    return np.array_equal(program_output, true_output), program_output

### Version 2: Heck yeah, some LRU Cache (use the index to the given task's np.arrays, since those are hashable)

In [22]:
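# NOTE: maxsize is read once, when this decorator is applied; changing MAX_LRU_SIZE later does not resize the cache
@lru_cache(maxsize=MAX_LRU_SIZE) # auto-memoization, but requires globals for inputs/outputs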
def run_program_on_task_lru_safe(io_index, program):
    global TASK_PAIRS
    task_input, true_output = TASK_PAIRS[io_index]
    # Recurse here for sub-programs (don't worry, we memoized)
    if program[1] is not None: # check if we have an inner arg: 
        inner_was_valid, task_input = run_program_on_task_lru_safe(io_index, program[1])
        if inner_was_valid: # a subprogram already produced the true output
            return True, task_input # hand back the matching output, not the program tuple
    # check that program_output and true_output are the same
    # - program[0] is a reference to a function
    # - program[2] are the additional arguments for the function
    program_output = program[0](task_input,program[2]) 
    return np.array_equal(program_output, true_output), program_output
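
The reason for passing an index instead of the arrays themselves can be shown in isolation. In this sketch `GRIDS` is a hypothetical stand-in for `TASK_PAIRS`, and the two `total_by_*` functions are made up for the demonstration.

```python
from functools import lru_cache

import numpy as np

GRIDS = [np.array([[1, 2], [3, 4]]), np.array([[5, 6]])]  # stand-in for TASK_PAIRS

@lru_cache(maxsize=None)
def total_by_array(grid):
    return int(grid.sum())

@lru_cache(maxsize=None)
def total_by_index(grid_index):
    # The integer index is hashable; the array is looked up from the global
    return int(GRIDS[grid_index].sum())

try:
    total_by_array(GRIDS[0])
    array_key_worked = True
except TypeError:  # ndarray is unhashable, so it cannot be used as a cache key
    array_key_worked = False

print(array_key_worked)                          # False
print(total_by_index(0), total_by_index(0))      # 10 10
print(total_by_index.cache_info().hits)          # 1
```

The cost of this workaround is that cached results silently depend on the global: if the global is reassigned to a different task, stale entries must be dropped with `cache_clear()`.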

### Enumerative Search at given depth
Make sure the globals for TASK_PAIRS and SEARCH_OPTIONS are set

In [23]:
def enumerative_search(depth=1): 
    global TASK_PAIRS
    global SEARCH_OPTIONS
    global NUM_SEARCH_OPTIONS
    task_range = range(len(TASK_PAIRS))
    # Create a generator that we can use to evaluate all of our potential programs
    program_generator = (programGenerator(num,depth=depth) for num in range(NUM_SEARCH_OPTIONS**depth))
    for attempt_index, program in enumerate(program_generator):
        # TODO: turn this into a function that is parallelizable
        # let's see if this attempt works!
        all_valid = True
        for io_index in task_range:
            is_valid, _ = run_program_on_task_lru_safe(io_index, program)
            if not is_valid:
                all_valid = False
                break
        # if every pair matched, return the 0-based index of this program, and the program itself
        if all_valid:
            return attempt_index, program
    # nothing was generated. return the number of programs checked at this depth
    return NUM_SEARCH_OPTIONS**depth, None

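Because every depth-`d` program embeds a depth-`d-1` program as its inner tuple, the lru_cache lets the recursive call reuse results already computed for shallower programs (up to MAXSIZE entries).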

### A function to print out a prettier version of our generated functions

In [24]:
def prettify(program_tuple):
    if program_tuple is None:
        return "Failed to generate a program."
    func_name = program_tuple[0].__name__
    args = program_tuple[2]
    # check if it was recursive
    if program_tuple[1] is not None:
        return f"{func_name}({prettify(program_tuple[1])},{args})"
    else:
        return f"{func_name}(input,{args})"

In [25]:
attempt_index, program = enumerative_search(2)
prettify(program)

'orthogonal(orthogonal(input,ROT_90),X_AXIS)'

## An all-in-one function to quickly test an ARC task

In [26]:
def generate_task_solution(task_path, max_depth=4, search_options=total_search_options, max_lru_size=None): 
    global TASK_PAIRS
    global SEARCH_OPTIONS
    global NUM_SEARCH_OPTIONS
    global MAX_LRU_SIZE
    
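    SEARCH_OPTIONS = search_options
    NUM_SEARCH_OPTIONS = len(SEARCH_OPTIONS)
    # pick a default size only after NUM_SEARCH_OPTIONS is up to date
    if max_lru_size is None:
        max_lru_size = min(NUM_SEARCH_OPTIONS**3,NUM_SEARCH_OPTIONS**(max_depth-1))
    # NOTE: the @lru_cache decorator read its maxsize when it was applied, so this only records the intended size
    MAX_LRU_SIZE = max_lru_size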
    train, test = get_task_data(task_path)
    TASK_PAIRS = copy(train) # we only consider training examples
    
    # We want to clear the cache of any prior examples in case that was for a different task
    run_program_on_task_lru_safe.cache_clear()
    counter = 0
    for depth in range(1,max_depth + 1):
        num_checked, program = enumerative_search(depth)
        counter += num_checked
        if program is not None:
            break
            
    # check if it was None
    if program is None:
        return None, "TRAIN_FAILED", counter
    
    # Now, since we have a program that worked on training data, let's see what happens on test data
    for test_input, true_test_output in test:
        is_valid, _ = run_program_on_task(test_input, true_test_output, program)
        if not is_valid:
            return program, "TEST_FAILED", counter
        
    # Ah, yes, very nice
    return program, "SUCCESS", counter

In [27]:
program, msg, ctr = generate_task_solution('data/training/9dfd6313.json')
print(msg)
print(f"Number of programs checked: {ctr}")
print(f"Generated Program:\n\n{prettify(program)}")

SUCCESS
Number of programs checked: 528
Generated Program:

orthogonal(orthogonal(input,ROT_90),X_AXIS)
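
To put the reported count in context, the size of the search space can be worked out from the DSL alone. This sketch assumes the 23 search options defined above (two color functions with 10 colors each, plus 3 `orthogonal` options); `sizes` is an illustrative name.

```python
num_options = 2 * 10 + 3  # filterColor/recolor x 10 colors, plus 3 orthogonal options

sizes = []
cumulative = 0
for depth in range(1, 5):
    at_depth = num_options ** depth  # programs of exactly this depth
    cumulative += at_depth           # programs of this depth or shallower
    sizes.append((depth, at_depth, cumulative))
    print(depth, at_depth, cumulative)
# 1 23 23
# 2 529 552
# 3 12167 12719
# 4 279841 292560
```

The space grows by a factor of 23 per level, so a failed search at `max_depth=4` has tried 292,560 programs per task, compared with 12,719 at `max_depth=3`.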


### Trying all of our training examples
Feel free to change the `max_depth`!

In [28]:
SOURCE_DIR_PATH = 'data/training/'
files_to_try = [f"{SOURCE_DIR_PATH}{file}" for file in os.listdir('data/training/') if os.path.isfile(f"{SOURCE_DIR_PATH}{file}")]
results = {'SUCCESS':[], 'TEST_FAILED':[], 'TRAIN_FAILED':[]}

for task_index, task_path in tqdm(enumerate(files_to_try),total=len(files_to_try)):
    program, msg, ctr = generate_task_solution(task_path, max_depth=3)
    results[msg].append((task_path, program))

print(
f'''STATISTICS FOR GIVEN DSL WITH MAX DEPTH 3
SUCCESS: {len(results['SUCCESS'])}
TEST_FAILED: {len(results['TEST_FAILED'])}
TRAIN_FAILED: {len(results['TRAIN_FAILED'])}
'''
)

print('\n'.join([f"{task_path}: {prettify(program)}" for task_path, program in results['SUCCESS']]))

STATISTICS FOR GIVEN DSL WITH MAX DEPTH 3
SUCCESS: 7
TEST_FAILED: 0
TRAIN_FAILED: 172

data/training/ed36ccf7.json: orthogonal(input,ROT_90)
data/training/74dd1130.json: orthogonal(orthogonal(input,ROT_90),X_AXIS)
data/training/6150a2bd.json: orthogonal(orthogonal(input,X_AXIS),Y_AXIS)
data/training/68b16354.json: orthogonal(input,X_AXIS)
data/training/3c9b0459.json: orthogonal(orthogonal(input,X_AXIS),Y_AXIS)
data/training/67a3c6ac.json: orthogonal(input,Y_AXIS)
data/training/9dfd6313.json: orthogonal(orthogonal(input,ROT_90),X_AXIS)


In [29]:
SOURCE_DIR_PATH = 'data/training/'
files_to_try = [f"{SOURCE_DIR_PATH}{file}" for file in os.listdir('data/training/') if os.path.isfile(f"{SOURCE_DIR_PATH}{file}")]
results = {'SUCCESS':[], 'TEST_FAILED':[], 'TRAIN_FAILED':[]}

for task_index, task_path in tqdm(enumerate(files_to_try),total=len(files_to_try)):
    program, msg, ctr = generate_task_solution(task_path, max_depth=4)
    results[msg].append((task_path, program))

print(
f'''STATISTICS FOR GIVEN DSL WITH MAX DEPTH 4
SUCCESS: {len(results['SUCCESS'])}
TEST_FAILED: {len(results['TEST_FAILED'])}
TRAIN_FAILED: {len(results['TRAIN_FAILED'])}
'''
)

print('\n'.join([f"{task_path}: {prettify(program)}" for task_path, program in results['SUCCESS']]))

STATISTICS FOR GIVEN DSL WITH MAX DEPTH 4
SUCCESS: 7
TEST_FAILED: 1
TRAIN_FAILED: 171

data/training/ed36ccf7.json: orthogonal(input,ROT_90)
data/training/74dd1130.json: orthogonal(orthogonal(input,ROT_90),X_AXIS)
data/training/6150a2bd.json: orthogonal(orthogonal(input,X_AXIS),Y_AXIS)
data/training/68b16354.json: orthogonal(input,X_AXIS)
data/training/3c9b0459.json: orthogonal(orthogonal(input,X_AXIS),Y_AXIS)
data/training/67a3c6ac.json: orthogonal(input,Y_AXIS)
data/training/9dfd6313.json: orthogonal(orthogonal(input,ROT_90),X_AXIS)


In [30]:
print('\n'.join([f"{task_path}: {prettify(program)}" for task_path, program in results['TEST_FAILED']]))

data/training/f76d97a5.json: recolor(recolor(recolor(filterColor(input,5),9),6),4)
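
One possible reading of this odd-looking program is the early return in `run_program_on_task`: a program is accepted for a pair as soon as *any* of its nested sub-programs reproduces the output, so different training pairs can be "solved" by different prefixes of the same chain. The sketch below reproduces that effect on made-up pairs; it says nothing about the actual contents of f76d97a5, and `run_program` is a simplified copy of the notebook's control flow.

```python
import numpy as np

def recolor(image, color):
    return np.where(image != 0, color, 0)

def run_program(task_input, true_output, program):
    # Same control flow as run_program_on_task: stop as soon as a sub-program matches
    func, inner, arg = program
    if inner is not None:
        inner_ok, task_input = run_program(task_input, true_output, inner)
        if inner_ok:
            return True, task_input
    output = func(task_input, arg)
    return np.array_equal(output, true_output), output

# Made-up pairs that need different colors
pairs = [
    (np.array([[5, 0]]), np.array([[9, 0]])),
    (np.array([[5, 0]]), np.array([[6, 0]])),
]
program = (recolor, (recolor, None, 9), 6)  # recolor(recolor(input,9),6)

# Accepted by the search: pair 0 matches at the inner call, pair 1 at the outer call
accepted = [run_program(x, y, program)[0] for x, y in pairs]
# Running the full chain end to end only reproduces pair 1
end_to_end = [np.array_equal(recolor(recolor(x, 9), 6), y) for x, y in pairs]

print(accepted)    # [True, True]
print(end_to_end)  # [False, True]
```

If this is what happened, the program passes training without being a single consistent transformation, which would explain a failure on the test pair.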
