In [1]:
# Expose only physical GPU 1 to this kernel; it shows up as cuda:0 (see device_map={"": 0} in the loader).
# Must run before torch/CUDA is initialised, hence the first cell. Comments sit on their own lines
# because %env would otherwise treat a trailing "# ..." as part of the value.
%env CUDA_VISIBLE_DEVICES=1
# Turn off Hugging Face tokenizers' internal parallelism (avoids its fork warning with DataLoader use).
%env TOKENIZERS_PARALLELISM=false

env: CUDA_VISIBLE_DEVICES=1
env: TOKENIZERS_PARALLELISM=false


In [9]:
import os

# Repository root. Defaults to the original machine's checkout; set ARC_BASE_PATH to run elsewhere.
BASE_PATH = os.environ.get("ARC_BASE_PATH", "/home/stepan/kaggle-arc-agi")
# Fine-tuned Gemma-2-9B-it checkpoint, resolved relative to the repo root instead of a second hardcoded path.
MODEL_ID = f"{BASE_PATH}/models/gemma-2-9b-it/checkpoint-750"
# Token budget: MAX_NEW_TOKENS is reserved for generation; MAX_SEQ_LENGTH is what remains of an
# 8192-token total and is passed to the model loader as max_seq_length.
MAX_NEW_TOKENS = 2048
MAX_SEQ_LENGTH = 8192 - MAX_NEW_TOKENS

In [3]:
import sys

# Make the repo root and its scripts/ folder importable — presumably where the local modules
# imported in the next cell (logger, train_utils, data_utils) live.
# Guarded so re-running this cell does not keep appending duplicate sys.path entries.
for extra_path in (BASE_PATH, f"{BASE_PATH}/scripts"):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

In [4]:
import torch  # type: ignore
import numpy as np  # type: ignore

from datasets import DatasetDict, Dataset  # type: ignore

from unsloth import FastLanguageModel  # type: ignore

from tqdm.auto import tqdm  # type: ignore

from logger import get_logger  # type: ignore
import train_utils  # type: ignore
import data_utils  # type: ignore

  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


In [5]:
# Run-wide logger named "arc-agi"; the first argument is presumably its log directory under the repo — confirm in logger.get_logger.
log = get_logger(BASE_PATH + "/logs/gemma-2-9b", "arc-agi")

