# Hyperparameters and libraries

In [1]:
from pathlib import Path
import random
from typing import List
import re
import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
from utils import get_text
import torch.nn as nn
from datetime import datetime

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"device: {device}")

device: cuda


# Character-wise tokenization
Here we will take each character as a unique token.

In [2]:
# Data-pipeline hyperparameters for the character-level experiment.
seq_length = 100  # tokens per training window
batch_size = 64   # windows per batch
epochs = 30       # passes over the dataset

## Dataset

We will create a custom `Dataset` class that can be used by a `DataLoader` and that carries a few additional helper methods for our use.

In [3]:
class TextDatasetPlus(Dataset):
    """
    Text dataset that can be handed directly to a ``DataLoader``.

    On construction it
    1. takes the paths of the data files and extracts their text via ``get_text``,
    2. builds a sorted vocabulary of the distinct tokens in that text,
    3. encodes the whole text as a list of integer token ids.

    Item ``i`` is the pair ``(ids[i : i + seq_length], ids[i + 1 : i + seq_length + 1])``:
    an input window and the same window shifted one token ahead (the targets).

    Attributes:
        paths: the data paths passed in.
        text: the raw text returned by ``get_text``.
        tokens: sorted list of the distinct tokens found in ``text``.
        vocab_size: number of distinct tokens.
        token_to_id / id_to_token: mappings between tokens and their ids.
        seq_length: length of each input/target window.
        batch_size: stored as given; not used by the dataset itself.
        text_as_token_ids: the full text encoded as token ids.
    """
    def __init__(self, data_paths: List[Path], seq_length: int = 10, batch_size: int = 64):
        self.paths: List[Path] = data_paths
        self.text: str = get_text(self.paths)
        self.seq_length = seq_length
        self.batch_size = batch_size

        # Tokenize once and reuse the result for both the vocabulary and the encoding.
        corpus_tokens = self.tokenize_text(self.text)
        self.tokens = sorted(set(corpus_tokens))
        self.vocab_size = len(self.tokens)
        self.token_to_id = {token: idx for idx, token in enumerate(self.tokens)}
        self.id_to_token = {idx: token for idx, token in enumerate(self.tokens)}
        self.text_as_token_ids = [self.token_to_id[w] for w in corpus_tokens]

    def __len__(self):
        """Number of full (input, target) windows; 0 when the text is too short for one."""
        # Each window needs seq_length inputs plus one extra token for the shifted target.
        # Clamp at 0 because len() raises if __len__ returns a negative number.
        return max(0, len(self.text_as_token_ids) - self.seq_length)

    def __getitem__(self, index):
        """
        Return the ``index``-th (input, target) pair as two ``torch.long`` tensors on ``device``.

        Negative indices count from the end.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``. Without this,
                plain iteration (``for item in dataset``) would never terminate.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(f"index out of range for dataset of length {length}")
        stop = index + self.seq_length
        # NOTE(review): tensors are moved to the notebook-level `device` here because the
        # training loop consumes batches as-is; this may not work with DataLoader
        # worker processes (num_workers > 0) on CUDA -- confirm before enabling workers.
        return (
            torch.tensor(self.text_as_token_ids[index:stop], dtype=torch.long).to(device),
            torch.tensor(self.text_as_token_ids[index+1:stop+1], dtype=torch.long).to(device),
        )

    def tokenize_text(self, text: str) -> list:
        """Split ``text`` into tokens; here every character is its own token."""
        return list(text)

    def textify_tokens(self, tokens: list) -> str:
        """Inverse of ``tokenize_text``: join tokens back into a string."""
        return ''.join(tokens)


In [4]:
# Build the dataset from the Shakespeare data and wrap it in a shuffling DataLoader.
data_path = Path('./data/Shakespeare')
dataset = TextDatasetPlus(data_paths=[data_path], seq_length=seq_length, batch_size=batch_size)
vocab_size = dataset.vocab_size

dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# One-off summary of what was loaded.
dataset_summary = (
    f"DATASET:  {data_path}\n"
    f"Vocab size: {dataset.vocab_size}  Corpus size: {len(dataset.text_as_token_ids)}"
    f"\tNumber of batches:{len(dataloader)} \t batch_size = {batch_size} \t seq_length = {seq_length}"
)
print(dataset_summary)

DATASET:  data\Shakespeare
Vocab size: 65  Corpus size: 1215376Number of batches:18988 	 batch_size = 64 	 seq_length = 100


## Understanding dataset

For each call to `dataset.__getitem__()` we get input and output of length `seq_length`.  
Example: for the text "To be or not to be" and `seq_length=17`, it would give:  
```python
(
    ['T','o',' ','b','e',' ','o','r',' ','n','o','t',' ','t','o',' ','b'],
    ['o',' ','b','e',' ','o','r',' ','n','o','t',' ','t','o',' ','b','e'],
)
# except that these would be their corresponding numerical-token-ids instead of character.
```

A batch from the dataloader stacks `batch_size` such (input, target) pairs.  

Here we will train the model to predict the next character, given the current character and hidden state.  
The input-prediction mapping would look as follows:   
```python
Input:     ['T','o',' ','b','e',' ','o','r',' ','n','o','t',' ','t','o',' ','b'],
             |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
Target:    ['o',' ','b','e',' ','o','r',' ','n','o','t',' ','t','o',' ','b','e'],
```

In [5]:
# Sample __getitem__()
X, Y = dataset[1]  # indexing dispatches to dataset.__getitem__(1)

# Decode both id sequences back to text so the one-token shift is visible.
input_text = ''.join(dataset.id_to_token[token_id] for token_id in X.cpu().tolist())
target_text = ''.join(dataset.id_to_token[token_id] for token_id in Y.cpu().tolist())

print(f"Input-Target token-ids from __getitem__():\n{X}\n{Y}")
print(f"---------------- input  -- shape: {X.shape}")
print(input_text)
print(f"---------------- target  -- shape: {Y.shape}")
print(target_text)
print(f"----------------")

Input-Target token-ids from __getitem__():
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59], device='cuda:0')
tensor([47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44, 53,
        56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,  1,
        44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1, 57,
        54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,  6,
         1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47, 58,
        47, 64, 43, 52, 10,  0, 37, 53, 59,  1], device='cuda:0')
---------------- input  -

In [6]:
# a batch
batch_iterator = iter(dataloader)
X, Y = next(batch_iterator)  # first (input, target) batch from a fresh iterator
print(f"Input shape: {X.shape} \t Target shape: {Y.shape}")

Input shape: torch.Size([64, 100]) 	 Target shape: torch.Size([64, 100])


## Model
We will make use of GRU layers.

In [7]:
class RNNGen(nn.Module):
    """
    Next-token model: embedding -> GRU -> linear projection onto the vocabulary.

    Args:
        vocab_size: number of distinct tokens (embedding rows and output logits).
        embed_dim: size of each token embedding.
        hidden_dim: size of the GRU hidden state.
        num_layers: number of stacked GRU layers.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        # Kept so init_state() can build a hidden state of the right shape.
        self.gru_layers = num_layers
        self.gru_dim = hidden_dim
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim
        )
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, states=None):
        """
        Args:
            x: token ids; the GRU is batch_first, so the layout is (batch, seq).
            states: GRU hidden state as produced by ``init_state`` or a previous call.
                ``None`` is passed through to ``nn.GRU``, which then starts from zeros.

        Returns:
            (logits, states): logits over the vocabulary for every position of ``x``,
            and the GRU hidden state after the last position.
        """
        x = self.embedding(x)
        gru_out, states = self.gru(x, states)
        out = self.out(gru_out)
        return out, states

    def init_state(self, batch_size: int):
        """
        Returns the initial (all-zero) hidden state for the gru layer,
        shaped (num_layers, batch_size, hidden_dim).
        """
        # Allocate next to the model's own weights instead of relying on a global
        # `device`, so the state always matches wherever the model currently lives.
        reference = self.out.weight
        return torch.zeros(
            self.gru_layers, batch_size, self.gru_dim,
            device=reference.device, dtype=reference.dtype,
        )

