In [None]:
# Expose only physical GPU 1 to this process; inside the notebook it is addressed as
# cuda:0, which is why the model below is loaded with device_map={"": 0}.
# This needs to be set before CUDA is initialised (i.e. before torch touches the GPU).
%env CUDA_VISIBLE_DEVICES=1
# Turn off the HF tokenizers library's internal parallelism (avoids fork-related warnings
# once DataLoader / other subprocesses are used).
%env TOKENIZERS_PARALLELISM=false

In [None]:
import os

# Project root. Defaults to the original machine's checkout; set ARC_BASE_PATH to run elsewhere.
BASE_PATH = os.environ.get("ARC_BASE_PATH", "/home/stepan/kaggle-arc-agi")
# Fine-tuned checkpoint to evaluate.
MODEL_ID = f"{BASE_PATH}/models/gemma-2-2b-it/checkpoint-500"
# Maximum number of tokens generated per sample.
MAX_NEW_TOKENS = 2048
# Prompt budget: a total of 8192 tokens minus the generation budget.
# NOTE(review): not referenced by the cells in this notebook section — confirm it is still needed.
MAX_SEQ_LENGTH = 8192 - MAX_NEW_TOKENS

In [None]:
import sys

# Make the project root and its scripts/ directory importable so the project modules
# (logger, train_utils, data_utils) can be imported below. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
for _extra_path in (BASE_PATH, f"{BASE_PATH}/scripts"):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [None]:
from collections import Counter

import torch  # type: ignore
import numpy as np  # type: ignore

from datasets import DatasetDict, Dataset  # type: ignore

from tqdm.auto import tqdm  # type: ignore

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  # type: ignore

from logger import get_logger  # type: ignore
import train_utils  # type: ignore
import data_utils  # type: ignore

In [None]:
# Run logger; the first argument looks like the log directory under the project root
# (see logger.get_logger for the exact meaning of both arguments).
log_dir = f"{BASE_PATH}/logs/gemma-2-2b"
log = get_logger(log_dir, "arc-agi")

In [None]:
def get_model_tokenizer(load_in_4bit=True):
    """Load the checkpoint at ``MODEL_ID`` together with its tokenizer.

    Args:
        load_in_4bit: if True, quantize the weights to 4-bit with bitsandbytes; if False,
            load them unquantized in bfloat16.

    Returns:
        ``(model, tokenizer)``. The whole model is placed on device 0 (the first *visible*
        GPU), and the tokenizer pads on the left so that batched prompts all end at the same
        position; ``evaluate_batch`` relies on this to slice off the generated tokens.
    """
    # Only build a quantization config when actually quantizing. A BitsAndBytesConfig with
    # nothing enabled is not the same as "no quantization" and, depending on the transformers
    # version, may still route loading through the bitsandbytes quantizer.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True) if load_in_4bit else None

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=quantization_config,
        attn_implementation="flash_attention_2",  # requires the flash-attn package
        torch_dtype=torch.bfloat16,
        device_map={"": 0},  # entire model on the first visible GPU
    )

    return model, tokenizer

In [None]:
# Load the (4-bit quantized) checkpoint and its left-padding tokenizer, then show the module tree.
model, tokenizer = get_model_tokenizer(load_in_4bit=True)
model

In [None]:
# Build the dataset splits with the loaded tokenizer (see data_utils.prepare_dataset for what
# fit_dataset controls); the "test" split is evaluated below.
dataset = data_utils.prepare_dataset(
    tokenizer,
    fit_dataset=True,
)
dataset

In [None]:
def evaluate_batch(
    model,
    tokenizer,
    batch,
    num_seq=5,
    max_new_tokens=None,
    num_beams=5,
    temperature=0.5,
    top_k=50,
):
    """Sample ``num_seq`` completions per prompt in ``batch`` and decode only the new tokens.

    Args:
        model: causal LM exposing ``generate`` and ``device`` (e.g. a Hugging Face model).
        tokenizer: tokenizer used to decode the generated token ids.
        batch: mapping with "input_ids" and "attention_mask" tensors. Prompts must be
            left-padded so that generated tokens start at the same column for every row.
        num_seq: number of completions returned per prompt.
        max_new_tokens: generation budget; ``None`` uses the notebook's ``MAX_NEW_TOKENS``.
        num_beams: beams for beam-sample decoding (``do_sample=True`` with beams). Raised to
            ``num_seq`` when smaller, because ``generate`` cannot return more sequences
            than it has beams.
        temperature: sampling temperature passed to ``generate``.
        top_k: top-k sampling cutoff passed to ``generate``.

    Returns:
        List of ``batch_size * num_seq`` decoded strings. Hugging Face ``generate`` keeps the
        ``num_seq`` completions of each prompt contiguous, in prompt order.
    """
    if max_new_tokens is None:
        max_new_tokens = MAX_NEW_TOKENS

    # Batches coming out of the DataLoader may still live on the CPU; put them where the model is.
    inputs = {key: batch[key].to(model.device) for key in ("input_ids", "attention_mask")}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            use_cache=True,
            num_beams=max(num_beams, num_seq),
            num_return_sequences=num_seq,
            temperature=temperature,
            top_k=top_k,
        )

    # With left padding every prompt ends at the same column, so everything after it is new.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

In [None]:
def _grid_shape(grid):
    """Return ``(rows, cols)`` of a rectangular grid, or ``None`` if its rows differ in length."""
    rows = len(grid)
    cols = len(grid[0]) if rows > 0 else 0
    if any(len(row) != cols for row in grid):
        return None
    return rows, cols