In [6]:
def get_model_tokenizer(dtype=None, load_in_4bit=True, model_name=None, max_seq_length=None):
    """Load the fine-tuned checkpoint and its tokenizer through Unsloth.

    Args:
        dtype: Forwarded to ``FastLanguageModel.from_pretrained``; ``None`` defers to Unsloth's default.
        load_in_4bit: Whether to load the weights 4-bit quantised.
        model_name: Checkpoint path or hub id; ``None`` means the notebook's ``MODEL_ID``.
        max_seq_length: Sequence length given to the loader; ``None`` means ``MAX_SEQ_LENGTH``.

    Returns:
        ``(model, tokenizer)`` as returned by ``FastLanguageModel.from_pretrained``.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        # Resolved at call time so later edits to the config cell are honoured.
        model_name=MODEL_ID if model_name is None else model_name,
        max_seq_length=MAX_SEQ_LENGTH if max_seq_length is None else max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        # Whole model on cuda:0, i.e. the single GPU left visible by CUDA_VISIBLE_DEVICES.
        device_map={"": 0},
        attn_implementation="flash_attention_2",
        # For gated models, export HF_TOKEN in the environment (huggingface_hub reads it);
        # never hardcode an access token in the notebook.
    )

    return model, tokenizer

In [7]:
def eval(f):
    """Decorator: switch the Unsloth model to inference mode, then call ``f``.

    The wrapped function must take ``(model, tokenizer, *args, **kwargs)``;
    ``FastLanguageModel.for_inference(model)`` is called on every invocation before ``f``.
    ``functools.wraps`` keeps the wrapped function's ``__name__``/``__doc__``.

    NOTE(review): this name shadows Python's builtin ``eval`` in the notebook namespace.
    It is kept because ``@eval`` is used on ``evaluate``; consider renaming (e.g. ``inference_mode``).
    """
    import functools

    @functools.wraps(f)
    def wrapper(model, tokenizer, *args, **kwargs):
        FastLanguageModel.for_inference(model)
        return f(model, tokenizer, *args, **kwargs)

    return wrapper

In [10]:
# Load checkpoint + tokenizer with the loader's default settings spelled out (dtype deferred to Unsloth, 4-bit weights).
model, tokenizer = get_model_tokenizer(dtype=None, load_in_4bit=True)

==((====))==  Unsloth 2024.9: Fast Gemma2 patching. Transformers = 4.43.4.
   \\   /|    GPU: NVIDIA RTX A5000. Max memory: 23.679 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = True]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unsloth 2024.9 patched 42 layers with 42 QKV layers, 42 O layers and 42 MLP layers.


In [11]:
# Build the DatasetDict of splits (train / test / val / predict, per the repr below).
dataset = data_utils.prepare_dataset(
    tokenizer,
    fit_dataset=True,
    base_path=BASE_PATH,
)
dataset

Map: 100%|██████████| 430/430 [00:00<00:00, 1397.47 examples/s]
Map: 100%|██████████| 112/112 [00:00<00:00, 495.15 examples/s]
Map: 100%|██████████| 459/459 [00:00<00:00, 929.61 examples/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'challenge', 'solution', 'texts', 'messages'],
        num_rows: 430
    })
    test: Dataset({
        features: ['id', 'challenge', 'solution', 'texts', 'messages'],
        num_rows: 321
    })
    val: Dataset({
        features: ['id', 'challenge', 'solution', 'texts', 'messages'],
        num_rows: 138
    })
    predict: Dataset({
        features: ['id', 'challenge', 'texts', 'messages'],
        num_rows: 112
    })
})

In [12]:
def generate_with_temp(model, inputs, temperature, max_new_tokens=None, top_k=50):
    """Run one sampled ``model.generate`` call at the given temperature.

    Args:
        model: Object exposing a Hugging Face style ``generate`` method.
        inputs: Mapping of generate keyword arguments (e.g. ``input_ids``, ``attention_mask``).
        temperature: Sampling temperature (sampling is always on: ``do_sample=True``).
        max_new_tokens: Generation cap; ``None`` means the notebook's ``MAX_NEW_TOKENS``.
        top_k: Top-k sampling cutoff; defaults to the previously hard-coded 50.

    Returns:
        Whatever ``model.generate`` returns (for HF models: prompt + generated token ids).
    """
    if max_new_tokens is None:
        max_new_tokens = MAX_NEW_TOKENS
    return model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        use_cache=True,
    )


def evaluate_batch(model, tokenizer, batch, temperatures=(0.3, 0.7)):
    """Sample one completion per prompt at each temperature and decode only the new tokens.

    Args:
        model: Model exposing ``generate`` and ``device`` (HF/Unsloth model).
        tokenizer: Tokenizer used for ``batch_decode``.
        batch: Mapping with ``input_ids`` and ``attention_mask`` tensors; other keys are ignored.
        temperatures: One sampled ``generate`` pass per entry. The default reproduces the
            original two passes at 0.3 and 0.7.

    Returns:
        Tuple with one list of decoded strings per temperature, in order (two lists by default).
    """
    # Move to the model's device; a no-op if the collate function already placed them there.
    inputs = {key: batch[key].to(model.device) for key in ("input_ids", "attention_mask")}
    # All rows share the padded prompt length, so slicing it off keeps only generated tokens.
    # NOTE(review): correct for left-padded prompts — confirm train_utils.collate pads on the left.
    prompt_length = inputs["input_ids"].shape[1]

    decoded = []
    with torch.no_grad():
        for temperature in temperatures:
            outputs = generate_with_temp(model, inputs, temperature)
            decoded.append(tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True))

    return tuple(decoded)

In [13]:
@eval
def evaluate(model, tokenizer, dataset, batch_size):
    """Generate two sampled answers per example and keep the better-scoring parse.

    Each batch is decoded at two temperatures by ``evaluate_batch``. Both texts are parsed
    with ``train_utils.parse_output``. Among the parses that succeed, the one with the
    highest ``train_utils.calculate_partial_match`` against the ground-truth label is kept.
    Ties go to the first (lower-temperature) sample.

    NOTE: the pick uses the label, i.e. it is an oracle "best of two" selection. That is
    comparable to two-attempt scoring, not to a single blind prediction.

    Args:
        model: Model forwarded to ``evaluate_batch``; ``@eval`` switches it to inference mode first.
        tokenizer: Used by the test-mode collate function and for decoding.
        dataset: Rows providing ``id``, ``challenge`` (with an ``"order"`` entry) and ``solution``.
        batch_size: DataLoader batch size.

    Returns:
        dict of parallel lists:
            ``ids``: ``(challenge_id, challenge["order"])`` tuples.
            ``preds``: parsed predictions, or ``None`` when neither sample parsed.
            ``labels``: ``train_utils.tensor_to_int(solution)``.
    """
    eval_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=train_utils.collate(mode="test", tokenizer=tokenizer),
    )

    challenge_ids = []
    preds = []
    labels = []
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
        generated_texts1, generated_texts2 = evaluate_batch(model, tokenizer, batch)

        rows = zip(generated_texts1, generated_texts2, batch["solution"], batch["id"], batch["challenge"])
        for gen_text1, gen_text2, label, challenge_id, challenge in rows:
            # Convert once; reused for scoring and for the stored label.
            label_grid = train_utils.tensor_to_int(label)
            candidates = [
                parsed
                for parsed in (train_utils.parse_output(gen_text1), train_utils.parse_output(gen_text2))
                if parsed is not None
            ]

            if not candidates:
                print(f"Failed to parse both outputs: {gen_text1} and {gen_text2}")
                preds.append(None)
            else:
                # Only successful parses compete, so a failed parse can no longer win a 0-vs-0 tie.
                # max() returns the first maximal item, so ties still favour sample 1.
                preds.append(
                    max(candidates, key=lambda parsed: train_utils.calculate_partial_match(parsed, label_grid))
                )

            labels.append(label_grid)
            challenge_ids.append((challenge_id, challenge["order"]))

    return {
        "ids": challenge_ids,
        "preds": preds,
        "labels": labels,
    }

In [14]:
# Inspect the fully formatted prompt of the first test example.
first_test_prompt = dataset["test"][0]["texts"]
print(first_test_prompt)

<bos><start_of_turn>user
You are a puzzle solving wizard. You are given a puzzle from the abstraction and reasoning corpus developed by Francois Chollet.

Here are the example input and output pairs from which you should learn the underlying rule to later predict the output for the given test input:
-----------------
<input>
0001000000
0011100000
0000006000
6600006600
6600000000
0011100000
0111100000
0000000000
0000000000
0000000000
</input>

<output>
0000000000
0000000000
0000000000
0000000000
0000000000
6000100000
6601110000
0000000000
6600111000
6601111000
</output>

<input>
0000000000
0000000000
0333300333
0330000000
0330444000
0000044000
0000040040
0000000040
0000000440
0000000000
</input>

<output>
0000000000
0000000000
0000000000
0000004000
0000004000
3330044000
0000000000
3333044400
3300004400
3300004000
</output>

<input>
0000000000
8800004000
8000444400
8880004000
0000000000
0000044440
0008800000
0088880000
0000000000
0000000000
</input>

<output>
0000000000
0000000000
000000

In [15]:
test_split = dataset["test"]
results = evaluate(model, tokenizer, test_split, batch_size=1)

# Exact-match accuracy and mean partial-match score over the test split.
preds, labels = results["preds"], results["labels"]
accuracy, avg_partial_match = train_utils.calculate_metrics(preds, labels)

for message in (
    f"Exact match accuracy: {accuracy:.4f}",
    f"Average partial match score: {avg_partial_match:.4f}",
):
    log.info(message)

 71%|███████▏  | 229/321 [2:25:05<1:50:09, 71.84s/it]

Failed to parse both outputs: <output>
683338183881631231113122682386
166881336388211368831218238882
883631683863128182132111111836
131281822213183631131111111263
812386888888888813811111118368
613186888888888811316111116283
333383888888888813116288888813
331321888888888881132632188828
183283888888888868882831336131
812316888888888883866183328881
331128888888888883818186281116
611216832281638263213338383112
132812881322223338321166383316
812631328133818366831813323388
881328131131183883333138288331
683281311311838863183888116331
381182332816166813183888881163
881236331883368181223838826861
318233633631131181318281121338
888128831133336333333138288128
683812838318331333333838383311
816312881138238333333838383313
212218631388822333333881311336
113663228263862333333881311326
136183213313838333333883833112
883882222886833333333822161868
331118833328233333333822811218
663336338888136833811161262638
128321333313866181621633383818
862638288311611218121318363338
</output and <output>
683338183

100%|█████████▉| 320/321 [3:20:38<01:17, 77.52s/it]  

Failed to parse both outputs: <output>
000000000333333300000000000000
000000000333333300000000000000
000000000033333000000000000000
000000000003333300000000000000
000444400333333300000000000000
000444400333333300000000000000
000444400033333300000000000000
000444400333333300000000000000
000000000333333300000000000000
000000000333333300000000000000
000000000333333300000000000000
000000000333333000000000000000
000000000333333300000000000000
600000000033333300000000000000
660000000033333330000000000000
666000000333333330000000000000
666600000333333330000002222000
000000000333333000000002222000
000000000033333300000002200000
000000000003333330000002200000
000000000033333333000002000000
000000000333333333000002000000
000000000333333333000002000000
000000000333333333000000000000
000000000033333333000000000000
000000000333333333000000000000
000000000333333333000000000000
000000000333333330000000000000
000000000333333330000000000000
000000000333333330000000000000
000000000333333330000000000000


100%|██████████| 321/321 [3:21:23<00:00, 37.64s/it]
Exact match accuracy: 0.0654
Average partial match score: 0.6283
