In [9]:
# ! pip install torch==2.4.1 torchvision==0.19.0
# ! pip install accelerate==0.34.2
# ! pip install transformers==4.45.1
# ! pip install unsloth==2024.9.post3
# ! pip install bitsandbytes==0.44.0
# ! pip install qwen-vl-utils
! pip install optimum
! pip install auto-gptq

Collecting optimum
  Downloading optimum-1.23.1-py3-none-any.whl.metadata (20 kB)
Collecting coloredlogs (from optimum)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->optimum)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading optimum-1.23.1-py3-none-any.whl (422 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.6/422.6 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: humanfriendly, coloredlogs, optimum
Successfully installed coloredlogs-15.0.1 human

In [1]:
%env CUDA_VISIBLE_DEVICES=0,1
%env TOKENIZERS_PARALLELISM=false

env: CUDA_VISIBLE_DEVICES=0,1
env: TOKENIZERS_PARALLELISM=false


In [2]:
# BASE_PATH = "/kaggle/input"
BASE_PATH = "/home/stepan/kaggle-arc-agi"
# MODEL_ID = f"unsloth/Meta-Llama-3.1-8B-bnb-4bit"
MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
# VLLM_MODEL_ID = "unsloth/Llama-3.2-11B-Vision-Instruct"
VLLM_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
# VLLM_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4"
MAX_NEW_TOKENS = 2048
MAX_SEQ_LENGTH = 32768 - MAX_NEW_TOKENS

In [3]:
# LLM_HIDDEN_SIZE = 4096 # 8B
# VLLM_HIDDEN_SIZE = 3584 #7B
LLM_HIDDEN_SIZE = 3072  # 3B
VLLM_HIDDEN_SIZE = 1536  # 2B

In [4]:
import sys

sys.path.append(BASE_PATH)
sys.path.append(f"{BASE_PATH}/scripts")
# sys.path.append('/kaggle/input/arc-agi-python-utilities')

In [5]:
import io
import os
import json
import base64
import random
from PIL import Image

import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig  # type: ignore
from transformers import MllamaForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor
from transformers import get_linear_schedule_with_warmup

from datasets import Dataset, DatasetDict  # type: ignore
from datasets import concatenate_datasets  # type: ignore

from qwen_vl_utils import process_vision_info  # type: ignore

import data_utils  # type: ignore
from logger import get_logger  # type: ignore
import train_utils  # type: ignore

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def allocate_memory():
    memory_gpu_0 = train_utils.gpu_stats(device_id=0)
    memory_gpu_1 = train_utils.gpu_stats(device_id=1)

    total_gpu_0 = int(memory_gpu_0["max_memory"])
    total_gpu_1 = int(memory_gpu_1["max_memory"])

    max_mem_gpu_0 = int(total_gpu_0) * 0.9
    max_mem_gpu_1 = int(total_gpu_1) * 0.9
    block_mem_gpu_0 = max_mem_gpu_0
    block_mem_gpu_1 = max_mem_gpu_1

    x = torch.rand((256, 1024, block_mem_gpu_0)).to("cuda:0")
    y = torch.rand((256, 1024, block_mem_gpu_1)).to("cuda:1")

    del x
    del y

In [7]:
# allocate_memory()

In [8]:
dtype = torch.bfloat16

In [9]:
def get_models():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
    llm_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        max_memory={0: "23.5GiB", "cpu": "16GiB"},
        attn_implementation="flash_attention_2",
        output_hidden_states=True,
        return_dict_in_generate=True,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    processor = AutoProcessor.from_pretrained(VLLM_MODEL_ID)
    vllm_model = Qwen2VLForConditionalGeneration.from_pretrained(
        VLLM_MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        max_memory={1: "23.5GiB", "cpu": "16GiB"},
        attn_implementation="flash_attention_2",
        output_hidden_states=True,
        return_dict_in_generate=True,
        #         quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    return {"llm": llm_model, "tokenizer": tokenizer, "vllm": vllm_model, "processor": processor}

In [10]:
models = get_models()

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
  def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
  def backward(ctx, grad_output):
  @custom_fwd(cast_inputs=torch.float16)
CUDA extension not installed.
CUDA extension not installed.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


In [11]:
log = get_logger(log_path=f"{BASE_PATH}/logs", log_file="llama-vllama")
log

<logger.SelectiveLogger at 0x72ff826533d0>

In [12]:
LLAMA_3_CHAT_TEMPLATE = """{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"""

TRAIN_IMAGE_SYSTEM_PROMPT = (
    "Puzzle Analysis Task: You are a puzzle-solving expert analyzing image sets from the abstraction and reasoning corpus by Francois Chollet. "
    "Steps to Follow: "
    "1. Overview: Count images, describe grid size and colored squares, list colors present. "
    "2. Image Analysis: Detail arrangement and frequency of colors, identify patterns or structures. "
    "3. Comparison: Note common color combinations and patterns, identify consistent color relationships. "
    "4. Rule Identification: Consider rules like color replacement, positional, and relational rules. "
    "5. Hypothesis Testing: Verify rules across images, noting exceptions. "
    "6. Pattern Refinement: Generalize the pattern to fit all images. "
    "7. Detailed Description: Explain the pattern with precise language and conditions. "
    "8. Verification: Ensure the pattern applies to all images, including hypothetical ones. "
    "9. Summary: Provide a concise summary of the pattern's logic."
)

TEST_IMAGE_SYSTEM_PROMPT = (
    "Puzzle Analysis Task: You are a puzzle-solving expert analyzing a single image from the abstraction and reasoning corpus by Francois Chollet. "
    "Steps to Follow: "
    "1. Overview: Describe grid size, count colored squares, list colors used. "
    "2. Color Analysis: Count squares per color, note dominant or rare colors. "
    "3. Spatial Analysis: Describe colored square positions and patterns. "
    "4. Edge and Corner Analysis: Detail colors on edges and corners, noting patterns. "
    "5. Symmetry and Balance: Check for symmetry and color balance. "
    "6. Pattern Identification: Identify repeating color patterns and structures. "
    "7. Color Relationships: Analyze color adjacency and arrangements. "
    "8. Unique Features: Highlight unusual characteristics. "
    "9. Quantitative Analysis: Calculate color ratios, note numerical patterns. "
    "10. Comparative Analysis: Compare elements like left vs right or top vs bottom. "
    "11. Abstraction: Describe the image abstractly, considering larger patterns. "
    "12. Summary: Summarize the image's key features and significant aspects."
)

TRAIN_TEXT_SYSTEM_PROMPT = (
    "Puzzle Analysis Task: You are a puzzle-solving expert analyzing matrices from the abstraction and reasoning corpus by Francois Chollet. "
    "Steps to Follow: "
    "1. Overview: Count matrices, describe structure, list numbers present. "
    "2. Matrix Analysis: Detail number arrangement and frequency, identify patterns or structures. "
    "3. Comparison: Note common number combinations and patterns, identify consistent number relationships. "
    "4. Rule Identification: Consider rules like number replacement, positional, quantity, shape, relational, and mathematical operations. "
    "5. Hypothesis Testing: Verify rules across matrices, noting exceptions. "
    "6. Pattern Refinement: Generalize the pattern to fit all matrices. "
    "7. Detailed Description: Explain the pattern with precise language and conditions. "
    "8. Verification: Ensure the pattern applies to all matrices, including hypothetical ones. "
    "9. Summary: Provide a concise summary of the pattern's logic."
)

TEST_TEXT_SYSTEM_PROMPT = (
    "Puzzle Analysis Task: You are a puzzle-solving expert analyzing a single matrix from the abstraction and reasoning corpus by Francois Chollet. "
    "Steps to Follow: "
    "1. Overview: Describe matrix dimensions, number range, note immediate patterns. "
    "2. Number Analysis: Count number frequency, note common or rare numbers. "
    "3. Spatial Analysis: Describe key number positions and patterns. "
    "4. Edge and Corner Analysis: Detail numbers on edges and corners, noting patterns. "
    "5. Symmetry and Balance: Check for symmetry and number balance. "
    "6. Pattern Identification: Identify repeating number patterns and structures. "
    "7. Number Relationships: Analyze number adjacency and arrangements, note mathematical relationships. "
    "8. Unique Features: Highlight unusual characteristics. "
    "9. Quantitative Analysis: Calculate relevant statistics, note numerical patterns. "
    "10. Comparative Analysis: Compare elements like rows vs columns or quadrants. "
    "11. Abstraction: Describe the matrix abstractly, considering larger patterns. "
    "12. Summary: Summarize the matrix's key features and significant aspects."
)

TRAIN_TEXT_PROMPT = (
    "Learn the underlying rule from these example input-output pairs to predict the output for the test input: "
    "----------------- "
    "{training_data}"
)

TEST_TEXT_PROMPT = "Analyze the following test input data: " "----------------- " "{input_test_data}"

TRAIN_IMAGE_PROMPT = "Analyze the images."
TEST_IMAGE_PROMPT = "Analyze the image."

In [13]:
# models["tokenizer"].chat_template = LLAMA_3_CHAT_TEMPLATE

In [14]:
def list_to_image(integer_list_2d, target_size=30):
    # Convert the 2D list to a NumPy array
    array = np.array(integer_list_2d)

    # Get the unique values in the array
    unique_values = np.unique(array)

    # Create a colormap
    cmap = plt.get_cmap("tab10")

    # Create a color lookup dictionary
    color_lookup = {value: cmap(i % 10)[:3] for i, value in enumerate(unique_values)}

    # Create an RGB array
    rgb_array = np.array([[color_lookup[val] for val in row] for row in array])

    # Convert to 8-bit color values
    rgb_array = (rgb_array * 255).astype(np.uint8)

    # Create an image from the colored array
    image = Image.fromarray(rgb_array, mode="RGB")

    # Create a new blank image with the target size
    new_image = Image.new("RGB", (target_size, target_size), color=(0, 0, 0))

    # Paste the original image onto the new image
    new_image.paste(image, (0, 0))

    new_image = new_image.resize((target_size * 15, target_size * 15), Image.NEAREST)

    return new_image


def pil_image_to_base64(image):
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return "data:image;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")

In [15]:
def prepare_inputs(dct, prepare_solution=False):
    if prepare_solution:
        return "<output>\n" + "\n".join(" ".join(map(str, row)) for row in dct) + "\n</output>"
    else:
        input_str = "\n".join(" ".join(map(str, row)) for row in dct["input"])
        output_str = "\n".join(" ".join(map(str, row)) for row in dct["output"]) if "output" in dct else ""
        text = f"<input>\n{input_str}\n</input>"
        if output_str:
            text += f"\n\n<output>\n{output_str}\n</output>"
        return text

In [16]:
def pad_matrix(matrix, target_rows, target_cols, pad_value=0):
    # Convert input to numpy array if it's not already
    matrix = np.array(matrix)

    # Get current dimensions
    current_rows, current_cols = matrix.shape

    # Pad rows
    if current_rows < target_rows:
        pad_rows = np.full((target_rows - current_rows, current_cols), pad_value)
        matrix = np.vstack((matrix, pad_rows))

    # Pad columns
    if current_cols < target_cols:
        pad_cols = np.full((target_rows, target_cols - current_cols), pad_value)
        matrix = np.hstack((matrix, pad_cols))

    return matrix

In [17]:
def ensure_size(matrix):
    matrix = np.array(matrix)
    h, w = matrix.shape
    if h < 1 or w < 1:
        return np.array([[0]])  # Minimum size 1x1
    return matrix[:30, :30]  # Maximum size 30x30


def rotate_grid(matrix, k):
    return ensure_size(np.rot90(matrix, k=k))


def flip_grid(matrix, axis):
    return ensure_size(np.flip(matrix, axis=axis))


def expand_grid(matrix, factor=3):
    expanded = np.repeat(np.repeat(matrix, factor, axis=0), factor, axis=1)
    return ensure_size(expanded)


def shrink_grid(matrix, factor=3):
    matrix = np.array(matrix)
    h, w = matrix.shape
    shrunk = matrix[: h // factor * factor : factor, : w // factor * factor : factor]
    return ensure_size(shrunk)


def roll_grid(matrix, shift, axis):
    return ensure_size(np.roll(matrix, shift=shift, axis=axis))


def add_border(matrix, value=0):
    bordered = np.pad(matrix, pad_width=1, mode="constant", constant_values=value)
    return ensure_size(bordered)


def remove_border(matrix):
    matrix = np.array(matrix)
    if matrix.shape[0] <= 2 or matrix.shape[1] <= 2:
        return ensure_size(matrix)  # Can't remove border from very small matrices
    return ensure_size(matrix[1:-1, 1:-1])


def replace_color(matrix, old_color, new_color):
    return ensure_size(np.where(matrix == old_color, new_color, matrix))


def add_noise(matrix, noise, noise_values):
    return ensure_size(np.where(noise, noise_values, matrix))


def mirror_grid(matrix):
    mirrored = np.hstack([matrix, np.fliplr(matrix)])
    return ensure_size(mirrored)


def tile_grid(matrix, tiles=(2, 2)):
    tiled = np.tile(matrix, tiles)
    return ensure_size(tiled)


def diagonal_shift(matrix, shift):
    shifted = np.pad(matrix, ((shift, 0), (shift, 0)), mode="constant")[:-shift, :-shift]
    return ensure_size(shifted)


def apply_mask(matrix, mask):
    return ensure_size(matrix * mask)


def swap_quadrants(matrix, order):
    matrix = ensure_size(matrix)
    h, w = matrix.shape
    if h < 2 or w < 2:
        return matrix  # Can't swap quadrants for very small matrices
    quadrants = [matrix[: h // 2, : w // 2], matrix[: h // 2, w // 2 :], matrix[h // 2 :, : w // 2], matrix[h // 2 :, w // 2 :]]
    reordered = [quadrants[i] for i in order]
    swapped = np.vstack([np.hstack([reordered[0], reordered[1]]), np.hstack([reordered[2], reordered[3]])])
    return ensure_size(swapped)


def generate_augmentation_pipeline(num_augmentations=3, seed=None):
    if seed is not None:
        random.seed(seed)

    operations = [
        ("rotate", lambda: {"k": random.choice([1, 2, 3])}),
        ("flip", lambda: {"axis": random.choice([0, 1])}),
        ("expand", lambda: {"factor": random.randint(2, 4)}),
        # ("shrink", lambda: {"factor": random.randint(2, 4)}),
        # ("roll", lambda: {"shift": random.randint(1, 5), "axis": random.choice([0, 1])}),
        ("add_border", lambda: {"value": random.randint(0, 9)}),
        # ("remove_border", lambda: {}),
        ("replace_color", lambda: {"old_color": random.randint(0, 9), "new_color": random.randint(0, 9)}),
        ("add_noise", lambda: {"noise_prob": random.uniform(0.05, 0.2)}),
        ("mirror", lambda: {}),
        ("tile", lambda: {"tiles": (random.randint(1, 3), random.randint(1, 3))}),
        # ("diagonal_shift", lambda: {"shift": random.randint(1, 5)}),
        ("apply_mask", lambda: {"mask_prob": random.uniform(0.1, 0.3)}),
        ("swap_quadrants", lambda: {"order": random.sample(range(4), 4)}),
    ]

    pipeline = []
    for _ in range(num_augmentations):
        op, param_func = random.choice(operations)
        pipeline.append((op, param_func()))

    return pipeline


def apply_augmentation_pipeline(matrix, pipeline):
    matrix = np.array(matrix)
    for op, params in pipeline:
        if op == "rotate":
            matrix = rotate_grid(matrix, **params)
        elif op == "flip":
            matrix = flip_grid(matrix, **params)
        elif op == "expand":
            matrix = expand_grid(matrix, **params)
        elif op == "shrink":
            matrix = shrink_grid(matrix, **params)
        elif op == "roll":
            matrix = roll_grid(matrix, **params)
        elif op == "add_border":
            matrix = add_border(matrix, **params)
        elif op == "remove_border":
            matrix = remove_border(matrix)
        elif op == "replace_color":
            matrix = replace_color(matrix, **params)
        elif op == "add_noise":
            noise = np.random.choice([0, 1], size=matrix.shape, p=[1 - params["noise_prob"], params["noise_prob"]])
            noise_values = np.random.randint(0, 10, size=matrix.shape)
            matrix = add_noise(matrix, noise, noise_values)
        elif op == "mirror":
            matrix = mirror_grid(matrix)
        elif op == "tile":
            matrix = tile_grid(matrix, **params)
        elif op == "diagonal_shift":
            matrix = diagonal_shift(matrix, **params)
        elif op == "apply_mask":
            mask = np.random.choice([0, 1], size=matrix.shape, p=[params["mask_prob"], 1 - params["mask_prob"]])
            matrix = apply_mask(matrix, mask)
        elif op == "swap_quadrants":
            matrix = swap_quadrants(matrix, **params)
    return matrix

In [18]:
def augment_challenge(challenge, solution, pipeline):
    new_challenge = {
        **challenge,
        "train": [
            {
                "input": apply_augmentation_pipeline(grid["input"], pipeline).tolist(),
                "output": apply_augmentation_pipeline(grid["output"], pipeline).tolist(),
            }
            for grid in challenge["train"]
        ],
        "test": {"input": apply_augmentation_pipeline(challenge["test"]["input"], pipeline).tolist()},
    }
    return {"challenge": new_challenge, "solution": apply_augmentation_pipeline(solution, pipeline).tolist()}


def augment_dataset(dataset, num_augmentations=3, total_number=1000, seed=11):
    augmented_dataset = {
        "id": [],
        "challenge": [],
        "solution": [],
    }

    while len(augmented_dataset["id"]) < total_number:
        idx = random.randint(0, len(dataset["id"]) - 1)
        challenge = dataset["challenge"][idx]

        solution = dataset["solution"][idx]

        pipeline = generate_augmentation_pipeline(num_augmentations, seed)
        augmented_challenge = augment_challenge(challenge, solution, pipeline)

        augmented_dataset["id"].append(f'aug-{dataset["id"][idx]}-{num_augmentations}')
        augmented_dataset["challenge"].append(augmented_challenge["challenge"])
        augmented_dataset["solution"].append(augmented_challenge["solution"])

    return augmented_dataset


def to_dataset(data, solutions=None, augment=False):
    restructured_data = {
        "id": [],
        "challenge": [],
    }
    if solutions is not None:
        restructured_data["solution"] = []

    for challenge_id, challenge_data in data.items():  # for all challenges
        for test_id, task in enumerate(
            challenge_data["test"]
        ):  # for all test tasks in this challenge we want to expand dataset so that each test task is separate dataset record
            restructured_data["id"].append(challenge_id)
            restructured_data["challenge"].append({"train": challenge_data["train"], "test": task, "order": test_id})
            if solutions is not None:
                restructured_data["solution"].append(solutions[challenge_id][test_id])

    if augment:
        augmented_data = augment_dataset(restructured_data, num_augmentations=3, total_number=10_000, seed=11)
        restructured_data["id"].extend(augmented_data["id"])
        restructured_data["challenge"].extend(augmented_data["challenge"])
        restructured_data["solution"].extend(augmented_data["solution"])

        augmented_data = augment_dataset(restructured_data, num_augmentations=5, total_number=10_000, seed=22)
        restructured_data["id"].extend(augmented_data["id"])
        restructured_data["challenge"].extend(augmented_data["challenge"])
        restructured_data["solution"].extend(augmented_data["solution"])

        augmented_data = augment_dataset(restructured_data, num_augmentations=7, total_number=10_000, seed=33)
        restructured_data["id"].extend(augmented_data["id"])
        restructured_data["challenge"].extend(augmented_data["challenge"])
        restructured_data["solution"].extend(augmented_data["solution"])

    return Dataset.from_dict(restructured_data)


def prepare_inputs(dct, prepare_solution=False):
    if prepare_solution:
        return "<output>\n" + "\n".join(" ".join(map(str, row)) for row in dct) + "\n</output>"
    else:
        input_str = "\n".join(" ".join(map(str, row)) for row in dct["input"])
        output_str = "\n".join(" ".join(map(str, row)) for row in dct["output"]) if "output" in dct else ""
        text = f"<input>\n{input_str}\n</input>"
        if output_str:
            text += f"\n\n<output>\n{output_str}\n</output>"
        return text


def prepare_dataset(tokenizer, base_path=None, final_training=False):
    # Load all datasets
    training_challenges = data_utils.load_data(f"{base_path}/arc-prize-2024/arc-agi_training_challenges.json")
    training_solutions = data_utils.load_data(f"{base_path}/arc-prize-2024/arc-agi_training_solutions.json")
    evaluation_challenges = data_utils.load_data(f"{base_path}/arc-prize-2024/arc-agi_evaluation_challenges.json")
    evaluation_solutions = data_utils.load_data(f"{base_path}/arc-prize-2024/arc-agi_evaluation_solutions.json")
    test_challenges = data_utils.load_data(f"{base_path}/arc-prize-2024/arc-agi_test_challenges.json")

    train_dataset = to_dataset(training_challenges, training_solutions, augment=True)
    eval_dataset = to_dataset(evaluation_challenges, evaluation_solutions)
    pred_dataset = to_dataset(test_challenges)

    def create_train_image_content(challenge):
        content = [{"type": "text", "text": TRAIN_IMAGE_SYSTEM_PROMPT}]

        for i, example in enumerate(challenge["train"]):
            content.extend(
                [
                    {"type": "text", "text": f"Input Task {i+1}"},
                    {"type": "image", "image": pil_image_to_base64(list_to_image(example["input"]))},
                    {"type": "text", "text": f"Output Task {i+1}"},
                    {"type": "image", "image": pil_image_to_base64(list_to_image(example["output"]))},
                ]
            )

        content.append({"type": "text", "text": TRAIN_IMAGE_PROMPT})
        return content

    def create_chat(challenge, solution=None):
        train_input = TRAIN_TEXT_SYSTEM_PROMPT.format(
            training_data="\n\n".join([prepare_inputs(ex) for ex in challenge["train"]]),
        )
        test_input = TEST_TEXT_SYSTEM_PROMPT.format(
            input_test_data=prepare_inputs(challenge["test"]),
        )

        train_text_messages = [
            {"role": "system", "content": TRAIN_TEXT_SYSTEM_PROMPT},
            {"role": "user", "content": train_input},
        ]

        test_text_messages = [
            {"role": "system", "content": TEST_TEXT_SYSTEM_PROMPT},
            {"role": "user", "content": test_input},
        ]

        train_image_messages = [
            {
                "role": "user",
                "content": create_train_image_content(challenge),
            },
        ]

        test_image_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": TEST_IMAGE_SYSTEM_PROMPT},
                    {"type": "text", "text": f"Input Test Task"},
                    {"type": "image", "image": pil_image_to_base64(list_to_image(challenge["test"]["input"]))},
                    {"type": "text", "text": TEST_IMAGE_PROMPT},
                ],
            },
        ]

        if solution:
            test_text_messages.append(
                {
                    "role": "assistant",
                    "content": prepare_inputs(solution, prepare_solution=True),
                }
            )

        return {
            "train_text_messages": train_text_messages,
            "test_text_messages": test_text_messages,
            "train_image_messages": train_image_messages,
            "test_image_messages": test_image_messages,
        }

    def process_dataset(examples, solutions=None):
        # Create messages for each challenge-solution pair
        chats = []
        for challenge, solution in zip(examples["challenge"], solutions or [None] * len(examples["challenge"])):
            chat = create_chat(challenge, solution)
            chats.append(chat)

        return {"messages": chats}

    pred_dataset = pred_dataset.map(lambda x: process_dataset(x), batched=True)
    train_dataset = train_dataset.map(lambda x: process_dataset(x, train_dataset["solution"]), batched=True)
    eval_dataset = eval_dataset.map(lambda x: process_dataset(x, eval_dataset["solution"]), batched=True)

    if final_training:  # if final training, we need to add the validation dataset to the training dataset
        train_dataset = concatenate_datasets([train_dataset, eval_dataset]).shuffle(seed=42)

        return DatasetDict(
            {
                "train": train_dataset,
                "predict": pred_dataset,
            }
        )

    test_dataset = eval_dataset.train_test_split(test_size=0.3)

    dataset = DatasetDict(
        {
            "train": train_dataset,
            "test": test_dataset["train"],
            "val": test_dataset["test"],
            "predict": pred_dataset,
        }
    )

    return dataset

In [19]:
dataset = prepare_dataset(models["tokenizer"], base_path=BASE_PATH, final_training=False)
dataset

Map: 100%|██████████| 30416/30416 [28:00<00:00, 18.10 examples/s]
Map: 100%|██████████| 419/419 [00:10<00:00, 38.60 examples/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'challenge', 'solution', 'messages'],
        num_rows: 30416
    })
    test: Dataset({
        features: ['id', 'challenge', 'solution', 'messages'],
        num_rows: 293
    })
    val: Dataset({
        features: ['id', 'challenge', 'solution', 'messages'],
        num_rows: 126
    })
    predict: Dataset({
        features: ['id', 'challenge', 'messages'],
        num_rows: 105
    })
})

