In [None]:
# Put the Hugging Face stack (datasets, hub, transformers) into offline mode so that
# nothing in this notebook attempts a network download.
%env HF_DATASETS_OFFLINE=1
%env HF_HUB_OFFLINE=1
%env TRANSFORMERS_OFFLINE=1
# Turn off the tokenizers library's own parallelism.
%env TOKENIZERS_PARALLELISM=false

In [None]:
# Local directory passed to pip's `--find-links` by the offline install cell that follows.
deps_path = "/kaggle/input/unsloth-library-install-v2"

In [None]:
%%capture
! pip install --no-index --find-links {deps_path} pip3-autoremove -y
! pip-autoremove torch -y
! pip install --no-index --find-links {deps_path} torch
! pip install --no-index --find-links {deps_path} triton
! pip install --no-index --find-links {deps_path} "unsloth[kaggle-new]"

In [None]:
%%capture
# Second local package directory; it also holds the requirements.txt that lists what to install.
deps_path_2 = '/kaggle/input/llama-3-arc-deps'
# Install everything named in that requirements file, resolving packages only from the same directory.
! pip install --no-index --find-links {deps_path_2} --requirement {deps_path_2}/requirements.txt

In [None]:
# Root directory handed to data_utils.prepare_dataset and used to locate sample_submission.json.
BASE_PATH = "/kaggle/input"
# Local directory of the model checkpoint loaded by get_model_tokenizer (tokenizer, config and weights).
MODEL_ID = "/kaggle/input/gemma-2-2b-it-baseline/pytorch/default/3/home/stepan/kaggle-arc-agi/models/gemma-2-2b-it/baseline"
# Upper bound on the number of tokens sampled per generation call.
MAX_NEW_TOKENS = 2048
# Token budget left for the prompt once the generation budget is subtracted from 8192.
# NOTE(review): 8192 is presumably the model's context window - confirm. This constant is
# not referenced by any cell visible in this notebook.
MAX_SEQ_LENGTH = 8192 - MAX_NEW_TOKENS

In [None]:
import json

import torch  # type: ignore
import numpy as np  # type: ignore

from datasets import DatasetDict, Dataset  # type: ignore

from tqdm.auto import tqdm  # type: ignore

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig  # type: ignore
import train_utils  # type: ignore
import data_utils  # type: ignore

