Thank you to HOMLP and Andrej Karpathy (https://karpathy.github.io/2015/05/21/rnn-effectiveness/) for the guide material.

In [109]:
from pathlib import Path
import urllib.request

def download_shakespeare_text():
  """Return the Shakespeare corpus as a single string.

  Despite its name, this function does not download anything: it reads a
  local copy at ``datasets/shakespeare/shakespeare.txt`` (relative to the
  current working directory).

  The file is decoded explicitly as UTF-8 so the result does not depend on
  the platform's default text encoding.

  Raises:
    FileNotFoundError: if the corpus file is not present at that path.
  """
  path = Path('datasets/shakespeare/shakespeare.txt')
  return path.read_text(encoding='utf-8')

# Load the full corpus into memory once, then preview its first 1000 characters.
shakespeare_text = download_shakespeare_text()
shakespeare_text[:1000]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [64]:
# Character-level vocabulary: every distinct character of the lowercased
# corpus, sorted so that each character's id is stable between runs.
vocab = sorted(set(shakespeare_text.lower()))
print(vocab)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [65]:
# Two-way lookup tables between characters and integer ids; an id is simply
# the character's position in the sorted `vocab` list.
char_to_id = {char: index for index, char in enumerate(vocab)}
id_to_char = {index: char for index, char in enumerate(vocab)}
id_to_char[13]

'a'

In [66]:
import torch

def encode_text(text):
  """Convert a string into a 1-D tensor of character ids.

  The text is lowercased first, then each character is looked up in the
  module-level ``char_to_id`` table.

  Args:
    text: the string to encode.

  Returns:
    A ``torch.long`` tensor with one id per character of ``text``.

  Raises:
    KeyError: if ``text`` contains a character missing from ``char_to_id``.
  """
  # The dtype is pinned because torch.tensor([]) would otherwise be float32,
  # which is not a valid index dtype for nn.Embedding; non-empty inputs were
  # already int64, so they are unaffected.
  return torch.tensor([char_to_id[char] for char in text.lower()], dtype=torch.long)

def decode_text(char_ids):
  """Turn an iterable of scalar id tensors back into the text they encode.

  Each element must expose ``.item()`` (e.g. the elements obtained by
  iterating a 1-D tensor); the resulting id is looked up in the module-level
  ``id_to_char`` table.
  """
  characters = (id_to_char[token.item()] for token in char_ids)
  return "".join(characters)

In [110]:
# Round-trip sanity check: decoding an encoded string should give it back.
encoded = encode_text("mama jama")
decode_text(encoded)

'mama jama'

In [111]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

device

'cuda'

In [69]:
from torch.utils.data import Dataset, DataLoader

class CharDataset(Dataset):
  """Sliding-window dataset for next-character prediction.

  Item ``i`` is a pair ``(window, target)`` of 1-D id tensors, each of length
  ``window_length``: ``window`` covers positions ``[i, i + window_length)`` of
  the encoded text and ``target`` is the same span shifted one position to
  the right.
  """

  def __init__(self, text, window_length):
    """Encode ``text`` once and remember the window size.

    Raises:
      ValueError: if ``window_length`` is smaller than 1.
    """
    if window_length < 1:
      raise ValueError(f"window_length must be >= 1, got {window_length}")
    self.encoded_text = encode_text(text)
    self.window_length = window_length

  def __len__(self):
    # Each window needs one extra trailing character for its shifted target.
    # Clamp at zero so texts shorter than the window give an empty dataset
    # instead of a negative length (which makes len() raise).
    return max(0, len(self.encoded_text) - self.window_length)

  def __getitem__(self, idx):
    """Return the ``(window, target)`` pair at ``idx``.

    Negative indices count from the end, as with Python sequences.

    Raises:
      IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
    """
    n_windows = len(self)
    if idx < 0:
      idx += n_windows
    if not 0 <= idx < n_windows:
      raise IndexError("Index out of range")
    end = idx + self.window_length
    window = self.encoded_text[idx:end]
    target = self.encoded_text[idx+1:end+1]
    return window, target


In [70]:
# Number of characters per training window, and number of windows per batch.
window_length = 50
batch_size = 512

# Split by character position: the first 1,000,000 characters train the model,
# the next 60,000 validate it, and whatever remains is held out for testing.
train_set = CharDataset(shakespeare_text[:1_000_000], window_length)
valid_set = CharDataset(shakespeare_text[1_000_000:1_060_000], window_length)
test_set = CharDataset(shakespeare_text[1_060_000:], window_length)

# Only the training loader shuffles; validation and test keep their order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

In [71]:
import torch.nn as nn

In [72]:
class ShakespeareModel(nn.Module):
  """Character-level language model: embedding -> stacked GRU -> linear head.

  Args:
    vocab_size: number of distinct character ids (size of the output layer).
    n_layers: number of stacked GRU layers.
    embed_dim: size of each character embedding.
    hidden_dim: size of the GRU hidden state.
    dropout: dropout probability applied between stacked GRU layers.
  """

  def __init__(self, vocab_size, n_layers=2, embed_dim=10, hidden_dim=128, dropout=0.1):
    super().__init__()
    self.embed = nn.Embedding(vocab_size, embed_dim)
    # nn.GRU only applies dropout between stacked layers. With a single layer
    # a non-zero value has no effect and triggers a UserWarning, so pass 0.
    gru_dropout = dropout if n_layers > 1 else 0.0
    self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=n_layers, dropout=gru_dropout, batch_first=True)
    self.output = nn.Linear(hidden_dim, vocab_size)

  def forward(self, X):
    """Compute next-character logits for every position of every sequence.

    Args:
      X: integer id tensor of shape (batch, seq_len).

    Returns:
      Logits of shape (batch, vocab_size, seq_len). The class axis is moved to
      dimension 1, the layout nn.CrossEntropyLoss expects for sequence targets.
    """
    embeddings = self.embed(X)
    outputs, _states = self.gru(embeddings)
    return self.output(outputs).permute(0,2,1)

