In [1]:
# Set CUDA_VISIBLE_DEVICES to "0" for this kernel process (restricts CUDA to device index 0).
# Comments are kept on their own lines: text after the value would become part of it.
%env CUDA_VISIBLE_DEVICES=0
# Set TOKENIZERS_PARALLELISM to "false"; done in the first cell, before any library import.
%env TOKENIZERS_PARALLELISM=false

env: CUDA_VISIBLE_DEVICES=0
env: TOKENIZERS_PARALLELISM=false


In [2]:
# Root of the local project checkout; the model directory, log location and the
# extra import paths used in later cells are all derived from it.
BASE_PATH = "/home/stepan/kaggle-arc-agi"
# Model location handed to FastLanguageModel.from_pretrained as `model_name`.
MODEL_ID = f"{BASE_PATH}/models/llama-3_1-8b-it"
# Upper bound on newly generated tokens per `model.generate` call.
MAX_NEW_TOKENS = 1024
# Sequence length passed to the model loader: 32768 minus the generation budget.
# NOTE(review): 32768 is presumably the intended total context window — confirm.
MAX_SEQ_LENGTH = 32768 - MAX_NEW_TOKENS

In [3]:
import sys

# Make the project root and its scripts/ directory importable for the local
# modules imported in the next cell.
# The membership check keeps the cell idempotent: re-running it no longer
# appends duplicate entries to sys.path.
for _extra_path in (BASE_PATH, f"{BASE_PATH}/scripts"):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [4]:
import torch  # type: ignore
import numpy as np  # type: ignore

from datasets import DatasetDict, Dataset  # type: ignore

from unsloth import FastLanguageModel  # type: ignore

from tqdm.auto import tqdm  # type: ignore

from logger import get_logger  # type: ignore
import train_utils  # type: ignore
import data_utils  # type: ignore

  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


In [5]:
# Module-wide logger used to report the final metrics.
# NOTE(review): the two arguments are presumably a log file location and a logger
# name — confirm against logger.get_logger.
log = get_logger(f"{BASE_PATH}/logs/llama-3_1-8b", "arc-agi")

In [6]:
def get_model_tokenizer(dtype=None, load_in_4bit=True):
    """Load the model at ``MODEL_ID`` and its tokenizer through Unsloth.

    Args:
        dtype: Forwarded unchanged to ``FastLanguageModel.from_pretrained``.
        load_in_4bit: Forwarded unchanged to ``FastLanguageModel.from_pretrained``.

    Returns:
        The ``(model, tokenizer)`` pair returned by ``from_pretrained``.
    """
    load_options = dict(
        model_name=MODEL_ID,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        attn_implementation="flash_attention_2",
        device_map="auto",
        # Per-device memory caps: device index 0 and the CPU.
        max_memory={0: "23GiB", "cpu": "16GiB"},
    )
    loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(**load_options)
    return loaded_model, loaded_tokenizer

In [7]:
def eval(f):
    """Decorator: put the model into Unsloth inference mode before calling ``f``.

    The wrapped function must take ``model`` and ``tokenizer`` as its first two
    positional arguments; any further arguments are passed through unchanged.

    NOTE: the name shadows Python's builtin ``eval`` inside this notebook. It is
    kept because a later cell applies the decorator as ``@eval``.

    Args:
        f: Callable with signature ``f(model, tokenizer, *args, **kwargs)``.

    Returns:
        A wrapper that calls ``FastLanguageModel.for_inference(model)`` and then
        returns ``f(model, tokenizer, *args, **kwargs)``.
    """
    # Local import so this cell does not depend on an edit to the import cell.
    import functools

    # functools.wraps keeps f's __name__/__doc__ on the wrapper, so the decorated
    # function still identifies itself correctly in tracebacks and help().
    @functools.wraps(f)
    def wrapper(model, tokenizer, *args, **kwargs):
        FastLanguageModel.for_inference(model)
        return f(model, tokenizer, *args, **kwargs)

    return wrapper

