In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib

import arckit
import rigging as rg

import mlframework.arc.images as images
from mlframework.files.json import load_json

In [None]:
# Root of the local data directory. Set the MLFRAMEWORK_DATA environment
# variable to run the notebook on another machine; the default keeps the
# original location.
PATH_DATA = pathlib.Path(
    os.environ.get("MLFRAMEWORK_DATA", "/home/kristian/Projects/mlframework/data/")
)
PATH_COMPETITION = PATH_DATA / "arc/arc-prize-2024/"
PATH_TRAIN_CHALLENGES = PATH_COMPETITION / "arc-agi_training_challenges.json"
PATH_TRAIN_SOLUTIONS = PATH_COMPETITION / "arc-agi_training_solutions.json"
# The ARC Prize 2024 data ships as JSON (no CSVs): the test challenges and the
# sample submission are both .json files.
PATH_TEST = PATH_COMPETITION / "arc-agi_test_challenges.json"
PATH_SUBMISSION_EXAMPLE = PATH_COMPETITION / "sample_submission.json"

# rigging generator identifier: Llama-3-8B-Instruct via transformers, loaded in
# 4-bit on the second GPU, generating at most 1024 tokens per call.
MODEL = "transformers!meta-llama/Meta-Llama-3-8B-Instruct,device_map=cuda:1,max_tokens=1024,load_in_4bit=True"

In [None]:
# Task sets from arckit's "kaggle2024" dataset version.
# NOTE(review): arckit returns (training, evaluation) sets, so `test_set` is
# presumably the evaluation split; confirm before relying on the name.
train_set, test_set = arckit.load_data("kaggle2024")

# Ground-truth outputs for the training challenges, handed to the plotting
# helper so it can draw solutions alongside the task.
train_solutions = load_json(PATH_TRAIN_SOLUTIONS)

# First training task, used as the running example below.
task_example = train_set.tasks[0]
drawing = images.show_task(
    task_example,
    train_solutions=train_solutions,
)

In [None]:
# Text prompt for the example task, without the expected answer appended.
# NOTE(review): the positional 0 presumably selects which test input to
# prompt for; confirm against arckit's Task.gpt_prompt.
prompt = task_example.gpt_prompt(0, include_completion=False)
# print() instead of a bare expression: the prompt is multi-line, and its repr
# would show the digit grids as escaped "\n" sequences.
print(prompt)

In [None]:
class AskerQuestion(rg.Model):
    """Structured reply expected from the "asker" LLM.

    Its ``xml_tags()`` are embedded in the user prompt so the model knows which
    tag to wrap its answer in, and the reply is parsed back into this model
    with ``Message.parse(AskerQuestion)``.
    """

    # The free-text question the model asks.
    question: str

def get_model(model_id=None):
    """Build the rigging generator used to query the LLM.

    Args:
        model_id: Optional rigging generator identifier string. When omitted,
            the notebook-level ``MODEL`` setting is used. It is read at call
            time, so re-running the config cell takes effect.

    Returns:
        The generator created by ``rg.get_generator``.
    """
    return rg.get_generator(MODEL if model_id is None else model_id)

