In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
from torch.utils.data import Dataset, DataLoader
import numpy as np
import re
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [28]:
corpus = [line.strip() for line in open('TheTimeMachine.txt') if line.strip()]
print("\n".join(corpus[:10]))

# Tokenize the sentences into words
corpus = [re.sub('[^A-Za-z0-9]+', ' ', line).lower() for line in corpus]
corpus = [re.sub(' +', ' ', line) for line in corpus]
corpus = [word for line in corpus for word in line.split()]

The Time Machine, by H. G. Wells [1898]
I
The Time Traveller (for so it will be convenient to speak of him)
was expounding a recondite matter to us. His grey eyes shone and
twinkled, and his usually pale face was flushed and animated. The
fire burned brightly, and the soft radiance of the incandescent
lights in the lilies of silver caught the bubbles that flashed and
passed in our glasses. Our chairs, being his patents, embraced and
caressed us rather than submitted to be sat upon, and there was that
luxurious after-dinner atmosphere when thought roams gracefully


In [3]:
vocab_size = 5000
tkn_counter = Counter([word for word in corpus])
vocab = {word: idx for idx, (word, _) in enumerate(tkn_counter.most_common(vocab_size))}
vocab["/UNK"] = len(vocab)
print(f"  * Found {len(vocab)} unique words in the provided corpus (of size {len(corpus)}).\n"
      f"  * Created vocabulary from corpus.\n"
      f"  * The 10 most common words are the following:")
print(tkn_counter.most_common(10))

  * Found 4582 unique words in the provided corpus (of size 32776).
  * Created vocabulary from corpus.
  * The 10 most common words are the following:
[('the', 2261), ('i', 1267), ('and', 1245), ('of', 1155), ('a', 816), ('to', 695), ('was', 552), ('in', 541), ('that', 443), ('my', 440)]


In [4]:
class TextCorpusDataset(Dataset):
    def __init__(self, corpus, vocab, snippet_len=50):
        super().__init__()
        self.corpus = corpus
        self.snippet_len = snippet_len

        # Vocabulary (word-to-index mapping)
        self.vocab = vocab

        # Inverse vocabulary (index-to-word mapping)
        self.inv_vocab = {idx: word for word, idx in self.vocab.items()}

    def convert2idx(self, word_sequence):
        return [self.vocab[word if word in self.vocab else "/UNK"] for word in word_sequence]

    def convert2words(self, idx_sequence):
        return [self.inv_vocab[idx] for idx in idx_sequence]

    def __len__(self):
        return (len(self.corpus) - self.snippet_len) // self.snippet_len

    def __getitem__(self, idx):
        idx = idx * self.snippet_len
        snippet = self.corpus[idx:idx+self.snippet_len]
        snippet = torch.tensor(self.convert2idx(snippet))
        return snippet

# Test dataset function
dataset = TextCorpusDataset(corpus, vocab, snippet_len=50)
snippet = dataset[123]
print("\nRandom snippet from the corpus.")
print("  * Token IDS:\t", snippet)
print("  * Words:\t\t", " ".join([dataset.inv_vocab[i] for i in snippet.tolist()]))


Random snippet from the corpus.
  * Token IDS:	 tensor([ 171,   50,    1,   52,    0,   49, 1176,   36,  133,   13,    1,  377,
          14,    4,  506,  697,   85,   18,   20,  855, 2619,    1,    6,   36,
           5,  585, 2620,    6, 1632,   59,    4, 1168,   85,    0, 2621,    3,
        2622, 2623,   17,    5,  149,    5,    4,  513, 2624,    0, 2625,    3,
          82, 1633])
  * Words:		 space which i or the machine occupied so long as i travelled at a high velocity through time this scarcely mattered i was so to speak attenuated was slipping like a vapour through the interstices of intervening substances but to come to a stop involved the jamming of myself molecule


In [10]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size=None):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # LSTM Parameters
        self.input_gate = nn.Linear(input_size+hidden_size, hidden_size)
        self.forget_gate = nn.Linear(input_size+hidden_size, hidden_size)
        self.candidate = nn.Linear(input_size+hidden_size, hidden_size)
        self.output = nn.Linear(input_size+hidden_size, hidden_size)

        self.predictor = nn.Linear(hidden_size, input_size) if output_size is not None else nn.Identity()

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)

    def init_state_cell(self, batch_size, device):
        state = torch.zeros(batch_size, self.hidden_size).to(device)
        cell = torch.zeros(batch_size, self.hidden_size).to(device)
        return state, cell

    def forward(self, x, state=None, cell=None):
        # Get sequence length and batch size
        seq_len, batch_size, _ = x.size()

        # Initialize hidden and cell states if not provided
        if state is None or cell is None:
            state, cell = self.init_state_cell(batch_size, x.device)

        # Lists to store outputs and cell states for each time step
        outputs = []

        # Iterate through the sequence
        for t in range(seq_len):
            # Input at time step t
            xh_t = torch.cat((x[t], state), 1)

            # Input gate
            inp_t = torch.sigmoid(self.input_gate(xh_t))

            # Forget gate
            forget_t = torch.sigmoid(self.forget_gate(xh_t))

            # Cell state
            c_tilda_t = torch.tanh(self.candidate(xh_t))
            cell = forget_t * cell + (1-forget_t) * c_tilda_t

            # Output gate
            ot = torch.sigmoid(self.output(xh_t))

            # Hidden state update
            state = torch.tanh(cell)

            # Normally an LSTM simply outputs the hidden state.
            # However, here we want our outputs to be the logits for the predicted next word.
            output = self.predictor(state)
            outputs.append(output)

        # Stack outputs and cell states along the sequence dimension
        outputs = torch.stack(outputs, dim=0)
        return outputs, (state, cell)

