In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer

%matplotlib inline
%reload_ext autoreload
%autoreload 2

from src import *

In [2]:
# Load resources from HF
# Use one checkpoint id for both the weights and the tokenizer so they cannot
# drift apart. "openai-community/gpt2" is the canonical Hub id of the
# checkpoint that the legacy short name "gpt2" points to.
HF_MODEL_ID = "openai-community/gpt2"

hf_model = GPT2LMHeadModel.from_pretrained(HF_MODEL_ID)
tokenizer = GPT2Tokenizer.from_pretrained(
    HF_MODEL_ID, clean_up_tokenization_spaces=False
)

In [3]:
# Load weights in model
# Architecture hyperparameters passed to our GPT2 implementation. They have to
# agree with the HF checkpoint loaded above, because its state dict is loaded
# into our model below.
num_heads = 12
embed_dim = 768
context_len = 1024
# Prefer Apple's Metal backend, but fall back to CPU so the notebook still
# runs top-to-bottom on machines without MPS support.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

model = modules.GPT2(tokenizer.vocab_size, embed_dim, context_len, num_heads)
# Copy the pretrained HF parameters into our own implementation.
model.load_state_dict(hf_model.state_dict())

# Move both models to the single `device` defined above (also used by the
# generation cell), instead of repeating a hard-coded device string.
# Assign to variable to silence output
_ = hf_model.to(device)
_ = model.to(device)

In [4]:
prompt = "Captain's Note (8/9/74, 13980 leagues under sea level): Today the crew"

# Both runs get a fresh generator seeded with the same value, so our model and
# the HF reference are sampled under identical random state and their outputs
# can be compared side by side.
SEED = 42

g = torch.Generator(device=device).manual_seed(SEED)
completions = pipeline.generate_completion(
    prompt,
    tokenizer,
    model,
    generator=g,
    loading_bar_prefix="Our Completions",
    device=device,
)
g = torch.Generator(device=device).manual_seed(SEED)
hf_completions = pipeline.generate_completion(
    prompt,
    tokenizer,
    hf_model,
    generator=g,
    loading_bar_prefix=" HF Completions",
    device=device,
)

completion_padding = " " * 7
# The replacement is done outside the f-strings: a backslash inside an
# f-string replacement field is a SyntaxError before Python 3.12.
paragraph_break = "\n\n"
indented_newline = "\n" + completion_padding
for i, (hf_completion, completion) in enumerate(zip(hf_completions, completions)):
    # Collapse blank lines into a single indented newline so multi-paragraph
    # completions stay visually attached to their label.
    ours_text = completion.replace(paragraph_break, indented_newline)
    hf_text = hf_completion.replace(paragraph_break, indented_newline)
    print(f"\nCompletion {i}:\n============")
    print(f"Ours  : {ours_text}")
    print(completion_padding + "---------------------------------------------")
    print(f"HF    : {hf_text}")

Our Completions (mps): 100%|██████████| 40/40 [00:04<00:00,  8.91it/s]
 HF Completions (mps): 100%|██████████| 40/40 [00:05<00:00,  7.59it/s]


Completion 0:
Ours  : Captain's Note (8/9/74, 13980 leagues under sea level): Today the crew is ashore in Volk Flora. The crew has sufficient equipment and is encountering a Pokémon that can help them overcome its 'Smith I'. This mysteriously happens to the Pokémon, too. A message of
       ---------------------------------------------
HF    : Captain's Note (8/9/74, 13980 leagues under sea level): Today the crew is ashore in Volk Flora. The crew has sufficient equipment and is encountering a Pokémon that can help them overcome its 'Smith I'. This mysteriously happens to the Pokémon, too. A message of

Completion 1:
Ours  : Captain's Note (8/9/74, 13980 leagues under sea level): Today the crew is learning about the current situation of the Black Sea fleet and birth of the new captain. Sword flipper David Crocker arrives at Marengo Engineering and joins the other leaders of the critical fleet to
       ---------------------------------------------
HF    : Captain's Note (8/9/74, 13980 


