# TinyShakespeare

A character Level lanaguage model to generate shakespeare.

In [1]:
import torch
import torch.nn.functional as F

In [2]:
with open('tiny_shakespeare.txt', 'r') as file:
    text=file.read()

print(type(text))
print(len(text))

<class 'str'>
1115394


In [3]:
print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [4]:
# Create Vocabulary (over characters)
vocab=sorted(list(set(text)))
print(f"Here is our Vocab of {len(vocab)} characters: \n{''.join(vocab)}")

Here is our Vocab of 65 characters: 

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [5]:
#Create Vocab Mappings

char_to_idx = {char: idx for idx, char in enumerate(vocab)}
idx_to_char = {idx: char for idx, char in enumerate(vocab)}

In [6]:
# create function to encode text strings and decode indicies

def encode(text:str):
    return [char_to_idx[char] for char in text]
def decode(indices: list):
    return ''.join([idx_to_char[idx] for idx in indices])

In [7]:
# Encode entire dataset and store in a torch.tensor

# 1D ARRAY
data = torch.tensor(encode(text), dtype=torch.int64) ## by default it trys to interprete dtype --> best to just define
print(data.shape)
print(data[:100])

torch.Size([1115394])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [8]:
# split into train-val split - what about splitting over a word? well this is a character level model so doesnt matter!

# Purpose: We dont want the model to just memorise the text. We want it to generate new, shakespeare like text.
# So by with holding some data during training. We can evaluate how good our model is by how well it reproduces shakespeare its never seen.

split = int(data.shape[0]*0.9) #90% Train 10% Validation
train_data = data[:split]
val_data = data[split:]

print(train_data.shape)
print(val_data.shape)

torch.Size([1003854])
torch.Size([111540])


## Establishing the Dataset

Two key points in the dataset:
1. For a certain context length the data has multiple examples, its not just the full context length and the target. This aligns with point 2.
2. We want the model to work with varying context lengths (up to the max context length) because the model may only be given a smaller context (ie. the text "why?" - which is 4 characters. If model only has fixed context length it wont be able to deal with this)

In [9]:
context_len=8

for i in range(1, len(train_data[:context_len])+1):
    
    x = train_data[:i]
    y = train_data[i]
    print(f"Context: {x.tolist()} --> target: {y}")

Context: [18] --> target: 47
Context: [18, 47] --> target: 56
Context: [18, 47, 56] --> target: 57
Context: [18, 47, 56, 57] --> target: 58
Context: [18, 47, 56, 57, 58] --> target: 1
Context: [18, 47, 56, 57, 58, 1] --> target: 15
Context: [18, 47, 56, 57, 58, 1, 15] --> target: 47
Context: [18, 47, 56, 57, 58, 1, 15, 47] --> target: 58


In [10]:
## Creating batches

context_len=8
batch_size=4

# inputs and targets will be tensors of shape (4,8), where targets context will be shifted 1 place to the right.

def create_batch(split:str):
    data: torch.tensor = train_data if split=="train" else val_data
    
    # randomly select batch_size number of starting points in the data
    # The upper bound is (len(data) - context_len) to ensure we can always build a complete
    # context of size context_len, even if we start at the last valid position
    starting_idxs: torch.tensor = torch.randint(0, len(data)-context_len, size=(batch_size,))
    
    #create a list of tensors, which we then stack
    X=torch.stack([data[start:start+context_len] for start in starting_idxs])
    y=torch.stack([data[start+1: start+context_len+1] for start in starting_idxs])
    
    return X, y

In [11]:
torch.manual_seed(196)

X, y = create_batch("train")
print(X)
print(y)
print(f"Total of {y.shape[0]*y.shape[1]} examples")

tensor([[44, 59, 52, 42, 39, 51, 43, 52],
        [52, 43,  5, 43, 56,  1, 57, 53],
        [ 0, 35, 43, 50, 50, 11,  1, 58],
        [53, 58, 46, 43, 56,  6,  1, 40]])
tensor([[59, 52, 42, 39, 51, 43, 52, 58],
        [43,  5, 43, 56,  1, 57, 53,  1],
        [35, 43, 50, 50, 11,  1, 58, 46],
        [58, 46, 43, 56,  6,  1, 40, 59]])
Total of 32 examples


In [12]:
#Spelling out examples


for batch in range(batch_size):
    print("\nNext Example in the Batch:\n")
    for example_idx in range(context_len):
        
        context = X[batch][:example_idx+1] # Remember indexing is up to but not including e.ge [:0] is empty so [:0+1] is 1st example
        target = y[batch][example_idx]
        
        print(f"Context: {context} --> Target: {target}")
    