hidden_dim, vocab_size = 256, len(dataset.vocab)
model = LSTM(vocab_size, hidden_dim, vocab_size).to(device)

sentence = "today is too darn cold".split()
inp = torch.tensor(dataset.convert2idx(sentence), device=device)[:, None]
inp = F.one_hot(inp, len(vocab)).float()
Yhat, new_state = model(inp)
Yhat = Yhat.squeeze(1).argmax(-1)
print(dataset.convert2words(Yhat.tolist()))

['tumulus', 'badly', 'anecdote', 'tumulus', 'tumulus']


In [22]:
@torch.no_grad()
def generate(prefix, num_preds, model, vocab):
    """Generates a sentence following the `prefix`."""
    prefix = torch.tensor(dataset.convert2idx(prefix.split()), device=device).long()

    state, cell, outputs = None, None, [prefix[0]]
    for i in range(1, len(prefix) + num_preds):
        # Prepare one token at a time to feed the model
        inp = F.one_hot(outputs[-1], len(vocab)).float()
        inp = inp[None, None]

        # Compute the prediction for the next token
        yhat, (state, cell) = model(inp, state, cell)

        if i < len(prefix):
            # During warmup (while parsing the prefix), we ignore the model prediction
            outputs.append(prefix[i])
        else:
            # Otherwise, append the model prediction to the output list
            yhat = yhat.argmax(dim=-1)[0, 0].long()
            outputs.append(yhat)
    return ' '.join([dataset.inv_vocab[tkn.item()] for tkn in outputs])

generate('i do not mean to ask you to accept anything', 10, model, vocab)

'i do not mean to ask you to accept anything to look at times i cannot move at that in'

In [21]:
def train_on_sequence(seq, model, optimizer, unroll=5):
    """Train the model within a batch of long text sequences."""
    batch_size, num_tokens = seq.shape

    total_loss, state, cell = 0., None, None
    for i in range(0, num_tokens-unroll-1, unroll):
        if state is not None:
            state.detach_(), cell.detach_()

        # Define the input sequence along which we will unroll the RNN
        x = seq[:, i:i+unroll].T
        y = seq[:, i+1:i+unroll+1].T

        # Forward the model and compute the loss
        x = F.one_hot(x, len(vocab)).float()
        y_hat, (state, cell) = model(x, state, cell)
        l = loss(y_hat.flatten(0, 1), y.flatten(0, 1).long())
        total_loss += l.item()

        # Backward step (clip gradients to prevent exploding gradients)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

    n_batches = (num_tokens-unroll-1) // unroll
    return total_loss/n_batches

def fit(model, loader, vocab, lr, num_epochs=100, unroll=5):
    optimizer = torch.optim.RMSprop(model.parameters(), lr)
    test_prompt = 'i do not mean to ask you to accept anything'
    for epoch in range(num_epochs):
        total_loss = 0
        for sequence in loader:
            total_loss += train_on_sequence(sequence.to(device), model, optimizer, unroll=unroll)
        total_loss /= len(loader)

        print(f'Epoch {epoch} | Perplexity {np.exp(total_loss):.1f}. Loss: {total_loss:.3f}')
        print(generate(test_prompt, 50, model, vocab))

num_epochs, lr = 100, 0.001
dataset = TextCorpusDataset(corpus, vocab, 100)
loader = DataLoader(dataset, batch_size=32)
model = LSTM(len(dataset.vocab), hidden_dim, output_size=len(dataset.vocab)).to(device)
loss = nn.CrossEntropyLoss()
fit(model, loader, dataset.vocab, lr, num_epochs)

Epoch 0 | Perplexity 950.7. Loss: 6.857
i do not mean to ask you to accept anything the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of
Epoch 1 | Perplexity 596.4. Loss: 6.391
i do not mean to ask you to accept anything the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of
Epoch 2 | Perplexity 487.7. Loss: 6.190
i do not mean to ask you to accept anything and the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the
Epoch 3 | Perplexity 403.9. Loss: 6.001
i do not mean to ask you to accept anything and the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of the of t