In [1]:
import os, glob, json, time, random
from arc.datatypes import *
from arc import utils

dataset_path = "../dataset"
json_file_paths = glob.glob(os.path.join(dataset_path, "*.json"))

all_tasks: list[TaskDict] = []

def load_dataset() -> list[TaskDict]:
    _all_tasks = []
    for json_file_path in json_file_paths:
        task_id = os.path.basename(json_file_path).split(".")[0]
        try:
            with open(json_file_path, "r") as f:
                task_json = json.load(f)
                if isinstance(task_json, list) and len(task_json) > 0:
                    _all_tasks.append({
                        "file_path": json_file_path,
                        "task_id": task_id,
                        "examples": task_json,
                    })
        except Exception as e:
            print(f"Error loading {json_file_path}: {e}")

    return _all_tasks

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if not all_tasks:
    all_tasks = load_dataset()
all_tasks[0]

{'file_path': '../dataset/22168020.json',
 'task_id': '22168020',
 'examples': [{'input': [[6, 6, 0, 6, 6, 6, 6, 6, 6, 4],
    [6, 6, 6, 0, 0, 6, 6, 4, 4, 6],
    [6, 6, 6, 0, 0, 7, 7, 4, 4, 6],
    [6, 6, 0, 6, 6, 7, 7, 6, 6, 4],
    [3, 6, 6, 3, 7, 6, 6, 7, 6, 6],
    [6, 3, 3, 7, 6, 6, 6, 6, 7, 6],
    [6, 3, 3, 6, 6, 6, 6, 6, 6, 6]],
   'output': [[6, 6, 0, 6, 6, 6, 6, 6, 6, 4],
    [6, 6, 0, 0, 0, 6, 6, 4, 4, 4],
    [6, 6, 0, 0, 0, 7, 7, 4, 4, 4],
    [6, 6, 0, 6, 6, 7, 7, 6, 6, 4],
    [3, 3, 3, 3, 7, 7, 7, 7, 6, 6],
    [6, 3, 3, 7, 7, 7, 7, 7, 7, 6],
    [6, 3, 3, 6, 6, 6, 6, 6, 6, 6]]},
  {'input': [[0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 8, 0, 0, 0],
    [0, 0, 0, 0, 8, 0, 0],
    [0, 0, 0, 0, 0, 8, 8],
    [0, 0, 0, 0, 0, 8, 8],
    [0, 0, 0, 0, 8, 0, 0],
    [0, 0, 0, 8, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0]],
   'output': [[0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 8, 0, 0, 0],
    [0, 0, 0, 8, 8, 0, 0],
    [0, 0, 0, 8, 8, 8, 8]

In [3]:
# Inference with thinking

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

In [35]:
from arc.utils import *

def sample_task(idx: int) -> tuple[DataPointDict, Grid]:
    task = all_tasks[idx]
    examples = task["examples"]
    
    train_examples = examples[:3]
    test_example_input = examples[3]["input"]
    test_example_output = examples[3]["output"]
    
    return {
        "train": train_examples,
        "test": [{"input": test_example_input}],
    }, test_example_output

def grid_to_str(grid: list[list[int]]) -> str:
    return "\n".join("".join(str(cell) for cell in row) for row in grid) + "\n"

def format_prompt(datapoint: DataPointDict) -> str:
    train_examples = datapoint["train"]
    test_input_grid = datapoint["test"][0]["input"]
    
    messages = [
        {"role": "system", "content": utils.system_prompt},
    ]
    msg = f"{user_message_template1}\n"
    for ex in train_examples:
        msg += (
            f"input:\n{grid_to_str(ex['input'])}\n"
            f"output:\n{grid_to_str(ex['output'])}\n"
        )

    test_msg = (
        f"\n{user_message_template2}\n"
        f"{user_message_template3}\n"
        f"input:\n{grid_to_str(test_input_grid)}\n"
    )
    messages.append({"role": "user", "content": msg + test_msg})
    
    print(messages)
    
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
        enable_thinking=True,
    )
    return input_text

In [36]:
datapoint, test_output_grid = sample_task(0)
input_text = format_prompt(datapoint)
print(input_text)
test_out_text = grid_to_str(test_output_grid)
print("---------------")
print(test_out_text)

[{'role': 'system', 'content': 'You are a puzzle solving wizard. You are given a puzzle from the abstraction and reasoning corpus developed by Francois Chollet.'}, {'role': 'user', 'content': 'Here are the example input and output pairs from which you should learn the underlying rule to later predict the output for the given test input:\n----------------------------------------\ninput:\n6606666664\n6660066446\n6660077446\n6606677664\n3663766766\n6337666676\n6336666666\n\noutput:\n6606666664\n6600066444\n6600077444\n6606677664\n3333777766\n6337777776\n6336666666\n\ninput:\n0000000\n0008000\n0000800\n0000088\n0000088\n0000800\n0008000\n0000000\n0000000\n0000000\n\noutput:\n0000000\n0008000\n0008800\n0008888\n0008888\n0008800\n0008000\n0000000\n0000000\n0000000\n\ninput:\n333333333\n343333433\n334334333\n333443333\n333443333\n333333333\n\noutput:\n333333333\n344444433\n334444333\n333443333\n333443333\n333333333\n\n\n----------------------------------------\nNow, solve the following puzzle

In [None]:
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Enable 4-bit quantization
    bnb_4bit_use_double_quant=True,  # Use double quantization for improved precision
    bnb_4bit_quant_type="nf4",  # Specify the quantization type
    bnb_4bit_compute_dtype=torch.float16,  # Set the computation data type
)

model_args = {
    "pretrained_model_name_or_path": "Qwen/Qwen3-4B",
    "trust_remote_code": True,  # Allow the model to use custom code from the repository
    "quantization_config": bnb_config,  # Apply the 4-bit quantization configuration
    "attn_implementation": "sdpa",  # Use scaled-dot product attention for better performance
    "torch_dtype": torch.float16,  # Set the data type for the model
    "use_cache": False,  # Disable caching to save memory
    "token": None,
    "device_map": "auto",  # Automatically map the model to available devices
}
        
model = AutoModelForCausalLM.from_pretrained(**model_args)
model.eval()

Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.51s/it]


In [13]:
model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
output_ids = model.generate(**model_inputs, max_new_tokens=32768)

In [25]:
output_ids_ = output_ids[0][len(model_inputs.input_ids[0]):].tolist()

try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids_[::-1].index(151668)
except ValueError:
    index = 0

# 
thinking_content = tokenizer.decode(output_ids_[:index], skip_special_tokens=True).strip("\n")
content = tokenizer.decode(output_ids_[index:], skip_special_tokens=True).strip("\n")

print("thinking content:\n", thinking_content)
print()
print("content:\n", content)

thinking content:
 <think>
Okay, let's try to figure out the pattern here. I need to find the rule that transforms the input grid into the output grid based on the examples provided. Let me start by analyzing the first example input and output.

Looking at the first example input:
Input lines are:
6606666664
6660066446
6660077446
6606677664
3663766766
6337666676
6336666666

Output is:
6606666664
6600066444
6600077444
6606677664
3333777766
6337777776
6336666666

Hmm. Let me compare the first input line with the output. The first input line is 6606666664, and the output is the same. So maybe that line is unchanged. But the second input line is 6660066446, and the output is 6600066444. Wait, the second line in input is 6660066446. The output is 6600066444. So the third character changes from 0 to 0? Wait, maybe looking at each position. Let me check each line.

Alternatively, maybe the rule is similar to a cellular automaton. For example, in the second example, the input is 0000000, then 

In [26]:
short_generation_ids = model.generate(**model_inputs, max_new_tokens=512)
short_output_ids_ = short_generation_ids[0][len(model_inputs.input_ids[0]):].tolist()

content = tokenizer.decode(short_output_ids_, skip_special_tokens=True).strip("\n")

In [28]:
print(content)

<think>
Okay, let's try to figure out the pattern here. I need to look at the examples provided to find the rule that transforms the input grid into the output grid. Let me start by analyzing the first example.

First example input:
6606666664
6660066446
6660077446
6606677664
3663766766
6337666676
6336666666

Output:
6606666664
6600066444
6600077444
6606677664
3333777766
6337777776
6336666666

Looking at the first line, input is 6606666664 and output is same. So maybe the first line is unchanged. Then the second input line is 6660066446 and output is 6600066444. Hmm, the third input line is 6660077446 and output is 6600077444. Wait, the second and third lines have 66600... and 66600... but the output changes the middle part. Let me check the third line: input is 6660077446 and output is 6600077444. So the first three digits are 666, then 007... but in the output, it's 660 followed by 000... Maybe there's a pattern where certain digits are being replaced or modified based on their posit