In [None]:
# Environment setup. The commented-out pins record the library versions this notebook was
# written against; they are presumably preinstalled in the runtime (TODO confirm).
# NOTE(review): qwen-vl-utils is installed unpinned — pin it for reproducible runs.
# ! pip install torch==2.4.1 torchvision==0.19.0
# ! pip install accelerate==0.34.2
# ! pip install transformers==4.45.1
# ! pip install unsloth==2024.9.post3
! pip install bitsandbytes==0.44.0
! pip install qwen-vl-utils

In [None]:
# Expose two GPUs: get_models caps the text LLM to device 0 and the vision-language model to device 1.
%env CUDA_VISIBLE_DEVICES=0,1
# Turn off the Hugging Face tokenizers library's internal parallelism.
%env TOKENIZERS_PARALLELISM=false

In [None]:
# Root of the Kaggle input mount; the dataset cell passes it to prepare_dataset, which reads
# {BASE_PATH}/arc-prize-2024/*.json.
BASE_PATH = "/kaggle/input"
# Text-only LLM whose last hidden layer provides text features (a 4-bit checkpoint, per its name).
# MODEL_ID = f"unsloth/Meta-Llama-3.1-8B-bnb-4bit"
MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
# Vision-language model whose last hidden layer provides image features.
# VLLM_MODEL_ID = "unsloth/Llama-3.2-11B-Vision-Instruct"
VLLM_MODEL_ID = "4bit/Qwen2-VL-7B-Instruct"
# Generation budget and the remaining context budget out of 32768 tokens.
# NOTE(review): neither constant is referenced in the visible cells.
MAX_NEW_TOKENS = 2048
MAX_SEQ_LENGTH = 32768 - MAX_NEW_TOKENS

In [None]:
import sys

# sys.path.append(BASE_PATH)
# sys.path.append(f"{BASE_PATH}/scripts")
sys.path.append('/kaggle/input/arc-agi-python-utilities')

In [None]:
import io
import json
import base64
from PIL import Image

import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig  # type: ignore
from transformers import MllamaForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor
from transformers import Trainer, TrainingArguments

from datasets import Dataset, DatasetDict  # type: ignore
from datasets import concatenate_datasets  # type: ignore

from qwen_vl_utils import process_vision_info # type: ignore

import data_utils  # type: ignore

In [None]:
# Notebook-wide dtype: used when loading both backbones and for every trainable layer defined below.
dtype = torch.bfloat16

In [None]:
def get_models():
    """Load the text LLM and the vision-language model used as frozen feature extractors.

    Both checkpoints are loaded in 4-bit via bitsandbytes with ``dtype`` weights and
    ``dtype`` compute, SDPA attention, and ``output_hidden_states=True`` (``ARCModel.encode``
    reads ``hidden_states[-1]``). The text model is capped to GPU 0 and the VLM to GPU 1
    (15 GiB each, 16 GiB of CPU memory allowed for offload).

    Returns:
        dict with keys "llm", "tokenizer", "vllm", "processor".
    """
    # Left padding keeps the real tokens of every sequence right-aligned in a batch.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
    llm_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        max_memory={0: "15GiB", "cpu": "16GiB"},  # text model on GPU 0 only
        attn_implementation="sdpa",
        output_hidden_states=True,
        return_dict_in_generate=True,
        # Without bnb_4bit_compute_dtype the 4-bit layers compute in float32, which is
        # slower and mismatches the bfloat16 activations used everywhere else.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=dtype),
    )

    processor = AutoProcessor.from_pretrained(VLLM_MODEL_ID)
    vllm_model = Qwen2VLForConditionalGeneration.from_pretrained(
        VLLM_MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        max_memory={1: "15GiB", "cpu": "16GiB"},  # vision-language model on GPU 1 only
        attn_implementation="sdpa",
        output_hidden_states=True,
        return_dict_in_generate=True,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=dtype),
    )

    return {"llm": llm_model, "tokenizer": tokenizer, "vllm": vllm_model, "processor": processor}

In [None]:
# Load both backbones (text LLM capped to GPU 0, vision-language model to GPU 1) plus tokenizer/processor.
models = get_models()

In [None]:
# System prompt for the text chats. Training-example and test-task chats use the same wording.
TRAIN_SYSTEM_PROMPT = (
    """You are a puzzle solving wizard. You are given a puzzle from the abstraction and reasoning corpus developed by Francois Chollet."""
)
TEST_SYSTEM_PROMPT = TRAIN_SYSTEM_PROMPT

# User prompt listing the serialized training input/output pairs.
TRAIN_PROMPT = """Here are the example input and output pairs from which you should learn the underlying rule to later predict the output for the given test input:
-----------------
{training_data}"""

# User prompt presenting the test input (the stray "." before ":" has been removed).
TEST_PROMPT = """Now, solve the following puzzle based on its input grid by applying the rules you have learned from the training data:
-----------------
{input_test_data}
-----------------
What is the output grid? Only provide the output grid in the form as in the example input and output pairs. Do not provide any additional information:"""

# Text parts that accompany the rendered grid images in the vision-language chats.
TRAIN_IMAGE_PROMPT = "Describe the images"
TEST_IMAGE_PROMPT = "Describe the image"

