In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import lightning as L
import torch as t
from data import EOT_TOKEN, WikitextDataset
from model import LangNet

In [3]:
DATAROOT = Path.home() / "mldata" / "wikitext-2-raw"
RUNROOT = Path.home() / "mlruns" / "makemore-2"

In [4]:
# Somehow figure out how to get this from the checkpoint's metadata
context_len = 5

In [5]:
dataset = WikitextDataset(
    DATAROOT / "wiki.train.raw", 
    context_len=context_len
)

print(len(dataset))

2075677


In [6]:
# ckpt_path = RUNROOT / "Makemore2" / "u9g12aj6" / "checkpoints" / "epoch=2-step=6228.ckpt"
ckpt_path = RUNROOT / "Makemore2" / "colorful-snow-16" / "checkpoints" / "model.ckpt"
lang_net = LangNet.load_from_checkpoint(ckpt_path)

In [113]:
from dataclasses import dataclass

@dataclass
class DebugInfo:
    rank: int
    prob: float
    idx: int
    top_idxs: list[int]
    top_probs: list[float]
    context: list[int]

Choosing the next token by sampling the entire vocabulary based on the probability outputted by the net. This does not give very good results because the system might choose some token with a very low probability.

In [114]:
vocab = dataset.vocab
eot_tok_idx = vocab.idx_of(EOT_TOKEN)
context = [eot_tok_idx] * context_len
predicted_idx = None
gen_text = []
debug_info = []
with t.no_grad():
    while predicted_idx != eot_tok_idx and len(gen_text) < 100:
        input = t.tensor(context).unsqueeze(0).to(lang_net.device)
        logits = lang_net.model.forward(input)
        probs = t.softmax(logits, 0)
        
        predicted_idx = t.multinomial(probs, 1)[0].item()

        # Debug
        vals, idxs = t.sort(probs, descending=True)
        rank = t.nonzero(idxs == predicted_idx).squeeze()
        prob = probs[predicted_idx]
        top_idxs = idxs.detach().numpy()[:10]
        top_probs = vals.detach().numpy()[:10]
        debug_info.append(DebugInfo(rank=rank, prob=prob, idx=predicted_idx, top_idxs=top_idxs, top_probs=top_probs, context=context))

        word = vocab.word_at(predicted_idx)
        gen_text.append(word)
        context = context[1:] + [predicted_idx]

print(" ".join(gen_text))
    

Marella mentioned Arton that 9 , Eshkol is bursary only . mélodie was similar completed , but would not outmoded during the name centre of us , but again , Vozdooshnykh of the United North on the 19th 2006 ; one , but without by the team mine 's armoured ) , granted ( regiment ft ) over children of its redundant , ALPAC they his west . As <0>


In [115]:
for di in debug_info:
    print(di)
    print("\n")

DebugInfo(rank=tensor(66897), prob=tensor(5.2043e-07), idx=73078, top_idxs=array([ 85,   1, 690,  62, 931, 127, 179, 146, 276, 538]), top_probs=array([0.20140465, 0.17070054, 0.06795194, 0.02305998, 0.02288323,
       0.01178322, 0.01056079, 0.01034859, 0.00708862, 0.00698738],
      dtype=float32), context=[0, 0, 0, 0, 0])


DebugInfo(rank=tensor(840), prob=tensor(9.8406e-05), idx=5854, top_idxs=array([ 13,  37, 132,  16,  62, 135,  85,  43,  10,  26]), top_probs=array([0.02212626, 0.01831166, 0.01516041, 0.01448331, 0.01347196,
       0.01053168, 0.01032   , 0.00998903, 0.00941204, 0.0092516 ],
      dtype=float32), context=[0, 0, 0, 0, 73078])


DebugInfo(rank=tensor(23621), prob=tensor(6.5174e-06), idx=44544, top_idxs=array([ 10, 132,  13,  16, 359,  37, 135, 715,  17,  79]), top_probs=array([0.04668538, 0.0323087 , 0.02991573, 0.02142455, 0.01558523,
       0.01515251, 0.01278486, 0.0116325 , 0.01157354, 0.01024644],
      dtype=float32), context=[0, 0, 0, 73078, 5854])


DebugInf

Choosing the highest probability token is also not a good idea because these tend to be the most frequently used words regardless of the context.

In [111]:
text_idxs = []
vocab = dataset.vocab
eot_tok_idx = vocab.idx_of(EOT_TOKEN)
context = [eot_tok_idx] * context_len
predicted_idx = None
with t.no_grad():
    while predicted_idx != eot_tok_idx and len(text_idxs) < 100:
        input = t.tensor(context).unsqueeze(0).to(lang_net.device)
        logits = lang_net.model.forward(input)
        probs = t.softmax(logits, 0)
        
        predicted_idx = t.argmax(probs).item()
        word = vocab.word_at(predicted_idx)
        print(word, end=" ")

        text_idxs.append(predicted_idx)
        context = context[1:] + [predicted_idx]

The first of the city , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city of the game , and the city 

A good solution seems to be to sample from the top K tokens.

In [136]:
def predict(top_k):
    text_idxs = []
    vocab = dataset.vocab
    eot_tok_idx = vocab.idx_of(EOT_TOKEN)
    context = [eot_tok_idx] * context_len
    predicted_idx = None
    with t.no_grad():
        while predicted_idx != eot_tok_idx and len(text_idxs) < 100:
            input = t.tensor(context).unsqueeze(0).to(lang_net.device)
            logits = lang_net.model.forward(input)
            probs = t.softmax(logits, 0)
            
            vals, idxs = t.sort(probs, descending=True)
            top_idxs = idxs.detach().numpy()[:top_k]
            top_probs = vals.detach().numpy()[:top_k]
            xx = t.multinomial(t.tensor(top_probs), 1)
            predicted_idx = top_idxs[xx]
            word = vocab.word_at(predicted_idx)
            print(word, end=" ")

            text_idxs.append(predicted_idx)
            context = context[1:] + [predicted_idx]

In [137]:
predict(10)

The song was a first of the second century , the first of a large in of the season of the United States , and the first of the episode . A , which he also his as a " , the " " . " is now in the " of the end , which was one of the game , the song , and the episode , and has been a large and the album . <0> 

In [138]:
predict(10)

At the same season . He has been a " for " a single " " , " is to " . The same and a new the end @-@ and the new , was the most . As , the second and the film of a few . <0> 

In [139]:
predict(100)

" ( 3 – 2 ) . The song was Hero by the group , it was similar to the old role , . The storm 's more the church . The family that " was not a run @-@ in November species . In 8 between the time and has been with . <0> 

In [140]:
predict(50)

At album , the north and its " ) , an and his first in many was sent to the team day , but many a major , where him and the United with the second , which had been found in 2010 's third , and is a few of three in two and the state of the village of the following in which he of the National film . The show of the new system of 1 @,@ 10 's " . There was not that the last video " I could its " and " his " 

In [141]:
predict(1000)

Later the greatest a was role as the Hurricane and originally considered , and an again to four eastern sense . The New York of an City being of 25 are end . According to the result was released prior to the Scientology . The food 's start of generally troops . It game that the woman above and than hours BBC - , she until 11 and due to the state of western @.@ and completed machine many of regions ( Church ) of America @-@ two main and point units to Indian Hall Carolina ( , but a 