In [20]:
def plot_example(sample, cmap="Set1", title=None):
    # print train examples as a multiplot with input -> output
    fig, axs = plt.subplots(len(sample["train"]) + 1, 2, figsize=(20, 20))
    for i, ex in enumerate(sample["train"]):
        axs[i][0].imshow(ex["input"], cmap=cmap)
        axs[i][1].imshow(ex["output"], cmap=cmap)

        axs[i][0].axis("off")
        axs[i][1].axis("off")

    idx = len(sample["train"])

    axs[idx][0].imshow(sample["test"]["input"], cmap=cmap)
    # plot empty image if no output is available black and white cmap
    # axs[i][1].imshow(np.zeros_like(ex['input']), cmap='gray')

    axs[idx][0].axis("off")
    axs[idx][1].axis("off")

    fig.suptitle(title)


def augment_challenge(challenge, augmentation):
    new_challenge = {
        "train": [{"input": augmentation(grid["input"]), "output": augmentation(grid["output"])} for grid in challenge["train"]],
        "test": {"input": augmentation(challenge["test"]["input"])},
    }
    return new_challenge


def demonstrate_transformations(dataset):
    # Sample a few examples from the dataset
    # samples = random.sample(list(dataset['train']), 3)
    samples = list(dataset["train"])[:3]

    for i, sample in enumerate(samples):
        challenge = augment_challenge(sample["challenge"], lambda x: rotate_grid(x, k=1))
        plot_example(challenge, title=f"Example {i+1} - Rotate 90")

        challenge = augment_challenge(sample["challenge"], lambda x: flip_grid(x, axis=0))
        plot_example(challenge, title=f"Example {i+1} - Flip Vertical")

        challenge = augment_challenge(sample["challenge"], lambda x: expand_grid(x, factor=3))
        plot_example(challenge, title=f"Example {i+1} - Expand")

        challenge = augment_challenge(sample["challenge"], lambda x: shrink_grid(x, factor=2))
        plot_example(challenge, title=f"Example {i+1} - Shrink")

        challenge = augment_challenge(sample["challenge"], lambda x: roll_grid(x, shift=2, axis=0))
        plot_example(challenge, title=f"Example {i+1} - Roll")

        challenge = augment_challenge(sample["challenge"], lambda x: add_border(x, value=1))
        plot_example(challenge, title=f"Example {i+1} - Add Border")

        challenge = augment_challenge(sample["challenge"], lambda x: remove_border(x))
        plot_example(challenge, title=f"Example {i+1} - Remove Border")

        challenge = augment_challenge(sample["challenge"], lambda x: replace_color(x, old_color=1, new_color=2))
        plot_example(challenge, title=f"Example {i+1} - Replace Color")

        challenge = augment_challenge(sample["challenge"], lambda x: add_noise(x, noise_prob=0.1))
        plot_example(challenge, title=f"Example {i+1} - Add Noise")

        challenge = augment_challenge(sample["challenge"], lambda x: mirror_grid(x))
        plot_example(challenge, title=f"Example {i+1} - Mirror")

        challenge = augment_challenge(sample["challenge"], lambda x: tile_grid(x, tiles=(2, 2)))
        plot_example(challenge, title=f"Example {i+1} - Tile")

        challenge = augment_challenge(sample["challenge"], lambda x: diagonal_shift(x, shift=2))
        plot_example(challenge, title=f"Example {i+1} - Diagonal Shift")

        challenge = augment_challenge(sample["challenge"], lambda x: apply_mask(x, mask=np.random.choice([0, 1], size=(3, 3), p=[0.2, 0.8])))
        plot_example(challenge, title=f"Example {i+1} - Apply Mask")

        challenge = augment_challenge(sample["challenge"], lambda x: swap_quadrants(x, order=[1, 2, 3, 0]))
        plot_example(challenge, title=f"Example {i+1} - Swap Quadrants")

        ...