def _majority_grid(grids, rows, cols):
    """Cell-wise majority vote over same-shaped grids; on a tie the value seen first wins."""
    return [
        [Counter(grid[r][c] for grid in grids).most_common(1)[0][0] for c in range(cols)]
        for r in range(rows)
    ]


def process_sequences(generated_texts, num_seq, parse_fn=None):
    """Reduce ``num_seq`` sampled generations per example to two candidate grids.

    Consecutive runs of ``num_seq`` entries in ``generated_texts`` are treated as the
    completions of one example (the layout ``evaluate_batch`` returns). For each example:

    1. parse every completion, dropping ones that fail to parse (``None``) or are ragged;
    2. keep the most common grid shape (the first-seen shape on a tie);
    3. build a cell-wise majority-vote grid over the grids of that shape;
    4. return the two grids most similar to the vote (cells in agreement).

    Args:
        generated_texts: decoded completions, ``num_seq`` per example, contiguous.
        num_seq: number of completions per example.
        parse_fn: text -> grid (or ``None`` if unparseable). Defaults to
            ``train_utils.parse_output``.

    Returns:
        One 2-tuple per example. It holds ``(None, None)`` when nothing usable parsed. When
        only a single grid has the winning shape, that grid is repeated, so callers can
        always unpack two candidates.
    """
    if parse_fn is None:
        parse_fn = train_utils.parse_output

    parsed_outputs = [parse_fn(text) for text in generated_texts]
    res = []
    for start in range(0, len(parsed_outputs), num_seq):
        # Bucket this example's usable grids by shape; dict keeps first-seen order for ties.
        structure_groups = {}
        for option in parsed_outputs[start : start + num_seq]:
            if option is None:
                continue
            shape = _grid_shape(option)
            if shape is None:
                continue  # ragged grids cannot take part in cell-wise voting
            structure_groups.setdefault(shape, []).append(option)

        if not structure_groups:
            res.append((None, None))
            continue

        (rows, cols), selected_options = max(
            structure_groups.items(), key=lambda item: len(item[1])
        )
        voted_option = _majority_grid(selected_options, rows, cols)

        def similarity_score(option):
            return sum(
                option[r][c] == voted_option[r][c] for r in range(rows) for c in range(cols)
            )

        # sorted() is stable, so equally similar grids keep their generation order.
        top_2_options = sorted(selected_options, key=similarity_score, reverse=True)[:2]
        if len(top_2_options) == 1:
            top_2_options *= 2  # only one grid of that shape: offer it as both candidates
        res.append(tuple(top_2_options))  # TODO this or top2 + voted
    return res

In [None]:
def evaluate(model, tokenizer, dataset, batch_size, num_seq=5, max_batches=None):
    """Generate predictions for ``dataset`` and collect them next to the ground truth.

    For each example, ``evaluate_batch`` samples ``num_seq`` completions and
    ``process_sequences`` reduces them to candidate grids. The candidate with the highest
    ``train_utils.calculate_partial_match`` score against the label is kept. Because the
    label is used to pick between candidates, downstream metrics are best-of-candidates
    (oracle) scores, not scores of a single blind prediction.

    Args:
        model: causal LM passed through to ``evaluate_batch``.
        tokenizer: tokenizer used by the collate function and for decoding.
        dataset: dataset of test examples (e.g. ``dataset["test"]``).
        batch_size: DataLoader batch size.
        num_seq: completions sampled per example.
        max_batches: stop after this many batches (useful for quick smoke runs);
            ``None`` evaluates the whole dataset.

    Returns:
        dict with
          - "ids": list of ``(id, challenge["order"])`` tuples, one per example;
          - "preds": the chosen grid per example, or ``None`` if no candidate parsed;
          - "labels": labels converted with ``train_utils.tensor_to_int``.
    """
    eval_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=train_utils.collate(mode="test", tokenizer=tokenizer),
    )
    num_batches = len(eval_dataloader)
    if max_batches is not None:
        num_batches = min(num_batches, max_batches)

    challenge_ids = []
    preds = []
    labels = []
    for batch_idx, batch in enumerate(tqdm(eval_dataloader, total=num_batches)):
        if batch_idx >= num_batches:
            break

        generated_texts = evaluate_batch(model, tokenizer, batch, num_seq=num_seq)
        candidate_sets = process_sequences(generated_texts, num_seq)

        for candidates, solution, challenge_id, challenge in zip(
            candidate_sets, batch["solution"], batch["id"], batch["challenge"]
        ):
            label = train_utils.tensor_to_int(solution)
            valid_candidates = [cand for cand in candidates if cand is not None]
            if valid_candidates:
                # Oracle pick: best partial match against the label; max() keeps the
                # earlier candidate on a tie.
                best_pred = max(
                    valid_candidates,
                    key=lambda cand: train_utils.calculate_partial_match(cand, label),
                )
            else:
                best_pred = None

            preds.append(best_pred)
            labels.append(label)
            challenge_ids.append((challenge_id, challenge["order"]))

    return {
        "ids": challenge_ids,
        "preds": preds,
        "labels": labels,
    }

In [None]:
# Evaluate the test split; `results` also keeps ids/preds/labels for later inspection.
results = evaluate(model, tokenizer, dataset["test"], batch_size=1)
eval_preds, eval_labels = results["preds"], results["labels"]

# Calculate metrics
accuracy, avg_partial_match = train_utils.calculate_metrics(eval_preds, eval_labels)

log.info(f"Exact match accuracy: {accuracy:.4f}")
log.info(f"Average partial match score: {avg_partial_match:.4f}")