In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("../")

In [2]:
import torch
import torch.nn.functional as F
from xent.tasks import Closure
from xent.models import M
from xent.lang import X
from xent.dataprocessing import Wikipedia, SkeinAdventures
from xent.config import *

In [3]:
# `model` is handed to the Closure task and is also used for tokenize/detokenize
# throughout the notebook.
model = M("gpt2", "M0", base="base")
# `checker_model.model` is the network whose logits are scored in the cells below.
# NOTE(review): the meaning of the "M1-evolved" / base="closure" arguments is not
# visible here — presumably a checkpoint name and its directory; confirm in xent.models.
checker_model = M("gpt2", "M1-evolved", base="closure")

In [4]:
# Text source for the task.
# NOTE(review): split=0.8 presumably is the train fraction (20% held out) — confirm
# in xent.dataprocessing. No random seed is set, so every run draws different text.
corpus_generator = SkeinAdventures(split=0.8)
# Keep the bound method itself (not its result); it is passed to task.generate(),
# which presumably calls it to draw a sample — confirm in xent.tasks.
get_test_sample = corpus_generator.get_random_test_text

In [5]:
# Build the Closure task around the base model and generate one synthetic example.
task = Closure(model)
# `synth` is indexed as synth[0, ...] below, i.e. a batch of one token sequence.
# NOTE(review): space="tokens" presumably selects token ids over text — confirm.
synth = task.generate(get_test_sample, space="tokens")

In [6]:
# Position of the X.xreturn marker inside `synth`, as reported by the task.
cut = task.find_xstring(synth, X.xreturn)
# Start scoring a fixed 6 tokens after the marker.
# NOTE(review): the later cell derives this offset as `cut + xlen + 1` using
# find_xstring(..., return_len=True); the literal 6 presumably equals that
# (marker length + newline) — confirm, or compute it the same way.
CUT = cut + 6

In [7]:
# Score the part of `synth` that follows CUT with the checker model.
checker_model.model.eval()
with torch.no_grad():
    print(synth.shape)
    logits = checker_model.model(synth).logits

    # Next-token setup: the logits at position p are compared with the token at
    # p + 1, so entry i of everything below refers to synth[0, CUT + 1 + i].
    scored_logits = logits[0, CUT:-1]
    next_tokens = synth[0, CUT+1:]

    # Per-token cross-entropy (no reduction, one value per position).
    loss = F.cross_entropy(scored_logits, next_tokens, reduction="none")
    # Probability distribution over the vocabulary at every scored position.
    probs = F.softmax(scored_logits, dim=-1)
    # Most likely token id at every scored position.
    highest_prob_tokens = torch.argmax(probs, dim=-1)

torch.Size([1, 681])


In [8]:
# Show the first 20 elements of the detokenized sample as a quick sanity check.
decoded_sample = model.detokenize(synth[0])
print(decoded_sample[:20])

[Themes: modern]

Ok


In [9]:
shift = 2  # distance in tokens from the token to its xent
xlen = 4  # offset (from CUT) of the first word shown in the table

# The sequence is read with a stride of 4 tokens. Predictions are offset by one
# with respect to `synth`: entry k of `highest_prob_tokens` / `loss` refers to
# synth[0, CUT + 1 + k], hence the `xlen - 1` start for the predicted columns.
word_origin = synth[0, CUT + xlen::4]
origin = synth[0, CUT + xlen + shift::4]
word_genera = highest_prob_tokens[xlen - 1::4]
genera = highest_prob_tokens[xlen - 1 + shift::4]
xloss = loss[xlen - 1 + shift::4]

print(f"{'actual':20}  | {'generated':21} | loss")
print("----------------------|-----------------------|--------")
table_rows = zip(word_origin, origin, word_genera, genera, xloss)
for true_word, true_xent, pred_word, pred_xent, tok_loss in table_rows:
    true_word_str, true_xent_str, pred_word_str, pred_xent_str = (
        model.detokenize(tok) for tok in (true_word, true_xent, pred_word, pred_xent)
    )
    # Keep each row on a single line: a word containing a newline is shown as \n.
    if '\n' in true_word_str:
        true_word_str = "\\n"
    if '\n' in pred_word_str:
        pred_word_str = "\\n"
    print(f"{true_word_str:15} {true_xent_str:5} | {pred_word_str:15} {pred_xent_str:5} | {tok_loss:.4f}")

actual                | generated             | loss
----------------------|-----------------------|--------
mes              8    | mes              10   | 1.6983
:                4    | ::               2    | 2.6235
 modern          11   | \n               17   | 6.0430
]                6    | ]:               4    | 1.7539
\n               1    | \n               2    | 1.3565
\n               0    | \n               0    | 0.0000
Okay             8    | Okay             22   | 15.4512
,                0    | ,                4    | 7.2654
 what            5    |  what            6    | 1.8765
 will            5    |  will            5    | 1.1765
 you             2    |  you             1    | 1.2818
 use             4    |  use             4    | 1.1973
 as              3    |  as              4    | 1.8025
 a               1    |  a               1    | 0.8028
 weapon          6    |  weapon          4    | 2.1925
