In [2]:
import torch
from torch.utils.data import Subset
from transformers import GPT2Tokenizer

%reload_ext autoreload
%autoreload 2

from src import data, modules, pipeline

In [2]:
# Build the tokenized TinyStories dataset used by the rest of the notebook.
# Construction tokenizes the stories up front, so this cell takes a while.
dataset = data.TinyStoriesDataset(
    1024,  # presumably the per-example context length in tokens -- confirm against the constructor
    num_stories=50_000,
)

Tokenizing Stories: 100%|██████████| 50000/50000 [00:34<00:00, 1434.81 stories/s]


In [6]:
# Train the GPT-2 model on the tokenized TinyStories dataset.
#
# By default this trains on the full `dataset`. Set `overfit_single_batch = True`
# to run the quick sanity check instead: train on only the first `batch_size`
# examples, where the loss should collapse towards zero if the model and the
# training loop are wired up correctly.
seed = 42
batch_size = 16
num_heads = 12
embed_dim = 768
# Same value as the 1024 passed to data.TinyStoriesDataset when `dataset` was
# built (presumably the same quantity -- keep them in sync; TODO confirm).
context_len = 1024
vocab_size = 50257  # size of the GPT-2 BPE vocabulary
# Fall back to the CPU so the cell still runs (slowly) on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
overfit_single_batch = False

# Seed the global RNG before building the model so that the run is repeatable
# after a kernel restart (assumes modules.GPT2 draws its initial weights from
# torch's global RNG -- TODO confirm).
torch.manual_seed(seed)
# Separate CPU generator handed to the training loop (presumably used for data
# shuffling -- confirm in pipeline.train_gpt2).
g = torch.Generator().manual_seed(seed)

model = modules.GPT2(vocab_size, embed_dim, context_len, num_heads)
model.to(device)

# The first `batch_size` examples only; used when overfitting a single batch.
single_batch_ds = Subset(dataset, list(range(batch_size)))
train_ds = single_batch_ds if overfit_single_batch else dataset

pipeline.train_gpt2(
    model,
    train_ds,
    batch_size=batch_size,
    logging_interval=10,  # log the loss every 10 minibatches
    num_epochs=2,
    device=device,
    generator=g,
)

Epoch:   0%|          | 0/2 [00:00<?, ? epochs/s]

Minibatch:   0%|          | 0/647 [00:00<?, ?it/s]

[Epoch  0, Minibatch  0]: Loss=10.9714
[Epoch  0, Minibatch   10]: Loss=11.8661
[Epoch  0, Minibatch   20]: Loss=6.2461
[Epoch  0, Minibatch   30]: Loss=5.7567
[Epoch  0, Minibatch   40]: Loss=5.5397
[Epoch  0, Minibatch   50]: Loss=5.2969
[Epoch  0, Minibatch   60]: Loss=5.1428
[Epoch  0, Minibatch   70]: Loss=4.9451
[Epoch  0, Minibatch   80]: Loss=4.7182
[Epoch  0, Minibatch   90]: Loss=4.5752
[Epoch  0, Minibatch  100]: Loss=4.5005
[Epoch  0, Minibatch  110]: Loss=4.3930
[Epoch  0, Minibatch  120]: Loss=4.2081
[Epoch  0, Minibatch  130]: Loss=4.2648
[Epoch  0, Minibatch  140]: Loss=4.1489
[Epoch  0, Minibatch  150]: Loss=4.1575
[Epoch  0, Minibatch  160]: Loss=4.0872
[Epoch  0, Minibatch  170]: Loss=4.0872
[Epoch  0, Minibatch  180]: Loss=4.1213
[Epoch  0, Minibatch  190]: Loss=4.0085
[Epoch  0, Minibatch  200]: Loss=4.0446
[Epoch  0, Minibatch  210]: Loss=3.9964
[Epoch  0, Minibatch  220]: Loss=3.8128
[Epoch  0, Minibatch  230]: Loss=3.9209
[Epoch  0, Minibatch  240]: Loss=3.8599


Minibatch:   0%|          | 0/647 [00:00<?, ?it/s]

[Epoch  1, Minibatch  0]: Loss=2.9509
[Epoch  1, Minibatch   10]: Loss=2.9785
[Epoch  1, Minibatch   20]: Loss=2.9265
[Epoch  1, Minibatch   30]: Loss=3.0180
[Epoch  1, Minibatch   40]: Loss=2.9483
[Epoch  1, Minibatch   50]: Loss=3.0269
[Epoch  1, Minibatch   60]: Loss=2.9856
[Epoch  1, Minibatch   70]: Loss=3.0272
[Epoch  1, Minibatch   80]: Loss=2.8430
[Epoch  1, Minibatch   90]: Loss=2.8608
[Epoch  1, Minibatch  100]: Loss=2.8182
[Epoch  1, Minibatch  110]: Loss=2.6982
[Epoch  1, Minibatch  120]: Loss=2.8824
[Epoch  1, Minibatch  130]: Loss=2.9007
[Epoch  1, Minibatch  140]: Loss=2.8973
[Epoch  1, Minibatch  150]: Loss=2.6070
[Epoch  1, Minibatch  160]: Loss=2.7794
[Epoch  1, Minibatch  170]: Loss=2.7807
[Epoch  1, Minibatch  180]: Loss=2.5002
[Epoch  1, Minibatch  190]: Loss=2.7908
[Epoch  1, Minibatch  200]: Loss=2.6032
[Epoch  1, Minibatch  210]: Loss=2.6693
[Epoch  1, Minibatch  220]: Loss=2.7044
[Epoch  1, Minibatch  230]: Loss=2.6698
[Epoch  1, Minibatch  240]: Loss=2.6012
[E

In [None]:
# Sample story continuations from the trained model for a fixed prompt.
# Load the pretrained GPT-2 tokenizer to encode the prompt and decode samples.
tokenizer = GPT2Tokenizer.from_pretrained(
    "openai-community/gpt2",
    clean_up_tokenization_spaces=False,
)

# Seeded generator created on the same device the model was moved to, so that
# sampling is repeatable.
g = torch.Generator(device=device)
g.manual_seed(42)

prompt = 'Once upon a time,'
sampling_options = dict(
    generator=g,
    loading_bar_prefix="Our Completions",
    num_completions=3,
    completion_len=1000,
    device=device,
)
completions = pipeline.generate_completion(prompt, tokenizer, model, **sampling_options)

In [9]:
# Display the completions returned by pipeline.generate_completion as the cell output.
completions

['Once upon a time, there was no lumber anymore. One day, a little girl named Lily was playing in her backyard when suddenly fell asleep under the field. She was looking for a little puppy named Max. Lily had lost',
 "Once upon a time, there was a little girl named Lily. She loved taking her toys all day long. One day, Lily's mom told her that they were going to ride a map! Lily was so happy that she",
 "Once upon a time, there was a girl named Lily. She loved to play in the sunshine and look for clear about playing in the sunshine. Suddenly, Lily's mom told her she was so excited that Lily quickly had an"]