In [8]:
# Model hyperparameters.
embed_dim = 256
hidden_dim = 1024
num_layers = 1
dropout = 0.2  # defined here but not passed to RNNGen in this cell

# Data/training hyperparameters.
batch_size = 64
seq_length = 100
epochs = 30

# Build the network first, then move it to the selected device.
model = RNNGen(vocab_size, embed_dim, hidden_dim, num_layers)
model = model.to(device)
print(f"\nMODEL: {model}")


MODEL: RNNGen(
  (embedding): Embedding(65, 256)
  (gru): GRU(256, 1024, batch_first=True)
  (out): Linear(in_features=1024, out_features=65, bias=True)
)


## pre-train performance

In [9]:
def predict(dataset, model, seed: str = None, gen_length=100, eval_mode: bool=True):
    """
    Generate text by sampling ``gen_length`` new tokens from ``model``.

    Args:
        dataset: provides ``tokens``, ``token_to_id``, ``id_to_token``,
            ``tokenize_text`` and ``textify_tokens``.
        model: provides ``init_state(batch_size=...)`` and returns ``(logits, state)``
            when called with ``(x, state)``.
        seed: starting text. If None (or if it yields no tokens) a random token
            from ``dataset.tokens`` is used as the start.
        gen_length: number of new tokens to sample.
        eval_mode: if True, run the model in eval mode for the duration of the call;
            the model's previous train/eval mode is restored afterwards.

    Returns:
        The seed tokens followed by the sampled tokens, joined via
        ``dataset.textify_tokens``.

    Raises:
        KeyError: if the seed contains a token missing from ``dataset.token_to_id``.
    """
    tokens = [] if seed is None else list(dataset.tokenize_text(seed))
    if not tokens:
        tokens = [random.choice(dataset.tokens)]

    was_training = getattr(model, "training", False)
    if eval_mode:
        model.eval()
    try:
        state = model.init_state(batch_size=1)
        # The hidden state is carried between steps, so each token must be fed exactly
        # once: the whole seed on the first step, then only the newest sampled token.
        pending = list(tokens)
        with torch.no_grad():  # sampling only; no autograd graph needed
            for _ in range(gen_length):
                x = torch.tensor(
                    [[dataset.token_to_id[w] for w in pending]],
                    dtype=torch.long,
                    device=state.device,
                )
                y_pred, state = model(x, state)
                last_word_logits = y_pred[0][-1]
                p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
                token_index = int(np.random.choice(len(last_word_logits), p=p))
                next_token = dataset.id_to_token[token_index]
                tokens.append(next_token)
                pending = [next_token]
    finally:
        if eval_mode:
            model.train(was_training)
    return dataset.textify_tokens(tokens)