# Seed the RNG before building the model so its initial weights are
# reproducible, then move the model to the selected device.
torch.manual_seed(42)
model = ShakespeareModel(len(vocab)).to(device)

In [73]:
### Reusing evaluate and train functions from past notebooks:
!pip install torchmetrics
import torchmetrics

def evaluate_tm(model, data_loader, metric):
    """Evaluate ``model`` over ``data_loader`` with a torchmetrics-style metric.

    The model is switched to eval mode and the metric is reset before any
    batch is processed. Gradients are disabled for the whole pass.

    Args:
        model: module called as ``model(X_batch)``.
        data_loader: iterable yielding ``(X_batch, y_batch)`` tensor pairs.
        metric: object exposing ``reset()``, ``update(preds, target)`` and
            ``compute()``; it must already live on the model's device.

    Returns:
        Whatever ``metric.compute()`` returns after all batches were seen.
    """
    # Move batches to wherever the model lives instead of relying on the
    # notebook-level `device` variable. Parameter-free models fall back to
    # that global to keep the previous behavior.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    model.eval()
    metric.reset()
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            X_batch, y_batch = X_batch.to(target_device), y_batch.to(target_device)
            y_pred = model(X_batch)
            metric.update(y_pred, y_batch)
    return metric.compute()

def train(model, optimizer, loss_fn, metric, train_loader, valid_loader,
          n_epochs, patience=2, factor=0.5, epoch_callback=None):
    """Train ``model`` for ``n_epochs`` and track loss and metric per epoch.

    After every epoch the model is evaluated on ``valid_loader`` with
    ``evaluate_tm``, and the learning rate is multiplied by ``factor`` once
    the validation metric has not improved for ``patience`` epochs
    (``ReduceLROnPlateau`` in ``"max"`` mode, so higher metric values count
    as better).

    Batches are moved to the notebook-level ``device``.

    Args:
        model: module to train.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(y_pred, y_batch)`` returning a scalar loss.
        metric: torchmetrics-style metric; it is reused for both the running
            training metric and the validation pass.
        train_loader: sized iterable of ``(X_batch, y_batch)`` training pairs.
        valid_loader: iterable of ``(X_batch, y_batch)`` validation pairs.
        n_epochs: number of passes over ``train_loader``.
        patience: epochs without improvement before the LR is reduced.
        factor: multiplicative LR reduction factor.
        epoch_callback: optional ``callback(model, epoch)`` invoked at the
            start of each epoch, after the model is put in train mode.

    Returns:
        Dict with lists ``"train_losses"``, ``"train_metrics"`` and
        ``"valid_metrics"``, one entry per epoch.

    Raises:
        ValueError: if training is requested on an empty ``train_loader``.
    """
    # Fail fast: with no batches the per-epoch averages would divide by zero
    # and the running training metric would never be assigned.
    if n_epochs > 0 and len(train_loader) == 0:
        raise ValueError("train_loader is empty: nothing to train on")
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=patience, factor=factor)
    history = {"train_losses": [], "train_metrics": [], "valid_metrics": []}
    for epoch in range(n_epochs):
        total_loss = 0.0
        metric.reset()
        model.train()  # evaluate_tm leaves the model in eval mode
        if epoch_callback is not None:
            epoch_callback(model, epoch)
        for index, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Running metric over all batches seen so far in this epoch.
            metric.update(y_pred, y_batch)
            train_metric = metric.compute().item()
            # "\r" rewrites the same console line to act as a progress bar.
            print(f"\rBatch {index + 1}/{len(train_loader)}", end="")
            print(f", loss={total_loss/(index+1):.4f}", end="")
            print(f", {train_metric=:.2%}", end="")
        history["train_losses"].append(total_loss / len(train_loader))
        history["train_metrics"].append(train_metric)
        # evaluate_tm resets `metric`; the training value was captured above.
        val_metric = evaluate_tm(model, valid_loader, metric).item()
        history["valid_metrics"].append(val_metric)
        scheduler.step(val_metric)
        print(f"\rEpoch {epoch + 1}/{n_epochs},                      "
              f"train loss: {history['train_losses'][-1]:.4f}, "
              f"train metric: {history['train_metrics'][-1]:.2%}, "
              f"valid metric: {history['valid_metrics'][-1]:.2%}")
    return history



