Sampling from the model.

In [34]:
import torch
import torch.nn.functional as F
import sys
# Fixed seed so sampling below is reproducible; torch.manual_seed seeds the
# CPU generator and every CUDA device's generator.
SEED = 123
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

from models import Transformer
from load_hp_data import load_data

In [35]:
data_file = 'hp_dataset/01 Harry Potter and the Sorcerers Stone.txt'

# load_data returns the processed text together with the vocabulary size and
# the encode/decode helpers; judging by its printed output the vocabulary is
# character-level (64 symbols after augmentation).
text, vocab_size, encode, decode = load_data(data_file)

# Whole corpus as a 1-D LongTensor of token ids, placed on the target device.
data = torch.tensor(
    encode(text),
    dtype=torch.long,
    device=device,
)

length of raw dataset in characters:  439478
Characters present in the raw dataset:  
 !'()*,-.0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz–—‘’“”…
Vocabulary size of the raw dataset:  82


length of dataset after augmentation in characters:  435847
Characters present in the augmented dataset:  
 !"',-.:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz…
Vocabulary size of the augmented dataset:  64




Load the trained model.

In [36]:
# Architecture hyperparameters: these must match the values used at training
# time, otherwise load_state_dict below fails on mismatched tensor shapes.
context_len = 256
embed_dim = 384
n_heads = 6
n_blocks = 6

loaded_model = Transformer(vocab_size, context_len, embed_dim, n_heads, n_blocks)

# The checkpoint file name encodes the training configuration; build the
# architecture part from the variables above so they cannot drift apart.
# NOTE(review): the name says dr_0.2 while the constructed model's dropout
# modules report p=0.1 — irrelevant for sampling in eval mode, but confirm
# the Transformer default matches training if this model is fine-tuned.
SAVE_PATH = (
    f'saved_models/hp_model_cl_{context_len}_ed_{embed_dim}_nh_{n_heads}_nb_{n_blocks}'
    '_bs_64_dr_0.2_lr_3e-4_max_iters_5000.pth'
)

# The file holds a plain state_dict, so restrict unpickling to tensors and
# containers (weights_only=True) instead of executing arbitrary pickle code.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
loaded_model.load_state_dict(
    torch.load(SAVE_PATH, map_location=device, weights_only=True)
)

# Move to the target device and switch off dropout; the cell displays the model.
loaded_model.to(device).eval()

Transformer(
  (token_embedding): Embedding(64, 384)
  (position_embedding): Embedding(256, 384)
  (embedding_dropout): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention_heads): ModuleList(
        (0-5): 6 x SingleHeadAttention(
          (k): Linear(in_features=384, out_features=64, bias=False)
          (q): Linear(in_features=384, out_features=64, bias=False)
          (v): Linear(in_features=384, out_features=64, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (attention_projection): Linear(in_features=384, out_features=384, bias=True)
      (attention_dropout): Dropout(p=0.1, inplace=False)
      (ffn): Sequential(
        (0): Linear(in_features=384, out_features=1536, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1536, out_features=384, bias=True)
      )
      (ffn_dropout): Dropout(p=0.1, inplace=False)
      (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True

In [37]:
def generate(model, idx, max_new_tokens, context_len, decode_func, stream=True):
    """Autoregressively sample ``max_new_tokens`` tokens from ``model``.

    At every step the running sequence is cropped to its last ``context_len``
    tokens, the model is evaluated, and the next token is drawn with
    ``torch.multinomial`` from the softmax of the final position's logits.

    Args:
        model: ``torch.nn.Module`` mapping a (B, T) LongTensor of token ids to
            logits of shape (B, T, vocab_size). It is put into eval mode.
        idx: (B, T) LongTensor prompt. It is moved to the model's device, so
            the caller does not have to match devices.
        max_new_tokens: number of tokens to append.
        context_len: maximum number of trailing tokens fed to the model;
            presumably must not exceed the model's positional-embedding size.
        decode_func: callable turning a list of token ids into a string; only
            used when streaming.
        stream: when True and the batch holds a single sequence, print each
            sampled token as soon as it is produced.

    Returns:
        (B, T + max_new_tokens) LongTensor: the prompt followed by the samples.
    """
    model.eval()
    device = next(model.parameters()).device
    # Inputs must live on the same device as the weights, or the forward pass fails.
    idx = idx.to(device)
    # Live printing only makes sense for one sequence; several would interleave.
    stream = stream and idx.size(0) == 1

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the context to at most context_len tokens
            idx_cond = idx[:, -context_len:]

            # Keep only the logits of the last time step: (B, T, C) -> (B, C)
            logits = model(idx_cond)[:, -1, :]

            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            if stream:
                # idx_next[0].tolist() is a one-element list of ints.
                print(decode_func(idx_next[0].tolist()), end="", flush=True)

            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)

    return idx

Sampling from the model with a blank context (a single start token with id 0).

In [38]:
print("--- Generating from a blank context ---")

# Sample on whatever device the model lives on. Hard-coding 'cpu' here causes
# a device-mismatch error whenever the model was moved to CUDA.
model_device = next(loaded_model.parameters()).device

# "Blank" prompt: a single start token with id 0 (presumably the first symbol
# of the sorted vocabulary — verify against load_data).
start_context = torch.zeros((1, 1), dtype=torch.long, device=model_device)
tokens_to_generate = 200

# generate streams each decoded token to stdout as it is sampled.
generated_ids = generate(
    model=loaded_model,
    idx=start_context,
    max_new_tokens=tokens_to_generate,
    context_len=context_len,  # same context length the model was built with
    decode_func=decode,
)

print("\n--- Generation Complete ---")

--- Generating from a blank context ---
"Brewill time three-heak," said great on the chance into the air.
"Ah, Malfoy!" shouted. "Can I can't stood get platformward Harry."
"I'm not must to lose without anyway, every now," said Hagrid, "but
--- Generation Complete ---