In [8]:
# Load with the helper's defaults (dtype=None, load_in_4bit=True).
model, tokenizer = get_model_tokenizer()

==((====))==  Unsloth 2024.9: Fast Llama patching. Transformers = 4.43.4.
   \\   /|    GPU: NVIDIA RTX A5000. Max memory: 23.677 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = True]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unsloth 2024.9 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [9]:
# Build the dataset splits with the project helper.
# NOTE(review): the meaning of fit_dataset=False is defined in data_utils — confirm there.
dataset = data_utils.prepare_dataset(tokenizer, fit_dataset=False, base_path=BASE_PATH)
# Bare expression so the notebook displays the resulting object as the cell output.
dataset

Map: 100%|██████████| 105/105 [00:00<00:00, 1424.91 examples/s]
Map: 100%|██████████| 416/416 [00:00<00:00, 964.19 examples/s]
Map: 100%|██████████| 419/419 [00:00<00:00, 958.41 examples/s]
Map: 100%|██████████| 416/416 [00:00<00:00, 1421.92 examples/s]
Map: 100%|██████████| 419/419 [00:00<00:00, 708.98 examples/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'challenge', 'solution', 'texts', 'messages'],
        num_rows: 416
    })
    test: Dataset({
        features: ['id', 'challenge', 'solution', 'texts', 'messages'],
        num_rows: 293
    })
    val: Dataset({
        features: ['id', 'challenge', 'solution', 'texts', 'messages'],
        num_rows: 126
    })
    predict: Dataset({
        features: ['id', 'challenge', 'texts', 'messages'],
        num_rows: 105
    })
})

In [10]:
def generate_with_temp(model, inputs, temperature):
    """Sample from ``model`` at the given temperature.

    Args:
        model: Object exposing ``generate(**kwargs)``.
        inputs: Mapping of keyword arguments forwarded to ``model.generate``.
        temperature: Sampling temperature.

    Returns:
        Whatever ``model.generate`` returns.
    """
    sampling_options = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "temperature": temperature,
        "top_k": 50,
        "use_cache": True,
    }
    return model.generate(**inputs, **sampling_options)


def evaluate_batch(model, tokenizer, batch):
    """Generate two sampled completions per example and decode only the new tokens.

    The batch is sampled once at temperature 0.3 and once at 0.7, in that order,
    with gradient tracking disabled.

    Args:
        model: Model passed through to ``generate_with_temp``.
        tokenizer: Object exposing ``batch_decode``.
        batch: Mapping containing at least ``"input_ids"`` and ``"attention_mask"``.

    Returns:
        A tuple ``(texts_at_0_3, texts_at_0_7)`` of decoded continuations.
    """
    model_inputs = {key: batch[key] for key in ("input_ids", "attention_mask")}

    with torch.no_grad():
        sampled_sequences = [
            generate_with_temp(model, model_inputs, temp) for temp in (0.3, 0.7)
        ]

    # Everything before this column index is the prompt that was fed in.
    prompt_len = model_inputs["input_ids"].shape[1]
    low_temp_texts, high_temp_texts = (
        tokenizer.batch_decode(sequences[:, prompt_len:], skip_special_tokens=True)
        for sequences in sampled_sequences
    )

    return low_temp_texts, high_temp_texts

In [11]:
def _select_best_prediction(parsed_outputs, label):
    """Return the parsed candidate that scores highest against ``label``.

    Candidates that failed to parse (``None``) are skipped, so a parsed output is
    never discarded in favour of ``None``. Ties go to the earliest candidate.
    Returns ``None`` only when no candidate parsed.
    """
    best_pred = None
    best_score = None
    for candidate in parsed_outputs:
        if candidate is None:
            continue
        score = train_utils.calculate_partial_match(candidate, label)
        if best_score is None or score > best_score:
            best_pred, best_score = candidate, score
    return best_pred