In [None]:
def list_to_image(integer_list_2d, target_size=30, scale=15):
    """Render an ARC grid as an RGB PIL image.

    Each cell value ``v`` is drawn with the ``tab10`` colour ``v % 10``, so a given value
    has the same colour in every grid of a puzzle (colours used to be assigned by rank
    among the grid's own unique values, so the same value could change colour between
    an input and its output). The grid is pasted at the top-left of a black
    ``target_size`` x ``target_size`` canvas, which is then upscaled with
    nearest-neighbour resampling so each cell becomes a ``scale`` x ``scale`` block.

    Args:
        integer_list_2d: rectangular 2-D list of integer cell values.
        target_size: side length of the canvas in cells (grids larger than this are cropped).
        scale: pixels per cell side in the returned image (default 15, as before).

    Returns:
        PIL RGB image of size (target_size * scale, target_size * scale).
    """
    grid = np.asarray(integer_list_2d)

    # Fixed value -> colour table; tab10 returns RGBA floats in [0, 1].
    cmap = plt.get_cmap("tab10")
    palette = (np.array([cmap(i)[:3] for i in range(10)]) * 255).astype(np.uint8)
    rgb_array = palette[grid % 10]  # (rows, cols, 3) uint8

    image = Image.fromarray(rgb_array)

    canvas = Image.new("RGB", (target_size, target_size), color=(0, 0, 0))
    canvas.paste(image, (0, 0))

    return canvas.resize((target_size * scale, target_size * scale), Image.NEAREST)

