# Building a char-RNN

Inspired by [HOML](https://www.oreilly.com/library/view/hands-on-machine-learning/9798341607972/)

## Create training dataset

In [1]:
# download a subset of Shakespeare's works (~1/4),
# taken from Andrej Karpathy's char-rnn project
from pathlib import Path
import urllib.request

def download_shakespeare_text():
  '''Return the Shakespeare corpus as one string, downloading it on first use.

  The text is cached at datasets/shakespeare/shakespeare.txt (relative to the
  current working directory); the download is skipped when that file exists.
  '''
  local_file = Path('datasets/shakespeare/shakespeare.txt')
  already_cached = local_file.is_file()
  if not already_cached:
    # first run: create the cache directory, then fetch the file into it
    local_file.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve('https://homl.info/shakespeare', local_file)
  return local_file.read_text()

# load the corpus (downloaded only when no local copy exists) and keep the
# whole text in memory as a single string
shakespeare_text = download_shakespeare_text()

In [2]:
# preview the first 80 characters of the corpus
print(shakespeare_text[:80])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.


## Tokenize

In [3]:
# build the character vocabulary:
# lowercase the text, collect the unique chars in a set,
# then sort them (strings sort by Unicode code point)
vocab = sorted(set(shakespeare_text.lower()))

# join the vocab into one string so the cell output shows it compactly
"".join(vocab)

"\n !$&',-.3:;?abcdefghijklmnopqrstuvwxyz"

In [4]:
# assign a token id to each character (its position in the sorted vocab) and
# build lookups in both directions: char -> id and id -> char
char_to_id = {char: index for index, char in enumerate(vocab)} # char -> id
id_to_char = {index: char for index, char in enumerate(vocab)} # id -> char

# sanity check: the two lookups are inverses of each other
print(char_to_id['a'])
print(id_to_char[13])

13
a


In [5]:
## helper funcs to encode text to tensors of token ids and to decode them
## back to text
import torch

def encode_text(text):
  '''Convert a string into a 1-D int64 tensor of token ids.

  The text is lower-cased first, because the vocabulary only contains
  lower-case characters.

  Args:
    text: string to encode; may be empty.

  Returns:
    torch.Tensor of dtype torch.long with one id per character.

  Raises:
    KeyError: if a character is not present in char_to_id.
  '''
  token_ids = [char_to_id[char] for char in text.lower()]
  # explicit dtype: torch.tensor([]) would otherwise be inferred as float32,
  # which cannot be used for embedding lookups or integer indexing
  return torch.tensor(token_ids, dtype=torch.long)

def decode_text(char_ids):
  '''Turn an iterable of scalar token-id tensors back into a string.

  Each element must expose .item() (e.g. the elements of a 1-D tensor); the
  resulting id is looked up in id_to_char.
  '''
  chars = []
  for char_id in char_ids:
    chars.append(id_to_char[char_id.item()])
  return "".join(chars)

In [6]:
## round-trip check, step 1: encode a sample string
## (encode_text lower-cases its input, so 'H' is encoded as 'h')
encoded = encode_text('Hello, world!')

encoded

tensor([20, 17, 24, 24, 27,  6,  1, 35, 27, 30, 24, 16,  2])

In [7]:
## round-trip check, step 2: decoding returns the lower-cased input
decode_text(encoded)

'hello, world!'

## Create dataset

In [26]:
from torch.utils.data import (Dataset,
                              DataLoader)

class CharDataset(Dataset):
  '''Sliding-window dataset over an encoded text for next-character prediction.

  Item i is the pair (window, target): window holds window_length token ids
  starting at position i, and target is the same span shifted right by one
  position, i.e. the next character for every position in the window.
  '''

  # dataset initialization
  def __init__(self, text, window_length):
    '''Encode `text` once and remember the window size.'''
    self.encoded_text = encode_text(text)
    self.window_length = window_length

  def __len__(self):
    '''Number of complete (window, target) pairs available.'''
    # every window needs one extra char after it for the shifted target, so
    # there are len(text) - window_length windows; clamp at 0 because len()
    # raises ValueError on a negative result (text shorter than the window)
    return max(0, len(self.encoded_text) - self.window_length)

  def __getitem__(self, idx):
    '''Return the (window, target) pair at position `idx`.

    Negative indices count from the end, like a list.

    Raises:
      IndexError: if `idx` is outside [-len(self), len(self)).
    '''
    n_windows = len(self)
    if idx < 0:
      # not `+=`: avoids mutating the caller's object if idx is a tensor
      idx = idx + n_windows
    if not 0 <= idx < n_windows:
      raise IndexError('dataset index out of range') # id beyond range of data
    # window ends window_length positions after idx
    end = idx + self.window_length
    window = self.encoded_text[idx : end]
    # shift by 1
    target = self.encoded_text[idx + 1: end + 1]
    return window, target

In [28]:
## create data loaders
window_length = 50 # number of chars in each input window
batch_size = 256

# contiguous split by character position: the first 1,000,000 chars train,
# the next 60,000 validate, and the remainder is the test set
train_set = CharDataset(shakespeare_text[:1_000_000], window_length)
valid_set = CharDataset(shakespeare_text[1_000_000:1_060_000], window_length)
test_set = CharDataset(shakespeare_text[1_060_000:], window_length)
# only the training loader shuffles; validation/test keep the default order
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

## Building and Training char-RNN Model

In [31]:
import torch.nn as nn
class ShakespeareModel(nn.Module):
  '''Character-level language model: embedding -> stacked GRU -> linear layer.

  Args:
    vocab_size: number of distinct token ids (embedding rows and output classes).
    n_layers: number of stacked GRU layers.
    embed_dim: size of each character embedding.
    hidden_dim: size of the GRU hidden state.
    dropout: dropout applied between GRU layers.
  '''

  def __init__(self, vocab_size, n_layers = 2, embed_dim = 10, hidden_dim = 128,
               dropout = 0.1):
    super().__init__()
    # lookup table: token id -> dense vector
    self.embed = nn.Embedding(num_embeddings=vocab_size,
                              embedding_dim=embed_dim)
    # recurrent stack; batch_first means inputs are (batch, seq, feature)
    self.gru = nn.GRU(input_size=embed_dim, hidden_size=hidden_dim,
                      num_layers=n_layers, batch_first=True, dropout=dropout)
    # projects every hidden state to one score per vocab entry
    self.output = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

  def forward(self, X):
    '''Map a batch of token ids to per-position class logits.

    Returns a tensor with the last two axes of the linear layer's output
    swapped, i.e. (batch, vocab_size, seq) for a (batch, seq) input, so the
    class axis is at index 1.
    '''
    token_vectors = self.embed(X)
    hidden_seq, _final_state = self.gru(token_vectors)
    per_step_logits = self.output(hidden_seq)
    # move the class axis from last to index 1 (swap dims 1 and 2)
    return per_step_logits.permute(0, 2, 1)

# seed PyTorch's RNG before the model is built so that the weight
# initialization is reproducible
torch.manual_seed(42) # set seed

# detect device: use the GPU when one is available, otherwise the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# create instance of model (one output class per vocab char) on that device
model = ShakespeareModel(len(vocab)).to(device)

In [11]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 983.2/983.2 kB 25.7 MB/s eta 0:00:00
Downloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.2


In [12]:
import torchmetrics

def evaluate_tm(model, data_loader, metric):
    '''Compute `metric` for `model` over every batch of `data_loader`.

    Puts the model in eval mode and resets the metric first, then feeds each
    batch through the model without gradient tracking. Returns whatever
    metric.compute() returns. The model is left in eval mode.
    '''
    model.eval()
    metric.reset()
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            metric.update(model(inputs), targets)
    return metric.compute()

def train(model, optimizer, loss_fn, metric, train_loader, valid_loader,
          n_epochs, patience=2, factor=0.5, epoch_callback=None):
    '''Train `model` and record per-epoch loss and metric values.

    After every epoch the validation metric is handed to a ReduceLROnPlateau
    scheduler created with mode="max", so the metric is treated as
    higher-is-better when deciding whether to scale the learning rate by
    `factor` (after `patience` epochs without improvement).

    Args:
        model, optimizer, loss_fn: the module to train, its optimizer and loss.
        metric: object with reset()/update()/compute(); shared by the training
            pass and the validation pass (it is reset before each).
        train_loader, valid_loader: iterables of (inputs, targets) batches.
        n_epochs: number of passes over train_loader.
        patience, factor: forwarded to ReduceLROnPlateau.
        epoch_callback: optional callable invoked as epoch_callback(model, epoch)
            at the start of each epoch, after model.train() has been called.

    Returns:
        dict with the lists "train_losses", "train_metrics" and
        "valid_metrics", one entry per epoch.
    '''
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=factor, patience=patience)
    history = dict(train_losses=[], train_metrics=[], valid_metrics=[])

    for epoch in range(n_epochs):
        running_loss = 0.0
        metric.reset()
        model.train()
        if epoch_callback is not None:
            epoch_callback(model, epoch)

        for step, (inputs, targets) in enumerate(train_loader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs)
            loss = loss_fn(predictions, targets)
            running_loss += loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            metric.update(predictions, targets)
            # running metric over all batches seen so far in this epoch
            train_metric = metric.compute().item()
            # "\r" rewrites the same console line to act as a progress display
            print(f"\rBatch {step}/{len(train_loader)}"
                  f", loss={running_loss/step:.4f}"
                  f", {train_metric=:.2%}", end="")

        epoch_loss = running_loss / len(train_loader)
        history["train_losses"].append(epoch_loss)
        history["train_metrics"].append(train_metric)
        # evaluate_tm resets the shared metric, so the training value had to be
        # read before this call
        valid_metric = evaluate_tm(model, valid_loader, metric).item()
        history["valid_metrics"].append(valid_metric)
        plateau_scheduler.step(valid_metric)
        print(f"\rEpoch {epoch + 1}/{n_epochs},                      "
              f"train loss: {epoch_loss:.4f}, "
              f"train metric: {train_metric:.2%}, "
              f"valid metric: {valid_metric:.2%}")
    return history

