In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from nltk import RegexpTokenizer
from tqdm import tqdm

# Text Preprocessing

In [2]:
# Read the whole corpus into memory. The context manager closes the file
# handle as soon as the text has been read instead of leaving it open for the
# lifetime of the kernel.
filename = "wonderland.txt"
with open(filename, 'r', encoding='utf-8') as corpus_file:
    raw_text = corpus_file.read()

### Text to Word tokens

In [3]:
# A token is a maximal run of lowercase letters, apostrophes (straight or
# U+2019), commas and periods. Punctuation therefore stays attached to the
# word it follows, e.g. "about," and "about." are distinct tokens.
tokenizer = RegexpTokenizer(r"[a-z',\u2019\.]+")
lowercased_text = raw_text.lower()
words = tokenizer.tokenize(lowercased_text)
n_words = len(words)
print(words[:20])

['alice’s', 'adventures', 'in', 'wonderland', 'by', 'lewis', 'carroll', 'the', 'millennium', 'fulcrum', 'edition', '.', 'contents', 'chapter', 'i.', 'down', 'the', 'rabbit', 'hole', 'chapter']


### Vocab

In [4]:
# np.unique yields the distinct tokens as a sorted 1-D array; a token's
# position in this array serves as its integer id further down.
vocab = np.unique(words)
n_vocab = vocab.shape[0]
vocab[:20]

array([',', ',’', '.', '.’', 'a', 'abide', 'able', 'about', 'about,',
       'about.', 'above', 'absence,', 'absurd', 'absurd,', 'acceptance',
       'accident', 'accident,', 'accidentally', 'account', 'accounting'],
      dtype='<U15')

In [5]:
# Forward lookup: token -> integer id (its position in `vocab`).
# No reverse dict is needed: indexing `vocab` with an id returns the token.
words_to_int = {token: idx for idx, token in enumerate(vocab)}

### Summary of the data

In [6]:
# Corpus size in tokens and number of distinct tokens.
print(f"Total words: {n_words}")
print(f"Total vocab: {n_vocab}")

Total words: 26845
Total vocab: 3671


# Hyperparameters

In [7]:
# Pick the best available accelerator instead of assuming Apple MPS, so the
# notebook also runs on CUDA machines and on plain CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps:0")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

seq_len = 8        # number of context words fed to the model per sample
# NOTE: the model cell instantiates BiLSTM() with no arguments, so the two
# values below are not passed to the network.
hidden_size = 256
num_layers = 10
batch_size = 256   # samples per DataLoader batch
epochs = 20        # full passes over the dataset in the training loop

# Word to integer mapping

In [8]:
# Encode the corpus once, then slide a window of seq_len ids over it: each
# window is one input sample and the id right after the window is its target.
encoded_corpus = [words_to_int[word] for word in words]
n_samples = n_words - seq_len

dataX = [encoded_corpus[start:start + seq_len] for start in range(n_samples)]
dataY = [encoded_corpus[start + seq_len] for start in range(n_samples)]

In [9]:
# Inputs get shape (n_samples, seq_len, 1): one scalar feature per time step,
# namely the word id divided by the vocabulary size (so values lie in [0, 1)).
X = torch.tensor(dataX, dtype=torch.float32)
X = X.reshape(len(dataX), seq_len, 1).to(device)
X = X / float(n_vocab)

# Targets stay as integer class ids.
y = torch.tensor(dataY).to(device)
print(X.shape, y.shape)

torch.Size([26837, 8, 1]) torch.Size([26837])


# Dataloader

In [10]:
# X and y were already moved to `device` when they were built, so every batch
# comes out of the loader on that device. Page-locked (pinned) memory only
# applies to CPU tensors that still have to be copied to an accelerator, so
# pin_memory is not requested here.
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Creating Model