Next Example in the Batch:

Context: tensor([44]) --> Target: 59
Context: tensor([44, 59]) --> Target: 52
Context: tensor([44, 59, 52]) --> Target: 42
Context: tensor([44, 59, 52, 42]) --> Target: 39
Context: tensor([44, 59, 52, 42, 39]) --> Target: 51
Context: tensor([44, 59, 52, 42, 39, 51]) --> Target: 43
Context: tensor([44, 59, 52, 42, 39, 51, 43]) --> Target: 52
Context: tensor([44, 59, 52, 42, 39, 51, 43, 52]) --> Target: 58

Next Example in the Batch:

Context: tensor([52]) --> Target: 43
Context: tensor([52, 43]) --> Target: 5
Context: tensor([52, 43,  5]) --> Target: 43
Context: tensor([52, 43,  5, 43]) --> Target: 56
Context: tensor([52, 43,  5, 43, 56]) --> Target: 1
Context: tensor([52, 43,  5, 43, 56,  1]) --> Target: 57
Context: tensor([52, 43,  5, 43, 56,  1, 57]) --> Target: 53
Context: tensor([52, 43,  5, 43, 56,  1, 57, 53]) --> Target: 1

Next Example in the Batch:

Context: tensor([0]) --> Target: 35
Context: tensor([ 0, 35]) --> Target: 43
Context: tensor([ 0, 35

## Baseline: Bigram Model

As a baseline we will use the Bigram Language model. Which predicts the next character given a previous character. The prediction is made based on the frequency of bigram pairs (i.e. How many times does a come after r in our dataset). We collect all the pairwise combinations of our vocab (creating a (vocab_size, vocab_size) Look-up Table). We normalise over the row, to create a probability dist of the next character given a certain character. i.e. if the first character in our vocab is a. then the first row of our look-up table represents the frequency of every character in our vocab coming after a (i.e. given a). We normalise to create a probility distribution of the next character given a. These are the logits of the bigram model.

On inference we can feed prob_dist into multinomial dist (which will sample the next character idx based on the prob_dist we gave). This idx is then used to index into our loo-up table, grabing the prob_dist of that character... and so on untill we reach an end character!

In [56]:
import torch.nn as nn

class BigramLanguageModel(nn.Module):
    
    def __init__(self, vocab_size):
        """
        Define Layers and Parameters
        """
        super().__init__()
        self.bigram_lookup_table = nn.Embedding(vocab_size, vocab_size)
    
    
    def forward(self, idxs, targets=None):
        """
        Define how data flows through the layers of the network
        
        idxss | (batch_dim, context_len): batches of indices of characters in our Vocab. 
        targets | (batch_dim, context_len): batch of target indices
        """
        # For each index in idxs you'll have a 1D array of length vocab_size representing the "frequency" dist over the whole vocab given that index
        logits = self.bigram_lookup_table(idxs) # (batch_dim, context_len, vocab_size) <-- (each row represents an index, along the columns is the dist for that index) we then have batch_dim num of those matricies
        
        if targets == None:
            loss=None
        else:
            ## ----- Loss ----- ##
            # ve- log likelihood <==> Cross entropy in this senario
            # Cross Entropy wants the last dim to be num_classes (i.e. the len(vocab))
            # reshape/view
            batch_size, context_len, num_classes = logits.shape
            logits = logits.view(batch_size*context_len, num_classes) # stack the batches. each row represents one example of context
            targets = targets.view(batch_size*context_len) # long 1D array
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idxs, max_token_length):
        """
        Define how the model goes about generating new tokens
        """
        # No longer generating names. Now we want to, given a sequence (idxs) generate new tokens and concat to idxs.
        
        for i in range(max_token_length):
            
            # Calling self() or model() in Pytorch calls forward()
            logits, _ = self(idxs, targets=None)
            
            # we only need the one context (as Bigram model). logits shape: (batch_dim, context_len, vocab_len)
            # take the last context in each batch. So one row is taken from each batch giving shape (batch_dim, vocab_len)
            logits = logits[:, -1, :] 
            
            ## create prob_dists and sample ##
            prob_dists = torch.softmax(logits, dim=1) # across the columns (so dim=1). Very Easy when you understand!
            indicies = torch.multinomial(prob_dists, num_samples=1, replacement=True) # (batch_dim, 1) out of the vocab_len one is chosen
            #append onto
            
            idxs = torch.cat((idxs, indicies), dim=1)
        return idxs
            
             
        

In [57]:
bigram = BigramLanguageModel(vocab_size=len(vocab))

In [58]:
X

tensor([[44, 59, 52, 42, 39, 51, 43, 52],
        [52, 43,  5, 43, 56,  1, 57, 53],
        [ 0, 35, 43, 50, 50, 11,  1, 58],
        [53, 58, 46, 43, 56,  6,  1, 40]])

In [77]:
#Randomly initalised Look-up table (we havent done any training yet)
# Select the very first row, which should act as a prob dist 
prob_dist, loss= bigram(idx=X, targets=y)
print(prob_dist)
print(loss)

tensor([[ 0.5626,  0.9114,  2.2276,  ...,  1.0004, -1.9456,  0.6057],
        [ 0.5789, -1.4890, -0.5426,  ...,  0.7079, -0.3116,  0.7640],
        [-0.0932, -0.3452,  0.2444,  ...,  0.7291, -0.7309,  1.2668],
        ...,
        [ 0.0129,  1.2247,  0.7694,  ..., -0.1291,  0.1203, -1.4294],
        [-0.7484, -1.6542,  0.1551,  ...,  0.2770,  1.2927, -0.1394],
        [-0.3436,  2.7319, -1.6450,  ...,  1.0328, -0.8345, -1.7028]],
       grad_fn=<ViewBackward0>)
tensor(4.5137, grad_fn=<NllLossBackward0>)


We expect the loss to be -ln(1/65) ~= 4.17

In [78]:
## Lets generate 10 new characters
idx = bigram.generate(idx=X, max_token_length=10)
print(f"Our input: \n{X}")
print(f"Our output: \n{idx}")

Our input: 
tensor([[44, 59, 52, 42, 39, 51, 43, 52],
        [52, 43,  5, 43, 56,  1, 57, 53],
        [ 0, 35, 43, 50, 50, 11,  1, 58],
        [53, 58, 46, 43, 56,  6,  1, 40]])
Our output: 
tensor([[44, 59, 52, 42, 39, 51, 43, 52, 63, 61,  6, 19, 17, 35, 23, 61, 28, 11],
        [52, 43,  5, 43, 56,  1, 57, 53, 11, 22, 25,  0, 19, 30, 55,  0, 46,  1],
        [ 0, 35, 43, 50, 50, 11,  1, 58,  8, 41, 31,  8, 27, 64, 30, 57, 14, 24],
        [53, 58, 46, 43, 56,  6,  1, 40, 24, 40, 19,  0, 31, 34, 58,  7, 30, 31]])


In [None]:
# starting with newline char as first token in the sequence, generating and then decoding
idx = torch.zeros((1, 1), dtype=torch.int64)
generation = bigram.generate(idx, max_token_length=100)
print(decode(generation.flatten().tolist()))


cSOSGEaciPG&;awqO3hgOuYI:a!Wey-:qnJMJ'uYK;$HQR'aj'ci:SV&ciAUvN.PNHt'EuTO3haC'3y?aIH.VHnvNLriAczgSQu,


#### Train Bigram Model

In [80]:
bigram = BigramLanguageModel(vocab_size=len(vocab))
optimiser = torch.optim.AdamW(bigram.parameters(), lr=1e-3)

In [88]:
num_steps=100_000
for i in range(num_steps):
    
    #create a batch
    X_batch, y_batch = create_batch(split="train")
    
    logits, loss = bigram.forward(idx=X_batch, targets=y_batch)
    loss.backward()
    optimiser.step()
    
    if i % 10_000 == 0:
        print(f"Loss on {i}th iteration: {loss.item():.4f}")
    

Loss on 0th iteration: 3.4301
Loss on 10000th iteration: 2.4263
Loss on 20000th iteration: 2.3831
Loss on 30000th iteration: 3.0552
Loss on 40000th iteration: 3.2765
Loss on 50000th iteration: 2.4752
Loss on 60000th iteration: 2.5668
Loss on 70000th iteration: 2.2361
Loss on 80000th iteration: 2.3899
Loss on 90000th iteration: 3.4464


In [89]:
## Lets generate again...
idx = torch.zeros((1, 1), dtype=torch.int64)
generation = bigram.generate(idx, max_token_length=250)
print(decode(generation.flatten().tolist()))


BELAUS:

A:

BELI yis
Noude atha, there, hie ongh, you woy,
bumeal s pongbe t buprdd, onenos! a ungore a chads s, od.

Asthad.
Se that d S:
Planke, ad ad, at:
Ge
TI liren, Makngrkillau itise bye ngr he toin; r t aved thy f l'le; me spro ad' d sonthyo
