In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("../")

In [2]:
from tqdm import tqdm

import torch
import torch.nn.functional as F
from xent.tasks import Closure
from xent.models import M
from xent.lang import X
from xent.dataprocessing import Wikipedia
from xent.config import *

In [3]:
# Model handed to the Closure task below (used to build the synthetic samples).
model = M("gpt2", "M0", base="base")
# Model whose logits are scored in the evaluation loop further down.
# NOTE(review): "M1-big" / base="closure" presumably select a closure-trained checkpoint -- confirm.
checker_model = M("gpt2", "M1-big", base="closure")

# Text source for the task; split=0.8 is presumably the train/test fraction -- TODO confirm.
corpus_generator = Wikipedia(split=0.8)
# Bound method passed to task.generate() as the sample provider.
get_test_sample = corpus_generator.get_random_test_text

# Task object that generates samples and locates the X.xreturn marker inside them.
task = Closure(model)

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

In [4]:
# String forms " 0" .. " 20" (with the leading space) whose token ids are looked up.
numbers = [f" {value}" for value in range(21)]

# One row of token ids per number string.
# assumes every string tokenizes to the same number of ids, otherwise the nested list is ragged
numtoks = torch.tensor(
    [model.tokenize(text).input_ids for text in numbers],
    device=device,
)

# Vocabulary-sized logit vector: zero everywhere except the number tokens, which get
# a random value drawn from [0.9, 1.1) -- an example of a random logits vector.
logit_vector = torch.zeros(model.model.config.vocab_size, device=device)
logit_vector[numtoks.flatten()] = 0.9 + 0.2 * torch.rand(numtoks.numel(), device=device)

In [5]:
# Compare the checker model with a randomised baseline on the answer span.
#
# For every synthetic sample, the per-token cross-entropy of `checker_model` is taken
# on the tokens that follow the X.xreturn marker, and averaged over every
# SCORED_STRIDE-th position starting at SCORED_OFFSET (presumably the positions that
# hold number tokens -- confirm against the Closure output format). The baseline uses
# the same logits, except that the scored positions are replaced by a vector that is
# zero everywhere and random in [0.9, 1.1) on the number tokens.
#
# NOTE(review): the non-number entries of the baseline vector stay at logit 0, so its
# softmax still spreads most of the probability mass over the whole vocabulary. If the
# intended baseline is "random guess among the number tokens", those entries should be
# -inf instead -- confirm the intent before relying on the baseline figure.
N_EVAL_SAMPLES = 100
SCORED_OFFSET = 2   # first scored position inside the answer span
SCORED_STRIDE = 4   # distance between scored positions
TOP_K = 5

checker_model.model.eval()
model_loss_on_nums = []
random_loss_on_nums = []

number_token_ids = numtoks.flatten()  # loop-invariant, computed once

with torch.no_grad():
    for n in tqdm(range(N_EVAL_SAMPLES)):

        synth = task.generate(get_test_sample, space="tokens")
        cut, xlen = task.find_xstring(synth, X.xreturn, return_len=True)
        CUT = cut + xlen + 1 # +1 is for the newline \n
        genshift = 1

        # Logits at position t predict token t+1, hence the one-step shift between
        # the logits slice and the target slice.
        logits = checker_model.model(synth).logits
        cut_logits = logits[0, CUT-genshift:-1]
        targets = synth[0, CUT-genshift+1:]

        if cut_logits.shape[0] <= SCORED_OFFSET:
            # No scored position in this sample: the mean over an empty slice would
            # be NaN and would poison the final averages, so skip the sample.
            continue

        random_logits = cut_logits.clone()
        for pos in range(SCORED_OFFSET, cut_logits.shape[0], SCORED_STRIDE):
            # Clone so that the shared `logit_vector` is not mutated in place
            # (plain assignment would only create an alias).
            new_random_logits = logit_vector.clone()
            new_random_logits[number_token_ids] = torch.rand(number_token_ids.shape[0], device=device) * 0.2 + 0.9
            random_logits[pos] = new_random_logits

        model_loss = F.cross_entropy(cut_logits, targets, reduction="none")
        random_loss = F.cross_entropy(random_logits, targets, reduction="none")

        # Not used in the aggregate metrics; kept so the last sample can be
        # inspected interactively after the loop.
        model_probs = F.softmax(cut_logits, dim=-1)
        highest_prob_tokens = torch.argmax(model_probs, dim=-1)
        model_values, model_indices = torch.topk(cut_logits, k=TOP_K)

        random_probs = F.softmax(random_logits, dim=-1)
        random_values, random_indices = torch.topk(random_logits, k=TOP_K)

        model_tot_loss = model_loss[SCORED_OFFSET::SCORED_STRIDE].mean().item()
        model_loss_on_nums.append(model_tot_loss)

        random_tot_loss = random_loss[SCORED_OFFSET::SCORED_STRIDE].mean().item()
        random_loss_on_nums.append(random_tot_loss)

# Mean loss over the samples that contributed: checker model first, baseline second.
if model_loss_on_nums:
    print(sum(model_loss_on_nums)/len(model_loss_on_nums))
    print(sum(random_loss_on_nums)/len(random_loss_on_nums))
else:
    print("no sample contained a scored position")

100%|██████████| 100/100 [01:23<00:00,  1.20it/s]

1.3356525659561158
9.833741092681885





In [6]:
# # visualize that random logits are being added at the correct position
# # (uses the top-5 values/indices of the last sample processed by the loop above)
# for v, i in zip(random_values, random_indices):
#     for (x, t) in zip(i, v):
#         print((model.detokenize(x), t))
#     print("---------")