# Demonstrate the transformations
# demonstrate_transformations(dataset)

In [21]:
def filter_long_entries(dataset):
    def check_length(example):
        messages = example["messages"]
        inputs = models["processor"](messages=messages, return_tensors="pt", padding=True)
        pixel_values = inputs.pixel_values
        return pixel_values.numel() <= 16000

    filtered_dataset = dataset.filter(check_length)
    return filtered_dataset


# Apply the filter to all splits in the dataset
# filtered_dataset = DatasetDict({
#     split: filter_long_entries(dataset[split])
#     for split in dataset.keys()
# })

# print("Original dataset sizes:")
# for split, ds in dataset.items():
#     print(f"{split}: {len(ds)}")

# print("\nFiltered dataset sizes:")
# for split, ds in filtered_dataset.items():
#     print(f"{split}: {len(ds)}")

# # Update the dataset variable with the filtered version
# dataset = filtered_dataset

In [22]:
def eval(f):
    def wrapper(model, *args, **kwargs):
        if hasattr(model, "to_inference"):
            model.to_inference()
        else:
            model.eval()
        with torch.no_grad():
            return f(model, *args, **kwargs)

    return wrapper


def train(f):
    def wrapper(model, *args, **kwargs):
        if hasattr(model, "to_training"):
            model.to_training()
        else:
            model.train()
        return f(model, *args, **kwargs)

    return wrapper


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [23]:
@eval
def describe_puzzle(model, processor, image, prompt):
    # Create prompt
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt")
    inputs = inputs.to(model.device)

    # Run inference
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids = generated_ids[0, inputs.input_ids.shape[1] :]
    generated_text = processor.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return generated_text