In [32]:
n_epochs = 20
xentropy = nn.CrossEntropyLoss() # loss on the per-position class logits
optimizer = torch.optim.NAdam(model.parameters()) # default hyperparameters
# multiclass accuracy with one class per vocab char, kept on the model's device
accuracy = torchmetrics.Accuracy(task="multiclass",
                                 num_classes=len(vocab)).to(device)

# train; `history` holds the per-epoch train loss and train/valid accuracy
history = train(model, optimizer, xentropy, accuracy, train_loader, valid_loader,
                n_epochs)

Epoch 1/20,                      train loss: 1.5213, train metric: 53.30%, valid metric: 53.37%
Epoch 2/20,                      train loss: 1.3738, train metric: 56.93%, valid metric: 53.93%
Epoch 3/20,                      train loss: 1.3535, train metric: 57.43%, valid metric: 53.97%
Epoch 4/20,                      train loss: 1.3408, train metric: 57.77%, valid metric: 54.33%
Epoch 5/20,                      train loss: 1.3304, train metric: 58.05%, valid metric: 54.54%
Epoch 6/20,                      train loss: 1.3257, train metric: 58.16%, valid metric: 54.51%
Epoch 7/20,                      train loss: 1.3225, train metric: 58.24%, valid metric: 54.62%
Epoch 8/20,                      train loss: 1.3200, train metric: 58.30%, valid metric: 54.66%
Epoch 9/20,                      train loss: 1.3182, train metric: 58.34%, valid metric: 54.43%
Epoch 10/20,                      train loss: 1.3165, train metric: 58.39%, valid metric: 54.59%
Epoch 11/20,                      train

