In [1]:
from utils.dsl import *
from utils.constants import *
import json
import re
import numpy as np
import os
from pathlib import Path


## Manual Verification: Test a Single Solver Function

Use the cells below to manually test one solver function on a specific example of a specific task.

**Configure these variables:**
- `task_id`: The ARC task identifier (e.g., '0934a4d8')
- `example_index`: Which example to test (0-based index)
- `test_or_train`: Dataset to use ('train' or 'test')
- Paste your solver function into the `solve()` function below


In [2]:
def solve(I):
    # Paste solver function here
    pass


In [3]:
task_id = 'insert_task_id_here'
example_index = 0
test_or_train = 'test' # 'train' or 'test'

def open_and_solve_example(task_id, example_index, solver, set='train'):
    path = Path(f'../data_v2/evaluation/{task_id}.json')
    if not path.exists():
        print(f"Error processing {task_id} example {example_index}")
        return None, None, None
    try:
        with path.open('r', encoding='utf-8') as f:
            task = json.load(f)
        I = tuple(tuple(row) for row in task[set][example_index]['input'])
        expected = task[set][example_index]['output']
        output = solver(I)
        return I, expected, output
    except Exception:
        print(f"Error processing {task_id} example {example_index}")
        return None, None, None

def verify_output(expected, output):
    if expected == None or output == None:
        print("No output to verify")
        return
    if np.array_equal(np.array(expected), np.array(output)):
        print("Output matches expected!")
    else:
        print("Output does NOT match expected.")
        print("Expected:")
        print(np.array(expected))
        print("Got:")
        print(np.array(output))

In [5]:
# Run programs and verify if output matches
input, expected, output = open_and_solve_example(task_id, example_index, solve, set=test_or_train)
verify_output(expected, output)

Error processing insert_task_id_here example 0
No output to verify


## Batch Verification: Test All Solver Functions from Log Files

Automatically extracts and tests solver functions from log files on both train and test datasets.

**Configure these variables:**
- `logs_dir`: Path to directory containing log files with solver functions
- `test_these_examples`: Range of examples indices to test (e.g., range(0, 6) tests indices 0-5)

The script will:
1. Extract solver functions from summary files in the log directory
2. Test each function on all specified examples indices for both train and test sets of all tasks
3. Print detailed results and summary statistics for each dataset


In [None]:
# Specify directory containing log files
logs_dir = Path('../logs/exp_name:testing_noreasoning_nodsl_k4_similar_fewshot_repair')
# Settings
test_these_examples = range(0, 6)  # Test indices 0 to 5

# Find all summary files
summary_files = [f for f in logs_dir.glob('*.txt') if 'selection' in f.name]
print(f"Found {len(summary_files)} summary files")

def test_solver_on_dataset(summary_files, dataset_type, test_these_examples):
    """Test all solver functions on a specific dataset (train or test)."""
    # Track results per file across all indices
    file_results = {}
    
    for file_path in summary_files:
        filename = file_path.name
        task_id_local = filename.split('_')[0]
        
        file_results[filename] = {
            'executed_all': True,
            'correct_all': True,
            'results_per_index': {}
        }
        
        try:
            # Read the file content
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            
            # Load task to check available tasks
            try:
                with open(f'../data_v2/evaluation/{task_id_local}.json') as f:
                    task = json.load(f)
            except FileNotFoundError:
                print(f"Skipping {filename}: task file not found")
                file_results[filename]['executed_all'] = False
                file_results[filename]['correct_all'] = False
                continue
            
            # Extract solver function code
            gen_code_idx = content.find('CODE:')
            if gen_code_idx == -1: 
                print(f"Skipping {filename}: markers not found")
                file_results[filename]['executed_all'] = False
                file_results[filename]['correct_all'] = False
                continue
            
            code_section = content[gen_code_idx:]
            def_match = re.search(r'\bdef\s+\w+', code_section)
            
            if not def_match:
                print(f"Skipping {filename}: no function definition found")
                file_results[filename]['executed_all'] = False
                file_results[filename]['correct_all'] = False
                continue
            
            # Extract function body
            def_start = def_match.start()
            return_matches = list(re.finditer(r'\n\s*return\s+.*', code_section[def_start:]))
            
            if not return_matches:
                print(f"Skipping {filename}: no return statement found")
                file_results[filename]['executed_all'] = False
                file_results[filename]['correct_all'] = False
                continue
            
            last_return = return_matches[-1]
            function_end = def_start + last_return.end()
            solver_code = code_section[def_start:function_end]
            
            # Create a fresh namespace with all globals for this solver
            exec_namespace = dict(globals())
            
            # Execute the solver function
            exec(solver_code, exec_namespace)
            func_name_match = re.search(r'def\s+(\w+)', solver_code)
            
            if not func_name_match:
                print(f"Skipping {filename}: could not extract function name")
                file_results[filename]['executed_all'] = False
                file_results[filename]['correct_all'] = False
                continue
            
            func_name = func_name_match.group(1)
            solver_func = exec_namespace[func_name]
            
            # Test across all task indices
            print(f"\nTesting {filename}:")
            for test_task_index in test_these_examples:
                # Check if task_index exists
                if len(task[dataset_type]) <= test_task_index:
                    print(f"  Index {test_task_index}: SKIP (not available)")
                    file_results[filename]['results_per_index'][test_task_index] = {
                        'executed': False,
                        'correct': False,
                        'skipped': True
                    }
                    continue
                
                try:
                    # Solve task
                    I, expected, output = open_and_solve_example(task_id_local, test_task_index, solver_func, set=dataset_type)
                    
                    # Check if output matches expected
                    diffs = np.where(np.array(expected) != np.array(output))
                    is_correct = not diffs[0].size
                    
                    file_results[filename]['results_per_index'][test_task_index] = {
                        'executed': True,
                        'correct': is_correct,
                        'skipped': False
                    }
                    
                    if is_correct:
                        print(f"  Index {test_task_index}: ✓ Correct")
                    else:
                        print(f"  Index {test_task_index}: ✗ Incorrect")
                        file_results[filename]['correct_all'] = False
                        
                except Exception as e:
                    print(f"  Index {test_task_index}: ✗ Error - {type(e).__name__}")
                    file_results[filename]['results_per_index'][test_task_index] = {
                        'executed': False,
                        'correct': False,
                        'skipped': False
                    }
                    file_results[filename]['executed_all'] = False
                    file_results[filename]['correct_all'] = False
                    
        except Exception as e:
            print(f"✗ {filename}: Fatal error - {type(e).__name__}: {str(e)[:100]}")
            file_results[filename]['executed_all'] = False
            file_results[filename]['correct_all'] = False
    
    return file_results