def pil_image_to_base64(image):
    """Serialize ``image`` as PNG and wrap it in a ``data:image;base64,`` URL string.

    This is the form in which images are embedded in the vision-language chat messages.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        payload = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
    return 'data:image;base64,' + payload


In [None]:
def prepare_inputs(dct, prepare_solution=False):
    """Serialize an ARC grid or example into the tagged text used in the prompts.

    Grids are written one row per line, with cells separated by single spaces.

    Args:
        dct: when ``prepare_solution`` is True, a grid (list of rows); otherwise an
            example mapping with an "input" grid and optionally an "output" grid.
        prepare_solution: serialize ``dct`` itself wrapped in <output> tags.

    Returns:
        The grid wrapped in <output> tags, or the input wrapped in <input> tags,
        followed by a blank line and an <output> block when the example carries a
        non-empty output.
    """

    def grid_to_text(grid):
        return "\n".join(" ".join(map(str, row)) for row in grid)

    if prepare_solution:
        return "<output>\n" + grid_to_text(dct) + "\n</output>"

    text = "<input>\n" + grid_to_text(dct["input"]) + "\n</input>"
    # .get() also covers rows where the "output" field exists but is None.
    output_grid = dct.get("output")
    output_str = grid_to_text(output_grid) if output_grid else ""
    if output_str:
        text += "\n\n<output>\n" + output_str + "\n</output>"
    return text

In [None]:
def to_dataset(data, solutions=None):
    """Flatten ARC challenges into a ``Dataset`` with one row per test task.

    Args:
        data: mapping challenge_id -> {"train": [...], "test": [...]}.
        solutions: optional mapping challenge_id -> list of solution grids, indexed
            like that challenge's "test" list.

    Returns:
        Dataset with columns "id", "challenge" (a dict holding the challenge's "train"
        pairs, one "test" task and its "order" index) and, when ``solutions`` is given,
        "solution".
    """
    has_solutions = solutions is not None
    columns = {"id": [], "challenge": []}
    if has_solutions:
        columns["solution"] = []

    for challenge_id, challenge in data.items():
        train_pairs = challenge["train"]
        # A challenge with several test tasks is expanded into several rows.
        for order, test_task in enumerate(challenge["test"]):
            columns["id"].append(challenge_id)
            columns["challenge"].append({"train": train_pairs, "test": test_task, "order": order})
            if has_solutions:
                columns["solution"].append(solutions[challenge_id][order])

    return Dataset.from_dict(columns)


# NOTE(review): an identical prepare_inputs is defined in the previous cell; this copy keeps
# the cell self-contained and is the definition in effect after Run All.
def prepare_inputs(dct, prepare_solution=False):
    """Serialize an ARC grid or example into the tagged text used in the prompts.

    Grids are written one row per line, with cells separated by single spaces.

    Args:
        dct: when ``prepare_solution`` is True, a grid (list of rows); otherwise an
            example mapping with an "input" grid and optionally an "output" grid.
        prepare_solution: serialize ``dct`` itself wrapped in <output> tags.

    Returns:
        The grid wrapped in <output> tags, or the input wrapped in <input> tags,
        followed by a blank line and an <output> block when the example carries a
        non-empty output.
    """

    def grid_to_text(grid):
        return "\n".join(" ".join(map(str, row)) for row in grid)

    if prepare_solution:
        return "<output>\n" + grid_to_text(dct) + "\n</output>"

    text = "<input>\n" + grid_to_text(dct["input"]) + "\n</input>"
    # .get() also covers rows where the "output" field exists but is None.
    output_grid = dct.get("output")
    output_str = grid_to_text(output_grid) if output_grid else ""
    if output_str:
        text += "\n\n<output>\n" + output_str + "\n</output>"
    return text


def prepare_dataset(tokenizer, base_path=None, final_training=False):
    """Load the ARC Prize 2024 JSON files and build chat-formatted dataset splits.

    Reads training/evaluation challenges + solutions and the test challenges from
    ``{base_path}/arc-prize-2024/``. Each challenge is expanded into one row per test task
    (``to_dataset``), and a "messages" column is added holding four chats per row:
    train/test text chats and train/test image chats.

    Args:
        tokenizer: unused; kept for interface compatibility.
        base_path: directory that contains the ``arc-prize-2024`` folder.
        final_training: if True, merge the evaluation rows into "train" (shuffled, seed 42)
            and return only "train" and "predict".

    Returns:
        DatasetDict with "train", "test", "val", "predict" (or "train", "predict" when
        ``final_training``). "test"/"val" are a seeded 70/30 split of the evaluation rows.
    """
    data_dir = f"{base_path}/arc-prize-2024"
    training_challenges = data_utils.load_data(f"{data_dir}/arc-agi_training_challenges.json")
    training_solutions = data_utils.load_data(f"{data_dir}/arc-agi_training_solutions.json")
    evaluation_challenges = data_utils.load_data(f"{data_dir}/arc-agi_evaluation_challenges.json")
    evaluation_solutions = data_utils.load_data(f"{data_dir}/arc-agi_evaluation_solutions.json")
    test_challenges = data_utils.load_data(f"{data_dir}/arc-agi_test_challenges.json")

    train_dataset = to_dataset(training_challenges, training_solutions)
    eval_dataset = to_dataset(evaluation_challenges, evaluation_solutions)
    pred_dataset = to_dataset(test_challenges)

    def image_part(grid):
        # One rendered grid as an image content part of a vision-language chat.
        return {"type": "image", "image": pil_image_to_base64(list_to_image(grid))}

    def create_chat(challenge, solution=None):
        # Build the four chats for one row; a truthy ``solution`` is appended to the
        # test text chat as the assistant turn.
        train_input = TRAIN_PROMPT.format(
            training_data="\n\n".join([prepare_inputs(ex) for ex in challenge["train"]]),
        )
        test_input = TEST_PROMPT.format(
            input_test_data=prepare_inputs(challenge["test"]),
        )

        train_text_messages = [
            {"role": "system", "content": TRAIN_SYSTEM_PROMPT},
            {"role": "user", "content": train_input},
        ]

        test_text_messages = [
            {"role": "system", "content": TEST_SYSTEM_PROMPT},
            {"role": "user", "content": test_input},
        ]

        # NOTE(review): only the training *inputs* are rendered as images; the outputs are
        # not, so the image chat alone cannot show the transformation — confirm intended.
        train_image_messages = [
            {
                "role": "user",
                "content": [
                    *[image_part(example["input"]) for example in challenge["train"]],
                    {"type": "text", "text": TRAIN_IMAGE_PROMPT},
                ],
            },
        ]

        test_image_messages = [
            {
                "role": "user",
                "content": [
                    image_part(challenge["test"]["input"]),
                    {"type": "text", "text": TEST_IMAGE_PROMPT},
                ],
            },
        ]

        if solution:
            test_text_messages.append(
                {
                    "role": "assistant",
                    "content": prepare_inputs(solution, prepare_solution=True),
                }
            )

        return {
            "train_text_messages": train_text_messages,
            "test_text_messages": test_text_messages,
            "train_image_messages": train_image_messages,
            "test_image_messages": test_image_messages,
        }

    def process_dataset(examples, solutions=None):
        # ``solutions``, when given, must be aligned with ``examples["challenge"]``.
        challenges = examples["challenge"]
        solutions = solutions or [None] * len(challenges)
        return {"messages": [create_chat(challenge, solution) for challenge, solution in zip(challenges, solutions)]}

    pred_dataset = pred_dataset.map(lambda batch: process_dataset(batch), batched=True)
    # With batched=True, map() hands over chunks of rows (1000 by default). The solutions
    # must therefore come from the same chunk; passing the whole column would pair every
    # chunk after the first with the wrong solutions.
    train_dataset = train_dataset.map(lambda batch: process_dataset(batch, batch["solution"]), batched=True)
    eval_dataset = eval_dataset.map(lambda batch: process_dataset(batch, batch["solution"]), batched=True)

    if final_training:  # if final training, we need to add the validation dataset to the training dataset
        train_dataset = concatenate_datasets([train_dataset, eval_dataset]).shuffle(seed=42)

        return DatasetDict(
            {
                "train": train_dataset,
                "predict": pred_dataset,
            }
        )

    # Seeded so the test/val split is reproducible across runs.
    eval_split = eval_dataset.train_test_split(test_size=0.3, seed=42)

    return DatasetDict(
        {
            "train": train_dataset,
            "test": eval_split["train"],
            "val": eval_split["test"],
            "predict": pred_dataset,
        }
    )

In [None]:
# Build the chat-formatted splits (train / test / val / predict) from the ARC Prize 2024 files.
dataset = prepare_dataset(models["tokenizer"], base_path=BASE_PATH, final_training=False)
dataset

In [None]:
def eval(f):
    """Decorator: call ``f(model, ...)`` with the model in inference mode and autograd off.

    Before each call the model is switched via ``model.to_inference()`` when the model
    defines it, otherwise via ``model.eval()``. The call runs under ``torch.no_grad()``.
    The wrapper keeps ``f``'s name and docstring.

    NOTE: this name shadows the Python builtin ``eval`` for the rest of the notebook.
    """
    import functools

    @functools.wraps(f)
    def wrapper(model, *args, **kwargs):
        if hasattr(model, "to_inference"):
            model.to_inference()
        else:
            model.eval()
        with torch.no_grad():
            return f(model, *args, **kwargs)

    return wrapper


def train(f):
    """Decorator: switch the model to training mode before calling ``f(model, ...)``.

    Uses ``model.to_training()`` when the model defines it, otherwise ``model.train()``.
    The wrapper keeps ``f``'s name and docstring.
    """
    import functools

    @functools.wraps(f)
    def wrapper(model, *args, **kwargs):
        if hasattr(model, "to_training"):
            model.to_training()
        else:
            model.train()
        return f(model, *args, **kwargs)

    return wrapper


def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
@eval
def describe_puzzle(model, processor, image, prompt, max_new_tokens=128):
    """Ask the vision-language model about one image and return its decoded reply.

    The ``eval`` decorator puts ``model`` in inference mode and disables autograd.

    Args:
        model: a generative vision-language model (e.g. ``models["vllm"]``).
        processor: the matching processor (chat template + image preprocessing).
        image: the image to describe (e.g. from ``list_to_image``).
        prompt: the text instruction placed after the image in the user turn.
        max_new_tokens: generation budget (default 128, as before).

    Returns:
        The newly generated text, without the prompt and without special tokens.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt")
    inputs = inputs.to(model.device)

    generated = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # get_models loads the model with return_dict_in_generate=True. When that reaches the
    # generation config, generate() returns an output object whose token ids are in
    # ``.sequences`` instead of a plain tensor, so handle both forms.
    sequences = getattr(generated, "sequences", generated)
    new_token_ids = sequences[0, inputs.input_ids.shape[1] :]  # drop the prompt tokens
    return processor.decode(new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

In [None]:
# image = list_to_image(dataset["train"][10]["challenge"]["train"][0]["input"])
# image

In [None]:
# describe_puzzle(models['vllm'], models['processor'], image, "Describe the image")

In [None]:
class Encoder(nn.Module):
    """Encoder of the CVAE's approximate posterior q(z | x, c).

    ``x`` and ``condition`` are concatenated along dim 1 (the sequence axis for
    (batch, seq, feature) tensors). Each position is projected to query/key/value vectors
    of size ``hidden_dim``, mixed with multi-head self-attention, and mean-pooled over the
    sequence. The pooled vector is mapped to ``mu`` and ``log_var`` of size ``latent_dim``.

    Args:
        input_dim: feature size of both ``x`` and ``condition``. They share the
            query/key/value projections, so their last dimensions must match.
        condition_dim: stored on the module; not used in the computation.
        latent_dim: size of the latent vector.
        hidden_dim: attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (default 4).
        param_dtype: dtype of all layers; ``None`` uses the notebook-wide ``dtype``.
    """

    def __init__(self, input_dim, condition_dim, latent_dim, hidden_dim, num_heads=4, param_dtype=None):
        super().__init__()
        if param_dtype is None:
            param_dtype = dtype  # notebook-wide default
        self.condition_dim = condition_dim

        self.query = nn.Linear(input_dim, hidden_dim, dtype=param_dtype)
        self.key = nn.Linear(input_dim, hidden_dim, dtype=param_dtype)
        self.value = nn.Linear(input_dim, hidden_dim, dtype=param_dtype)

        # Inputs are (batch, seq, feature). With the default batch_first=False,
        # MultiheadAttention would read them as (seq, batch, feature) and attend across
        # the batch instead of across each sample's own positions.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=num_heads, batch_first=True, dtype=param_dtype)

        self.fc1 = nn.Linear(hidden_dim, hidden_dim, dtype=param_dtype)

        self.fc_mu = nn.Linear(hidden_dim, latent_dim, dtype=param_dtype)  # mean of q(z | x, c)
        self.fc_var = nn.Linear(hidden_dim, latent_dim, dtype=param_dtype)  # log-variance of q(z | x, c)

    def forward(self, x, condition):
        """Return ``(mu, log_var)``, each of shape (batch, latent_dim)."""
        x_cond = torch.cat([x, condition], dim=1)

        attn_output, _ = self.attention(self.query(x_cond), self.key(x_cond), self.value(x_cond))
        h = F.relu(self.fc1(attn_output.mean(dim=1)))  # pool over positions -> one vector per sample

        mu = self.fc_mu(h)
        log_var = self.fc_var(h)

        return mu, log_var

In [None]:
def reparameterize(mu, log_var):
    """Draw z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(log_var / 2).

    Writing the sample this way keeps it differentiable with respect to ``mu`` and ``log_var``.
    """
    sigma = log_var.mul(0.5).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

In [None]:
class Decoder(nn.Module):
    """Decoder of the CVAE: maps a latent sample plus condition features to per-cell logits.

    ``z`` (batch, latent_dim) is broadcast to every position of ``condition``
    (batch, seq, condition_dim) and concatenated with it. The result is projected to
    ``hidden_dim``, passed through multi-head self-attention, and mapped to
    ``output_dim * output_dim * 10`` raw scores per position: one score per colour class
    for each cell of an ``output_dim`` x ``output_dim`` grid.

    Args:
        latent_dim: size of ``z``.
        condition_dim: feature size of ``condition``.
        output_dim: side length of the predicted square grid.
        hidden_dim: attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (default 4).
        param_dtype: dtype of all layers; ``None`` uses the notebook-wide ``dtype``.
    """

    def __init__(self, latent_dim, condition_dim, output_dim, hidden_dim, num_heads=4, param_dtype=None):
        super().__init__()
        if param_dtype is None:
            param_dtype = dtype  # notebook-wide default
        self.condition_dim = condition_dim
        self.fc1 = nn.Linear(latent_dim + condition_dim, hidden_dim, dtype=param_dtype)

        self.query = nn.Linear(hidden_dim, hidden_dim, dtype=param_dtype)
        self.key = nn.Linear(hidden_dim, hidden_dim, dtype=param_dtype)
        self.value = nn.Linear(hidden_dim, hidden_dim, dtype=param_dtype)

        # Inputs are (batch, seq, feature); see batch_first in nn.MultiheadAttention.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=num_heads, batch_first=True, dtype=param_dtype)
        # One score per colour class (10) for each of the output_dim x output_dim cells.
        self.fc_output = nn.Linear(hidden_dim, output_dim * output_dim * 10, dtype=param_dtype)

    def forward(self, z, condition, output_len):
        """Return logits of shape (batch, seq, output_dim * output_dim * 10).

        ``output_len`` is accepted for interface compatibility and is not used.
        """
        z_tiled = z.unsqueeze(1).expand(-1, condition.shape[1], -1)
        z_cond = torch.cat([z_tiled, condition], dim=-1)

        h = F.relu(self.fc1(z_cond))

        attn_output, _ = self.attention(self.query(h), self.key(h), self.value(h))

        # Raw logits: the training loss applies its own sigmoid (binary_cross_entropy_with_logits).
        # A softmax over this flattened vector would normalise across all cells jointly
        # rather than per cell.
        return self.fc_output(attn_output)

In [None]:
class CVAE(nn.Module):
    """Conditional VAE: ``Encoder`` -> reparameterised latent sample -> ``Decoder``.

    ``forward(x, condition, output_len)`` returns ``(output, mu, log_var)``. ``output`` is
    the decoder's result, and ``mu``/``log_var`` parameterise the encoder's posterior
    (needed for the KL term of the loss).
    """

    def __init__(self, input_dim, condition_dim, latent_dim, output_dim, hidden_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, condition_dim, latent_dim, hidden_dim)
        self.decoder = Decoder(latent_dim, condition_dim, output_dim, hidden_dim)

    def forward(self, x, condition, output_len):
        posterior_mu, posterior_log_var = self.encoder(x, condition)
        latent = reparameterize(posterior_mu, posterior_log_var)  # (B, latent_dim)
        decoded = self.decoder(latent, condition, output_len)  # (B, cond_len, output_dim * output_dim * 10)
        return decoded, posterior_mu, posterior_log_var

In [None]:
class ARCModel(torch.nn.Module):
    """Frozen text LLM + vision-language model feature extractors feeding a trainable CVAE head.

    Last-layer hidden states of ``llm_model`` (width 3072) and ``vllm_model`` (width 3584)
    are projected to a shared width of 2304 and concatenated along the sequence axis. The
    training-example features are fed to a ``CVAE`` as ``x`` and the test-task features as
    the condition. The CVAE emits per-cell logits for a 30x30 grid with 10 colour classes.
    """

    def __init__(self, llm_model, vllm_model):
        super().__init__()
        self.llm_model = llm_model
        self.vllm_model = vllm_model

        self.text_proj = nn.Linear(3072, 2304, dtype=dtype)
        self.image_proj = nn.Linear(3584, 2304, dtype=dtype)

        # Side length of the predicted square grid.
        self.output_dim = 30

        self.cvae = CVAE(input_dim=2304, condition_dim=2304, latent_dim=512, output_dim=self.output_dim, hidden_dim=1024)

        # Device of the trainable head, updated by ``to``. Defaults to CPU so ``encode`` does
        # not fail with AttributeError when ``to`` was never called.
        self.device = torch.device("cpu")

    def to(self, device):
        """Move only the trainable head (projections + CVAE) to ``device`` and return ``self``.

        The backbones stay where ``device_map`` placed them at load time. Unlike
        ``nn.Module.to``, this accepts a device only.
        """
        self.device = device
        self.cvae.to(device)
        self.text_proj.to(device)
        self.image_proj.to(device)
        return self

    def to_inference(self):
        """Put the whole model (backbones and trainable head) in eval mode."""
        self.eval()

    def to_training(self):
        """Put the whole model (backbones and trainable head) in training mode."""
        self.train()

    def cvae_loss(self, recon_x, x, mu, log_var):
        """Summed reconstruction loss plus KL(q(z | x, c) || N(0, I)).

        Args:
            recon_x: predicted per-cell class logits.
            x: one-hot targets with the same shape as ``recon_x``.
            mu, log_var: posterior parameters from the encoder.
        """
        # Each of the 10 classes per cell is scored as an independent binary decision.
        recon_loss = F.binary_cross_entropy_with_logits(recon_x, x, reduction="sum")  # TODO: try BCELoss
        # KL Divergence loss (closed form for a diagonal Gaussian against N(0, I))
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + kl_loss

    def encode(self, text_inputs, image_inputs):
        """Run both frozen backbones and project their last hidden layer into the shared space.

        Args:
            text_inputs: tokenizer output for ``llm_model`` (must support ``.to`` and ``**``-unpacking).
            image_inputs: processor output for ``vllm_model`` (same requirements).

        Returns:
            Text and image features concatenated along dim 1, projected to width 2304 on
            ``self.device``: (batch_size, seq_len + vid_len, 2304) per the shape comments below.
        """
        with torch.no_grad():
            text_features = self.llm_model(**text_inputs.to(self.llm_model.device)).hidden_states[-1]  # (batch_size, seq_len, 3072)
            image_features = self.vllm_model(**image_inputs.to(self.vllm_model.device)).hidden_states[-1]  # (batch_size, vid_len, 3584)

        # Move the inputs back off the GPUs (relies on ``.to`` moving them in place) and free
        # cached blocks. TODO: cleanup
        text_inputs.to("cpu")
        image_inputs.to("cpu")
        torch.cuda.empty_cache()

        text_features = self.text_proj(text_features.to(self.device))
        image_features = self.image_proj(image_features.to(self.device))

        return torch.cat([text_features, image_features], dim=1)  # (batch_size, seq_len + vid_len, 2304)

    def forward(self, train_inputs, test_inputs, targets=None):
        """Predict the test output grid; compute the CVAE loss when ``targets`` is given.

        Args:
            train_inputs / test_inputs: dicts with "text" and "image" encodings.
            targets: optional batch of integer grids, each output_dim x output_dim.

        Returns:
            dict with "loss" (None without targets), "outputs" (B, output_dim**2, 10) float32
            logits on CPU, and the posterior "mu" / "log_var".
        """
        train_features = self.encode(text_inputs=train_inputs["text"], image_inputs=train_inputs["image"])  # (B, seq_len + vid_len, 2304)
        test_features = self.encode(text_inputs=test_inputs["text"], image_inputs=test_inputs["image"])  # (B, seq_len + vid_len, 2304)

        outputs, mu, log_var = self.cvae(train_features, test_features, output_len=self.output_dim)  # (B, cond_seq_len, output_dim**2 * 10)

        batch_size = outputs.shape[0]
        n_cells = self.output_dim * self.output_dim
        # NOTE(review): only the decoder output at condition position 0 is used. With
        # left-padded text this can be a padding position — confirm intended.
        outputs = outputs[:, 0, :].reshape(batch_size, n_cells, 10).cpu().float()

        if targets is None:
            return {"loss": None, "outputs": outputs, "mu": mu, "log_var": log_var}

        # Labels are only built when targets exist (torch.tensor(None) would raise).
        target_ids = torch.as_tensor(targets).long().reshape(batch_size, n_cells)
        labels = F.one_hot(target_ids, num_classes=10).float()
        loss = self.cvae_loss(outputs, labels, mu, log_var)
        return {"loss": loss, "outputs": outputs, "mu": mu, "log_var": log_var}

    def from_pretrained(self, path):
        """Load the trainable head written by ``save_pretrained`` from ``path``; returns ``self``.

        The backbones are not touched; they are loaded separately (``get_models``).
        """
        import os

        state = torch.load(os.path.join(path, "arc_head.pth"), map_location="cpu", weights_only=True)
        self.text_proj.load_state_dict(state["text_proj"])
        self.image_proj.load_state_dict(state["image_proj"])
        self.cvae.load_state_dict(state["cvae"])
        return self

    def save_pretrained(self, path):
        """Save the trainable head (projections + CVAE) to ``path/arc_head.pth``.

        The frozen backbones are not saved.
        """
        import os

        os.makedirs(path, exist_ok=True)
        torch.save(
            {
                "text_proj": self.text_proj.state_dict(),
                "image_proj": self.image_proj.state_dict(),
                "cvae": self.cvae.state_dict(),
            },
            os.path.join(path, "arc_head.pth"),
        )

In [None]:
# Wrap the loaded backbones with the trainable projection + CVAE head. ARCModel.to moves only
# that head to cuda:0; the backbones stay where device_map placed them.
arc_model = ARCModel(models["llm"], models["vllm"])
arc_model.to("cuda:0")

In [None]:
def pad_matrix(matrix, target_rows, target_cols, pad_value=0):
    """Return a new grid: ``matrix`` right-padded to ``target_cols`` and bottom-padded to ``target_rows``.

    Nothing is truncated. Rows already at least ``target_cols`` long and matrices already
    at least ``target_rows`` tall keep their size. The input lists are not modified.
    """
    widened = []
    for row in matrix:
        missing = target_cols - len(row)
        widened.append(row + [pad_value] * missing)  # a negative count repeats to []

    extra_rows = [[pad_value] * target_cols for _ in range(target_rows - len(widened))]
    return widened + extra_rows

In [None]:
def collate(mode, tokenizer, processor, include_solution_text=False):
    """Build a DataLoader ``collate_fn`` that turns dataset rows into model-ready encodings.

    Each row's "messages" holds text and image chats for the training examples and for the
    test task. Text chats are encoded with ``tokenizer.apply_chat_template``; image chats
    go through the processor's chat template plus ``process_vision_info``.

    By default, assistant turns are removed from the text chats before encoding and a
    generation prompt is appended. For train/val rows the assistant turn is the test
    solution, so keeping it would put the label into the conditioning features. It would
    also make training inputs differ from prediction inputs, which have no solution.

    Args:
        mode: split name. Only used when ``include_solution_text`` is True, to decide
            whether a generation prompt is appended (no prompt for "train"/"val").
        tokenizer: text tokenizer with a chat template.
        processor: vision-language processor with a chat template.
        include_solution_text: restore the old behaviour of keeping assistant turns.

    Returns:
        ``collate_fn(batch)``, which returns a dict with "train_inputs" and "test_inputs"
        ({"text": ..., "image": ...}) and, when rows carry a "solution", "targets"
        (solutions padded to 30x30 with 0).
    """
    add_generation_prompt = (mode not in ["train", "val"]) if include_solution_text else True

    def drop_none_fields(parts):
        # Arrow unifies struct schemas, so a text part comes back with image=None and an
        # image part with text=None; drop those placeholders.
        return [{k: v for k, v in part.items() if v is not None} for part in parts]

    def without_assistant_turns(messages):
        return [message for message in messages if message["role"] != "assistant"]

    def encode_chats(text_messages, image_messages):
        # Encode one batch of text chats and the matching batch of image chats.
        if not include_solution_text:
            text_messages = [without_assistant_turns(chat) for chat in text_messages]
        image_messages = [[{**msg, "content": drop_none_fields(msg["content"])} for msg in chat] for chat in image_messages]

        text_encodings = tokenizer.apply_chat_template(
            text_messages,
            tokenize=True,
            add_generation_prompt=add_generation_prompt,
            return_tensors="pt",
            return_dict=True,
            padding=True,
        )

        image_text = processor.apply_chat_template(image_messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(image_messages)
        image_encodings = processor(
            text=image_text,
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )

        return text_encodings, image_encodings

    def collate_fn(batch):
        train_text_messages = [item["messages"]["train_text_messages"] for item in batch]
        train_image_messages = [item["messages"]["train_image_messages"] for item in batch]
        test_text_messages = [item["messages"]["test_text_messages"] for item in batch]
        test_image_messages = [item["messages"]["test_image_messages"] for item in batch]

        train_text_encodings, train_image_encodings = encode_chats(train_text_messages, train_image_messages)
        test_text_encodings, test_image_encodings = encode_chats(test_text_messages, test_image_messages)

        collated = {
            "train_inputs": {"text": train_text_encodings, "image": train_image_encodings},
            "test_inputs": {"text": test_text_encodings, "image": test_image_encodings},
        }
        # Train/val rows carry a solution grid; pad it to the fixed 30x30 prediction size.
        if "solution" in batch[0]:
            collated["targets"] = [pad_matrix(item["solution"], target_rows=30, target_cols=30) for item in batch]
        return collated

    return collate_fn

In [None]:
# Inspection loader over the training split (batch size 1), used below to look at one collated batch.
dataloader = torch.utils.data.DataLoader(
    dataset["train"], batch_size=1, collate_fn=collate(mode="train", tokenizer=models["tokenizer"], processor=models["processor"])
)


def print_recursive(obj, indent=0):
    """Print a nested structure of tensors / mappings / lists as an indented outline.

    - tensors: their ``shape``
    - mappings (plain dicts and transformers' ``BatchEncoding``/``BatchFeature``, which are
      ``UserDict`` subclasses): each key, then its value one level deeper
    - lists: "List of length: d0, d1, ...", the length at each nesting level, following
      the first element while it is itself a non-empty list (empty or flat lists are fine)
    - anything else: ``str(obj)``
    """
    from collections.abc import Mapping

    pad = "  " * indent
    if isinstance(obj, torch.Tensor):
        print(pad + str(obj.shape))
    elif isinstance(obj, Mapping):
        for key, value in obj.items():
            print(pad + str(key) + ":")
            print_recursive(value, indent + 1)
    elif isinstance(obj, list):
        dims = []
        level = obj
        while isinstance(level, list):
            dims.append(len(level))
            if not level:
                break
            level = level[0]
        print(pad + "List of length: " + ", ".join(map(str, dims)))
    else:
        print(pad + str(obj))


# Print the structure (tensor shapes / list sizes) of the first collated batch only.
for batch in dataloader:
    print_recursive(batch)
    # Uncomment to also run one forward pass and inspect its outputs:
#     outputs = arc_model(**batch)
#     print('-'* 30)
#     print_recursive(outputs)
    break

In [None]:
def compute_metrics(pred):
    """Cell-level classification metrics for grid predictions.

    Accepts either an object with ``predictions`` and ``label_ids`` attributes (the shape of
    a Hugging Face ``EvalPrediction``) or a ``(predictions, labels)`` pair.

    Assumes (TODO confirm against the eventual caller): ``predictions`` holds per-cell
    class scores with the class axis last, and ``labels`` holds integer class ids with one
    entry per cell (the same number of cells as ``predictions``).

    Returns:
        dict with "accuracy" (fraction of cells predicted correctly) and macro-averaged
        "precision", "recall" and "f1" over the classes occurring in labels or predictions.
        A class with no predicted (resp. actual) cells contributes 0 to precision (resp. recall).

    Raises:
        ValueError: if predictions and labels cover different numbers of cells.
    """
    if hasattr(pred, "predictions"):
        predictions, labels = pred.predictions, pred.label_ids
    else:
        predictions, labels = pred

    y_pred = np.asarray(predictions).argmax(axis=-1).reshape(-1)
    y_true = np.asarray(labels).reshape(-1)
    if y_pred.shape != y_true.shape:
        raise ValueError(f"predictions cover {y_pred.size} cells but labels cover {y_true.size}")
    if y_true.size == 0:
        return {"accuracy": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}

    accuracy = float(np.mean(y_pred == y_true))

    precisions, recalls, f1s = [], [], []
    for cls in np.union1d(y_true, y_pred):
        tp = np.sum((y_pred == cls) & (y_true == cls))
        fp = np.sum((y_pred == cls) & (y_true != cls))
        fn = np.sum((y_pred != cls) & (y_true == cls))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)

    return {
        "accuracy": accuracy,
        "f1": float(np.mean(f1s)),
        "precision": float(np.mean(precisions)),
        "recall": float(np.mean(recalls)),
    }

In [None]:
@train
def training(model, tokenizer, processor, dataset, config):
    """Train ``model`` on ``dataset["train"]`` and report the validation loss after each epoch.

    Args:
        model: the model to train; ``model(**batch)["loss"]`` must be a scalar tensor.
        tokenizer, processor: passed to ``collate`` to build the batches.
        dataset: DatasetDict with "train" and "val" splits (rows carry solutions).
        config: dict with "lr", "batch_size" and "epochs". The optional "max_train_batches"
            and "max_val_batches" (int or None) cap the batches visited per epoch, e.g. 1 for
            a quick smoke test. None (the default) means a full pass.

    Returns:
        {"train_loss": [...], "val_loss": [...]}: per epoch, the mean loss per batch visited.
    """
    import itertools

    # Only parameters that can receive gradients. The backbone features are computed under
    # torch.no_grad() in ARCModel.encode, so the backbones are never updated anyway.
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=config["lr"])

    train_dataloader = torch.utils.data.DataLoader(
        dataset["train"], batch_size=config["batch_size"], collate_fn=collate(mode="train", tokenizer=tokenizer, processor=processor)
    )
    val_dataloader = torch.utils.data.DataLoader(
        dataset["val"], batch_size=config["batch_size"], collate_fn=collate(mode="val", tokenizer=tokenizer, processor=processor)
    )

    max_train_batches = config.get("max_train_batches")
    max_val_batches = config.get("max_val_batches")

    def batches_to_visit(loader, limit):
        # Progress-bar total: the full loader length unless capped.
        return len(loader) if limit is None else min(limit, len(loader))

    history = {"train_loss": [], "val_loss": []}
    for epoch in tqdm(range(config["epochs"]), desc="Epochs", total=config["epochs"]):
        model.train()
        train_loss, n_train = 0.0, 0  # reset every epoch so the average is per epoch
        for batch in tqdm(
            itertools.islice(train_dataloader, max_train_batches),
            desc="Train Batches",
            total=batches_to_visit(train_dataloader, max_train_batches),
        ):
            optimizer.zero_grad()
            loss = model(**batch)["loss"]
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_train += 1

        epoch_train_loss = train_loss / max(n_train, 1)
        print(f"Epoch {epoch + 1}, Loss: {epoch_train_loss}")

        # Validation: eval mode and no autograd graph.
        model.eval()
        val_loss, n_val = 0.0, 0
        with torch.no_grad():
            for batch in tqdm(
                itertools.islice(val_dataloader, max_val_batches),
                desc="Val Batches",
                total=batches_to_visit(val_dataloader, max_val_batches),
            ):
                val_loss += model(**batch)["loss"].item()
                n_val += 1

        epoch_val_loss = val_loss / max(n_val, 1)
        print(f"Epoch {epoch + 1}, Val Loss: {epoch_val_loss}")

        history["train_loss"].append(epoch_train_loss)
        history["val_loss"].append(epoch_val_loss)

    model.train()  # leave the model in training mode, as the @train decorator set it
    return history

In [None]:
# Hyper-parameters read by `training`: number of epochs, DataLoader batch size and AdamW learning rate.
config = {
    'epochs': 5,
    'batch_size': 2,
    'lr': 2e-5,
}

In [None]:
# Train the model; `history` holds the per-epoch train/val losses.
history = training(arc_model, models['tokenizer'], models['processor'], dataset, config)