@eval
def evaluate(model, tokenizer, dataset, batch_size):
    """Generate two candidates per example and collect the better one with its label.

    For each example the two sampled completions are parsed and the one with the
    higher partial-match score against the ground-truth solution is kept. Because
    the label is used for that choice, the returned predictions are a best-of-two
    (oracle) selection rather than a single blind guess.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer used for collation and decoding.
        dataset: Dataset whose collated batches provide ``"id"``, ``"challenge"``
            and ``"solution"`` alongside the model inputs.
        batch_size: Number of examples per generation batch.

    Returns:
        Dict with parallel lists:
            ``"ids"``    - ``(challenge_id, challenge["order"])`` tuples,
            ``"preds"``  - chosen parsed prediction, or ``None`` if neither parsed,
            ``"labels"`` - solutions converted with ``train_utils.tensor_to_int``.
    """
    eval_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=train_utils.collate(mode="test", tokenizer=tokenizer),
    )

    challenge_ids = []
    preds = []
    labels = []
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
        generated_texts1, generated_texts2 = evaluate_batch(model, tokenizer, batch)

        examples = zip(
            generated_texts1,
            generated_texts2,
            batch["solution"],
            batch["id"],
            batch["challenge"],
        )
        for gen_text1, gen_text2, solution, challenge_id, challenge in examples:
            # Convert the label once; it is needed for scoring and for the result.
            label = train_utils.tensor_to_int(solution)
            parsed_outputs = (
                train_utils.parse_output(gen_text1),
                train_utils.parse_output(gen_text2),
            )

            if parsed_outputs[0] is None and parsed_outputs[1] is None:
                print(f"Failed to parse both outputs: {gen_text1} and {gen_text2}")

            # Yields None when nothing parsed; otherwise always a parsed output,
            # even if it is the only one that parsed and its score is 0.
            preds.append(_select_best_prediction(parsed_outputs, label))
            labels.append(label)
            challenge_ids.append((challenge_id, challenge["order"]))

    return {
        "ids": challenge_ids,
        "preds": preds,
        "labels": labels,
    }

In [12]:
# Long-running cell: generates two samples for every example of the test split,
# one example per batch (the recorded run below took roughly 2h48m for 293 examples).
results = evaluate(model, tokenizer, dataset["test"], batch_size=1)
# Calculate metrics
# NOTE: predictions are the better of two samples chosen using the label (see
# `evaluate`), so these numbers reflect best-of-two performance.
accuracy, avg_partial_match = train_utils.calculate_metrics(results["preds"], results["labels"])

log.info(f"Exact match accuracy: {accuracy:.4f}")
log.info(f"Average partial match score: {avg_partial_match:.4f}")

 66%|██████▌   | 192/293 [2:10:13<5:27:06, 194.33s/it]

Failed to parse both outputs: <output>
0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

 80%|███████▉  | 233/293 [2:26:07<3:07:46, 187.78s/it]

Failed to parse both outputs: <output>
009009009
009009009
009009009
009009009
009009009
009009009
009009009
009009009
009009009
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
01109111
0110911

 88%|████████▊ | 259/293 [2:39:31<05:47, 10.21s/it]   

Failed to parse both outputs: <output>
0070070000000000000070000
0000000000000700000070000
0070000000007080000000000
0000000707080007007787700
0070000007080077000000000
0000000000000000000000000
0077780007000000000000000
0000000070007780000000000
0000000000000000000000000
0000000000000007787780000
0000000000000070000000000
7000000000000000000000000
0070000787700000000070000
0000070000000000000000000
7000000000000007000000000
0000000000000000000000000
0007000000000000000000000
0000000077780000700000070
0000070000000000000000700
0000000000700000000000000
0000000000000000000000000
07000000000000007070877870
0000007000000000000000000
0000000000070000007000000
0000000000000000700000000
0000000000000000000070000
0060000607000070006007000
</output> and <output>
0070070000000000000070000
0000000000000700000070000
0070000000007070000000000
0000000707070007007787700
0070000007070077000000000
0000000000000000000000000
0077770007000000000000000
0000000070007780000000000
0000000000000000000000000
0

100%|██████████| 293/293 [2:48:30<00:00, 34.51s/it] 
Exact match accuracy: 0.0375
Average partial match score: 0.5913