In [24]:
# image = list_to_image(dataset["train"][10]["challenge"]["train"][0]["input"])
# image

In [25]:
# describe_puzzle(models['vllm'], models['processor'], image, "Describe the image")

In [26]:
class Encoder(nn.Module):
    def __init__(self, input_dim, condition_dim, latent_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.condition_dim = condition_dim

        self.query = nn.Linear(input_dim, hidden_dim, dtype=dtype)
        self.key = nn.Linear(input_dim, hidden_dim, dtype=dtype)
        self.value = nn.Linear(input_dim, hidden_dim, dtype=dtype)

        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, dtype=dtype)

        self.fc1 = nn.Linear(hidden_dim, hidden_dim, dtype=dtype)

        self.fc_mu = nn.Linear(hidden_dim, latent_dim, dtype=dtype)  # Mean of the latent space
        self.fc_var = nn.Linear(hidden_dim, latent_dim, dtype=dtype)  # Variance of the latent space

    def forward(self, x, condition):
        # Add the condition to the input
        x_cond = torch.cat([x, condition], dim=1)

        # Apply attention
        attn_output, _ = self.attention(self.query(x_cond), self.key(x_cond), self.value(x_cond))
        h = F.relu(self.fc1(attn_output.mean(dim=1)))  # Reduce to a single representation per sample

        # Compute the mean and variance for the latent space
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)

        return mu, log_var