In [10]:
# Sample from the model before any training, as a baseline for later comparison.
untrained_sample = predict(dataset, model)
print(untrained_sample)

kKP
E
arMslvgICdo
Tq jtVmlvb!XZQDoy hL;FtVupcGrrLuUDdOl-FarHoDWWgYXkVy;!unyRxzzz&NAZkWyrMnOVPb3H?T-Ec


## Creating new model

### Training
Loss: `CrossEntropyLoss()`, as this can be seen as a classification problem:
> where (current_character, hidden_state) needs to be classified as belonging to the class of one of the unique characters
  
Optimizer: `Adam`  

We also have the option to save checkpoints and to generate a sample text after each epoch.

In [11]:
def train(dataloader, model, max_epochs, sequence_length, checkpoints: bool = True, sample_text=True):
    """
    Train ``model`` for next-token prediction with Adam (lr=0.001) and cross-entropy loss.

    Args:
        dataloader: yields ``(x, y)`` batches of token ids; its ``.dataset`` is also
            used to generate the per-epoch sample text.
        model: provides ``init_state(batch_size=...)`` and returns ``(logits, state)``
            when called with ``(x, state)``.
        max_epochs: number of passes over ``dataloader``.
        sequence_length: accepted for interface compatibility; not used in the loop.
        checkpoints: if True, save ``model.state_dict()`` under ``./models/checkpoints``
            after every epoch (the directory is created if missing).
        sample_text: if True, print a generated sample after every epoch.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(1, max_epochs+1):
        print(f"\nEPOCH : {epoch}")
        state = None
        # Wrap the dataloader itself (not enumerate(...)) so tqdm knows the total
        # number of batches and can show a real progress bar.
        pbar = tqdm(dataloader)
        for batch, (x, y) in enumerate(pbar):
            # Size the hidden state from the batch actually received rather than a
            # global batch_size; rebuild it if the batch size changes mid-epoch.
            if state is None or state.size(1) != x.size(0):
                state = model.init_state(batch_size=x.size(0))

            optimizer.zero_grad()

            y_pred, state = model(x, state)
            # CrossEntropyLoss expects the class dimension second: (batch, vocab, seq).
            loss = criterion(y_pred.transpose(1, 2), y)

            # Carry the state's values forward but cut the graph, so backward()
            # does not reach into previous batches.
            state = state.detach()

            loss.backward()
            optimizer.step()
            if batch % 1000 == 0:
                pbar.set_description(f'epoch: {epoch}/{max_epochs} \t batch: {batch} \t loss: {loss.item()} \t')
        if checkpoints:
            checkpoint_path = Path(f'./models/checkpoints/{datetime.now().strftime("%d_%m_%Y__%H_%M_%S")}__{epoch}__char')
            # Ensure the target directory exists so a finished epoch is never lost to a save error.
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
        if sample_text:
            # eval_mode=False: leave the model in training mode for the next epoch.
            print(
                f"Generated text:\n"
                f"-------------------\n"
                f"{predict(dataloader.dataset, model, eval_mode=False)}\n"
                f"-------------------\n"
            )

In [12]:
# Run the full training schedule, bracketed by start/finish banners.
print("\nSTARTING TRAINING:")
train(
    dataloader=dataloader,
    model=model,
    max_epochs=epochs,
    sequence_length=seq_length,
)
print("TRAINING COMPLETE.\n")


STARTING TRAINING:

EPOCH : 1


epoch: 1/30 	 batch: 18000 	 loss: 1.379300594329834 	: : 18988it [15:18, 20.68it/s] 


Generated text:
-------------------
-gravive of all,
I sup peace the bore to kill'd,
Call who so I being eyes of hight,
And then will we 
-------------------


EPOCH : 2


epoch: 2/30 	 batch: 18000 	 loss: 1.5849789381027222 	: : 18988it [14:11, 22.30it/s]


Generated text:
-------------------
.
Alclas! I am on humpest answel then the lack'd his libers,
Tommovent blesperate is discomes
Unpent 
-------------------


EPOCH : 3


epoch: 3/30 	 batch: 18000 	 loss: 1.1817270517349243 	: : 18988it [14:05, 22.46it/s]


Generated text:
-------------------
her way
deformation; in a school;
And will you alter'd, I say.
Fere
is the worthy fault
I am few or g
-------------------


EPOCH : 4


epoch: 4/30 	 batch: 18000 	 loss: 1.1623046398162842 	: : 18988it [14:04, 22.48it/s]


Generated text:
-------------------
queansk this arms
As I will not, say your hearts,
That suage my deed'st not part with blesshrely.

BE
-------------------


EPOCH : 5


epoch: 5/30 	 batch: 18000 	 loss: 1.9458537101745605 	: : 18988it [14:04, 22.49it/s]


Generated text:
-------------------
&ME:
SEBAUTOLY:
Happre geores:
Thishis my le! LAURS:
But is know will Look.

LARIN SARRENCE:
I not bl
-------------------


EPOCH : 6


epoch: 6/30 	 batch: 18000 	 loss: 1.9079746007919312 	: : 18988it [14:02, 22.54it/s]


Generated text:
-------------------
,
What armsen farend me?s,
That pe cherpit.
VOLUMNNIO:
Yen, Pe so my pas.
Then meage, dainquit.
Why, 
-------------------


EPOCH : 7


epoch: 7/30 	 batch: 18000 	 loss: 1.9395318031311035 	: : 18988it [14:02, 22.54it/s]


Generated text:
-------------------
vUCESTA:
Look uizeve my Cloud not Manghight whereepll,
Sitize,
BRUCHASIS:
I lovers,
What hearcarry wh
-------------------


EPOCH : 8


epoch: 8/30 	 batch: 18000 	 loss: 1.9247877597808838 	: : 18988it [14:02, 22.54it/s]


Generated text:
-------------------
gc?OMNTESTrung'd.
Yemby infysweverckisheirstrm.
Mou hang onk, antionds
at,
Marrn dafurry knful uponde
-------------------


EPOCH : 9


epoch: 9/30 	 batch: 18000 	 loss: 1.9846082925796509 	: : 18988it [14:02, 22.54it/s]


Generated text:
-------------------
,&THYCUTHN:
Unow,--have ow, Bupord; trulds vid know, I frowead; know my tr my darrdid Comen their.
SI
-------------------


EPOCH : 10


epoch: 10/30 	 batch: 18000 	 loss: 1.7498091459274292 	: : 18988it [14:01, 22.57it/s]


Generated text:
-------------------
Q$HAnd's my bad, Edwatozew,
Whind.
Grown the pitizenting,
Go, dides macce:
Oure crentild, ife.
Urits 
-------------------


EPOCH : 11


epoch: 11/30 	 batch: 18000 	 loss: 1.3973308801651 	: : 18988it [14:00, 22.58it/s]   


Generated text:
-------------------
T: His ham
Shallath towburn, courtish dear.

DUKE VINCENTIO:
This i' it kisd, he faitted to and like 
-------------------


EPOCH : 12


epoch: 12/30 	 batch: 18000 	 loss: 1.2981277704238892 	: : 18988it [14:01, 22.57it/s]


Generated text:
-------------------
GHAMIANDA:
Yond mooming?
Hadst, counse butters, jest wsheap in Barname,
My and amender may makes hope
-------------------


EPOCH : 13


epoch: 13/30 	 batch: 18000 	 loss: 1.2730536460876465 	: : 18988it [14:01, 22.57it/s]


Generated text:
-------------------
izy:
Yet by butt be cand?

GRUMIO:
Anon!
The gods he case of joyful next doot in are here you have hi
-------------------


EPOCH : 14


epoch: 14/30 	 batch: 18000 	 loss: 1.2160072326660156 	: : 18988it [14:00, 22.59it/s]


Generated text:
-------------------
. dides thee to hardness
Did pardon'd the tran abhor unged
Shall here in sole in your father.

POMPEY
-------------------


EPOCH : 15


epoch: 15/30 	 batch: 18000 	 loss: 1.2180771827697754 	: : 18988it [14:01, 22.58it/s]


Generated text:
-------------------
ET:
Ming him and what you come?

ABHORSON:
No, hear, as just
Murgeneth thanks to the bed.

HASTINA:
T
-------------------


EPOCH : 16


epoch: 16/30 	 batch: 18000 	 loss: 1.1997156143188477 	: : 18988it [14:00, 22.59it/s]


Generated text:
-------------------
UKE VINCENTINENIUS:
Let's have but peace thou fierceives
With precute my doth before I may be as ther
-------------------


EPOCH : 17


epoch: 17/30 	 batch: 18000 	 loss: 1.203384518623352 	: : 18988it [14:00, 22.58it/s] 


Generated text:
-------------------
t reator,
A crew shall be the famous walls,
And peace. A prized into the tame most purt
Is secretty p
-------------------


EPOCH : 18


epoch: 18/30 	 batch: 18000 	 loss: 1.195938229560852 	: : 18988it [14:01, 22.56it/s] 


Generated text:
-------------------
KY VI:
A banius?

Nurse:
God kept this hundred bleed?
ANGELO:
'Tis sunder great king.

PRINCE EDWARD:
-------------------


EPOCH : 19


epoch: 19/30 	 batch: 18000 	 loss: 1.1954779624938965 	: : 18988it [14:01, 22.57it/s]


Generated text:
-------------------
MBRO:
What too?

ANTIGONUS:
Are you only better:
The lists you the briddling him all them Forth
Woung
-------------------


EPOCH : 20


epoch: 20/30 	 batch: 18000 	 loss: 1.1936008930206299 	: : 18988it [14:01, 22.57it/s]


Generated text:
-------------------
$lll, knew my duct.

First Citizen:
We had palms:
That in with his head to my king,
When Rome, to but
-------------------


EPOCH : 21


epoch: 21/30 	 batch: 18000 	 loss: 1.158156394958496 	: : 18988it [14:01, 22.56it/s] 


Generated text:
-------------------
RIUS:
Go, go hond, but born, desperpetire here! Your word
That hath bud that man lies himself, are ma
-------------------


EPOCH : 22


epoch: 22/30 	 batch: 18000 	 loss: 1.1933094263076782 	: : 18988it [14:01, 22.57it/s]


Generated text:
-------------------
XFFA:
This beast the dead-beard,
After to cour yourself too of Bolicy and receive
tay who see senses 
-------------------


EPOCH : 23


epoch: 23/30 	 batch: 18000 	 loss: 1.1669394969940186 	: : 18988it [14:00, 22.58it/s]


Generated text:
-------------------
Y?

MERCUTIO:
That's my mother.

GLOUCESTER:
Sir, my kneeds us bowe,
Respecting fire from his new cou
-------------------


EPOCH : 24


epoch: 24/30 	 batch: 18000 	 loss: 1.1842360496520996 	: : 18988it [14:01, 22.57it/s]


Generated text:
-------------------
?RIA:
What, hunous in their king?

CAMILLO:
I will wast dead of Richard in share to hews.
Would I sha
-------------------


EPOCH : 25


epoch: 25/30 	 batch: 18000 	 loss: 1.2100005149841309 	: : 18988it [14:01, 22.57it/s]


Generated text:
-------------------
PHIIS:
The bodiest my bitterns none hand
Where none. This is the midle upon.

KING HENRY VI:
Far asho
-------------------


EPOCH : 26


epoch: 26/30 	 batch: 18000 	 loss: 1.1597119569778442 	: : 18988it [14:01, 22.58it/s]


Generated text:
-------------------
DWhod absence,
Taught and so quaint up with the enemies our crow,
And I'll woe to the bold afternoon;
-------------------


EPOCH : 27


epoch: 27/30 	 batch: 18000 	 loss: 1.193662166595459 	: : 18988it [14:02, 22.54it/s] 


Generated text:
-------------------
RUNUS:
This is question
There o' their feverforces and down our knees.

CAMILLO:
Who, holour beed to 
-------------------


EPOCH : 28


epoch: 28/30 	 batch: 18000 	 loss: 1.2368059158325195 	: : 18988it [14:01, 22.56it/s]


Generated text:
-------------------
Qz
Forge firm, God's and I:
Her dry dies! he for jest day!

BENVOLIO:
Tuthody, poor with a borneed'st
-------------------


EPOCH : 29


epoch: 29/30 	 batch: 18000 	 loss: 1.1921899318695068 	: : 18988it [14:02, 22.54it/s]


Generated text:
-------------------
DWOS do. Why, alack, what Henry neither was dried;
incensmy household! what couls must break.

KING R
-------------------


EPOCH : 30


epoch: 30/30 	 batch: 18000 	 loss: 1.2401832342147827 	: : 18988it [14:02, 22.55it/s]


Generated text:
-------------------
ax-None.

DUKE VINCENTIO:
But when, then I was his life, and hold
Shall have you:
This o two will, si
-------------------

TRAINING COMPLETE.



### Generation sample

In [13]:
# Sample from the model that was just trained.
sampling_options = {
    "seed": None,        # starting token; if None, randomly selects starting token
    "gen_length": 200,   # count of new tokens to generate
    "eval_mode": True,   # set model.eval() for prediction
}
generated = predict(dataset, model, **sampling_options)
print(generated)

INNE:
Virtue hark you!

DUKE VINCENTIO:
I shall have found, but thou rest in April and me; you'll this knows, as for your love
Than as thou deceive, lest thou slay with these wonder'd both your life to


## Loading trained model

In [18]:
# Rebuild the architecture with the same hyperparameters, then restore the saved weights.
model = RNNGen(vocab_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim, num_layers=num_layers).to(device)
checkpoint_path = Path('./models/checkpoints/27_07_2024__05_48_09__30__char')
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

<All keys matched successfully>

In [21]:
# Sample from the model restored from the checkpoint.
sampling_options = {
    "seed": None,        # starting token; if None, randomly selects starting token
    "gen_length": 200,   # count of new tokens to generate
    "eval_mode": True,   # set model.eval() for prediction
}
generated = predict(dataset, model, **sampling_options)
print(generated)

Y PERTRUTUS:
O, even by hand, who ready was then were accomplain:
By heaven be less this damnable there?

ARIEL:
No, sir, your wive of our shall be to hear mock asmilest ever:
Since this fires the youn


# Word-wise tokenization