In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("../")

In [2]:
import torch
import torch.nn.functional as F
from xent.tasks import Closure
from xent.models import M
from xent.lang import X
from xent.dataprocessing import Wikipedia
from xent.config import *

In [3]:
# Base model: passed to the Closure task below and used for tokenize/detokenize in later cells.
model = M("gpt2", "M0", base="base")
# Model under evaluation: later cells run it on the generated sample and inspect its logits.
checker_model = M("gpt2", "M1-evolved", base="closure")

In [4]:
# Wikipedia corpus wrapper.
# NOTE(review): split=0.8 is presumably the train/test fraction — confirm in xent.dataprocessing.
corpus_generator = Wikipedia(split=0.8)
# Keep a handle on the bound method; it is passed to task.generate as the text source.
get_test_sample = corpus_generator.get_random_test_text

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

In [5]:
# Build a Closure task around the base model and generate one sample from the test-text source.
# `synth` is indexed as synth[0, ...] below; in the saved run its shape is [1, 991].
task = Closure(model)
synth = task.generate(get_test_sample, space="tokens")

In [6]:
# Locate the X.xreturn marker in the sample (presumably its token position — confirm in Closure.find_xstring).
cut = task.find_xstring(synth, X.xreturn)
# Index from which the sample is scored in the following cells.
# NOTE(review): 6 is a hard-coded offset; a later cell derives the same quantity as
# cut + xlen + 1 using return_len=True — confirm the two agree.
CUT = cut + 6

In [7]:
# Score the generated sample with the checker model, from position CUT onwards.
checker_model.model.eval()
print(synth.shape)

with torch.no_grad():
    logits = checker_model.model(synth).logits
    # The logit at position i is scored against the token at position i + 1,
    # hence logits[CUT:-1] paired with synth[CUT+1:].
    probs = logits[0, CUT:-1].softmax(dim=-1)
    loss = F.cross_entropy(logits[0, CUT:-1], synth[0, CUT + 1:], reduction="none")
    # Most likely next token at every scored position.
    highest_prob_tokens = probs.argmax(dim=-1)

torch.Size([1, 991])


In [8]:
# Quick sanity check: show the first 20 characters of the decoded sample.
print(model.detokenize(synth[0])[:20])

 Odessa
Ukrainian fo


In [9]:
shift = 2  # distance in tokens from a word to its xent number
xlen = 4   # offset from CUT of the first word (skips the first token, which is not predicted by the model)

# Ground truth: words and their xent numbers, one pair every 4 tokens.
word_origin = synth[0, CUT + xlen::4]
origin = synth[0, CUT + xlen + shift::4]
# highest_prob_tokens[i] and loss[i] refer to synth[0, CUT + 1 + i], so the same
# positions start at index 3 (= xlen - 1) on the prediction side.
word_genera = highest_prob_tokens[3::4]
genera = highest_prob_tokens[3 + shift::4]
xloss = loss[3 + shift::4]


def _word_label(token):
    """Decode a word token; any text containing a newline is shown as a literal backslash-n."""
    text = model.detokenize(token)
    return "\\n" if '\n' in text else text


print(f"{'actual':20}  | {'generated':21} | loss")
print("----------------------|-----------------------|--------")
for true_word, true_num, pred_word, pred_num, num_loss in zip(word_origin, origin, word_genera, genera, xloss):
    row = (
        f"{_word_label(true_word):15} {model.detokenize(true_num):5} | "
        f"{_word_label(pred_word):15} {model.detokenize(pred_num):5} | {num_loss:.4f}"
    )
    print(row)

actual                | generated             | loss
----------------------|-----------------------|--------
\n               4    | \n               2    | 3.5025
Uk               16   | Uk               21   | 3.6473
rain             0    | rain             11   | 12.7875
ian              0    | ian              5    | 5.4352
 football        8    |  football        9    | 1.5963
ers              3    | ers              5    | 2.1817
\n               5    | \n               1    | 6.2508
Ass              10   | Ass              15   | 4.4861
ociation         2    | ociation         6    | 3.9144
 football        5    |  football        14   | 12.4833
 forwards        11   |  forwards        11   | 1.9272
\n               1    | \n               1    | 0.7969
Uk               3    | Uk               18   | 9.4555
rain             0    | rain             13   | 3.3235
ian              0    | ian              0    | 0.3105
 exp             8    |  exp             10   | 2.7996