In [11]:
class BiLSTM(nn.Module):
    """Bidirectional LSTM that scores every vocabulary word as the next word.

    The network consumes sequences with a single scalar feature per time step
    and returns log-probabilities over the vocabulary.

    Args:
        vocab_size: Number of output classes. When None, the notebook-level
            ``n_vocab`` is used, which is what ``BiLSTM()`` did before.
        hidden_size: Hidden units per LSTM direction; also the width of the
            intermediate linear layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between the two linear layers.

    The submodule names (``lstm``, ``drop1``, ``fc1``, ``fc2``) are part of
    the checkpoint format (state_dict keys) and must not be renamed.
    """

    def __init__(self, vocab_size=None, hidden_size=256, num_layers=5, dropout=0.3):
        super().__init__()
        if vocab_size is None:
            # Fall back to the vocabulary size computed earlier in the notebook.
            vocab_size = n_vocab
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )
        self.drop1 = nn.Dropout(dropout)
        # A bidirectional LSTM concatenates both directions -> 2 * hidden_size.
        self.fc1 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Return log-probabilities of shape (batch, vocab_size).

        Args:
            x: Float tensor of shape (batch, seq_len, 1).
        """
        x, _ = self.lstm(x)
        # Keep only the features at the final time step.
        x = x[:, -1, :]
        x = F.relu(x)
        x = self.fc1(x)
        x = self.drop1(x)
        x = self.fc2(x)
        # log_softmax is idempotent, so a loss that applies it again (such as
        # nn.CrossEntropyLoss) sees the same values; the argmax is unaffected.
        x = F.log_softmax(x, dim=1)
        return x

In [12]:
# The loss is summed (not averaged) over each batch, so reported values scale
# with the number of samples.
criterion = nn.CrossEntropyLoss(reduction="sum")

# Network on the selected device, trained with Adam at its default settings.
model = BiLSTM().to(device)
optimizer = optim.Adam(params=model.parameters())

# Training & Testing

In [13]:
# Train for `epochs` passes, remember the weights of the epoch with the lowest
# summed cross-entropy, and save them together with the vocabulary mapping.
best_model = None
best_loss = np.inf

for epoch in range(epochs):
    # ---- optimisation pass ----
    model.train()
    with tqdm(total=len(dataloader), ncols=80, desc=f"Epoch {epoch+1}/{epochs}") as pbar:
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item())
            pbar.update()

    # ---- evaluation pass ----
    # NOTE: this iterates the same `dataloader` used for training, so the
    # number is a training-set loss (dropout disabled), not a held-out
    # validation loss.
    model.eval()
    epoch_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            # .item() keeps the running total a plain Python float instead of
            # a device tensor.
            epoch_loss += criterion(y_pred, y_batch).item()

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        # state_dict() returns references to the live parameters; without a
        # copy, later optimizer steps would silently overwrite the "best"
        # weights and the file below would hold the last epoch instead.
        best_model = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    print(f"Epoch {epoch+1}: Cross-entropy: {epoch_loss:.4f}")

torch.save([best_model, words_to_int], "TextGen_Alice_Wonderland.pth")

Epoch 1/20: 100%|███████████████| 105/105 [00:22<00:00,  4.70it/s, loss=1.45e+3]


Epoch 1: Cross-entropy: 182411.4375


Epoch 2/20: 100%|███████████████| 105/105 [00:21<00:00,  4.80it/s, loss=1.36e+3]


Epoch 2: Cross-entropy: 170530.4688


Epoch 3/20: 100%|███████████████| 105/105 [00:21<00:00,  4.81it/s, loss=1.28e+3]


Epoch 3: Cross-entropy: 169999.9844


Epoch 4/20: 100%|███████████████| 105/105 [00:21<00:00,  4.80it/s, loss=1.43e+3]


Epoch 4: Cross-entropy: 170188.5625


Epoch 5/20: 100%|███████████████| 105/105 [00:22<00:00,  4.76it/s, loss=1.41e+3]


Epoch 5: Cross-entropy: 169916.6875


Epoch 6/20: 100%|███████████████| 105/105 [00:21<00:00,  4.78it/s, loss=1.36e+3]


Epoch 6: Cross-entropy: 170157.0312


Epoch 7/20: 100%|███████████████| 105/105 [00:21<00:00,  4.80it/s, loss=1.35e+3]


Epoch 7: Cross-entropy: 169530.2969


Epoch 8/20: 100%|███████████████| 105/105 [00:21<00:00,  4.79it/s, loss=1.36e+3]


Epoch 8: Cross-entropy: 169498.0938


Epoch 9/20: 100%|███████████████| 105/105 [00:21<00:00,  4.78it/s, loss=1.33e+3]


Epoch 9: Cross-entropy: 169489.5000


Epoch 10/20: 100%|██████████████| 105/105 [00:21<00:00,  4.79it/s, loss=1.34e+3]


Epoch 10: Cross-entropy: 169451.5000


Epoch 11/20: 100%|██████████████| 105/105 [00:21<00:00,  4.80it/s, loss=1.37e+3]


Epoch 11: Cross-entropy: 169476.6719


Epoch 12/20: 100%|██████████████| 105/105 [00:22<00:00,  4.70it/s, loss=1.37e+3]


Epoch 12: Cross-entropy: 169429.5000


Epoch 13/20: 100%|██████████████| 105/105 [00:21<00:00,  4.78it/s, loss=1.32e+3]


Epoch 13: Cross-entropy: 169482.9219


Epoch 14/20: 100%|██████████████| 105/105 [00:21<00:00,  4.78it/s, loss=1.38e+3]


Epoch 14: Cross-entropy: 169872.8750


Epoch 15/20: 100%|██████████████| 105/105 [00:21<00:00,  4.81it/s, loss=1.39e+3]


Epoch 15: Cross-entropy: 169544.0625


Epoch 16/20: 100%|██████████████| 105/105 [00:21<00:00,  4.81it/s, loss=1.34e+3]


Epoch 16: Cross-entropy: 169561.5000


Epoch 17/20: 100%|███████████████| 105/105 [00:21<00:00,  4.80it/s, loss=1.4e+3]


Epoch 17: Cross-entropy: 169394.1562


Epoch 18/20: 100%|██████████████| 105/105 [00:22<00:00,  4.71it/s, loss=1.41e+3]


Epoch 18: Cross-entropy: 169739.4531


Epoch 19/20: 100%|██████████████| 105/105 [00:21<00:00,  4.81it/s, loss=1.37e+3]


Epoch 19: Cross-entropy: 169523.4688


Epoch 20/20: 100%|██████████████| 105/105 [00:21<00:00,  4.78it/s, loss=1.32e+3]


Epoch 20: Cross-entropy: 169456.9219


# Load the trained model

In [14]:
# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints from
# a trusted source.
# map_location places the stored tensors on the current `device`, so a
# checkpoint written on one accelerator can be restored on another machine.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, the pickled vocabulary dict may be rejected -- confirm.
best_model, words_to_int = torch.load("TextGen_Alice_Wonderland.pth", map_location=device)
model.load_state_dict(best_model)

<All keys matched successfully>

# Text Generation

In [15]:
# Use a random seq_len-word window of the corpus as the seed prompt.
start = np.random.randint(0, len(words)-seq_len)
seed_token = words[start : start+seq_len]
pattern = [words_to_int[word] for word in seed_token]
prompt = ' '.join(seed_token)
print(f"Prompt: {prompt}")

# The seed followed by every generated word.
generated_text = prompt

model.eval()
with torch.no_grad():
    for _ in range(20):
        # Encode the current window exactly like the training inputs:
        # float32 ids of shape (1, seq_len, 1), divided by the vocabulary size.
        x = torch.tensor(pattern, dtype=torch.float32).reshape(1, len(pattern), 1).to(device)
        x = x / float(n_vocab)
        # Log-probabilities over the vocabulary for the next word.
        prediction = model(x)
        # Greedy decoding: always take the single most likely word.
        index = int(prediction.argmax())
        generated_text = generated_text + " " + vocab[index]
        # Slide the window: drop the oldest id, append the predicted one.
        pattern = pattern[1:] + [index]

print(f"Generated Text:\n{generated_text}")
print()

Prompt: just in time to see it pop down
Generated Text:
just in time to see it pop down the the the the the the the the the the the the the the the the the the the the

