Sampling text from a Transformer trained on *Harry Potter and the Sorcerer's Stone*.

In [None]:
import torch
import torch.nn.functional as F
# Seed PyTorch's global RNG so that weight initialisation and the
# torch.multinomial sampling further down are repeatable run to run.
torch.manual_seed(123)
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

from models import Transformer
from load_hp_data import load_data

In [2]:
# Corpus the model was trained on. load_data returns the raw text, the
# vocabulary size, and the encode/decode helpers used below to map between
# text and token IDs. vocab_size must match the checkpoint's embedding table.
data_file = 'hp_dataset/01 Harry Potter and the Sorcerers Stone.txt'
text, vocab_size, encode, decode = load_data(data_file)

# Whole corpus as a LongTensor of token IDs on the compute device
# (presumably 1-D, if encode returns a flat list of ints — confirm in load_data).
# NOTE(review): `data` is not used by the sampling cells in this notebook.
token_ids = encode(text)
data = torch.tensor(token_ids, dtype=torch.long, device=device)

length of raw dataset in characters:  439478
Characters present in the raw dataset:  
 !'()*,-.0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz–—‘’“”…
Vocabulary size of the raw dataset:  82


length of dataset after augmentation in characters:  435847
Characters present in the augmented dataset:  
 !"',-.:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz…
Vocabulary size of the augmented dataset:  64




Load the trained model.

In [None]:
# Architecture hyperparameters. They must match the ones used to train the
# checkpoint; the same values are encoded in SAVE_PATH (cl_256_ed_384_nh_6_nb_6).
# NOTE(review): SAVE_PATH says dr_0.2, but the constructed model reports
# Dropout(p=0.1). Dropout is inactive in eval mode, so sampling is unaffected.
context_len = 256
embed_dim = 384
n_heads = 6
n_blocks = 6

# Build the model directly on the compute device so the weights can be loaded in place.
loaded_model = Transformer(vocab_size, context_len, embed_dim, n_heads, n_blocks).to(device)

SAVE_PATH = 'saved_models/hp_model_cl_256_ed_384_nh_6_nb_6_bs_64_dr_0.2_lr_3e-4_max_iters_5000.pth'
# map_location remaps tensors that were saved from a GPU onto `device`, so the
# checkpoint also loads on CPU-only machines. torch.load unpickles the file,
# so only load checkpoints from a trusted source.
loaded_model.load_state_dict(torch.load(SAVE_PATH, map_location=device))
# eval() disables dropout and returns the module itself, so the cell still displays the model.
loaded_model.eval()

Transformer(
  (token_embedding): Embedding(64, 384)
  (position_embedding): Embedding(256, 384)
  (embedding_dropout): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention_heads): ModuleList(
        (0-5): 6 x SingleHeadAttention(
          (k): Linear(in_features=384, out_features=64, bias=False)
          (q): Linear(in_features=384, out_features=64, bias=False)
          (v): Linear(in_features=384, out_features=64, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (attention_projection): Linear(in_features=384, out_features=384, bias=True)
      (attention_dropout): Dropout(p=0.1, inplace=False)
      (ffn): Sequential(
        (0): Linear(in_features=384, out_features=1536, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1536, out_features=384, bias=True)
      )
      (ffn_dropout): Dropout(p=0.1, inplace=False)
      (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True

In [None]:
def generate(model, idx, max_new_tokens, context_len):
    """Autoregressively sample ``max_new_tokens`` token IDs from ``model``.

    At each step the running sequence is cropped to its last ``context_len``
    tokens and fed to the model. The next token is then drawn from the softmax
    distribution over the logits at the final position.

    Args:
        model: ``torch.nn.Module`` expected to map a ``(B, T)`` LongTensor of
            token IDs to logits of shape ``(B, T, vocab_size)``. The model is put
            into eval mode and left there.
        idx: ``(B, T)`` LongTensor of conditioning token IDs. If it is on a
            different device, it is moved to the model's device.
        max_new_tokens: number of tokens to append (0 returns ``idx`` unchanged).
        context_len: maximum number of trailing tokens fed to the model per
            step, normally the model's training context length. Must be positive.

    Returns:
        ``(B, T + max_new_tokens)`` LongTensor on the model's device: ``idx``
        followed by the sampled tokens.

    Raises:
        ValueError: if ``context_len`` is not positive.
    """
    if context_len <= 0:
        # idx[:, -0:] would silently select the *whole* sequence rather than fail.
        raise ValueError(f"context_len must be positive, got {context_len}")

    model.eval()
    device = next(model.parameters()).device
    # Keep the prompt on the same device as the weights so that neither the
    # forward pass nor torch.cat below mixes devices.
    idx = idx.to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop to at most context_len tokens: the model cannot attend further back.
            idx_cond = idx[:, -context_len:]

            # Only the last time step's logits predict the next token: (B, C).
            logits = model(idx_cond)[:, -1, :]

            probs = F.softmax(logits, dim=-1)                   # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T+1)

    return idx

Sampling from the model with a blank context (a single start token with ID 0).

In [7]:
print("--- Generating from a blank context ---")

# A (1, 1) prompt holding token ID 0 acts as a blank context.
# NOTE(review): ID 0 is presumably the newline character if load_data builds a
# sorted vocabulary — confirm.
start_context = torch.zeros((1, 1), dtype=torch.long, device=device)

tokens_to_generate = 10000

generated_ids = generate(
    model=loaded_model,
    idx=start_context,
    max_new_tokens=tokens_to_generate,
    # Reuse the configured context length instead of repeating the literal 256,
    # so this cell cannot drift from the architecture the model was built with.
    context_len=context_len,
)

# Decode the generated token IDs (batch element 0) back to text and print it.
print(decode(generated_ids[0].tolist()))

--- Generating from a blank context ---

"Good-bye," said Ron. "Coddle, you wouldn't me them, I've just carry ourselves and pelay you father green my."
"I've done out a month," said Uncle Vernon, racing the boy he'd explain, led because "I I gotta bit me to cupboard for Gryffindor," said Mr. Ollivander. "Got of of the mirror magic and don't you."
"Never really speak," Harry cat to Ron and Hermione spoke. "A pair of you see. I means thinks…"
He sat down down on the fireplashing what he'd hand. He mummed to argue dark to them and down at the forth-nobody else when outside have draid, "It's guarding through - I Got ages all right, out a batch. Keful tried out, not want a who cold second, don't duck terrible - a thought of circus I can you can't see what do it's somethin'. No really An' the Snitch appeare, wouldn't it at bit real myself have people in my parents your fink."
One knocked morning. It was a very centaur it. Like numbed asked if they led-got it, he wouldn't be inter the pronpap