In [27]:
def reparameterize(mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

In [28]:
class Decoder(nn.Module):
    def __init__(self, latent_dim, condition_dim, output_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.condition_dim = condition_dim
        self.fc1 = nn.Linear(latent_dim + condition_dim, hidden_dim, dtype=dtype)

        self.query = nn.Linear(hidden_dim, hidden_dim, dtype=dtype)
        self.key = nn.Linear(hidden_dim, hidden_dim, dtype=dtype)
        self.value = nn.Linear(hidden_dim, hidden_dim, dtype=dtype)

        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, dtype=dtype)
        self.fc_output = nn.Linear(
            hidden_dim, output_dim * output_dim * 10, dtype=dtype
        )  # output is the 30x30 image with each pixel being a vector of logits

    def forward(self, z, condition, output_len):
        # Combine latent variable z and condition
        z_cond = torch.cat([z.unsqueeze(1).repeat(1, condition.shape[1], 1), condition], dim=-1)

        h = F.relu(self.fc1(z_cond))

        # Apply attention to guide the generation process
        attn_output, _ = self.attention(self.query(h), self.key(h), self.value(h))

        # Generate output
        output = torch.softmax(self.fc_output(attn_output), dim=-1)

        return output

In [29]:
class CVAE(nn.Module):
    def __init__(self, input_dim, condition_dim, latent_dim, output_dim, hidden_dim):
        super(CVAE, self).__init__()
        self.encoder = Encoder(input_dim, condition_dim, latent_dim, hidden_dim)
        self.decoder = Decoder(latent_dim, condition_dim, output_dim, hidden_dim)

    def forward(self, x, condition, output_len):
        # Encode
        mu, log_var = self.encoder(x, condition)

        # Reparameterization trick
        z = reparameterize(mu, log_var)  # (B, latent_dim)

        # Decode
        output = self.decoder(z, condition, output_len)  # (B, output_len, output_dim * output_dim * 10)

        return output, mu, log_var

In [30]:
class ARCModel(torch.nn.Module):
    def __init__(self, llm_model, vllm_model, beta=1.0, gamma=1.0):
        super().__init__()
        self.llm_model = llm_model
        self.vllm_model = vllm_model

        self.llm_model.requires_grad_(False)
        self.vllm_model.requires_grad_(False)

        self.text_proj = nn.Linear(LLM_HIDDEN_SIZE, 2304, dtype=dtype)
        self.image_proj = nn.Linear(VLLM_HIDDEN_SIZE, 2304, dtype=dtype)

        self.cvae = CVAE(input_dim=2304, condition_dim=2304, latent_dim=512, output_dim=30, hidden_dim=1024)

        self.output_dim = 30
        self.beta = beta
        self.gamma = gamma

    def to(self, device):
        self.device = device
        self.cvae.to(device)
        self.text_proj.to(device)
        self.image_proj.to(device)
        return self

    def to_inference(self):
        self.llm_model.eval()
        self.vllm_model.eval()

    def to_training(self):
        self.llm_model.train()
        self.vllm_model.train()

    # def cvae_loss(self, recon_x, x, mu, log_var, rows, cols):
    #     recon_loss = F.binary_cross_entropy_with_logits(recon_x, x, reduction="sum")
    #     # KL Divergence loss
    #     kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    #     return recon_loss + kl_loss

    def cvae_loss(self, recon_x, x, mu, log_var, rows, cols):
        """
        recon_x: (B, 30, 30, 10)
        x: (B, 30, 30)
        beta: weight for KL divergence loss
        gamma: weight for dimensions loss
        """
        B = recon_x.shape[0]

        # Create a mask for the actual content area
        mask = torch.zeros_like(x, dtype=torch.bool)
        for i in range(B):
            mask[i, : rows[i], : cols[i]] = True

        # Calculate reconstruction loss (normalized)
        recon_loss = F.cross_entropy(recon_x.view(-1, 10), x.view(-1), reduction="mean")

        # Calculate dimensions loss (normalized)
        dims_loss = (torch.sum(recon_x[mask].argmax(dim=-1) * x[mask]) - torch.sum(recon_x[~mask].argmax(dim=-1) * x[~mask])) / mask.sum()

        # KL Divergence loss (scaled)
        kl_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())

        # Combine losses with scaling factors
        total_loss = recon_loss + self.beta * kl_loss + self.gamma * dims_loss

        return total_loss

    def encode(self, text_inputs, image_inputs):
        with torch.no_grad():
            text_features = self.llm_model(**text_inputs.to(self.llm_model.device)).hidden_states[-1]  # (batch_size, seq_len, 3072)
            image_features = self.vllm_model(**image_inputs.to(self.vllm_model.device)).hidden_states[-1]  # (batch_size, vid_len, 3584)

        # -- todo: cleanup
        text_inputs.to("cpu")
        image_inputs.to("cpu")

        torch.cuda.empty_cache()
        # -- todo: cleanup

        text_features = self.text_proj(text_features.to(self.device))
        image_features = self.image_proj(image_features.to(self.device))

        features = torch.cat([text_features, image_features], dim=1)  # (batch_size, seq_len + vid_len, 2304)
        return features

    def forward(self, train_inputs, test_inputs, targets=None):
        train_features = self.encode(text_inputs=train_inputs["text"], image_inputs=train_inputs["image"])  # (B, seq_len + vid_len, 2304)
        test_features = self.encode(text_inputs=test_inputs["text"], image_inputs=test_inputs["image"])  # (B, seq_len + vid_len, 2304)

        outputs, mu, log_var = self.cvae(train_features, test_features, output_len=30)  # (B, cond_seq_len, 30)

        B = outputs.shape[0]
        outputs = outputs[:, 0, :].reshape(B, self.output_dim, self.output_dim, 10).cpu().float()

        if targets is not None:
            rows = targets["original_rows"]
            cols = targets["original_cols"]
            padded_matrices = targets["padded_matrix"]

            labels = torch.tensor(np.array(padded_matrices)).reshape(B, self.output_dim, self.output_dim)

            loss = self.cvae_loss(outputs, labels, mu, log_var, rows, cols)
            return {"loss": loss, "outputs": outputs, "mu": mu, "log_var": log_var}

        # we will only take (B, 30, 30) for the loss calculation
        return {"loss": None, "outputs": outputs, "mu": mu, "log_var": log_var}

    def from_pretrained(self, path):
        self.text_proj.load_state_dict(torch.load(f"{path}/text_proj.pth"))
        self.image_proj.load_state_dict(torch.load(f"{path}/image_proj.pth"))
        self.cvae.load_state_dict(torch.load(f"{path}/cvae.pth"))
        return self

    def save_pretrained(self, path):
        # Create the directory if it doesn't exist
        os.makedirs(path, exist_ok=True)
        torch.save(self.text_proj.state_dict(), f"{path}/text_proj.pth")
        torch.save(self.image_proj.state_dict(), f"{path}/image_proj.pth")
        torch.save(self.cvae.state_dict(), f"{path}/cvae.pth")
        # Save any other non-llm and non-vllm weights here

