In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [None]:
"""To check our implementation, we will try to load the GPT-2 weights from huggingface, alongside the official implementation 
from the `transformers` library, and see if we get the same token probabilities for an example prompt.
"""

# Load the reference tokenizer/model and our own model from the same checkpoint.
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
official_gpt_model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")

our_gpt_model = GPT.from_pretrained("openai-community/gpt2")

# Put both models in evaluation mode before comparing: a freshly constructed
# nn.Module defaults to training mode, and any train-only layer (e.g. dropout)
# would make the logits stochastic and the comparison below meaningless.
official_gpt_model.eval()
our_gpt_model.eval()

prompt = "Have you ever seen the rain?"
token_ids = tokenizer.encode(prompt, return_tensors="pt")

# Pure inference: no autograd graph is needed for a numerical comparison.
with torch.no_grad():
    our_logits = our_gpt_model(token_ids)
    official_logits = official_gpt_model(token_ids).logits

# Show the first 10 logits of the last position from each model side by side.
print("Our logits:")
print(our_logits[0, -1, :10])
print("Official logits:")
print(official_logits[0, -1, :10])

print("=" * 60)
print("Mean difference (this should be close to 0):")
print((our_logits - official_logits).abs().mean())

In [None]:
"""
Generation
Implement top-k sampling: from the token probabilities, keep only the top k, and sample one of them according to the probabilities

compare generated samples from the gpt model you implemented to the official gpt implementation. Are they both intelligible?
"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # You can configure colab to run on a T4 GPU for faster generation

# Sampling configuration (kept together so it is easy to tune).
num_return_sequences = 5  # number of samples generated in parallel
max_length = 30           # total sequence length (prompt + generated tokens)
k = 20                    # number of highest-scoring tokens to sample from

tokens = [15496, 11, 314, 1101, 257, 3303, 2746, 11] # "Hello, I'm a language model,"
tokens = torch.tensor(tokens, dtype=torch.long) # (8,)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1) # (num_return_sequences, 8)
x = tokens.to(device)

# Move the model to the correct device and switch it to evaluation mode
our_gpt_model.to(device)
our_gpt_model.eval()

print("Generating on device:", device)
# generate! No gradients are needed, so disable autograd for the whole loop.
with torch.no_grad():
    while x.size(1) < max_length:
        # forward the model and keep only the logits of the last position
        logits = our_gpt_model(x)[:, -1, :] # (batch, vocab_size)

        # Take exactly the k best candidates per row. Using the indices returned
        # by topk (instead of thresholding on the k-th value) guarantees that ties
        # cannot let more than k tokens through; min() guards against k > vocab_size.
        topk_logits, topk_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)

        # Softmax over the k kept logits == renormalized top-k probabilities
        topk_probs = F.softmax(topk_logits, dim=-1) # (batch, k)

        # Sample a position inside the top-k list, then map it back to a vocab id
        choice = torch.multinomial(topk_probs, num_samples=1) # (batch, 1)
        next_token = torch.gather(topk_indices, -1, choice)   # (batch, 1)

        # Append the sampled token to every sequence in the batch
        x = torch.cat((x, next_token), dim=1)

# Show the batch as a list, then one decoded sample per line.
decoded_samples = tokenizer.batch_decode(x)
print(decoded_samples)
for sample in decoded_samples:
    print(sample)