In [None]:
def get_model_tokenizer():
    """Load the tokenizer and the 4-bit quantized causal LM stored at ``MODEL_ID``.

    The tokenizer is created with left-side padding. The model is loaded with a
    bitsandbytes 4-bit quantization config (bfloat16 compute dtype), automatic
    device placement, and explicit memory caps for device 0 and the CPU, then
    switched to evaluation mode.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")

    load_options = {
        "config": AutoConfig.from_pretrained(MODEL_ID, local_files_only=True),
        "quantization_config": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        # Per-device memory limits handed to the automatic device mapping.
        "max_memory": {0: "15.5GiB", "cpu": "16GiB"},
        "local_files_only": True,
    }
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_options)

    # Inference only: switch the module to eval mode before handing it out.
    model.eval()

    return model, tokenizer

In [None]:
# Load the model and tokenizer once; later cells reuse these notebook-level objects.
model, tokenizer = get_model_tokenizer()

In [None]:
# Build the dataset with the project-local helper; a later cell reads its "predict" entry.
# NOTE(review): the effect of `fit_dataset=True` is defined inside data_utils, which is not
# visible from this notebook - confirm against that module.
dataset = data_utils.prepare_dataset(tokenizer, fit_dataset=True, base_path=BASE_PATH)
# Bare last expression: display the dataset object as the cell output.
dataset

In [None]:
def generate_with_temp(model, inputs, temperature):
    """Run sampled generation on ``inputs`` at the requested temperature.

    Args:
        model: Object exposing ``generate(**kwargs)``.
        inputs: Mapping of keyword arguments forwarded to ``model.generate``.
        temperature: Sampling temperature for this call.

    Returns:
        Whatever ``model.generate`` returns for these arguments.
    """
    sampling_options = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "temperature": temperature,
        "top_k": 50,
        "use_cache": True,
    }
    return model.generate(**inputs, **sampling_options)


def evaluate_batch(model, tokenizer, batch):
    """Generate two sampled completions per prompt and return them as decoded text.

    Only ``input_ids`` and ``attention_mask`` are taken from ``batch``. Generation
    runs twice under ``torch.no_grad()``: first at temperature 0.3, then at 0.7.
    From each result, the columns beyond the prompt width are decoded with special
    tokens skipped.

    Args:
        model: Model passed through to ``generate_with_temp``.
        tokenizer: Object exposing ``batch_decode``.
        batch: Mapping that holds ``"input_ids"`` and ``"attention_mask"``.

    Returns:
        tuple: ``(texts_at_0.3, texts_at_0.7)`` as returned by ``batch_decode``.
    """
    model_inputs = {key: batch[key] for key in ("input_ids", "attention_mask")}
    # Width of the prompt; generated sequences are cut at this column.
    # NOTE(review): assumes generate returns prompt tokens followed by new tokens - confirm.
    prompt_width = model_inputs["input_ids"].shape[1]

    with torch.no_grad():
        sampled = [generate_with_temp(model, model_inputs, temp) for temp in (0.3, 0.7)]

    low_temp_texts, high_temp_texts = (
        tokenizer.batch_decode(sequences[:, prompt_width:], skip_special_tokens=True)
        for sequences in sampled
    )
    return low_temp_texts, high_temp_texts

In [None]:
def predict(model, tokenizer, dataset, batch_size):
    """Generate and parse two attempts for every example in ``dataset``.

    Examples are batched in their original order with the project's "predict"
    collate function. Each example yields two generated texts, both of which are
    passed to ``train_utils.parse_output``; an attempt that parses to ``None`` is
    replaced by the placeholder grid ``[[0]]``. When both attempts fail, the raw
    texts are printed.

    Args:
        model: Model forwarded to ``evaluate_batch``.
        tokenizer: Tokenizer forwarded to the collate function and ``evaluate_batch``.
        dataset: Dataset accepted by ``torch.utils.data.DataLoader``.
        batch_size: Number of examples per batch.

    Returns:
        dict: ``{"ids": [(challenge_id, order), ...], "preds": [{"attempt_1": ..., "attempt_2": ...}, ...]}``
        with both lists aligned index by index.
    """

    def _or_placeholder(grid):
        # A fresh [[0]] is built on every call so predictions never share one list object.
        return [[0]] if grid is None else grid

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=train_utils.collate(mode="predict", tokenizer=tokenizer),
    )

    challenge_ids, preds = [], []
    for batch in tqdm(loader, total=len(loader)):
        texts_first, texts_second = evaluate_batch(model, tokenizer, batch)

        rows = zip(texts_first, texts_second, batch["id"], batch["challenge"])
        for gen_text1, gen_text2, challenge_id, challenge in rows:
            first_grid = train_utils.parse_output(gen_text1)
            second_grid = train_utils.parse_output(gen_text2)

            if first_grid is None and second_grid is None:
                print(f"Failed to parse both outputs: {gen_text1} and {gen_text2}")

            preds.append(
                {
                    "attempt_1": _or_placeholder(first_grid),
                    "attempt_2": _or_placeholder(second_grid),
                }
            )
            challenge_ids.append((challenge_id, challenge["order"]))

    return {"ids": challenge_ids, "preds": preds}

In [None]:
def group_preds_by_challenge_id(challenge_ids, preds):
    """Group predictions per challenge, ordered by their test-output index.

    Args:
        challenge_ids: Iterable of ``(challenge_id, order)`` pairs, aligned with ``preds``.
        preds: Iterable of predictions, one per entry of ``challenge_ids``.

    Returns:
        dict: Maps each challenge id (in first-seen order) to the list of its
        predictions sorted by ``order``. When the same ``(challenge_id, order)``
        pair occurs more than once, only the first prediction is kept.
    """
    # challenge_id -> {order: pred}. Keying on `order` makes the duplicate check a
    # dict lookup instead of a linear scan over everything collected so far.
    preds_by_order = {}
    for (challenge_id, order), pred in zip(challenge_ids, preds):
        # setdefault keeps the first prediction seen for a given (id, order) pair.
        preds_by_order.setdefault(challenge_id, {}).setdefault(order, pred)

    # Emit the predictions in ascending order and drop the order keys.
    return {
        challenge_id: [pred for _, pred in sorted(by_order.items(), key=lambda item: item[0])]
        for challenge_id, by_order in preds_by_order.items()
    }

In [None]:
# Run generation over the "predict" entry of the dataset, one example per batch.
pred_results = predict(model, tokenizer, dataset["predict"], batch_size=1)
# Collapse the flat (id, order) / prediction lists into {challenge_id: [prediction, ...]}.
grouped_preds = group_preds_by_challenge_id(pred_results["ids"], pred_results["preds"])

In [None]:
# Display how many distinct challenge ids received predictions.
len(grouped_preds)

In [None]:
# Compare the grouped predictions with sample_submission.json.
with open(f"{BASE_PATH}/arc-prize-2024/sample_submission.json", "r") as json_file:
    sample_submission = json.load(json_file)


def _is_valid_grid(grid):
    """Return True when ``grid`` is a non-empty list of non-empty lists (at least 1x1)."""
    if not isinstance(grid, list) or len(grid) < 1:
        return False
    return all(isinstance(row, list) and len(row) >= 1 for row in grid)


# Check that every challenge id of the sample submission is present in grouped_preds with
# the expected number of predictions, and that every attempt is a 2-D grid of at least 1x1.
# This cell only reports problems; it never modifies grouped_preds.
for challenge_id in sample_submission:
    if challenge_id not in grouped_preds:
        print(f"Challenge ID {challenge_id} in sample_submission is not in grouped_preds.")
        # Nothing to inspect for this id; indexing grouped_preds below would raise KeyError.
        continue

    if len(grouped_preds[challenge_id]) != len(sample_submission[challenge_id]):
        print(
            f"Challenge ID {challenge_id} in sample_submission has {len(sample_submission[challenge_id])} predictions, but grouped_preds has {len(grouped_preds[challenge_id])}."
        )

    for pred in grouped_preds[challenge_id]:
        if not isinstance(pred, dict):
            print(f"Challenge ID {challenge_id} in sample_submission has invalid predictions: {pred}")
            continue
        for attempt_key in ("attempt_1", "attempt_2"):
            # .get() turns a missing attempt key into a reported problem instead of a KeyError.
            attempt = pred.get(attempt_key)
            if not _is_valid_grid(attempt):
                print(f"Challenge ID {challenge_id} in sample_submission has invalid predictions: {attempt}")

In [None]:
# Serialize the grouped predictions to submission.json in the current working directory.
with open("submission.json", "w") as json_file:
    json.dump(grouped_preds, json_file)