In [31]:
arc_model = ARCModel(models["llm"], models["vllm"], gamma=0.5, beta=0.9)
arc_model.to("cuda:0")

ARCModel(
  (llm_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 3072)
      (layers): ModuleList(
        (0-27): 28 x LlamaDecoderLayer(
          (self_attn): LlamaFlashAttention2(
            (q_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
            (k_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
            (v_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
            (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
            (up_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
            (down_proj): Linear4bit(in_features=8192, out_features=3072, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
 

In [32]:
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Count and print the number of trainable parameters
trainable_params = count_trainable_parameters(arc_model)
print(f"Number of trainable parameters: {trainable_params:,}")

Number of trainable parameters: 43,456,808


In [33]:
def collate(mode, tokenizer, processor):
    def prepare_inputs(text_messages, image_messages):

        def clean_none_values(messages):
            return [{k: v for k, v in message.items() if v is not None} for message in messages]

        image_messages = [[{**msg, "content": clean_none_values(msg["content"])} for msg in msgs] for msgs in image_messages]

        text_encodings = tokenizer.apply_chat_template(
            text_messages,
            tokenize=True,
            add_generation_prompt=(mode not in ["train", "val"]),
            return_tensors="pt",
            return_dict=True,
            padding=True,
        )

        image_text = processor.apply_chat_template(image_messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(image_messages)

        image_encodings = processor(
            text=image_text,
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )

        return text_encodings, image_encodings

    def collate_fn(batch):
        # Separate the different components of the batch
        # For 'test' mode, remove the last assistant message from each entry
        train_text_messages = [item["messages"]["train_text_messages"] for item in batch]
        train_image_messages = [item["messages"]["train_image_messages"] for item in batch]

        test_text_messages = [item["messages"]["test_text_messages"] for item in batch]
        test_image_messages = [item["messages"]["test_image_messages"] for item in batch]

        # Tokenize the texts
        train_text_encodings, train_image_encodings = prepare_inputs(train_text_messages, train_image_messages)
        test_text_encodings, test_image_encodings = prepare_inputs(test_text_messages, test_image_messages)

        # If 'solution' is present (for training/validation data)
        if "solution" in batch[0]:
            return {
                "train_inputs": {"text": train_text_encodings, "image": train_image_encodings},
                "test_inputs": {"text": test_text_encodings, "image": test_image_encodings},
                "targets": {
                    "padded_matrix": [pad_matrix(item["solution"], target_rows=30, target_cols=30) for item in batch],
                    "original_rows": [len(item["solution"]) for item in batch],
                    "original_cols": [len(item["solution"][0]) for item in batch],
                },
            }
        else:
            return {
                "train_inputs": {"text": train_text_encodings, "image": train_image_encodings},
                "test_inputs": {"text": test_text_encodings, "image": test_image_encodings},
            }

    return collate_fn

In [34]:
dataloader = torch.utils.data.DataLoader(
    dataset["train"], batch_size=1, collate_fn=collate(mode="train", tokenizer=models["tokenizer"], processor=models["processor"])
)


def print_recursive(obj, indent=0):
    if isinstance(obj, torch.Tensor):
        print("  " * indent + str(obj.shape))
    elif (
        isinstance(obj, dict)
        or isinstance(obj, transformers.tokenization_utils_base.BatchEncoding)
        or isinstance(obj, transformers.feature_extraction_utils.BatchFeature)
    ):
        for key, value in obj.items():
            print("  " * indent + str(key) + ":")
            print_recursive(value, indent + 1)
    elif isinstance(obj, list):
        if len(obj) > 0 and isinstance(obj[0], list):
            if len(obj[0]) > 0 and isinstance(obj[0][0], list):
                print("  " * indent + f"List of length: {len(obj)}, {len(obj[0])}, {len(obj[0][0])}")
            else:
                print("  " * indent + f"List of length: {len(obj)}, {len(obj[0])}")
        else:
            print("  " * indent + f"List of length: {len(obj)}")
            print_recursive(obj[0], indent + 1)
    else:
        print("  " * indent + str(obj))


for batch in dataloader:
    print_recursive(batch)
    #     outputs = arc_model(**batch)
    #     print('-'* 30)
    #     print_recursive(outputs)
    break

train_inputs:
  text:
    input_ids:
      torch.Size([1, 393])
    attention_mask:
      torch.Size([1, 393])
  image:
    input_ids:
      torch.Size([1, 2824])
    attention_mask:
      torch.Size([1, 2824])
    pixel_values:
      torch.Size([10240, 1176])
    image_grid_thw:
      torch.Size([10, 3])
test_inputs:
  text:
    input_ids:
      torch.Size([1, 630])
    attention_mask:
      torch.Size([1, 630])
  image:
    input_ids:
      torch.Size([1, 499])
    attention_mask:
      torch.Size([1, 499])
    pixel_values:
      torch.Size([1024, 1176])
    image_grid_thw:
      torch.Size([1, 3])
targets:
  padded_matrix:
    List of length: 1
      [[7 0 7 0 0 0 7 0 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [7 0 7 0 0 0 7 0 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [7 7 0 0 0 0 7 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [7 0 7 0 0 0 7 0 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [7 0 7 0 0 0 7 0 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [7 7 0 0 0 0 7 7 0 0 

In [35]:
def calculate_partial_match(pred, label, masks):
    # Convert inputs to numpy arrays if they're not already
    pred = np.array(pred)
    label = np.array(label)
    masks = np.array(masks)

    # Calculate the match only where mask is 1
    matched = (pred == label) & (masks == 1)

    # Sum of matched elements divided by sum of mask elements
    return matched.sum() / masks.sum()

In [36]:
def calculate_accuracy(pred, label, masks):
    return (calculate_partial_match(pred, label, masks) == 1.0).mean()

In [37]:
def compute_metrics(outputs, labels, masks):
    return {
        "accuracy": calculate_accuracy(outputs, labels, masks),
        "partial_match": calculate_partial_match(outputs, labels, masks),
    }

In [38]:
@train
def training(model, tokenizer, processor, dataset, config):
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

    train_dataloader = torch.utils.data.DataLoader(
        dataset["train"], batch_size=config["batch_size"], collate_fn=collate(mode="train", tokenizer=tokenizer, processor=processor), shuffle=True
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset["val"], batch_size=config["batch_size"], collate_fn=collate(mode="val", tokenizer=tokenizer, processor=processor), shuffle=False
    )

    model.train()

    train_loss = 0

    history = {"train_loss": [], "val_loss": [], "accuracy": [], "partial_match": []}
    # Calculate total number of training steps
    total_steps = len(train_dataloader) * config["epochs"]

    print(f"Total steps: {total_steps}")

    # Create the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config["warmup_steps"], num_training_steps=total_steps)

    for epoch in range(config["epochs"]):
        train_loss = 0
        steps = 0
        for batch in tqdm(train_dataloader, desc="Train Batches", total=len(train_dataloader)):
            outputs = model(**batch)

            loss = outputs["loss"] / config["gradient_accumulation_steps"]
            loss.backward()

            if (steps + 1) % config["gradient_accumulation_steps"] == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate
                optimizer.zero_grad()

            train_loss += loss.item() * config["gradient_accumulation_steps"]
            steps += 1

        print(f"Epoch {epoch + 1}, Loss: {train_loss / len(train_dataloader)}")

        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Val Batches", total=len(val_dataloader)):
                outputs = model(**batch)
                loss = outputs["loss"]
                val_loss += loss.item()

                # outputs (B, 900, 10)
                B = outputs["outputs"].shape[0]
                pred = outputs["outputs"].reshape(B, 30, 30, 10).argmax(dim=-1).numpy()  # (B, 30, 30)
                labels = np.array(batch["targets"]["padded_matrix"])
                masks = np.array(
                    [pad_matrix(np.ones((r, c)), 30, 30) for r, c in zip(batch["targets"]["original_rows"], batch["targets"]["original_cols"])]
                )  # (B, 30, 30), where 1 is the original value and 0 is the padded value

                metrics = compute_metrics(pred, labels, masks)

                history["accuracy"].append(metrics["accuracy"])
                history["partial_match"].append(metrics["partial_match"])

        log.info(f"Epoch {epoch + 1}, Train Loss: {train_loss / len(train_dataloader)}", terminal=True)
        log.info(f"Epoch {epoch + 1}, Val Loss: {val_loss / len(val_dataloader)}", terminal=True)
        log.info(f"Epoch {epoch + 1}, Accuracy: {np.mean(history['accuracy'])}", terminal=True)
        log.info(f"Epoch {epoch + 1}, Partial Match: {np.mean(history['partial_match'])}", terminal=True)
        
        model.save_pretrained(f"models/checkpoints/arc-agi-llama-vllm-{epoch + 1}")

        history["train_loss"].append(train_loss / len(train_dataloader))
        history["val_loss"].append(val_loss / len(val_dataloader))

    return history

In [39]:
config = {"epochs": 50, "batch_size": 2, "lr": 5e-5, "gradient_accumulation_steps": 4, "warmup_steps": 200, "weight_decay": 0.01}

In [40]:
history = training(arc_model, models["tokenizer"], models["processor"], dataset, config)

Total steps: 760400


Train Batches:   0%|          | 0/15208 [00:00<?, ?it/s]

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
Train Batches:  37%|███▋      | 5562/15208 [3:46:22<5:59:01,  2.23s/it]

: 

In [47]:
@eval
def evaluate(model, tokenizer, processor, dataset, config):
    test_dataloader = torch.utils.data.DataLoader(
        dataset["test"], batch_size=config["batch_size"], collate_fn=collate(mode="test", tokenizer=tokenizer, processor=processor)
    )

    history = {"accuracy": [], "partial_match": []}

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Test Batches", total=len(test_dataloader)):
            outputs = model(**batch)

            B = outputs["outputs"].shape[0]
            pred = outputs["outputs"].reshape(B, 30, 30, 10).argmax(dim=-1).numpy()  # (B, 30, 30)
            labels = batch["targets"]["padded_matrix"]
            masks = np.array(
                [pad_matrix(np.ones((r, c)), 30, 30) for r, c in zip(batch["targets"]["original_rows"], batch["targets"]["original_cols"])]
            )  # (B, 30, 30), where 1 is the original value and 0 is the padded value

            metrics = compute_metrics(pred, labels, masks)

            history["accuracy"].append(metrics["accuracy"])
            history["partial_match"].append(metrics["partial_match"])

    return history

In [48]:
config = {"batch_size": 2}

In [49]:
eval_history = evaluate(arc_model, models["tokenizer"], models["processor"], dataset, config)
eval_history

Test Batches: 100%|██████████| 147/147 [06:40<00:00,  2.73s/it]


{'accuracy': [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
 

In [50]:
np.mean(eval_history["accuracy"]), np.mean(eval_history["partial_match"])

(0.0, 0.4407089855812507)