# Tiny Shakespeare Language Model

A GPT-like language model trained on the tinyshakespeare dataset. This notebook was written while following Karpathy's 'Let's build GPT' video. The only notable difference is the use of SentencePiece tokenization instead of character-level tokenization.

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!mkdir data
!curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o data/tinyshakespeare

mkdir: data: File exists
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  2638k      0 --:--:-- --:--:-- --:--:-- 2643k


In [2]:
import torch
import torch.nn as nn

from zeptogpt.model import SimpleGPT

In [3]:
# Read the whole corpus into memory as a single string. The encoding is given
# explicitly so the result does not depend on the platform's default locale.
with open('data/tinyshakespeare', encoding='utf-8') as f:
    text = f.read()

In [4]:
!mkdir models
import sentencepiece as spm
spm.SentencePieceTrainer.train(input='data/tinyshakespeare',
                               model_prefix='models/shakespeare_tokenizer_model',
                               vocab_size=1000,
                               character_coverage=1.0,
                               model_type='unigram',
                               remove_extra_whitespaces=False,
                               user_defined_symbols=["\n", "\r"])

mkdir: models: File exists


sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: data/tinyshakespeare
  input_format: 
  model_prefix: models/shakespeare_tokenizer_model
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 1
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  user_defined_symbols: 

  user_defined_symbols: 
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s

In [5]:
# Load the tokenizer trained above and record how many pieces its vocabulary
# holds; the model is later built with this size.
sp = spm.SentencePieceProcessor(model_file='models/shakespeare_tokenizer_model.model')
vocab_size = sp.get_piece_size()

In [6]:
# Encode the full corpus with the tokenizer and keep the ids as one tensor.
data = torch.tensor(sp.encode(text))

# The first 90% of the tokens go to training, the remaining 10% to testing.
n_train = int(0.9 * len(data))
traindata, testdata = data[:n_train], data[n_train:]