atri     

In [10]:
# Candidate xent values 0..25 written as " <n>".
# Every entry carries the leading space: the numbers in the sample are tokens of the
# form ' 13', ' 4', ..., and "21" without the space would map to a different token.
numbers = [f" {n}" for n in range(26)]
# Token ids for each candidate, moved to the configured device.
# NOTE(review): torch.tensor requires every number to tokenize to the same length — confirm for the tokenizer in use.
numtoks = torch.tensor([model.tokenize(num).input_ids for num in numbers]).to(device)

In [11]:
# Draw a fresh sample, score it with the checker model and print, for each position,
# the predicted token, its loss and the five highest-logit candidates.
checker_model.model.eval()

synth = task.generate(get_test_sample, space="tokens")
cut, xlen = task.find_xstring(synth, X.xreturn, return_len=True)
CUT = cut + xlen + 1  # +1 is for the newline \n

# The scored window starts `genshift` positions before CUT, so that with genshift = 1
# loss[i] / highest_prob_tokens[i] refer to the token synth[0, CUT + i].
genshift = 1

with torch.no_grad():
    logits = checker_model.model(synth).logits
    loss = F.cross_entropy(logits[0, CUT - genshift:-1], synth[0, CUT - genshift + 1:], reduction="none")
    probs = F.softmax(logits[0, CUT - genshift:-1], dim=-1)
    highest_prob_tokens = torch.argmax(probs, dim=-1)

# Top-5 raw logits per position. The window must start at the same index as the one
# used for `loss`, otherwise the columns printed below are misaligned.
values, indices = torch.topk(logits[0, CUT - genshift:], k=5)
idxvaltup = list(zip(indices, values))

# NOTE(review): `cut` (the marker position) is reused here as the number of rows to show.
for xl, gent, (top_idx, top_val) in zip(loss[:cut], highest_prob_tokens[:cut], idxvaltup):
    gent = model.detokenize(gent)
    if gent == "\n":
        gent = "\\n"
    xlist = [(model.detokenize(t), round(v.item(), 5)) for t, v in zip(top_idx, top_val)]
    print(f"token: {gent:<15} | xent: {round(xl.item(), 5):<10} | {xlist}")

# Mean loss over the number slots only (every 4th position starting at offset 2).
xloss = loss[2::4]
xloss.mean().item()

token:  Embassy        | xent: 0.00035    | [(' Embassy', 28.66735), (' embassy', 20.6739), (' Ambassador', 17.41423), (' Airport', 16.17044), (' Mission', 13.55211)]
token: :               | xent: -0.0       | [(':', -68.56298), ('::', -89.91278), ('":', -90.33371), ("':", -90.45376), ('.:', -91.25053)]
token:  13             | xent: 3.53707    | [(' 13', 9.81415), (' 14', 9.804), (' 15', 8.92762), (' 12', 8.79944), (' 16', 7.39639)]
token: \n              | xent: 0.0        | [('\n', 2.77421), ('\n\n', -13.56106), ('', -14.50026), ('\n\xa0', -19.83238), (' -', -20.3785)]
token:  of             | xent: 3e-05      | [(' of', 25.57396), (' Of', 13.16839), (' to', 13.00054), ('of', 12.42415), (' and', 12.27743)]
token: :               | xent: -0.0       | [(':', -79.53226), ('?:', -104.47343), ('::', -105.02264), ('!:', -105.51805), ('.:', -106.04768)]
token:  4              | xent: 2.8838     | [(' 4', 4.52289), (' 3', 4.47051), (' 2', 4.07417), (' 5', 3.69267), (' 1', 2.85217)]
token: 

3.490494966506958

In [12]:


# Same table as the previous cell, restricted to the number slots
# (every 4th position starting at offset 2).
number_rows = zip(loss[2:cut:4], highest_prob_tokens[2:cut:4], idxvaltup[2::4])
for num_loss, pred_tok, (top_idx, top_val) in number_rows:
    pred_text = model.detokenize(pred_tok)
    if pred_text == "\n":
        pred_text = "\\n"
    top5 = [(model.detokenize(tok), round(val.item(), 5)) for tok, val in zip(top_idx, top_val)]
    print(f"token: {pred_text:<8} | xent: {round(num_loss.item(), 5):<10} | {top5}")