In [33]:
## use trained model to predict next char in sequence
model.eval() # inference mode (e.g. dropout off); gradient tracking is disabled separately below
text = 'To be or not to b'
encoded_text = encode_text(text).unsqueeze(dim=0).to(device) # add batch dim -> (1, seq_len)

with torch.no_grad(): # no gradients needed for prediction
  Y_logits = model(encoded_text)
  # logits are (batch, vocab, seq): batch 0, all classes, last time step;
  # argmax picks the single most likely next char (greedy)
  predicted_char_id = Y_logits[0, :, -1].argmax().item() # get predicted next char
  predicted_char = id_to_char[predicted_char_id]

print(f'{text} -> {predicted_char}')

To be or not to b -> e


## Generating Fake Shakespearean Text

In [34]:
## use multinomial to sample indices given class probs
torch.manual_seed(42) # fixed seed so the drawn samples are reproducible
probs = torch.tensor([0.5, 0.4, 0.1]) # probabilities of index 0, 1 and 2

## example: draw 8 indices with replacement; index i is drawn with probability probs[i]
samples = torch.multinomial(probs, replacement=True, num_samples = 8)
samples

tensor([0, 0, 0, 0, 1, 0, 2, 2])

In [35]:
## func to use temperature to pick next char to add to input txt
import torch.nn.functional as F

