In [1]:
import os

import numpy as np

import torch
from torch import nn
import torch.optim as optim

from Transformer import Transformer_Decoder

In [2]:
# Read the entire corpus into memory as a single string.
with open('data/shakespeare.txt', 'r', encoding='utf-8') as corpus_file:
    shakespeare_text = corpus_file.read()

# Corpus size in characters.
len(shakespeare_text)

1115394

## Tokenizer

In [3]:
# Character-level vocabulary: every distinct character in the corpus, in sorted order.
chars = sorted(set(shakespeare_text))

vocab_size = len(chars)

# Lookup tables between characters and integer ids; an id is the character's
# position in the sorted vocabulary.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, return a list of integer ids (one per character).

    Raises KeyError for a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take an iterable of integer ids, return the corresponding string.

    Raises KeyError for an id that is not in the vocabulary.
    """
    return ''.join(itos[i] for i in l)

In [5]:
# Encode the whole corpus into a Python list of integer ids (one id per character).
shakespeare_tokenized = encode(shakespeare_text)
# Peek at the first ten ids as a sanity check.
shakespeare_tokenized[:10]

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47]

In [7]:
# Round-trip check: decoding the first 100 ids should reproduce the start of the corpus.
decode(shakespeare_tokenized[:100])

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [15]:
# Hold the encoded corpus as a 1-D int64 tensor so batches can be sliced from it directly.
shakespeare_tokenized_tensor = torch.tensor(shakespeare_tokenized, dtype=torch.long)

total_length = shakespeare_tokenized_tensor.shape[0]

# Contiguous 90% / 10% split: the leading part trains, the tail validates.
# Integer arithmetic keeps the boundary exact.
split_index = total_length * 90 // 100
train_data = shakespeare_tokenized_tensor[:split_index]
val_data = shakespeare_tokenized_tensor[split_index:]

train_data.shape, val_data.shape

(torch.Size([1003854]), torch.Size([111540]))

## Dataloader

In [16]:
def get_batch(split, device, input_length=256, batch_size=64):
    """Sample a random batch of (input, target) windows from the corpus.

    Args:
        split: 'val' selects ``val_data``; any other value selects ``train_data``.
        device: device the returned tensors are moved to.
        input_length: number of tokens in each window.
        batch_size: number of windows in the batch.

    Returns:
        A tuple ``(x, y)`` of tensors, each of shape (batch_size, input_length),
        where ``y`` is ``x`` shifted one position to the right in the corpus
        (i.e. the next-token targets).
    """
    source = val_data if split == 'val' else train_data

    # randint's upper bound is exclusive, so the largest start still leaves room
    # for the target window, which ends one token later than the input window.
    starts = torch.randint(len(source) - input_length, (batch_size,)).tolist()

    inputs = torch.stack([source[s:s + input_length] for s in starts])
    targets = torch.stack([source[s + 1:s + input_length + 1] for s in starts])
    return inputs.to(device), targets.to(device)

# Smoke-test the sampler on CPU: inputs and targets should both be (batch_size, input_length).
x, y = get_batch('train', 'cpu', input_length=256, batch_size=64)
x.shape, y.shape

(torch.Size([64, 256]), torch.Size([64, 256]))

## Environment and hyperparameters

In [12]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Window length in tokens: passed both to the model constructor and to get_batch.
max_input_length = 256

# Model hyperparameters, forwarded positionally to Transformer_Decoder.
num_layers = 6
num_heads = 6
embed_dim = 384
input_dropout = 0.2

# Optimisation settings.
lr = 0.0003          # learning rate handed to AdamW
epochs = 1000        # number of training iterations; each one processes a single random batch
# NOTE(review): with eval_interval = 1 the training loop evaluates and writes a
# checkpoint on every iteration — consider a larger value to cut disk I/O.
eval_interval = 1
batch_size = 64

## Training

In [13]:
SEED = 42

# Seed every RNG in play so that batch sampling and weight initialisation are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

model = Transformer_Decoder(vocab_size, num_layers, num_heads, embed_dim, max_input_length, input_dropout).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=lr)

# torch.save does not create missing directories; make sure the target exists up
# front instead of failing after the first training step.
checkpoint_path = 'model/shakespeare_char_checkpoint.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# NOTE: an "epoch" here is a single optimisation step on one random batch,
# not a full pass over the training data.
for epoch in range(epochs):
    model.train()

    x, y = get_batch('train', device=device, input_length=max_input_length, batch_size=batch_size)
    y_logits = model(x)

    # Flatten (B, T, C) logits and (B, T) targets so that CrossEntropyLoss sees
    # one classification problem per token position.
    B, T, C = y_logits.shape
    loss = loss_fn(y_logits.reshape(B * T, C), y.reshape(B * T))
    train_loss = loss.item()

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Periodically (and always on the last step) estimate the validation loss
    # from one random validation batch, report it, and checkpoint.
    if epoch % eval_interval == 0 or epoch == epochs - 1:
        model.eval()

        with torch.inference_mode():
            x, y = get_batch('val', device=device, input_length=max_input_length, batch_size=batch_size)

            y_logits = model(x)

            B, T, C = y_logits.shape
            val_loss = loss_fn(y_logits.reshape(B * T, C), y.reshape(B * T)).item()

        print(f"Epoch {epoch} | Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")

        # The checkpoint carries the optimizer state as well, so training can be resumed.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
        }, checkpoint_path)

Epoch 0 | Train loss: 4.4013 | Val loss: 3.9155
Epoch 1 | Train loss: 3.7516 | Val loss: 3.4881
Epoch 2 | Train loss: 3.5366 | Val loss: 3.5311
Epoch 3 | Train loss: 3.5521 | Val loss: 3.4301
Epoch 4 | Train loss: 3.4187 | Val loss: 3.3043
Epoch 5 | Train loss: 3.3228 | Val loss: 3.2782
Epoch 6 | Train loss: 3.2462 | Val loss: 3.3088
Epoch 7 | Train loss: 3.2902 | Val loss: 3.3433
Epoch 8 | Train loss: 3.2938 | Val loss: 3.2300
Epoch 9 | Train loss: 3.2563 | Val loss: 3.1969
Epoch 10 | Train loss: 3.2076 | Val loss: 3.1429
Epoch 11 | Train loss: 3.1441 | Val loss: 3.1190
Epoch 12 | Train loss: 3.1279 | Val loss: 3.0718
Epoch 13 | Train loss: 3.1251 | Val loss: 3.0477
Epoch 14 | Train loss: 3.0532 | Val loss: 2.9681
Epoch 15 | Train loss: 2.9792 | Val loss: 2.9208
Epoch 16 | Train loss: 2.9543 | Val loss: 2.9221
Epoch 17 | Train loss: 2.9004 | Val loss: 2.9486
Epoch 18 | Train loss: 2.9213 | Val loss: 2.8706
Epoch 19 | Train loss: 2.8821 | Val loss: 2.8542
Epoch 20 | Train loss: 2.8387 

In [14]:
# Persist only the model weights (no optimizer state); the generation cell reloads this file.
torch.save(model.state_dict(), 'model/shakespeare_char_model_weights.pth')

## Generate Shakespeare-like text

In [28]:
# Rebuild the architecture with the same hyperparameters used for training, then restore the weights.
model = Transformer_Decoder(vocab_size, num_layers, num_heads, embed_dim, max_input_length, input_dropout).to(device)
# map_location remaps the stored tensors onto the current device, so weights
# saved on a GPU still load on a CPU-only machine.
model.load_state_dict(torch.load('model/shakespeare_char_model_weights.pth', map_location=device))
# A freshly constructed module is in training mode; switch to eval so that
# dropout (assuming the model uses standard nn.Dropout) is disabled while sampling.
model.eval()

# Prompt: a single sequence holding one token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# No gradients are needed for sampling; inference_mode avoids autograd bookkeeping.
with torch.inference_mode():
    output = model.generate(context, vocab_size, max_new_tokens=256)
context.shape, output.shape

(torch.Size([1, 1]), torch.Size([1, 257]))

In [29]:
# Turn the first (and only) generated sequence back into text; print so newlines render.
print(decode(output[0].cpu().tolist()))


TORWAy, think, we be the wayn tis they
Are eye, I no she; to me rast be gounes
All reart wilto city to farthed?

GLOUCESTER:
For thour's laster.

SIR Engman:
They arvelly deveren of and the desepterous.
Pwhom: with thou do cemion the.

GLOUCESTER:
By ples 