# Mean loss over all number slots; left as the last expression so the cell displays it.
xloss = loss[2::4]
xloss.mean().item()



token:  13      | xent: 3.53707    | [(' 13', 9.81415), (' 14', 9.804), (' 15', 8.92762), (' 12', 8.79944), (' 16', 7.39639)]
token:  4       | xent: 2.8838     | [(' 4', 4.52289), (' 3', 4.47051), (' 2', 4.07417), (' 5', 3.69267), (' 1', 2.85217)]
token:  10      | xent: 6.2874     | [(' 10', 7.22309), (' 9', 7.1543), (' 11', 6.82907), (' 8', 6.47371), (' 12', 5.87219)]
token:  0       | xent: 0.21175    | [(' 0', 13.00431), (' 1', 11.36886), (' 2', 9.52392), (' 3', 8.00952), (' 4', 6.96141)]
token:  5       | xent: 1.24775    | [(' 5', 1.27841), (' 4', 1.11363), (' 6', 0.61957), (' 3', 0.2383), (' 7', -0.68766)]
token:  5       | xent: 3.76765    | [(' 5', 13.21952), (' 4', 13.18919), (' 6', 12.62882), (' 3', 12.20369), (' 7', 11.40727)]
token:  6       | xent: 3.97416    | [(' 6', 3.88394), (' 7', 3.64248), (' 5', 3.41548), (' 8', 2.91451), (' 4', 2.38375)]
token:  4       | xent: 1.37544    | [(' 4', 12.92315), (' 3', 12.33092), (' 5', 12.25872), (' 6', 10.10727), (' 2', 10.02455)]

3.490494966506958

In [13]:
# Final report on the number slots: actual vs. predicted number, the loss on that
# number, and the five highest-logit candidates.
print("-"*95+f" TOTAL LOSS ON NUMBERS: {loss[2::4].mean().item()}")
# The candidate values come from torch.topk over the raw logits, so they are labelled as logits.
print("   actual -- predicted  | loss on number   | first 5 picked numbers: (num_tok, logit)")
print("-"*24+"|"+"-"*18+"|"+"-"*93)

# loss[i] is scored against synth[0, CUT - genshift + 1 + i]; the number slots start at i = 2.
actual_numbers = synth[0, CUT - genshift + 3::4]

for num_loss, gent, (top_idx, top_val), sent in zip(
    loss[2:cut:4], highest_prob_tokens[2:cut:4], idxvaltup[2::4], actual_numbers
):
    gent = model.detokenize(gent)
    if gent == "\n":
        gent = "\\n"
    sent = model.detokenize(sent)
    if sent == "\n":
        sent = "\\n"
    top5 = [(model.detokenize(t), round(v.item(), 5)) for t, v in zip(top_idx, top_val)]
    # One fixed-width cell per candidate; a separate name avoids clobbering the loss variable.
    candidates = "".join(f"{str(pair):<18}" for pair in top5)
    print(f"nums: {sent:<3} -- {gent:<10} | xent: {round(num_loss.item(), 5):<10} | {candidates}")

----------------------------------------------------------------------------------------------- TOTAL LOSS ON NUMBERS: 3.490494966506958
   actual -- predicted  | loss on number   | first 5 picked numbers: (num_tok, perplexity)
------------------------|------------------|---------------------------------------------------------------------------------------------
nums:  11 --  13        | xent: 3.53707    | (' 13', 9.81415)  (' 14', 9.804)    (' 15', 8.92762)  (' 12', 8.79944)  (' 16', 7.39639)  
nums:  1  --  4         | xent: 2.8838     | (' 4', 4.52289)   (' 3', 4.47051)   (' 2', 4.07417)   (' 5', 3.69267)   (' 1', 2.85217)   
nums:  5  --  10        | xent: 6.2874     | (' 10', 7.22309)  (' 9', 7.1543)    (' 11', 6.82907)  (' 8', 6.47371)   (' 12', 5.87219)  
nums:  0  --  0         | xent: 0.21175    | (' 0', 13.00431)  (' 1', 11.36886)  (' 2', 9.52392)   (' 3', 8.00952)   (' 4', 6.96141)   
nums:  4  --  5         | xent: 1.24775    | (' 5', 1.27841)   (' 4', 1.11363)   (' 6', 0.