In [53]:
import numpy as np

from minbpe import BasicTokenizer

import torch
from torch import nn
import torch.optim as optim

from Transformer import Transformer_Decoder

In [3]:
with open('data/shakespeare.txt', 'r', encoding='utf-8') as f:
  shakespeare_text = f.read()

len(shakespeare_text)

1115394

## Tokenizer

In [4]:
tokenizer = BasicTokenizer()
tokenizer.train(shakespeare_text, vocab_size=1024)

In [5]:
# Add special tokens, we do not need <unk> here because data is in english and fit in ASCII
max_vocab_id = list(tokenizer.vocab.keys())[-1]
tokenizer.special_tokens = {
    "<sos>": max_vocab_id + 1,
    "<eos>": max_vocab_id + 2,
    "<unk>": max_vocab_id + 3,
    "<pad>": max_vocab_id + 4,
}

# Save to disk
tokenizer.save("model/model_shakespeare")

In [55]:
# Load from disk
tokenizer = BasicTokenizer()
tokenizer.load("model/model_shakespeare.model")

In [10]:
shakespeare_tokenized = tokenizer.encode(shakespeare_text)

In [12]:
len(shakespeare_tokenized)

443727

In [15]:
shakespeare_tokenized_tensor = torch.tensor(shakespeare_tokenized, dtype=torch.long)

total_length = len(shakespeare_tokenized_tensor)
train_data = shakespeare_tokenized_tensor[:total_length * 80 // 100]
val_data = shakespeare_tokenized_tensor[total_length * 80 // 100:total_length * 90 // 100]
test_data = shakespeare_tokenized_tensor[total_length * 90 // 100:]

train_data.shape, val_data.shape, test_data.shape

(torch.Size([354981]), torch.Size([44373]), torch.Size([44373]))

## Dataloader

In [32]:
def get_batch(split, device, input_length=256, batch_size=64):
    data = train_data
    if split == 'val':
        data = val_data
    elif split == 'test':
        data = test_data
    
    ix = torch.randint(len(data) - input_length, (batch_size,))
    x = torch.stack([data[i:i+input_length] for i in ix])
    y = torch.stack([data[i+1:i+input_length+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

x, y = get_batch('train', 'cpu', input_length=256, batch_size=64)
x.shape, y.shape

(torch.Size([64, 256]), torch.Size([64, 256]))

## Env

In [41]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

max_input_length = 256

vocab_size = len(tokenizer.vocab)

num_layers = 6
num_heads = 6
embed_dim = 384
input_dropout = 0.2

lr = 0.0001
epochs = 1000
eval_interval = 1
batch_size = 64

## Training

In [42]:
SEED = 42

np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

model = Transformer_Decoder(vocab_size, num_layers, num_heads, embed_dim, max_input_length, input_dropout).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=lr)

for epoch in range(epochs):
  model.train()

  train_loss = 0

  x, y = get_batch('train', device=device, input_length=max_input_length, batch_size=batch_size)
  y_logits = model(x)
    
  B, T, C = y_logits.shape
  loss = loss_fn(y_logits.reshape(B * T, C), y.reshape(B * T))
  train_loss += loss.item()

  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  if epoch % eval_interval == 0 or epoch == epochs-1:
    model.eval()
    val_loss = 0

    with torch.inference_mode():
      x, y = get_batch('val', device=device, input_length=max_input_length, batch_size=batch_size)

      y_logits = model(x)
    
      B, T, C = y_logits.shape
      loss = loss_fn(y_logits.reshape(B * T, C), y.reshape(B * T))
      val_loss += loss.item()
    
    print(f"Epoch {epoch} | Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")
    
    torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': train_loss,
    }, 'model/shakespeare_checkpoint.pth')

Epoch 0 | Train loss: 7.0943 | Val loss: 6.9431
Epoch 1 | Train loss: 6.9581 | Val loss: 6.8423
Epoch 2 | Train loss: 6.8479 | Val loss: 6.7657
Epoch 3 | Train loss: 6.7625 | Val loss: 6.7087
Epoch 4 | Train loss: 6.7243 | Val loss: 6.6528
Epoch 5 | Train loss: 6.6800 | Val loss: 6.6189
Epoch 6 | Train loss: 6.6530 | Val loss: 6.6026
Epoch 7 | Train loss: 6.6145 | Val loss: 6.5718
Epoch 8 | Train loss: 6.5904 | Val loss: 6.5537
Epoch 9 | Train loss: 6.5516 | Val loss: 6.5300
Epoch 10 | Train loss: 6.5503 | Val loss: 6.5156
Epoch 11 | Train loss: 6.5221 | Val loss: 6.5062
Epoch 12 | Train loss: 6.5024 | Val loss: 6.4848
Epoch 13 | Train loss: 6.4896 | Val loss: 6.4853
Epoch 14 | Train loss: 6.4883 | Val loss: 6.4538
Epoch 15 | Train loss: 6.4642 | Val loss: 6.4460
Epoch 16 | Train loss: 6.4531 | Val loss: 6.4120
Epoch 17 | Train loss: 6.4380 | Val loss: 6.4303
Epoch 18 | Train loss: 6.4173 | Val loss: 6.4102
Epoch 19 | Train loss: 6.4121 | Val loss: 6.4011
Epoch 20 | Train loss: 6.3903 

In [44]:
torch.save(model.state_dict(), 'model/shakespeare_model_weights.pth')

## Testing

In [None]:
SEED = 42

np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

loss_fn = nn.CrossEntropyLoss()

model.eval()
test_loss = 0

with torch.inference_mode():
  x, y = get_batch('test', device=device, input_length=max_input_length, batch_size=batch_size)

  y_logits = model(x)
    
  B, T, C = y_logits.shape
  loss = loss_fn(y_logits.reshape(B * T, C), y.reshape(B * T))
  test_loss += loss.item()
    
print(f"Test loss: {test_loss:.4f}")

Test loss: 4.6359


## Generate Shakespeare-like text

In [54]:
model = Transformer_Decoder(vocab_size, num_layers, num_heads, embed_dim, max_input_length, input_dropout).to(device)
model.load_state_dict(torch.load('model/shakespeare_model_weights.pth'))
context = torch.zeros((1, 1), dtype=torch.long, device=device)
output = model.generate(context, tokenizer.special_tokens['<eos>'], max_new_tokens=256)
context.shape, output.shape

(torch.Size([1, 1]), torch.Size([1, 257]))

In [56]:
print(tokenizer.decode(output[0].cpu().numpy()))

 up mine own:
Fell, the helfter per, one ashower.
Now, the pted lie my and pike of your noble t,
And heaven or lover a king.

DORDORTHUMNORTHASTANNwot to himself besces inter:
And not! theirs.

KING RICHARD II:
Marence.

D:
And then be etion I toorper dand:
Bosusback you;
Andier, and meanswer, I'lly, that wayill?

WEDWARD IV:
Barwick, I hamour plaguster, no man forbece, fin their of ht, ifesed I shall have meetchilts to this g our Now V:
Warwick, that stit!
I have had buts it my in solge. Fa mays ones,
Since fe, to seas?

HENRY Brazisso.

CLAND:
GLOUCESTER:
It helree
Thanest has now thou nolow, held, with thy seenough spossild hat's faafforth