def print_summary(file_results, dataset_type, test_these_examples, total_files):
    """Print summary statistics for a dataset."""
    files_executed_all = sum(1 for r in file_results.values() if r['executed_all'])
    files_correct_all = sum(1 for r in file_results.values() if r['correct_all'])
    
    # Per-index statistics
    per_index_stats = {}
    for idx in test_these_examples:
        executed = sum(1 for r in file_results.values() if idx in r['results_per_index'] and r['results_per_index'][idx]['executed'])
        correct = sum(1 for r in file_results.values() if idx in r['results_per_index'] and r['results_per_index'][idx]['correct'])
        skipped = sum(1 for r in file_results.values() if idx in r['results_per_index'] and r['results_per_index'][idx]['skipped'])
        per_index_stats[idx] = {'executed': executed, 'correct': correct, 'skipped': skipped}
    
    print(f"\n{'='*60}")
    print(f"SUMMARY FOR {dataset_type.upper()} SET:")
    print(f"{'='*60}")
    print(f"Total files: {total_files}")
    print(f"Programs that executed without error for all examples: {files_executed_all}")
    print(f"Programs with correct output for all examples: {files_correct_all}")
    print(f"\nPER-EXAMPLE-INDEX BREAKDOWN:")
    for idx in test_these_examples:
        stats = per_index_stats[idx]
        print(f"  Index {idx}: {stats['executed']} executed, {stats['correct']} correct, {stats['skipped']} skipped")
    print(f"{'='*60}")

# Test on train set
print("\n" + "="*60)
print("TESTING ON TRAIN SET")
print("="*60)
train_results = test_solver_on_dataset(summary_files, 'train', test_these_examples)

# Test on test set
print("\n" + "="*60)
print("TESTING ON TEST SET")
print("="*60)
test_results = test_solver_on_dataset(summary_files, 'test', test_these_examples)

# Print summary
print_summary(train_results, 'train', test_these_examples, len(summary_files))
print_summary(test_results, 'test', test_these_examples, len(summary_files))


Found 101 summary files

TESTING ON TRAIN SET

Testing 7b3084d4_repair_selection_summary.txt:
Error processing 7b3084d4 example 0
  Index 0: ✗ Error - ValueError
Error processing 7b3084d4 example 1
  Index 1: ✗ Error - ValueError
Error processing 7b3084d4 example 2
  Index 2: ✗ Error - ValueError
  Index 3: SKIP (not available)
  Index 4: SKIP (not available)
  Index 5: SKIP (not available)

Testing 8f3a5a89_repair_selection_summary.txt:
Error processing 8f3a5a89 example 0
  Index 0: ✗ Error - ValueError
Error processing 8f3a5a89 example 1
  Index 1: ✗ Error - ValueError
Error processing 8f3a5a89 example 2
  Index 2: ✗ Error - ValueError
  Index 3: SKIP (not available)
  Index 4: SKIP (not available)
  Index 5: SKIP (not available)

Testing 7491f3cf_repair_selection_summary.txt:
Error processing 7491f3cf example 0
  Index 0: ✗ Error - ValueError
Error processing 7491f3cf example 1
  Index 1: ✗ Error - ValueError
Error processing 7491f3cf example 2
  Index 2: ✗ Error - ValueError
Error 