In [1]:
import torch
from pathlib import Path
from copy import deepcopy
import os
import json
import matplotlib.pyplot as plt
import seaborn as sn

In [2]:
import model as m

In [3]:
DEVICE = "cuda"
CHECKPOINT_PATH = Path("checkpoints/model_diff_best.tar")
DATA_PATH = Path("data/pretraining")

In [4]:
checkpoint = torch.load(CHECKPOINT_PATH, weights_only=False)
config, tokenizer = checkpoint["config"], checkpoint["tokenizer"]
model = m.GPT(config=config, device=DEVICE).to(DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
mask_token = tokenizer.special_tokens["mask_token"]

In [5]:
tokenizer.special_tokens

{'start_of_input': 10,
 'end_of_input': 11,
 'start_of_output': 12,
 'end_of_output': 13,
 'row_indicator': 14,
 'context_indicator': 15,
 'mask_token': 16}

In [6]:
from get_dataloaders import CustomDataset
from torch.utils.data import DataLoader

In [59]:
def calculate_accuracy(model, x, y, mask_token, num_iter):
    mask_ids = (x == mask_token).nonzero(as_tuple=True)[1]
    nonmask_ids = (x != mask_token).nonzero(as_tuple=True)[1]
    gen_len = mask_ids.numel()
    x_clone = x.clone()
    for i in range(1, num_iter + 1):
        with torch.inference_mode():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(x_clone)
            s = 1 - (i / num_iter)
            num_masks = int(gen_len * s)
            probs = torch.softmax(logits, dim=-1)
            top_probs, top_preds = torch.max(probs, dim=-1)
            top_probs[:, nonmask_ids] = 1.0
            _, remask_ids = torch.topk(top_probs, k=num_masks, largest=False, dim=1)
            x_clone[:, mask_ids] = top_preds[:, mask_ids]
            x_clone[:, remask_ids.view(-1)] = mask_token
            print(f"Iter {i}:")
            print("start_of_input:", torch.sum(top_preds == 9).item())
            print("end_of_input:", torch.sum(top_preds == 10).item())
            print("start_of_output:", torch.sum(top_preds == 11).item())
            print("end_of_output:", torch.sum(top_preds == 12).item())
            print("row_indicator:", torch.sum(top_preds == 13).item())

            accuracy = torch.sum(y[:, mask_ids] == top_preds[:, mask_ids]) / y[:, mask_ids].size(1)

    return accuracy, top_preds

In [48]:
tokenizer.special_tokens

{'start_of_input': 9,
 'end_of_input': 10,
 'start_of_output': 11,
 'end_of_output': 12,
 'row_indicator': 13,
 'context_indicator': 14,
 'mask_token': 15}

In [68]:
block_size = 2048
tr_dataset = CustomDataset(
    Path("data/pretraining"), block_size, is_train=True, mask_token=mask_token
)

In [69]:
tr_dataloader = DataLoader(tr_dataset, batch_size=1, shuffle=True)
tr_iter = iter(tr_dataloader)

In [98]:
x, y = next(tr_iter)
x, y = x.to("cuda"), y.to("cuda")

In [99]:
print("start_of_input:", torch.sum(y == 9).item())
print("end_of_input:", torch.sum(y == 10).item())
print("start_of_output:", torch.sum(y == 11).item())
print("end_of_output:", torch.sum(y == 12).item())
print("row_indicator:", torch.sum(y == 13).item())

start_of_input: 5
end_of_input: 3
start_of_output: 3
end_of_output: 3
row_indicator: 88


In [100]:
acc, answer = calculate_accuracy(model, x, y, mask_token, 30)
acc, torch.sum(x == 16).item(), torch.sum(y == 9).item()

Iter 1:
start_of_input: 3
end_of_input: 3
start_of_output: 3
end_of_output: 1
row_indicator: 69
Iter 2:
start_of_input: 2
end_of_input: 3
start_of_output: 3
end_of_output: 2
row_indicator: 62
Iter 3:
start_of_input: 2
end_of_input: 3
start_of_output: 3
end_of_output: 2
row_indicator: 63
Iter 4:
start_of_input: 2
end_of_input: 3
start_of_output: 3
end_of_output: 2
row_indicator: 61
Iter 5:
start_of_input: 2
end_of_input: 3
start_of_output: 2
end_of_output: 1
row_indicator: 64
Iter 6:
start_of_input: 2
end_of_input: 3
start_of_output: 3
end_of_output: 2
row_indicator: 59
Iter 7:
start_of_input: 2
end_of_input: 3
start_of_output: 2
end_of_output: 2
row_indicator: 64
Iter 8:
start_of_input: 2
end_of_input: 3
start_of_output: 3
end_of_output: 2
row_indicator: 59
Iter 9:
start_of_input: 2
end_of_input: 2
start_of_output: 2
end_of_output: 2
row_indicator: 64
Iter 10:
start_of_input: 2
end_of_input: 3
start_of_output: 3
end_of_output: 2
row_indicator: 55
Iter 11:
start_of_input: 2
end_of_input

(tensor(0.8486, device='cuda:0'), 1361, 5)

In [443]:
step = 0

In [444]:
window_size = 16
step += 1
start, end = step*window_size, (step+1)*window_size 
mask_ids = (x == mask_token).nonzero(as_tuple=True)[1]
y_mask, a_mask = y[:, mask_ids], answer[:, mask_ids]
print("y:", y_mask[:, start:end])
print("a:", a_mask[:, start:end])

y: tensor([[ 3,  3,  3,  3, 13,  3,  0,  3,  3,  3,  3,  3,  3,  3,  3,  3]],
       device='cuda:0')
a: tensor([[ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3, 13,  3,  3,  3,  3,  3]],
       device='cuda:0')


In [408]:
x[:, :50], y[:, :50]

(tensor([[ 7,  8, 16, 16,  8, 16, 16, 16,  8, 16,  3,  2,  3, 16, 16,  9,  8, 16,
          16, 16, 16,  1, 16, 16, 16, 16,  1, 16, 16,  7,  0, 16]],
        device='cuda:0'),
 tensor([[ 7,  8,  7,  7,  8,  7,  7,  2,  8,  3,  3,  2,  3,  0,  0,  9,  8,  8,
           0, 13,  7,  1,  1,  1,  1,  1,  1,  1,  1,  7,  0,  3]],
        device='cuda:0'))

In [382]:
y[:, :40]

tensor([[ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         10, 11,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
          3,  3,  3,  3]], device='cuda:0')

In [225]:
x.size()

torch.Size([1, 2048])

In [374]:
tokenizer.special_tokens

{'start_of_input': 10,
 'end_of_input': 11,
 'start_of_output': 12,
 'end_of_output': 13,
 'row_indicator': 14,
 'context_indicator': 15,
 'mask_token': 16}

In [None]:
def create_context(task, test_index, tokenizer):
    new_task = {}
    new_task["context"] = deepcopy(task["train"])
    train = tokenizer.encode(task["train"])
    test = tokenizer.encode([task["test"][test_index]])
    soi, eoi, soo, eoo = test.index(10), test.index(11), test.index(12), test.index(13)
    test_input, solution = test[soi : eoi + 1], test[soo + 1 : eoo + 1]
    test_input = test_input + [tokenizer.special_tokens["start_of_output"]]
    context = train + test_input
    new_task["context"].append(test_input)
    return context, solution

In [None]:
def calculate_accuracy_inference(
    model, context_tensor, solution, gen_len, mask_token, num_iter
):
    con_len = context_tensor.size(1)
    next_tokens = torch.tensor([mask_token] * gen_len, device="cuda").view(1, -1)
    context_masked = torch.cat((context_tensor, next_tokens), dim=-1)
    for i in range(1, num_iter + 1):
        with torch.inference_mode():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(context_masked)
            s = 1 - (i / num_iter)
            num_masks = int(gen_len * s)
            probs = torch.softmax(logits, dim=-1)[:, con_len:]
            top_probs, top_tokens = torch.max(probs, dim=-1)
            _, mask_indices = torch.topk(top_probs, k=num_masks, largest=False, dim=1)
            next_tokens = top_tokens.clone()
            next_tokens[:, mask_indices.view(-1)] = mask_token
            context_masked = torch.cat((context_tensor, next_tokens), dim=-1)

    accuracy = torch.sum(
        solution.view(1, -1) == top_tokens[:, : solution.size(0)]
    ) / solution.size(0)

    return accuracy, top_tokens

In [383]:
tokenizer.special_tokens

{'start_of_input': 10,
 'end_of_input': 11,
 'start_of_output': 12,
 'end_of_output': 13,
 'row_indicator': 14,
 'context_indicator': 15,
 'mask_token': 16}

In [32]:
from get_tokenizer import Tokenizer

In [33]:
tokenizer = Tokenizer(16)

In [34]:
tokenizer.special_tokens

{'start_of_input': 9,
 'end_of_input': 10,
 'start_of_output': 11,
 'end_of_output': 12,
 'row_indicator': 13,
 'context_indicator': 14,
 'mask_token': 15}

In [35]:
data = [{"input": [[1,2,3], [4, 5, 6], [1, 1, 1]], "output": [[7,8,9], [7,7,6]]},
        {"input": [[1,2,3], [4, 9, 6], [1, 3, 1]], "output": [[2,8,9], [7,7,6]]},]

In [37]:
torch.tensor(tokenizer.encode(data))

tensor([14,  9,  1,  2,  3, 13,  4,  5,  6, 13,  1,  1,  1, 10, 11,  7,  8,  9,
        13,  7,  7,  6, 12,  9,  1,  2,  3, 13,  4,  9,  6, 13,  1,  3,  1, 10,
        11,  2,  8,  9, 13,  7,  7,  6, 12])