In [38]:
import os, glob, json, time, random
from arc.datatypes import *
from arc import utils


dataset_path = "../dataset"
json_file_paths = glob.glob(os.path.join(dataset_path, "*.json"))

all_tasks: list[TaskDict] = []

def load_dataset() -> list[TaskDict]:
    _all_tasks = []
    for json_file_path in json_file_paths:
        task_id = os.path.basename(json_file_path).split(".")[0]
        try:
            with open(json_file_path, "r") as f:
                task_json = json.load(f)
                if isinstance(task_json, list) and len(task_json) > 0:
                    _all_tasks.append({
                        "file_path": json_file_path,
                        "task_id": task_id,
                        "examples": task_json,
                    })
        except Exception as e:
            print(f"Error loading {json_file_path}: {e}")

    return _all_tasks

if not all_tasks:
    all_tasks = load_dataset()
all_tasks[0]

{'file_path': '../dataset/22168020.json',
 'task_id': '22168020',
 'examples': [{'input': [[6, 6, 0, 6, 6, 6, 6, 6, 6, 4],
    [6, 6, 6, 0, 0, 6, 6, 4, 4, 6],
    [6, 6, 6, 0, 0, 7, 7, 4, 4, 6],
    [6, 6, 0, 6, 6, 7, 7, 6, 6, 4],
    [3, 6, 6, 3, 7, 6, 6, 7, 6, 6],
    [6, 3, 3, 7, 6, 6, 6, 6, 7, 6],
    [6, 3, 3, 6, 6, 6, 6, 6, 6, 6]],
   'output': [[6, 6, 0, 6, 6, 6, 6, 6, 6, 4],
    [6, 6, 0, 0, 0, 6, 6, 4, 4, 4],
    [6, 6, 0, 0, 0, 7, 7, 4, 4, 4],
    [6, 6, 0, 6, 6, 7, 7, 6, 6, 4],
    [3, 3, 3, 3, 7, 7, 7, 7, 6, 6],
    [6, 3, 3, 7, 7, 7, 7, 7, 7, 6],
    [6, 3, 3, 6, 6, 6, 6, 6, 6, 6]]},
  {'input': [[0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 8, 0, 0, 0],
    [0, 0, 0, 0, 8, 0, 0],
    [0, 0, 0, 0, 0, 8, 8],
    [0, 0, 0, 0, 0, 8, 8],
    [0, 0, 0, 0, 8, 0, 0],
    [0, 0, 0, 8, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0]],
   'output': [[0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 8, 0, 0, 0],
    [0, 0, 0, 8, 8, 0, 0],
    [0, 0, 0, 8, 8, 8, 8]

In [39]:
def sample_task(idx: int) -> dict:
    task = all_tasks[idx]
    examples = task["examples"]

    sampled_indices = random.sample(range(len(examples)), 4)
    train_example_indices = sampled_indices[:3]
    test_example_index = sampled_indices[3]
    
    train_examples = [examples[i] for i in train_example_indices]
    test_exapmle = examples[test_example_index]
    test_example_input = test_exapmle["input"]
    test_example_output = test_exapmle["output"]

    return {
        "datapoint": {
            "train": train_examples,
            "test": [{"input": test_example_input}],
        },
        "test_output": test_example_output,
        "train_indices": train_example_indices,
        "test_index": test_example_index,
    }

In [40]:
def grid_to_str(grid: list[list[int]]) -> str:
    return "\n".join("".join(str(cell) for cell in row) for row in grid) + "\n"

In [41]:
def format_prompt(datapoint: DataPointDict) -> list[dict]:
    train_examples = datapoint["train"]
    test_input_grid = datapoint["test"][0]["input"]
    
    messages = [
        {"role": "system", "content": utils.system_prompt},
    ]
    
    msg = f"{utils.user_message_template1}\n"
    for ex in train_examples:
        msg += (
            f"input:\n{grid_to_str(ex['input'])}\n"
            f"output:\n{grid_to_str(ex['output'])}\n"
        )
        
    test_msg = (
        f"\n{utils.user_message_template2}\n"
        f"{utils.user_message_template3}\n"
        f"input:\n{grid_to_str(test_input_grid)}\n"
    )
    messages.append({"role": "user", "content": msg + test_msg})
    
    return messages


In [None]:
datapoint, test_output_grid = sample_task(0)
input_messages = format_prompt(datapoint)
test_output_text = grid_to_str(test_output_grid)
print("Input Messages:")
print(input_messages)
print("Test Output Text:")
print(test_output_text)

Input Messages:
[{'role': 'system', 'content': 'You are a puzzle solving wizard. You are given a puzzle from the abstraction and reasoning corpus developed by Francois Chollet.'}, {'role': 'user', 'content': 'Here are the example input and output pairs from which you should learn the underlying rule to later predict the output for the given test input:\n----------------------------------------\ninput:\n666666666\n666666666\n666666666\n666666666\n669666696\n666966966\n666699666\n666699666\n666666666\n\noutput:\n666666666\n666666666\n666666666\n666666666\n669999996\n666999966\n666699666\n666699666\n666666666\n\ninput:\n333333333\n333333333\n333333333\n333333333\n333333333\n383333833\n338338333\n333883333\n333883333\n\noutput:\n333333333\n333333333\n333333333\n333333333\n333333333\n388888833\n338888333\n333883333\n333883333\n\ninput:\n777117\n777117\n771771\n777977\n777799\n777799\n777977\n777777\n777777\n\noutput:\n777117\n777117\n771111\n777977\n777999\n777999\n777977\n777777\n777777\n\

In [42]:
from dotenv import load_dotenv

load_dotenv()

True

In [43]:
from openai import OpenAI

client = OpenAI()

def generate_response(messages: list[dict], model: str = "o4-mini"):
    response = client.responses.create(
        model=model,
        reasoning={"effort": "medium", "summary": "detailed"},
        input=messages,
    )
    return response

In [7]:
datapoint, test_output_grid = sample_task(0)
input_messages = format_prompt(datapoint)
test_output_text = grid_to_str(test_output_grid)

response = generate_response(input_messages)

ValueError: too many values to unpack (expected 2)

In [44]:
def extract_reasoning_summary(response) -> str:
    texts = []
    for output in response.output:
        if output.type == "reasoning":
            for summary in output.summary:
                if summary.type == "summary_text":
                    texts.append(summary.text)
    return "".join(texts)


# output_text = response.output_text
# reasoning_summary = extract_reasoning_summary(response)

# print("Output Text:")
# print(output_text)
# print("Reasoning Summary:")
# print(reasoning_summary)

In [45]:
import os
result_save_path = "reasoning_summary_results"
os.makedirs(result_save_path, exist_ok=True)

def generate_and_save_reasoning(task_idx):
    task = all_tasks[task_idx]
    task_id = task["task_id"]
    
    sampled_result = sample_task(task_idx)
    datapoint = sampled_result["datapoint"]
    test_output_grid = sampled_result["test_output"]

    train_indices = sampled_result["train_indices"]
    test_index = sampled_result["test_index"]

    input_messages = format_prompt(datapoint)
    test_output_text = grid_to_str(test_output_grid)

    response = generate_response(input_messages)
    output_text = response.output_text
    reasoning_summary = extract_reasoning_summary(response)

    teacher_correct = test_output_text.strip() == output_text.strip()

    result = {
        "task_id": task_id,
        "train_indices": train_indices,
        "test_index": test_index,

        "datapoint": datapoint,
        "input_messages": input_messages,
        "output_text": output_text,
        "reasoning": reasoning_summary,

        "test_output_text": test_output_text,
        "correct": teacher_correct,  
    }

    train_indices_postfix = "train_" + "_".join(map(str, train_indices))
    test_index_postfix = "test_" + str(test_index)
    with open(os.path.join(result_save_path, f"{task_id}_{train_indices_postfix}_{test_index_postfix}.json"), "w") as f:
        json.dump(result, f, indent=4)
    print(f"Saved result for task {task_id} to {result_save_path}/{task_id}.json")

In [None]:
import time
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Semaphore

# Config
num_test = len(all_tasks)
num_iter_per_task = 3
MAX_WORKERS = 4
REQUESTS_PER_MINUTE = 100
INTERVAL = 60.0 / REQUESTS_PER_MINUTE
rate_limit_semaphore = Semaphore(MAX_WORKERS)

def safe_generate_and_save_reasoning(task_idx):
    with rate_limit_semaphore:
        start = time.time()
        try:
            generate_and_save_reasoning(task_idx)
        except Exception as e:
            print(f"Error on task {task_idx}: {e}")
        elapsed = time.time() - start
        if elapsed < INTERVAL:
            time.sleep(INTERVAL - elapsed)

def run_parallel(num_tasks: int, num_iter_per_task: int):
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = []
        for i in range(num_tasks - 1, -1, -1):
            for _ in range(num_iter_per_task):
                futures.append(executor.submit(safe_generate_and_save_reasoning, i))
        for future in as_completed(futures):
            future.result()  # will raise exception if any occurred

print(f"Running {num_iter_per_task} iterations for each of the {num_test} tasks in parallel...")
run_parallel(num_test, num_iter_per_task)

Running 3 iterations for each of the 300 tasks in parallel...
Saved result for task 4c4377d9 to reasoning_summary_results/4c4377d9.json
Saved result for task 4c4377d9 to reasoning_summary_results/4c4377d9.json
Saved result for task 4c4377d9 to reasoning_summary_results/4c4377d9.json
Saved result for task 91413438 to reasoning_summary_results/91413438.json
Saved result for task 91413438 to reasoning_summary_results/91413438.json
Saved result for task 6773b310 to reasoning_summary_results/6773b310.json
Saved result for task 91413438 to reasoning_summary_results/91413438.json
Saved result for task dc1df850 to reasoning_summary_results/dc1df850.json
Saved result for task 6773b310 to reasoning_summary_results/6773b310.json
Saved result for task 6773b310 to reasoning_summary_results/6773b310.json
Saved result for task a2fd1cf0 to reasoning_summary_results/a2fd1cf0.json
Saved result for task a2fd1cf0 to reasoning_summary_results/a2fd1cf0.json
Saved result for task a2fd1cf0 to reasoning_summar