?                1    | ?:               3    | 2.3845
 You      

In [10]:
# Candidate xent values " 0" ... " 25" and their token ids.
# Every entry carries a leading space, matching the number tokens the model
# emits (e.g. ' 7', ' 10' in the outputs below). The hand-written list dropped
# the space for 21-25, which maps them to different token ids than the rest.
numbers = [f" {value}" for value in range(26)]
# NOTE(review): torch.tensor needs every tokenization to have the same length —
# presumably one token per number for this tokenizer; confirm.
numtoks = torch.tensor([model.tokenize(num).input_ids for num in numbers]).to(device)

In [11]:
# Generate a fresh sample, score it with the checker model and print, for each
# position after the X.xreturn marker, the predicted token, its cross-entropy
# and the five highest-scoring candidates with their raw logits.
checker_model.model.eval()

synth = task.generate(get_test_sample, space="tokens")
cut, xlen = task.find_xstring(synth, X.xreturn, return_len=True)
CUT = cut + xlen + 1 # +1 is for the newline \n

# With genshift = 1 the scored window starts one position before CUT, so entry i
# of `loss`, `highest_prob_tokens` and `idxvaltup` all refer to synth[0, CUT + i].
genshift = 1
start = CUT - genshift

with torch.no_grad():
    logits = checker_model.model(synth).logits
    scored_logits = logits[0, start:-1]
    loss = F.cross_entropy(scored_logits, synth[0, start+1:], reduction="none")
    probs = F.softmax(scored_logits, dim=-1)
    highest_prob_tokens = torch.argmax(probs, dim=-1)
    # Take the top-k over exactly the same window as loss/probs. Slicing from a
    # hard-coded CUT-1 would drift out of alignment as soon as genshift changes.
    values, indices = torch.topk(scored_logits, k=5)

# One (candidate_ids, candidate_logits) pair per scored position.
idxvaltup = list(zip(indices, values))

# NOTE(review): `cut` (the marker position) doubles as the number of rows to
# print here and in the following cells — presumably just a convenient cap; confirm.
for tok_loss, pred_tok, (cand_ids, cand_logits) in zip(loss[:cut], highest_prob_tokens[:cut], idxvaltup):
    pred_str = model.detokenize(pred_tok)
    if pred_str == "\n":
        pred_str = "\\n"  # keep the row on one line
    candidates = [
        (model.detokenize(tok_id), round(logit.item(), 5))
        for tok_id, logit in zip(cand_ids, cand_logits)
    ]
    print(f"token: {pred_str:<15} | xent: {round(tok_loss.item(), 5):<10} | {candidates}")

# Every 4th position starting at offset 2 is a number token (see the printed
# rows), so this is the mean loss on the numbers only.
xloss = loss[2::4]
xloss.mean().item()

token:  do             | xent: 0.0        | [(' do', 0.70465), (' don', -13.89131), (' doing', -13.99693), (' done', -14.19346), ('do', -14.43748)]
token: :               | xent: -0.0       | [(':', -76.80453), ('.:', -98.21325), ("':", -100.04868), ('\n', -100.5218), ('":', -101.20831)]
token:  7              | xent: 3.11536    | [(' 7', 21.17551), (' 6', 20.70576), (' 8', 20.47319), (' 5', 18.89612), (' 9', 18.40056)]
token: \n              | xent: 0.0        | [('\n', 26.60911), ('\n\n', 10.68337), ('', 9.75642), (' |', 6.60054), (' -', 4.08818)]
token: ,               | xent: 6e-05      | [(',', 1.12814), (' do', -10.2446), (' only', -11.30838), (' then', -11.37747), ("'t", -11.65106)]
token: :               | xent: -0.0       | [(':', -36.99671), ('.:', -58.23947), ('\n', -59.56508), ("':", -59.8117), (' :', -60.61139)]
token:  2              | xent: 1.18578    | [(' 2', 6.86219), (' 3', 6.3443), (' 1', 5.40201), (' 4', 4.67621), (' 5', 1.88768)]
token: \n              | xent: -0.

2.7493796348571777

In [12]:


