In [1]:
import numpy as np
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import re
from PIL import Image

import os
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
stopwords = set(stopwords.words("english"))

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [2]:
path_to_data = "/content/harry_potter"

In [3]:
text_files = os.listdir(path_to_data)

In [4]:
all_text = " "

for text in text_files:
  path_to_book = os.path.join(path_to_data, text)

  with open(path_to_book, 'r') as f:
    text = f.readlines()

  text = [line for line in text if "Page" not in line]
  text = " ".join(text).replace("\n", "")
  text = [word for word in text.split(" ") if len(word) > 0]
  text = " ".join(text)
  all_text = text

In [5]:
unique_chars = sorted(list(set(all_text)))

In [6]:
print(unique_chars)

[' ', '!', '"', "'", '(', ')', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '9', ':', ';', '>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '\\', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '—', '‘', '’', '“', '”', '•']


In [7]:
len(unique_chars)

83

In [8]:
char2idx = {c:i for (i, c) in enumerate(unique_chars)}

In [9]:
word2idx = {i:c for (i, c) in enumerate(unique_chars)}

In [10]:
len(text)

1121342

In [11]:
class DataBuilder:
  def __init__(self, seq_len = 200, text = all_text):
    self.seq_len = seq_len
    self.text = text
    self.file_length = len(text)

  def grab_random_sample(self):

    start = np.random.randint(0, self.file_length - self.seq_len)
    end = start + self.seq_len

    text_slice = self.text[start:end]

    input_text = text_slice[:-1]
    label = text_slice[1:]

    input_text = torch.tensor([char2idx[c] for c in input_text])
    label = torch.tensor([char2idx[c] for c in label])

    return input_text, label

  def grab_random_batch(self, batch_size):
    input_texts, labels = [], []

    for _ in range(batch_size):
      input_text, label = self.grab_random_sample()

      input_texts.append(input_text)
      labels.append(label)

    input_texts = torch.stack(input_texts)
    labels = torch.stack(labels)

    return input_texts, labels

In [12]:
dataset = DataBuilder(seq_len = 10)
input_texts, labels = dataset.grab_random_batch(batch_size=4)

## **Model**

In [13]:
class LSTMGenerator(nn.Module):
  def __init__(self, embedding_dim = 120, num_characters = len(char2idx), hidden_size = 256, num_layers = 3, device = "cpu"):
    super().__init__()

    self.embedding_dim = embedding_dim
    self.num_characters = num_characters
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.device = device

    self.embedding = nn.Embedding(num_characters, embedding_dim)
    self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first = True)

    self.fc = nn.Linear(hidden_size, num_characters)

    self.softmax = nn.Softmax(dim = -1)

  def forward(self, x):
    x = self.embedding(x)
    output, (h, c) = self.lstm(x)
    logits = self.fc(output)

    return logits

  def write(self, text, max_characters, greedy = False):
    idx = torch.tensor([char2idx[c] for c in text], device = self.device)

    hidden = torch.zeros(self.num_layers, self.hidden_size).to(self.device)
    cell = torch.zeros(self.num_layers, self.hidden_size).to(self.device)

    for i in range(max_characters):
      if i == 0:
        selected_idx = idx
      else:
        selected_idx = idx[-1].unsqueeze(0)

      x = self.embedding(selected_idx)
      output, (hidden, cell) = self.lstm(x, (hidden, cell))
      output = self.fc(output)

      if len(output) > 1:
        output = output[-1, :].unsqueeze(0)
      probs = self.softmax(output)
      # print(probs)

      if greedy:
        idx_next = torch.argmax(probs)
      else:
        idx_next = torch.multinomial(probs, num_samples = 1)
      # print(idx)
      idx = torch.cat([idx, idx_next[0]])
      # print(idx)
    gen_string = [word2idx[int(c)] for c in idx]
    gen_string = "".join(gen_string)

    return gen_string


model = LSTMGenerator()
text = "Hello"
model.write(text, 10, greedy = False)

'HelloDORL2TdBo-'

In [14]:
iterations = 3000
max_len = 300
evaluate_internal = 300
embedding_dim = 128
hidden_size = 256
n_layers = 3
lr = 0.003
batch_size = 128
device = "cuda" if torch.cuda.is_available() else "cpu"

model = LSTMGenerator(embedding_dim, len(char2idx), hidden_size, n_layers, device).to(device)

In [15]:
model

LSTMGenerator(
  (embedding): Embedding(83, 128)
  (lstm): LSTM(128, 256, num_layers=3, batch_first=True)
  (fc): Linear(in_features=256, out_features=83, bias=True)
  (softmax): Softmax(dim=-1)
)

In [16]:
optimizer = optim.Adam(model.parameters(), lr = lr)
loss_fn = nn.CrossEntropyLoss()

In [17]:
dataset = DataBuilder()

In [18]:
for iteration in range(iterations):
  input_texts, labels = dataset.grab_random_batch(batch_size = batch_size)
  input_texts, labels = input_texts.to(device), labels.to(device)

  optimizer.zero_grad()
  output = model(input_texts)

  output = output.transpose(1, 2)

  loss = loss_fn(output, labels)

  loss.backward()
  optimizer.step()

  if iteration % evaluate_internal == 0:
    print("--------------------------------------")
    print(f"Iteration: {iteration}")
    print(f"Loss: {loss.item()}")

    generated_text = model.write("Spells", max_len)
    print("Sample")
    print(generated_text)
    print("--------------------------------------")

--------------------------------------
Iteration: 0
Loss: 4.428373336791992
Sample
SpellsZ\y10Qvp4Uy.—o4jAA)7\Ufop\qH9‘A;9]:.Xgs—,rlb1ZQ—X'qAc9M“in3nPS—YFfSgc2t.X9eJnZ5‘(']rS’Srvj]:7W>OsECs,]W9t\6OC—!•ONJ9/”L4C4)SJeT L)’5qn\qLCEX1CiksY]0'9V(vM”>rGgE, —"rIo"qvWI:L—foH?4‘KIz?"Q”ifF’)2)Hs.3O(p-qofrp:OftWrjEmxg”gxLFO\6h!Jb,"2”LFkc-?eI4v4kmB"qIDx!;Y)b!wGvQ).LuLy7-Wl;2NQ4'SLnQ2"C]0"]XA0QrFB,4
--------------------------------------
--------------------------------------
Iteration: 300
Loss: 1.744490385055542
Sample
Spells, yoaming her it ell.” “Bors and there loux clatchently when Fres sump going siry worked whind is fom to that but exter, onvizurs go she you ruxting hen aluquins, darred hids tik torll cogpar-shidgit wemextly all frins. sunched as how’s Dustingt.ment the mas shorlys. What the he was paiy.” Harry nea
--------------------------------------
--------------------------------------
Iteration: 600
Loss: 1.3954167366027832
Sample
Spells regain in frighters of a spon all that,” disapp