In [74]:
# Training configuration: cross-entropy loss over characters, NAdam with its
# default hyperparameters, and multiclass accuracy over the vocabulary as the
# tracked metric (moved to the same device as the model).
n_epochs = 20
xentropy = nn.CrossEntropyLoss()
optimizer = torch.optim.NAdam(model.parameters())
accuracy = torchmetrics.Accuracy(task="multiclass",
                                 num_classes=len(vocab)).to(device)

# Run the training loop; `history` holds per-epoch losses and metrics.
history = train(model, optimizer, xentropy, accuracy, train_loader, valid_loader,
                n_epochs)

Epoch 1/20,                      train loss: 1.6040, train metric: 51.28%, valid metric: 51.98%
Epoch 2/20,                      train loss: 1.3843, train metric: 56.72%, valid metric: 52.83%
Epoch 3/20,                      train loss: 1.3547, train metric: 57.46%, valid metric: 53.64%
Epoch 4/20,                      train loss: 1.3403, train metric: 57.81%, valid metric: 53.45%
Epoch 5/20,                      train loss: 1.3320, train metric: 58.02%, valid metric: 53.32%
Epoch 6/20,                      train loss: 1.3264, train metric: 58.16%, valid metric: 53.79%
Epoch 7/20,                      train loss: 1.3225, train metric: 58.26%, valid metric: 53.71%
Epoch 8/20,                      train loss: 1.3193, train metric: 58.34%, valid metric: 54.09%
Epoch 9/20,                      train loss: 1.3167, train metric: 58.40%, valid metric: 54.33%
Epoch 10/20,                      train loss: 1.3148, train metric: 58.45%, valid metric: 54.08%
Epoch 11/20,                      train

In [75]:
# Switch to evaluation mode (turns dropout off) before generating text.
model.eval()

ShakespeareModel(
  (embed): Embedding(39, 10)
  (gru): GRU(10, 128, num_layers=2, batch_first=True, dropout=0.1)
  (output): Linear(in_features=128, out_features=39, bias=True)
)