# Same dump as the previous cell, restricted to the number positions
# (every 4th entry starting at offset 2).
number_rows = zip(loss[2:cut:4], highest_prob_tokens[2:cut:4], idxvaltup[2::4])
for number_loss, predicted_tok, (candidate_ids, candidate_scores) in number_rows:
    predicted_str = model.detokenize(predicted_tok)
    if predicted_str == "\n":
        predicted_str = "\\n"
    ranked = [
        (model.detokenize(tok_id), round(score.item(), 5))
        for tok_id, score in zip(candidate_ids, candidate_scores)
    ]
    print(f"token: {predicted_str:<8} | xent: {round(number_loss.item(), 5):<10} | {ranked}")

# Mean loss over all number positions (displayed as the cell result).
xloss = loss[2::4]
xloss.mean().item()



token:  7       | xent: 3.11536    | [(' 7', 21.17551), (' 6', 20.70576), (' 8', 20.47319), (' 5', 18.89612), (' 9', 18.40056)]
token:  2       | xent: 1.18578    | [(' 2', 6.86219), (' 3', 6.3443), (' 1', 5.40201), (' 4', 4.67621), (' 5', 1.88768)]
token:  8       | xent: 1.27894    | [(' 8', 8.01476), (' 7', 7.78295), (' 9', 7.17071), (' 6', 6.89117), (' 10', 6.41702)]
token:  0       | xent: 3.67148    | [(' 0', 1.61399), (' 1', 0.39264), (' 2', -1.76876), (' 3', -3.76555), (' 4', -5.25219)]
token:  10      | xent: 1.11828    | [(' 10', 18.21907), (' 9', 18.08381), (' 8', 17.24315), (' 11', 17.06043), (' 7', 15.26709)]
token:  8       | xent: 3.02015    | [(' 8', -3.48154), (' 9', -3.68496), (' 7', -3.70479), (' 10', -3.9458), (' 6', -4.11503)]
token:  1       | xent: 1.86178    | [(' 1', 5.28553), (' 2', 4.82629), (' 0', 4.326), (' 3', 4.009), (' 4', 3.17111)]
token:  6       | xent: 2.67691    | [(' 6', 3.78728), (' 5', 3.28833), (' 7', 3.25789), (' 8', 2.21254), (' 4', 2.09294)]


2.7493796348571777

In [13]:
# Formatted table for the number positions (every 4th entry starting at offset 2):
# actual number, predicted number, per-token loss, and the top-5 candidates.
# The summary is the mean (not a sum) of the per-number losses, and the values
# paired with each candidate are raw logits from torch.topk — the labels say so.
print("-"*95+f" MEAN LOSS ON NUMBERS: {loss[2::4].mean().item()}")
print("   actual -- predicted  | loss on number   | first 5 picked numbers: (num_tok, logit)")
print("-"*24+"|"+"-"*18+"|"+"-"*93)
number_rows = zip(loss[2:cut:4], highest_prob_tokens[2:cut:4], idxvaltup[2::4], synth[0, CUT+2::4])
for number_loss, predicted_tok, (candidate_ids, candidate_logits), actual_tok in number_rows:
    predicted_str = model.detokenize(predicted_tok)
    if predicted_str == "\n":
        predicted_str = "\\n"
    actual_str = model.detokenize(actual_tok)
    if actual_str == "\n":
        actual_str = "\\n"
    candidates = [
        (model.detokenize(tok_id), round(logit.item(), 5))
        for tok_id, logit in zip(candidate_ids, candidate_logits)
    ]
    # Distinct names for row and candidate variables: the inner loop used to
    # reuse `xl`, clobbering the row's loss value.
    candidate_cells = "".join(f"{str(candidate):<18}" for candidate in candidates)
    print(f"nums: {actual_str:<3} -- {predicted_str:<10} | xent: {round(number_loss.item(), 5):<10} | {candidate_cells}")

----------------------------------------------------------------------------------------------- TOTAL LOSS ON NUMBERS: 2.7493796348571777
   actual -- predicted  | loss on number   | first 5 picked numbers: (num_tok, perplexity)
------------------------|------------------|---------------------------------------------------------------------------------------------
nums:  5  --  7         | xent: 3.11536    | (' 7', 21.17551)  (' 6', 20.70576)  (' 8', 20.47319)  (' 5', 18.89612)  (' 9', 18.40056)  
nums:  3  --  2         | xent: 1.18578    | (' 2', 6.86219)   (' 3', 6.3443)    (' 1', 5.40201)   (' 4', 4.67621)   (' 5', 1.88768)   
nums:  7  --  8         | xent: 1.27894    | (' 8', 8.01476)   (' 7', 7.78295)   (' 9', 7.17071)   (' 6', 6.89117)   (' 10', 6.41702)  
nums:  2  --  0         | xent: 3.67148    | (' 0', 1.61399)   (' 1', 0.39264)   (' 2', -1.76876)  (' 3', -3.76555)  (' 4', -5.25219)  
nums:  9  --  10        | xent: 1.11828    | (' 10', 18.21907) (' 9', 18.08381)  (' 8', 1