### Transformers 103: Using Transformers to generate Shakespeare


This is the last notebook in this series of Transformer notebooks. In the previous notebooks we built up, in code, a working understanding of the Transformer network. In this notebook, we simply use the `Transformer` module class we built with PyTorch to train a model and generate Shakespeare-like text.

The dataset that we are going to use in this notebook is [Tiny Shakespeare](https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt).


Let's get started with the imports.


In [1]:
## Importing necessary packages ##

import torch
import torch.nn
from torch.utils.data import Dataset, DataLoader, random_split

from transformer import Transformer

from tqdm import tqdm
import numpy as np

## Display exact values instead of scientific notation ##
torch.set_printoptions(sci_mode=False)


In [2]:
## Setting the necessary Hyperparameters ##

SEED = 97  # for reproducing each and every result

# Specific to data
SEQUENCE_LENGTH = 128  # (also needed in model)
TRAIN_VAL_TEST_SPLIT = [0.9, 0.05, 0.05]
BATCH_SIZE = 64

# Specific to model
EMBEDDING_DIM = 32
NUM_HEADS = 8
USE_ENCODER = False
USE_DECODER = True
DECODER_NUM_LAYERS = 6

## Setting the device ##
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10
DROPOUT_RATE = 0.1

# Saving
CKPT = 'checkpoint/bst_transformer.pt'
LOAD_MODEL = False


In this notebook we are going to develop a character-level generative model from this data. So, without further ado, let's load our data and build our vocabulary.


#### Reading Data


In [3]:
## Reading data ##

with open("tiny_shakespeare.txt", "r") as file:
    data = file.read()

## Splitting it into a character level list ##
data = list(data)

print(f"Dataset length is : {len(data)}")


Dataset length is : 1115394


Nice... So, the dataset is now one long sequence of individual characters.


We will process it further in the next few steps. But first, let us build our vocabulary.


In [4]:
## Vocabulary ##

vocab = sorted(list(set(data)))

print(f"Length of the vocabulary : {len(vocab)} and the vocabulary is : {vocab}")


Length of the vocabulary : 65 and the vocabulary is : ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


Amazing... So we have loaded our data and built our vocabulary.

Much like in the transformers_101 notebook, we need a mapping from characters to indices and vice versa, so that the model receives numeric input and its numeric output can be converted back to characters during generation.


In [5]:
## Mapping ##

char_2_idx = {ch: idx for idx, ch in enumerate(vocab)}

idx_2_char = {idx: ch for idx, ch in enumerate(vocab)}

print(
    f"For sanity check:\n-------------------\n\nchar_2_idx :\n{char_2_idx}\n-------------------\nidx_2_char:\n{idx_2_char}"
)


For sanity check:
-------------------

char_2_idx :
{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}
-------------------
idx_2_char:
{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: '

Looks perfect.
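
The two dictionaries can be wrapped in a pair of helpers for round-tripping text. The sketch below is illustrative only: `encode` and `decode` are hypothetical names that do not appear elsewhere in this notebook, and a short stand-in string replaces the real dataset.

```python
# Hypothetical helpers (not part of the notebook's code), built the same
# way as char_2_idx / idx_2_char but on a tiny stand-in corpus.
sample_text = "To be, or not to be"
toy_vocab = sorted(set(sample_text))

toy_char_2_idx = {ch: idx for idx, ch in enumerate(toy_vocab)}
toy_idx_2_char = {idx: ch for idx, ch in enumerate(toy_vocab)}


def encode(text, mapping=toy_char_2_idx):
    """Map a string to a list of integer indices."""
    return [mapping[ch] for ch in text]


def decode(indices, mapping=toy_idx_2_char):
    """Map a list of integer indices back to a string."""
    return "".join(mapping[i] for i in indices)


encoded = encode("to be")
print(encoded)
print(decode(encoded))  # round-trips back to "to be"
```

Encoding a character that is missing from the vocabulary raises a `KeyError`, which is worth keeping in mind when prompting a trained model with new text.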

Now it's time to build our PyTorch dataset and dataloader.


In [6]:
## Dataset ##


class ShakespeareDataset(Dataset):
    """Builds the Shakespeare dataset."""

    def __init__(self, data, sequence_length):
        """Constructor."""

        self.data = data
        self.sequence_length = sequence_length

    def __getitem__(self, idx):
        """Returns a single sample, i.e., a sequence and label."""
        sequence = [
            char_2_idx[ch]
            for ch in self.data[
                idx : idx + self.sequence_length
            ]
        ]
        label = [
            char_2_idx[ch]
            for ch in self.data[
                idx + 1 : idx + 1 + self.sequence_length
            ]
        ]

        return torch.tensor(sequence), torch.tensor(label)

    def __len__(self):
        """Length of the dataset"""
        return len(self.data) - (self.sequence_length + 1)


Let's just do some sanity checking on the dataset.


In [7]:
## Sanity checking ##

shakespeare_dataset = ShakespeareDataset(data=data, sequence_length=SEQUENCE_LENGTH)

print(f"Dataset length is {len(shakespeare_dataset)}")


Dataset length is 1115265
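
To make the indexing concrete, here is a small sketch of the same windowing on a toy string, assuming a window length of 4 instead of `SEQUENCE_LENGTH`. The names `toy_data`, `toy_len` and `window` are made up for illustration. It also reproduces the length printed above from the `__len__` formula.

```python
toy_data = list("hello world")
toy_len = 4  # stand-in for SEQUENCE_LENGTH


def window(data, idx, sequence_length):
    """Return (input, label): the label is the input shifted right by one."""
    sequence = data[idx : idx + sequence_length]
    label = data[idx + 1 : idx + 1 + sequence_length]
    return sequence, label


sequence, label = window(toy_data, 0, toy_len)
print(sequence)  # ['h', 'e', 'l', 'l']
print(label)     # ['e', 'l', 'l', 'o']

# Every label position t holds the character that follows input position t.
assert sequence[1:] == label[:-1]

# Same formula as ShakespeareDataset.__len__ with the real sizes.
print(1115394 - (128 + 1))  # 1115265
```

Because consecutive indices yield windows that overlap in all but one character, the dataset contains almost as many samples as the text has characters.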


Now, before doing anything else, we must split the dataset into train, validation, and test sets.


In [8]:
## Splitting the shakespeare dataset into train, validation and test ##

train_dataset, val_dataset, test_dataset = random_split(
    shakespeare_dataset,
    lengths=TRAIN_VAL_TEST_SPLIT,
    generator=torch.Generator().manual_seed(SEED),
)

print(f"Train dataset length : {len(train_dataset)}")
print(f"Validation dataset length : {len(val_dataset)}")
print(f"Test dataset length : {len(test_dataset)}")


Train dataset length : 1003739
Validation dataset length : 55763
Test dataset length : 55763
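
The split sizes follow from how `random_split` handles fractional lengths: per the PyTorch documentation, each length is `floor(frac * len(dataset))` and any leftover samples are handed out one at a time in round-robin order. The pure-Python sketch below mirrors that rule and reproduces the numbers above. The helper name `fractional_lengths` is made up here; it is not a PyTorch function. Note that fractional lengths are only accepted by reasonably recent PyTorch releases, while older ones expect integer lengths.

```python
import math


def fractional_lengths(total, fractions):
    """Mimic the documented rule: floor each share, then round-robin the rest."""
    lengths = [math.floor(total * frac) for frac in fractions]
    remainder = total - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(fractional_lengths(1115265, [0.9, 0.05, 0.05]))
# [1003739, 55763, 55763]
```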


Perfect!! This looks awesome!!

Now we can take a peek inside the training dataset and check that it actually gives us what we want.


For this we are going to use a utility function.


In [9]:
## Utility function to show data ##


def show_data(data, verbose=True):
    """Given a data tensor, maps them to string and prints them."""
    str_data = [idx_2_char[each_char.item()] for each_char in data.data]
    if verbose:
        print(str_data)
    else:
        return str_data


In [10]:
## Testing ##

torch.manual_seed(SEED)

random_idx = torch.randint(low=0, high=len(train_dataset), size=(1,))

sequence, label = train_dataset[random_idx]

print(f"Sequence")
print(f"----------")
print(f"{show_data(sequence, False)}")
print(f"The shape of the sequence is : {sequence.shape}")

print(f"\nLabel")
print(f"----------")
print(f"{show_data(label, False)}")
print(f"The shape of the label is : {label.shape}")


Sequence
----------
['\n', 'A', 'n', 'd', ' ', 'p', 'i', 't', 'c', 'h', ' ', 'o', 'u', 'r', ' ', 'e', 'v', 'i', 'l', 's', ' ', 't', 'h', 'e', 'r', 'e', '?', ' ', 'O', ',', ' ', 'f', 'i', 'e', ',', ' ', 'f', 'i', 'e', ',', ' ', 'f', 'i', 'e', '!', '\n', 'W', 'h', 'a', 't', ' ', 'd', 'o', 's', 't', ' ', 't', 'h', 'o', 'u', ',', ' ', 'o', 'r', ' ', 'w', 'h', 'a', 't', ' ', 'a', 'r', 't', ' ', 't', 'h', 'o', 'u', ',', ' ', 'A', 'n', 'g', 'e', 'l', 'o', '?', '\n', 'D', 'o', 's', 't', ' ', 't', 'h', 'o', 'u', ' ', 'd', 'e', 's', 'i', 'r', 'e', ' ', 'h', 'e', 'r', ' ', 'f', 'o', 'u', 'l', 'l', 'y', ' ', 'f', 'o', 'r', ' ', 't', 'h', 'o', 's', 'e', ' ', 't', 'h']
The shape of the sequence is : torch.Size([128])

Label
----------
['A', 'n', 'd', ' ', 'p', 'i', 't', 'c', 'h', ' ', 'o', 'u', 'r', ' ', 'e', 'v', 'i', 'l', 's', ' ', 't', 'h', 'e', 'r', 'e', '?', ' ', 'O', ',', ' ', 'f', 'i', 'e', ',', ' ', 'f', 'i', 'e', ',', ' ', 'f', 'i', 'e', '!', '\n', 'W', 'h', 'a', 't', ' ', 'd', 'o', 's', 't

Perfect...

Now it's time to set up our DataLoaders.


In [11]:
## Setting up dataloaders ##

train_dl = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Train Dataloader length : {len(train_dl)}")
print(f"Val Dataloader length : {len(val_dl)}")
print(f"Test Dataloader length : {len(test_dl)}")


Train Dataloader length : 15684
Val Dataloader length : 872
Test Dataloader length : 872
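
These counts are just the split sizes divided by `BATCH_SIZE` and rounded up, because `DataLoader` keeps the final, smaller batch by default (`drop_last=False`). A quick check with the numbers printed above:

```python
import math

BATCH_SIZE = 64
split_sizes = {"train": 1003739, "val": 55763, "test": 55763}

for name, size in split_sizes.items():
    num_batches = math.ceil(size / BATCH_SIZE)
    # Size of the final, possibly incomplete batch.
    last_batch = size - (num_batches - 1) * BATCH_SIZE
    print(f"{name}: {num_batches} batches, last batch has {last_batch} samples")
```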


Nice... So we are almost done with everything regarding the dataset and data loading. Now it's time to set up our model. To do this we are going to use the Transformer module we built in the previous notebook.


In [12]:
## Setting our transformer model ##

transformer = Transformer(
    vocab_size=len(vocab),
    sequence_length=SEQUENCE_LENGTH,
    d_embed=EMBEDDING_DIM,
    use_encoder=USE_ENCODER,
    use_decoder=USE_DECODER,
    num_heads=NUM_HEADS,
    decoder_num_layers=DECODER_NUM_LAYERS,
    device=DEVICE,
    dropout_rate=DROPOUT_RATE
)


In [13]:
## Setting the entire transformer to default device ##

transformer = transformer.to(DEVICE)


Now that we have set up our transformer model, it's time to set up the loss function and the optimizer. We are going to use the Adam optimizer and the `CrossEntropyLoss` function.


In [14]:
## Setting up loss function ##

criterion = torch.nn.CrossEntropyLoss()

## Setting up optimizer ##

optim = torch.optim.Adam(transformer.parameters(), lr=LEARNING_RATE)

## Learning Rate scheduler ##

#scheduler = torch.optim.lr_scheduler.CyclicLR(optim, base_lr=1e-4, max_lr=1e-2, cycle_momentum=False)

## Keeping the epoch losses ##
epoch_train_losses = []
epoch_val_losses = []
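
The loss is applied per character position. The model output has shape `(B, L, C)` (batch, sequence length, vocabulary size), and in the cells that follow it is flattened to `(B * L, C)` before being passed to the criterion, so each position counts as one classification example. The dependency-free sketch below spells out what that mean cross-entropy is for made-up logits. It is a plain-Python illustration of the formula, not PyTorch's implementation.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of the target class for one position."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[target]


# Made-up example: B=1 sequence, L=2 positions, C=3 classes.
batch_logits = [[[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
batch_labels = [[0, 2]]

# Flatten (B, L, C) -> (B * L, C) and (B, L) -> (B * L,).
flat_logits = [pos for seq in batch_logits for pos in seq]
flat_labels = [lab for seq in batch_labels for lab in seq]

losses = [cross_entropy(l, t) for l, t in zip(flat_logits, flat_labels)]
mean_loss = sum(losses) / len(losses)
print(mean_loss)

# With all-equal logits the loss is log(C): the cost of a uniform guess.
assert math.isclose(losses[1], math.log(3))
```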


Now we are going to set up a small utility function that computes the average loss over a single sweep of a dataloader. This is needed to check the validation loss after each epoch. At the end it will also come in handy for seeing how our model does on the test data.


In [15]:
## Single sweep loss function ##


def single_sweep_loss(model, dataloader, device=DEVICE):
    """Returns the mean minibatch loss over one full pass of `dataloader` (no gradients)."""

    minibatch_losses = []
    with torch.no_grad():
        for sequence, label in dataloader:
            sequence = sequence.to(device)
            label = label.to(device)
            pred = model(outputs=sequence)
            B, L, C = pred.shape
            loss = criterion(pred.view(B * L, C), label.view(B * L))
            minibatch_losses.append(loss.item())

    return sum(minibatch_losses) / len(minibatch_losses)
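
One subtlety: `single_sweep_loss` averages the per-batch means, so the smaller final batch carries the same weight as a full one. With 872 validation batches the effect is tiny, but if an exact per-sample mean is wanted, each batch loss can be weighted by its size. A plain-Python sketch with made-up loss values (the function name `weighted_mean_loss` is hypothetical):

```python
def weighted_mean_loss(batch_losses, batch_sizes):
    """Per-sample mean loss from per-batch mean losses and batch sizes."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Made-up numbers: two full batches of 64 and a final batch of 19.
batch_losses = [1.50, 1.40, 2.00]
batch_sizes = [64, 64, 19]

plain_mean = sum(batch_losses) / len(batch_losses)
weighted_mean = weighted_mean_loss(batch_losses, batch_sizes)
print(plain_mean, weighted_mean)
```

In this toy case the plain mean is pulled up by the small, high-loss final batch, while the weighted mean is not.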


Nicely set up...

Now it is time to train our model for 10 epochs (`NUM_EPOCHS`) and see how it performs.


In [16]:
best_val_loss = float("inf")

## Training our model ##

if LOAD_MODEL:
    try:
        transformer.load_state_dict(torch.load(CKPT))
    except FileNotFoundError:
        # Only a missing checkpoint file is handled; other errors should surface.
        print("Check the model checkpoint path!")

for i in range(NUM_EPOCHS):
    loop = tqdm(train_dl)

    transformer.train()

    for sequence, label in loop:
        sequence = sequence.to(DEVICE)
        label = label.to(DEVICE)
        pred = transformer(outputs=sequence)
        B, L, C = pred.shape

        loss = criterion(pred.view(B * L, C), label.view(B * L))

        loop.set_description(f"Epoch : {i + 1} / {NUM_EPOCHS} ::")
        loop.set_postfix(loss=loss.item())

        optim.zero_grad()
        loss.backward()
        optim.step()
        # scheduler.step()

    print(f"Doing single sweep loss calculation")

    transformer.eval()

    train_loss = single_sweep_loss(transformer, train_dl)
    val_loss = single_sweep_loss(transformer, val_dl)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        print(f"Best Validation Loss : {best_val_loss}")
        print(f"Saving model now...")
        torch.save(transformer.state_dict(), CKPT)

    print(
        f"Epoch : {i + 1} / {NUM_EPOCHS} ::: Training Loss : {train_loss} , Validation Loss : {val_loss}"
    )

    epoch_train_losses.append(train_loss)
    epoch_val_losses.append(val_loss)


Epoch : 1 / 10 ::: 100%|██████████| 15684/15684 [50:08<00:00,  5.21it/s, loss=1.66]


Doing single sweep loss calculation
Best Validation Loss : 1.5737870949397392
Saving model now...
Epoch : 1 / 10 ::: Training Loss : 1.5741180489331414 , Validation Loss : 1.5737870949397392


Epoch : 2 / 10 ::: 100%|██████████| 15684/15684 [50:47<00:00,  5.15it/s, loss=1.6]  


Doing single sweep loss calculation
Best Validation Loss : 1.5285188949436224
Saving model now...
Epoch : 2 / 10 ::: Training Loss : 1.528699892496339 , Validation Loss : 1.5285188949436224


Epoch : 3 / 10 ::: 100%|██████████| 15684/15684 [51:17<00:00,  5.10it/s, loss=1.57]


Doing single sweep loss calculation
Best Validation Loss : 1.5013233946824291
Saving model now...
Epoch : 3 / 10 ::: Training Loss : 1.5013290766139202 , Validation Loss : 1.5013233946824291


Epoch : 4 / 10 ::: 100%|██████████| 15684/15684 [51:56<00:00,  5.03it/s, loss=1.61]


Doing single sweep loss calculation
Best Validation Loss : 1.4826875184107264
Saving model now...
Epoch : 4 / 10 ::: Training Loss : 1.4826288189139485 , Validation Loss : 1.4826875184107264


Epoch : 5 / 10 ::: 100%|██████████| 15684/15684 [51:50<00:00,  5.04it/s, loss=1.58]


Doing single sweep loss calculation
Best Validation Loss : 1.4727591244726006
Saving model now...
Epoch : 5 / 10 ::: Training Loss : 1.473019695212056 , Validation Loss : 1.4727591244726006


Epoch : 6 / 10 ::: 100%|██████████| 15684/15684 [51:50<00:00,  5.04it/s, loss=1.61]


Doing single sweep loss calculation
Best Validation Loss : 1.4650515639180437
Saving model now...
Epoch : 6 / 10 ::: Training Loss : 1.4654447093997431 , Validation Loss : 1.4650515639180437


Epoch : 7 / 10 ::: 100%|██████████| 15684/15684 [51:44<00:00,  5.05it/s, loss=1.6] 


Doing single sweep loss calculation
Best Validation Loss : 1.4604199184190243
Saving model now...
Epoch : 7 / 10 ::: Training Loss : 1.4607035502491177 , Validation Loss : 1.4604199184190243


Epoch : 8 / 10 ::: 100%|██████████| 15684/15684 [51:43<00:00,  5.05it/s, loss=1.57]


Doing single sweep loss calculation
Best Validation Loss : 1.4540711500502508
Saving model now...
Epoch : 8 / 10 ::: Training Loss : 1.4543130292222988 , Validation Loss : 1.4540711500502508


Epoch : 9 / 10 ::: 100%|██████████| 15684/15684 [51:51<00:00,  5.04it/s, loss=1.62]


Doing single sweep loss calculation
Best Validation Loss : 1.4481872492427126
Saving model now...
Epoch : 9 / 10 ::: Training Loss : 1.4485138261448938 , Validation Loss : 1.4481872492427126


Epoch : 10 / 10 ::: 100%|██████████| 15684/15684 [51:52<00:00,  5.04it/s, loss=1.5] 


Doing single sweep loss calculation
Best Validation Loss : 1.4456292814617857
Saving model now...
Epoch : 10 / 10 ::: Training Loss : 1.4458694900520235 , Validation Loss : 1.4456292814617857
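
To put the final numbers in context: `CrossEntropyLoss` reports the average negative log-likelihood in nats per character. A model that guessed uniformly over the 65-character vocabulary would score `ln(65)`, and exponentiating a loss gives the per-character perplexity, roughly the number of equally likely characters the model is choosing between. A quick calculation using the validation loss from the last epoch above:

```python
import math

VOCAB_SIZE = 65
final_val_loss = 1.4456292814617857  # from the epoch 10 log above

uniform_loss = math.log(VOCAB_SIZE)    # loss of a uniform guess
perplexity = math.exp(final_val_loss)  # effective number of choices
bits_per_char = final_val_loss / math.log(2)

print(f"Uniform baseline loss : {uniform_loss:.3f}")
print(f"Validation perplexity : {perplexity:.2f}")
print(f"Bits per character    : {bits_per_char:.2f}")
```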


In [19]:
## Generating ##

transformer.load_state_dict(torch.load(CKPT))

transformer.eval()

transformer.generate(max_length=696, idx_2_char_map=idx_2_char, device=DEVICE)


Tell to have the yet puner me speaking,
That away your kingland: tray a sorring'd and a tyrat's.

CLOMENES:
I prove, I happy one, and we can; more!

MENENIUS:
You cannot my business ance, stand why hour,
Would sit me as you will: my queen spirike,
For what nather of by the lighterity,
And for this poison; in the dispose defend;
And not mother shall strange to dost maid
Do faith the king!

GLOUCESTER:
Tyrrachignce to galise.

ANTONIO:
Ox he not drook and was but our oddness.

LEONTES:
Would fear to spake of her winderful hell be
Romeo, dield by his manner guilt, beingmon,
That changely repeaty my defut.

JULIET,
Thou lettheren's I will with it, to near. as thou leave.

QUEEN MARGARET:
You
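
The `generate` method comes from the `Transformer` class built in the previous notebook and is not shown here. As a rough idea of what character-level autoregressive sampling generally looks like, here is a dependency-free sketch. Everything in it is an assumption made for illustration: `toy_next_char_probs` is a hypothetical stand-in for a trained model, `sample_text` is a made-up name, and the loop is not claimed to match the author's implementation.

```python
import random

toy_vocab = ["\n", " ", "a", "b", "c"]
CONTEXT_LENGTH = 8  # stand-in for SEQUENCE_LENGTH


def toy_next_char_probs(context):
    """Hypothetical stand-in for a model: returns one probability per
    vocabulary entry, favouring the index after the last one seen."""
    probs = [1.0] * len(toy_vocab)
    probs[(context[-1] + 1) % len(toy_vocab)] += 3.0
    total = sum(probs)
    return [p / total for p in probs]


def sample_text(max_length, start_idx=0, seed=97):
    """Sample one character at a time, feeding each sample back in."""
    rng = random.Random(seed)
    indices = [start_idx]
    for _ in range(max_length):
        context = indices[-CONTEXT_LENGTH:]  # crop to the model's window
        probs = toy_next_char_probs(context)
        next_idx = rng.choices(range(len(toy_vocab)), weights=probs, k=1)[0]
        indices.append(next_idx)
    # Drop the start token and map indices back to characters.
    return "".join(toy_vocab[i] for i in indices[1:])


text = sample_text(max_length=20)
print(repr(text))
```

Sampling from the predicted distribution, rather than always taking the most likely character, is what lets the output vary from run to run when the seed is changed.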
