In [1]:
import torch
from pathlib import Path
from copy import deepcopy
import os
import json
import matplotlib.pyplot as plt
import seaborn as sn

In [2]:
import model as m

In [3]:
DEVICE = "cuda"
CHECKPOINT_PATH = Path("checkpoints/model_diff_best.tar")
DATA_PATH = Path("data/ARC-AGI-2/data/training")

In [4]:
checkpoint = torch.load(CHECKPOINT_PATH, weights_only=False)
config, tokenizer = checkpoint["config"], checkpoint["tokenizer"]
model = m.GPT(config=config, device=DEVICE).to(DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [30]:
def create_context(task, test_index, tokenizer):
    new_task = {}
    new_task["context"] = deepcopy(task["train"])
    train = tokenizer.encode(task["train"])
    test = tokenizer.encode([task["test"][test_index]])
    soi, eoi, soo, eoo = test.index(10), test.index(11), test.index(12), test.index(13)
    test_input, solution = test[soi:eoi+1], test[soo+1:eoo+1]
    test_input = test_input + [tokenizer.special_tokens["start_of_output"]]
    context = train + test_input
    new_task["context"].append(test_input)
    return context, solution

In [None]:
def calculate_accuracy(
    model, context_tensor, solution, gen_len, mask_token, num_iter
):
    con_len = context_tensor.size(1)
    next_tokens = torch.tensor([mask_token] * gen_len, device="cuda").view(1, -1)
    context_masked = torch.cat((context_tensor, next_tokens), dim=-1)
    context_masked = context_masked[:, -2048:]
    for i in range(1, num_iter + 1):
        with torch.inference_mode():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(context_masked)
            s = 1 - (i / num_iter)
            num_masks = int(gen_len * s)
            probs = torch.softmax(logits, dim=-1)[:, con_len:]
            top_probs, top_tokens = torch.max(probs, dim=-1)
            _, mask_indices = torch.topk(
                top_probs, k=num_masks, largest=False, dim=1
            )
            next_tokens = top_tokens.clone()
            next_tokens[:, mask_indices.view(-1)] = mask_token
            context_masked = torch.cat((context_tensor, next_tokens), dim=-1)
            context_masked = context_masked[:, -2048:]

    accuracy = torch.sum(
        solution.view(1, -1) == top_tokens[:, : solution.size(0)]
    ) / solution.size(0)

    return accuracy

In [60]:
index = 1
filelist = os.listdir(DATA_PATH)
task_path = DATA_PATH / filelist[index]
task = json.loads(task_path.read_text())

In [61]:
mask_token = tokenizer.special_tokens["mask_token"]
gen_len = 10
context, solution = create_context(task, 0, tokenizer)
context_tensor = torch.tensor(context, device="cuda").view(1, -1)
solution = torch.tensor(solution, device="cuda").view(1, -1)

In [62]:
calculate_accuracy(
    model,
    context_tensor,
    solution,
    gen_len,
    mask_token,
    num_iter=1,
)

RuntimeError: The size of tensor a (930) must match the size of tensor b (0) at non-singleton dimension 1

In [22]:
test = tokenizer.encode([task["test"][0]])
soi, eoi, soo, eoo = test.index(10), test.index(11), test.index(12), test.index(13)
test_input, solution = test[soi:eoi+1], test[soo:eoo+1]
test_input[-3:]

[0, 0, 11]

In [17]:
tokenizer.special_tokens

{'start_of_input': 10,
 'end_of_input': 11,
 'start_of_output': 12,
 'end_of_output': 13,
 'row_indicator': 14,
 'context_indicator': 15,
 'mask_token': 16}

In [12]:
from collections import Counter

In [14]:
tc = Counter(test)
tc

Counter({0: 189,
         3: 42,
         4: 30,
         14: 17,
         8: 10,
         15: 1,
         10: 1,
         11: 1,
         12: 1,
         13: 1})

In [5]:
from utils import plot_losses

In [6]:
results = checkpoint["results"]

In [7]:
plot_losses(results)

NameError: name 'plot_losses' is not defined

In [12]:
results["train_losses"][-20:], len(results["train_losses"])

([0.21436849447432904,
  0.2189096684404649,
  0.21017801734618843,
  0.23862626820802688,
  0.2284164482052438,
  0.24584521789569408,
  0.22737450523185543,
  0.2263147470052354,
  0.2353070877667051,
  0.21298127432819455,
  0.22931468431837856,
  0.2515357826091349,
  0.197449276142288,
  0.22920571574242785,
  0.24085723243653775,
  0.22511778990039602,
  0.23992838334990665,
  0.23327979229157791,
  0.24409853484481572,
  0.2230953457870055],
 3840)

In [10]:
results["val_losses"][180:]

[(3620, 0.12067817151546478),
 (3640, 0.12050575017929077),
 (3660, 0.11765394359827042),
 (3680, 0.11918774247169495),
 (3700, 0.11804640293121338),
 (3720, 0.1210043877363205),
 (3740, 0.11865618824958801),
 (3760, 0.11785630881786346),
 (3780, 0.12139032036066055),
 (3800, 0.12202248722314835),
 (3820, 0.12273457646369934),
 (3840, 0.11935696750879288)]

In [34]:
torch.tensor([3])[:, None, None, None].size()

torch.Size([1, 1, 1, 1])