# Seed the RNG so the randomly sampled batches are repeatable.
torch.manual_seed(1337)
def get_batch(data, device, batch_size, block_size):
    """Sample a random batch of input/target windows from `data`.

    Args:
        data: tensor of token ids, sliced along its first dimension.
        device: device the returned tensors are moved to.
        batch_size: number of windows to sample.
        block_size: length of each window.

    Returns:
        A tuple `(x, y)` of stacked windows. Each row of `y` is the matching
        row of `x` shifted one position forward in `data`, i.e. the
        next-token targets.
    """
    # Random window starts; the upper bound leaves room for the shifted target.
    starts = torch.randint(0, len(data) - block_size, (batch_size, )).tolist()
    inputs = torch.stack([data[s:s + block_size] for s in starts])
    targets = torch.stack([data[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)

# Smoke test: draw 4 windows of 8 tokens on the CPU. In the displayed pair the
# second tensor is the first one shifted forward by one position.
get_batch(traindata, 'cpu', 4, 8)

(tensor([[  3, 175,  13,  66, 610,  26,  27, 200],
         [ 97, 128,  10,   5,  77,  11,  46, 109],
         [ 39,  16,  12, 709,  30,   3,   3, 191],
         [101, 182,  20, 242,   5,  94, 388, 119]]),
 tensor([[175,  13,  66, 610,  26,  27, 200,  60],
         [128,  10,   5,  77,  11,  46, 109, 130],
         [ 16,  12, 709,  30,   3,   3, 191,  57],
         [182,  20, 242,   5,  94, 388, 119,  36]]))

In [7]:
# Training
from tqdm import tqdm

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model size.
embed_dims = 32
num_heads = 4
num_decoder_layers = 2

# Batch geometry: windows per batch and tokens per window.
batch_size = 4
block_size = 8

# Schedule: total optimisation steps, how often losses are reported, and how
# many batches each loss estimate averages over.
num_training_iters = 10000
eval_interval = 1000
eval_iters = 100

model = SimpleGPT(
    vocab_size,
    embed_dims,
    block_size,
    num_heads,
    num_decoder_layers,
).to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters())

# The called form `torch.no_grad()` works as a decorator on every PyTorch
# release; the bare `@torch.no_grad` form is only accepted by newer ones.
@torch.no_grad()
def estimate_loss(dataset):
    """Estimate the mean cross-entropy loss of `model` on `dataset`.

    Averages the loss over `eval_iters` randomly sampled batches, using the
    notebook-level `model`, `loss_fn`, `get_batch`, `device`, `batch_size` and
    `block_size`.

    Args:
        dataset: token tensor that evaluation batches are sampled from.

    Returns:
        A 0-dim tensor holding the mean loss over the sampled batches.
    """
    losses = torch.zeros(eval_iters)
    # Remember the current mode so it can be restored afterwards, even if an
    # evaluation batch raises, instead of unconditionally forcing train mode.
    was_training = model.training
    model.eval()
    try:
        for i in range(eval_iters):
            inputs, targets = get_batch(dataset, device, batch_size, block_size)
            logits = model(inputs)
            B, T, C = logits.shape
            # Flatten batch and time so each position is one classification example.
            loss = loss_fn(logits.view(B*T, C), targets.view(B*T))
            losses[i] = loss.item()
    finally:
        model.train(was_training)
    return losses.mean()


# Optimisation loop: one randomly sampled training batch per step.
final_step = num_training_iters - 1
for step in tqdm(range(num_training_iters)):
    inputs, targets = get_batch(traindata, device, batch_size, block_size)

    optimizer.zero_grad()
    logits = model(inputs)
    n_batch, n_time, n_classes = logits.shape
    # Flatten batch and time so each position is one classification example.
    flat_logits = logits.view(n_batch * n_time, n_classes)
    flat_targets = targets.view(n_batch * n_time)
    loss = loss_fn(flat_logits, flat_targets)
    loss.backward()
    optimizer.step()

    # Report estimated losses every `eval_interval` steps and on the last step.
    if step == final_step or step % eval_interval == 0:
        train_loss = estimate_loss(traindata)
        test_loss = estimate_loss(testdata)
        print(f"Train Loss={train_loss} Test Loss={test_loss}")
    

SimpleGPT(
  (tok_emb_table): Embedding(1000, 32)
  (pos_emb_table): Embedding(8, 32)
  (decoder_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadedSelfAttention(
        (c_attn): Linear(in_features=32, out_features=96, bias=True)
        (c_proj): Linear(in_features=32, out_features=32, bias=True)
      )
      (ln2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=32, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=32, bias=True)
        (3): Dropout(p=0.5, inplace=False)
      )
    )
    (1): DecoderBlock(
      (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadedSelfAttention(
        (c_attn): Linear(in_features=32, out_features=96, bias=True)
        (c_proj): Linear(in_features=32, out_features=32, bias=True)
      )
      (ln2): LayerNorm((32,), eps=

  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 21/10000 [00:00<02:40, 62.25it/s]

Train Loss=7.0566182136535645 Test Loss=7.075974941253662


 10%|█         | 1036/10000 [00:05<00:54, 165.27it/s]

Train Loss=5.057060718536377 Test Loss=4.96154260635376


 20%|██        | 2027/10000 [00:10<00:48, 165.44it/s]

Train Loss=4.6959309577941895 Test Loss=4.719968318939209


 30%|███       | 3044/10000 [00:15<00:38, 179.65it/s]

Train Loss=4.432583808898926 Test Loss=4.549083709716797


 40%|████      | 4041/10000 [00:20<00:40, 145.57it/s]

Train Loss=4.3404436111450195 Test Loss=4.444079399108887


 50%|█████     | 5041/10000 [00:24<00:27, 177.76it/s]

Train Loss=4.250292778015137 Test Loss=4.368217468261719


 60%|██████    | 6035/10000 [00:29<00:23, 172.10it/s]

Train Loss=4.156818389892578 Test Loss=4.302825927734375


 70%|███████   | 7028/10000 [00:33<00:17, 169.48it/s]

Train Loss=4.123233318328857 Test Loss=4.143256664276123


 80%|████████  | 8032/10000 [00:39<00:21, 91.80it/s] 

Train Loss=3.9391846656799316 Test Loss=4.1493916511535645


 90%|█████████ | 9034/10000 [00:44<00:05, 176.92it/s]

Train Loss=3.969677448272705 Test Loss=4.232929706573486


100%|██████████| 10000/10000 [00:48<00:00, 206.42it/s]

Train Loss=4.024553298950195 Test Loss=4.1811370849609375





In [8]:
# Sample from the trained model and decode the generated token ids into text.
start_token_id = 80  # id of the single token used to seed generation

# Switch to inference mode first: the printed model contains Dropout layers and
# the loss-estimation helper puts the model back into train mode, so without
# this the samples may be drawn with dropout still active.
model.eval()
context = torch.ones((1, 1), dtype=torch.long, device=device) * start_token_id
with torch.no_grad():  # sampling needs no gradients
    # 1000 is presumably the number of tokens to generate -- confirm against
    # SimpleGPT.generate.
    generated = model.generate(context, 1000)
print(sp.decode(generated[0].tolist()))

Sowtlabssked by three that am thy married
Theth mine, he graveer I but see'chent-bome ed:
Nor is have heaven not finrigeage but makes'll think
EEN EL:
Mend, when I would op'th bad-bly.

ANTIGONTHSingan fighticail.
Mhe me there now vies laveriess that have fosemonk him and I am,
three rather insolel?

NORK: for the grace Lord you.

ESCALUS:
I hast to set and
ARS answer,
Second subjects, thou shalt soon lety
I not nothing, I bulueth of call,
I 'And sovereign more wordly to'd of my daughter to themer,
Neaster a sateice
Second huitizeplastutonceal me for most
Thoal kind what,
Loundantly must appeasoedfulpel stillgickerk's me to labarvsumer wit must to puten sweet fe.

PNEV:
Do that fortdthers oby, much vainrond life.

YomRUT ELIZ:
Sover meuld it like adoorem's.
Mel you bid that it be here, howtimeg and this now.

H
ASIOLERRENCE:
Head die tend wit, but have give of this abasse; see the Lady here, and mother mech to reper att takes and semstiT:
I veryt our country, joundsit
Bus lieleten us y