In [112]:
# Predict the single most likely next character for a short prompt.
text = "to be or n"
# Add a leading batch dimension of size 1 and move the ids to the model's device.
encoded_text = encode_text(text).unsqueeze(dim=0).to(device)
with torch.no_grad():
  Y_logits = model(encoded_text)
  # The model returns (batch, vocab, seq), so [0, :, -1] selects the scores of
  # every character at the last time step of the only sequence in the batch.
  predicted_char_id = Y_logits[0, :, -1].argmax().item()
  predicted_char = id_to_char[predicted_char_id]

predicted_char

'o'

In [89]:
import torch.nn.functional as F

def next_char(model, text, temperature=1):
  """Sample the next character after ``text`` from the model's distribution.

  The logits at the last time step are divided by ``temperature`` before the
  softmax: values below 1 sharpen the distribution (more conservative text),
  values above 1 flatten it (more random text).

  This function does not change the model's train/eval mode.

  Args:
    model: trained model returning logits of shape (batch, vocab, seq).
    text: prompt to continue; it is encoded with ``encode_text``.
    temperature: strictly positive scaling factor for the logits.

  Returns:
    The sampled character, as a one-character string.

  Raises:
    ValueError: if ``temperature`` is not strictly positive.
  """
  # temperature == 0 would divide by zero and feed NaN probabilities to
  # torch.multinomial; a negative value would silently invert the ranking.
  if temperature <= 0:
    raise ValueError(
        f"temperature must be > 0, got {temperature}; "
        "use greedy decoding for deterministic output")
  encoded_text = encode_text(text).unsqueeze(dim=0).to(device)
  with torch.no_grad():
    Y_logits = model(encoded_text)
    # [0, :, -1]: scores over the vocabulary at the last position of the prompt.
    Y_probas = F.softmax(Y_logits[0,:,-1]/temperature, dim=-1)
    predicted_char_id = torch.multinomial(Y_probas, num_samples=1).item()
  return id_to_char[predicted_char_id]

In [105]:
def extend_text(model, text, n_chars=200, temperature=1):
  """Grow ``text`` by ``n_chars`` sampled characters and return the result.

  Each new character is drawn by ``next_char`` from everything generated so
  far, using the given sampling ``temperature``.
  """
  generated = text
  for _step in range(n_chars):
    sampled = next_char(model, generated, temperature)
    generated = generated + sampled
  return generated

In [108]:
# Sample 200 characters at a low temperature (0.4) for more conservative text.
print(extend_text(model, "To be or n", temperature=0.4))

To be or not thy son,
the common vault for his house of heart the bear.

first citizen:
and why the princes to my son, and the destroy'd them off.

first servingman:
come, come, sir, a court the people to them 


In [114]:
### Greedy Decoding Version

def next_greedy_char(model, text):
  """Return the single most likely next character after ``text``.

  Unlike sampling, this always picks the highest-scoring character at the
  last position of the prompt, so the result is deterministic.
  """
  prompt_ids = encode_text(text)
  batch = prompt_ids.unsqueeze(dim=0).to(device)
  with torch.no_grad():
    logits = model(batch)
    # Scores over the vocabulary at the final time step of the only sequence.
    last_step_scores = logits[0, :, -1]
    best_id = last_step_scores.argmax().item()
  return id_to_char[best_id]

def extend_greedy_text(model, text, n_chars=200):
  """Grow ``text`` by ``n_chars`` characters using greedy decoding.

  Every step appends the character ``next_greedy_char`` chooses for the text
  generated so far.
  """
  generated = text
  for _step in range(n_chars):
    best_char = next_greedy_char(model, generated)
    generated = generated + best_char
  return generated

# Greedy decoding demo: always taking the top-scoring character is deterministic.
print(extend_greedy_text(model, "Shakespeare is just like Cartman, only"))

Shakespeare is just like Cartman, only the country,
and the common that shall be the county and the state of the county and the county
and the county and the county and the county and the county and the county
and the county and the count
