# Tiny Shakespeare Language Model

GPT-like language model that is trained on the tinyshakespeare dataset. This notebook was written while following Karpathy's 'Let's build GPT' vide. The only notable difference is the use of SentencePiece tokenization instead of a character level tokenization.

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!mkdir data
!curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o data/tinyshakespeare

mkdir: data: File exists
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  7689k      0 --:--:-- --:--:-- --:--:-- 7725k


In [2]:
with open('data/tinyshakespeare') as f:
    text = f.read()

In [3]:
!mkdir models
import sentencepiece as spm
spm.SentencePieceTrainer.train(input='data/tinyshakespeare',
                               model_prefix='models/shakespeare_tokenizer_model',
                               vocab_size=1000,
                               character_coverage=1.0,
                               model_type='unigram',
                               remove_extra_whitespaces=False,
                               user_defined_symbols=["\n", "\r"])

mkdir: models: File exists


sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: data/tinyshakespeare
  input_format: 
  model_prefix: models/shakespeare_tokenizer_model
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 1
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  user_defined_symbols: 

  user_defined_symbols: 
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s

In [4]:
sp = spm.SentencePieceProcessor()
sp.load('models/shakespeare_tokenizer_model.model')
vocab_size = sp.get_piece_size()

In [5]:
import jax
import jax.numpy as jnp

In [12]:
from functools import partial

data = jnp.array(sp.encode(text))

traindata = data[:int(0.9 * len(data))]
testdata = data[int(0.9 * len(data)):]

dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))

@partial(jax.jit, static_argnames=['batch_size, block_size'])
def get_batch(key, data, batch_size, block_size):
    batch_size=4
    block_size=8
    ix = jax.random.randint(key, shape=(batch_size, 1), minval=0, maxval=len(data) - block_size)
    x = dynamic_slice_vmap(data, ix, (block_size,))
    y = dynamic_slice_vmap(data, ix + 1, (block_size,))
    return x, y

key = jax.random.key(1337)
for _ in range(1):
    key, subkey = jax.random.split(key)
    print(get_batch(key, traindata, 4, 8))

(Array([[  3,  61,  58,  56, 119,  55,  25,  86],
       [  7,   3,  89, 146,  41, 431, 283, 349],
       [  3,   3, 360,  61,  89, 111, 599, 122],
       [ 72, 125, 170,  22,  25,   7, 102, 356]], dtype=int32), Array([[ 61,  58,  56, 119,  55,  25,  86,   7],
       [  3,  89, 146,  41, 431, 283, 349,   5],
       [  3, 360,  61,  89, 111, 599, 122, 126],
       [125, 170,  22,  25,   7, 102, 356,  94]], dtype=int32))


In [7]:
import zeptogpt
import importlib

importlib.reload(zeptogpt)

from zeptogpt.model import GPT
from zeptogpt.trainer import Trainer

In [8]:
# Training
from flax.training import train_state
import optax
import functools

embed_dims = 32
num_heads = 4
num_decoder_layers = 2
eval_iters = 100
eval_interval = 1000
num_training_iters = 10000
batch_size = 4
block_size = 8

key = jax.random.key(1337)
model = GPT(vocab_size, block_size, embed_dims, num_heads, num_decoder_layers)
params = model.init(key, jnp.ones((1, block_size), dtype=jnp.int32))
optimizer = optax.adamw(learning_rate=0.002)

# Create training state
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer
)
loss_fn = lambda logits, targets: optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
train_batch_fn = functools.partial(get_batch, data=traindata, batch_size=batch_size, block_size=block_size)
test_batch_fn = functools.partial(get_batch, data=testdata, batch_size=batch_size, block_size=block_size)
trainer = Trainer(state, loss_fn, train_batch_fn, test_batch_fn, eval_iters, eval_interval, num_training_iters)
state = trainer.train(key)

  0%|          | 13/10000 [00:05<54:08,  3.07it/s]  

Train Loss=7.338044166564941 Test Loss=7.307309627532959


 10%|█         | 1033/10000 [00:09<00:45, 198.87it/s]

Train Loss=4.81251859664917 Test Loss=4.781274795532227


 21%|██        | 2073/10000 [00:12<00:26, 302.16it/s]

Train Loss=4.3535614013671875 Test Loss=4.4693522453308105


 31%|███       | 3063/10000 [00:14<00:25, 271.85it/s]

Train Loss=4.334170818328857 Test Loss=4.40308141708374


 41%|████      | 4052/10000 [00:17<00:18, 329.11it/s]

Train Loss=4.162240028381348 Test Loss=4.308862209320068


 51%|█████     | 5061/10000 [00:19<00:17, 279.46it/s]

Train Loss=4.128535270690918 Test Loss=4.30137825012207


 61%|██████    | 6075/10000 [00:22<00:12, 325.63it/s]

Train Loss=3.9623284339904785 Test Loss=4.1711297035217285


 71%|███████   | 7069/10000 [00:24<00:08, 361.55it/s]

Train Loss=3.969635009765625 Test Loss=4.069582939147949


 81%|████████  | 8090/10000 [00:26<00:05, 321.44it/s]

Train Loss=3.937488317489624 Test Loss=4.138076305389404


 91%|█████████ | 9059/10000 [00:29<00:03, 249.73it/s]

Train Loss=4.022124290466309 Test Loss=4.181920528411865


100%|██████████| 10000/10000 [00:32<00:00, 312.00it/s]

Train Loss=3.875575542449951 Test Loss=4.13692569732666





In [9]:
key = jax.random.key(1337)
model = GPT(vocab_size, block_size, embed_dims, num_heads, num_decoder_layers)
print(sp.decode(model.generate(key, state.params, jnp.ones((1, 1), dtype=jnp.int32) * 13, 1000)[:, 0, 0].tolist()))

d forward they come, cousin at her
My, or age.
ISer'sish, you are well to,
I have brield slain
I amre, in honour which kned with, see, the coplave you Kate, I have exvy, it should follow Henry?
Give
OKE VINCENTIO cres curst against thee fromet Cise inweashers your full.

KING EDWARD I's thou' how I here:
To.
Hein'pose my royal so now,

I, strership came.BRUTUS: come cold,
DUCCANO you, and kill him me and myness, and, and long to cubar ppostercience be then I knows, when itfckselamey k and deck, king.
Hhereas partter.
Gogo soun you rather trg with'd,
ARuituse ho?
Are to unt:
KING RICHARD I learn of thy way theic no?

BUCAMILLO:

f Cors somein, your honour that I,
BUKE ONow! should this news a spow'd exLL against him asin, but Here him,
Are.

I:
As! holy are not,
By. Good,
'sed pom a old serly
UKE VINCENTILAU'ss's,
I EDWARDLAURENCENTIO:
chiin some de's
NO:
Wersion tongue?
Hans more I provant had spey say nots.
By; and know by France nothing

I am from a her with thy face.
Ho is yet yetni