# Example system prompt for a single ARC task. It is used only when the caller
# passes `system_prompt=None`.
# NOTE(review): this appears to be a pasted `gpt_prompt` output for the first
# kaggle2024 training task.
EXAMPLE_SYSTEM_PROMPT = (
    "We are playing a game which involves transforming an input grid of digits "
    "into an output grid of digits. In general, digits form objects in 2D and the "
    "task is to perform some spatial transformation of these objects to go from "
    "the input grid to the output grid. All the information about the "
    "transformation is contained within the input pairs themselves, and your "
    "answer will only be correct if the output grid is exactly correct, so this "
    "is what I expect from you. I will begin by giving you several examples of "
    "input-output pairs. You will then be given a new input grid, and you must "
    "provide the corresponding output grid.\n\n"
    "Input 1: \n0 7 7\n7 7 7\n0 7 7\n"
    "Output 1: \n"
    "0 0 0 0 7 7 0 7 7\n0 0 0 7 7 7 7 7 7\n0 0 0 0 7 7 0 7 7\n"
    "0 7 7 0 7 7 0 7 7\n7 7 7 7 7 7 7 7 7\n0 7 7 0 7 7 0 7 7\n"
    "0 0 0 0 7 7 0 7 7\n0 0 0 7 7 7 7 7 7\n0 0 0 0 7 7 0 7 7\n\n"
    "Input 2: \n4 0 4\n0 0 0\n0 4 0\n"
    "Output 2: \n"
    "4 0 4 0 0 0 4 0 4\n0 0 0 0 0 0 0 0 0\n0 4 0 0 0 0 0 4 0\n"
    "0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 0\n"
    "0 0 0 4 0 4 0 0 0\n0 0 0 0 0 0 0 0 0\n0 0 0 0 4 0 0 0 0\n\n"
    "Input 3: \n0 0 0\n0 0 2\n2 0 2\n"
    "Output 3: \n"
    "0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 2\n0 0 0 0 0 0 2 0 2\n"
    "0 0 0 0 0 0 0 0 0\n0 0 2 0 0 0 0 0 2\n2 0 2 0 0 0 2 0 2\n\n"
    "Input 4: \n6 6 0\n6 0 0\n0 6 6\n"
    "Output 4: \n"
    "6 6 0 6 6 0 0 0 0\n6 0 0 6 0 0 0 0 0\n0 6 6 0 6 6 0 0 0\n"
    "6 6 0 0 0 0 0 0 0\n6 0 0 0 0 0 0 0 0\n0 6 6 0 0 0 0 0 0\n"
    "0 0 0 6 6 0 6 6 0\n0 0 0 6 0 0 6 0 0\n0 0 0 0 6 6 0 6 6\n\n"
    "Input 5: \n2 2 2\n0 0 0\n0 2 2\n"
    "Output 5: \n"
    "2 2 2 2 2 2 2 2 2\n0 0 0 0 0 0 0 0 0\n0 2 2 0 2 2 0 2 2\n"
    "0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 0\n"
    "0 0 0 2 2 2 2 2 2\n0 0 0 0 0 0 0 0 0\n0 0 0 0 2 2 0 2 2\n\n"
    "Input 6:\n7 0 7\n7 0 7\n7 7 0\n\n"
    "Please provide a step-by-step explanation. Specifically, answer the "
    "following in your explanation.\n"
    "1. Justify the output shape of your answer.\n"
    "2. Did you consider shapes in the outputs and why?\n"
    "Provide the output for Input 6 again at the end of your answer."
)


async def ask_next_question(system_prompt, model, verbose=False, prev_content="", questions=None):
    """Ask the LLM for its next question about an ARC task.

    Args:
        system_prompt: Task description sent as the system message, e.g. the
            output of ``task_example.gpt_prompt(...)``. ``None`` falls back to
            ``EXAMPLE_SYSTEM_PROMPT``.
        model: rigging generator (as built by ``get_model``); must expose
            ``.chat(messages)`` returning an object with an async ``.run()``.
        verbose: If true, print the question together with its 1-based number.
        prev_content: Transcript of earlier questions and answers. It is
            inserted into the user prompt; empty/None means no history yet.
        questions: Questions asked so far. Only its length is used, to number
            the printed question.

    Returns:
        The question text parsed from the model's ``AskerQuestion`` tag.
    """
    # Honour the caller's prompt; use the bundled example only when none is given.
    if system_prompt is None:
        system_prompt = EXAMPLE_SYSTEM_PROMPT
    if questions is None:
        questions = []

    history = (prev_content or "").strip() or "(none yet)"
    # Built from flat string pieces so no source indentation leaks into the
    # text the model sees.
    user_prompt = (
        "Previous questions and answers are:\n"
        f"{history}\n\n"
        f"Ask your question within the tag {AskerQuestion.xml_tags()}"
    )
    asker = (
        await model
        .chat(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
        )
        .run()
    )
    # `.last` is presumably the model's reply message; extract the tagged question.
    question = asker.last.parse(AskerQuestion).question
    if verbose:
        print(f"=== Question {len(questions) + 1} ====")
        print(question)

    return question

In [None]:
# Build the rigging generator from the MODEL config string (get_model wraps
# rg.get_generator).
# NOTE(review): with the current MODEL value this loads Llama-3-8B-Instruct in
# 4-bit on cuda:1, so this cell needs that GPU and may take a while.
model = get_model()