def next_char(model, text, temperature=1.0):
  '''Sample the character that follows `text` from the model's predictions.

  Args:
    model: callable mapping a (1, seq_len) tensor of token ids to logits laid
      out as (batch, vocab, seq); the last-step logits are read with
      [0, :, -1].
    text: prompt string, encoded with encode_text().
    temperature: positive divisor applied to the logits before the softmax.
      Values below 1 sharpen the distribution (closer to always picking the
      top char); values above 1 flatten it (closer to uniform sampling).

  Returns:
    The sampled next character as a one-character string.

  Raises:
    ValueError: if temperature is not strictly positive (0 would divide the
      logits by zero; a negative value would invert the ranking).
  '''
  # `not >` also rejects NaN
  if not temperature > 0:
    raise ValueError(f'temperature must be > 0, got {temperature}')

  encoded_text = encode_text(text).unsqueeze(dim=0).to(device) # add dim

  # sample with dropout disabled even if the caller left the model in train
  # mode (e.g. when called from a training callback), then restore that mode
  was_training = getattr(model, 'training', False)
  if was_training:
    model.eval()
  try:
    with torch.no_grad():
      Y_logits = model(encoded_text)
      # get logits for last step and div by temp
      Y_probs = F.softmax(Y_logits[0, :, -1] / temperature, dim=-1)
      # get predicted next char
      predicted_char_id = torch.multinomial(Y_probs, num_samples=1).item()
  finally:
    if was_training:
      model.train()

  return id_to_char[predicted_char_id]

In [36]:
# func to repeatedly call next_char() and append the sampled char to the text
def extend_text(model, text, n_chars=80, temperature=1.0):
  '''Grow `text` by `n_chars` sampled characters and return the result.

  Each new character is drawn by next_char() from the full text generated so
  far, so every sample is conditioned on all previous ones.
  '''
  generated = text
  for _step in range(n_chars):
    generated = generated + next_char(model, generated, temperature)
  return generated

In [37]:
## low temp: logits are divided by 0.01 before the softmax, which sharpens the
## distribution so sampling almost always picks the most likely char
print(extend_text(model, 'To be or not to b', temperature=0.01))

To be or not to be the state,
and the state and the state and the state and the state and the sta


In [38]:
## medium temp: 0.4 still favours likely chars but leaves room for variety
print(extend_text(model, 'To be or not to b', temperature=0.4))

To be or not to be so far
and be not the heavens will sit and end the first stay with my death
of


In [39]:
## high temp: dividing the logits by 100 flattens the distribution toward
## uniform, so the sampled chars are close to random
print(extend_text(model, 'To be or not to b', temperature=100))

To be or not to bmhf:my:r,k;s-h cqvvnfnfsut&-oq'ryoeen?x-hp:d,y&wv f3,dzrdzj-p$lv?xpzc,fborp;'?$u
