# The LSTM model shown in the KDnuggets article

https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

In [1]:
import torch
from torch import nn, optim
import numpy as np
import pandas as pd
from collections import Counter
from torch.utils.data import DataLoader

In [2]:
# True: train the model and save its weights; False: load previously saved weights.
is_training = True
# Parameters needed to run the model.
# These originally needed to be specified from the terminal.
sequence_length = 4  # words per input window (original default: 4)
batch_size = 256  # original default: 256; reduce if the machine runs out of memory
max_epochs = 200  # original default: 10
# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Notes on the model architecture

Based on model from https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

The model has three components:
1. **Embedding layer:** converts input of size (batch_size, sequence_length) to embedding of size (batch_size, sequence_length, embedding_dim)
2. **Stacked LSTM of 3 layers:** accepts the embedding and a tuple (previous hidden state, previous cell state) and gives an output of size (batch_size, sequence_length, lstm_size) and the tuple (current hidden state, current cell state). The hidden state and cell state both have size (num_layers, sequence_length, lstm_size). In this model lstm_size and embedding_dim are both 128.
3. **Linear layer:** Maps the output of LSTM to logits for each word in vocab. Not a probability yet. Output size is  (batch_size, sequence_length, vocab_size)

In [3]:
class Model(nn.Module):
    """Word-level language model: embedding -> 3-layer LSTM -> linear layer.

    The vocabulary size is taken from ``len(dataset.uniq_words)``; it sizes
    both the embedding table and the output layer.

    NOTE(review): the LSTM is created without ``batch_first=True``, so it
    reads axis 0 of its input as the time axis and axis 1 as the batch axis.
    ``init_state`` sizes axis 1 of the state from its ``sequence_length``
    argument to match. Confirm this layout is intended before changing either
    side, because callers build their state through ``init_state``.
    """

    def __init__(self, dataset):
        """Build the layers.

        Args:
            dataset: object exposing ``uniq_words``, the vocabulary as a
                sized collection.
        """
        super().__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3  # stack 3 LSTM layers for abstract representation

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width is the
            # embedding width (not the hidden width, even though both are 128).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.1,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one forward pass.

        Args:
            x: integer tensor of word indices.
            prev_state: tuple ``(hidden, cell)`` as returned by ``init_state``
                or by a previous call.

        Returns:
            ``(logits, state)`` where ``logits`` has the shape of ``x`` plus a
            trailing vocabulary axis (unnormalised scores, not probabilities)
            and ``state`` is the new ``(hidden, cell)`` tuple.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return zeroed ``(hidden, cell)`` tensors for a fresh sequence.

        Each tensor has shape ``(num_layers, sequence_length, lstm_size)`` and
        is allocated on the same device as the model's parameters, so the
        state always matches the model regardless of any global setting.
        """
        state_device = self.embedding.weight.device
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=state_device),
                torch.zeros(shape, device=state_device))

### Notes on the custom dataset

According to the PyTorch documentation, a custom dataset needs at least the functions \_\_len\_\_ and \_\_getitem\_\_. \_\_len\_\_ allows len(dataset) to return the size of the dataset and \_\_getitem\_\_ allows the ith element of the dataset to be fetched with dataset\[i\].

In this custom dataset, \_\_len\_\_ and \_\_getitem\_\_ are designed like this. Let's say the only sentence we have in the dataset is:

__*We are using LSTM to create the Retard-bot language model.*__

__\_\_len\_\_:__<br>
For this custom dataset it's defined as "the number of words in the dataset - sequence length". Each item needs sequence length + 1 consecutive words (an input window plus a target window shifted by one word), so start positions too close to the end of the text cannot be used. In the example sentence above (10 words, default sequence length 4), it will return 6, which is the length of "**to create the Retard-bot language model.**"

__\_\_getitem\_\_:__<br>
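It returns a tuple of two word windows (n-grams with n defined by the sequence length): the input window and the same window shifted one word to the right. So if we say dataset\[0\] in the simple example, we would get (**We are using LSTM**, **are using LSTM to**). The second window is the training target: at every position the model is trained to predict the next word, so the loss in train() compares the model's output for the first window against the second window.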

In [4]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window word dataset built from the ``Joke`` column of a CSV file.

    Attributes:
        sequence_length: number of words in each input/target window.
        csv_path:        path of the CSV file the words are read from.
        words:           all jokes joined with single spaces, then split on
                         single spaces.
        uniq_words:      the unique words sorted by frequency (most frequent first).
        index_to_word:   {index: word}; more frequent words get smaller indices.
        word_to_index:   {word: index}; inverse of ``index_to_word``.
        words_indexes:   ``words`` converted to indices via ``word_to_index``.
    """

    def __init__(
        self,
        sequence_length,
        csv_path='reddit-cleanjokes.csv',
    ):
        """Load the text and build the vocabulary.

        Args:
            sequence_length: number of words per window.
            csv_path: CSV file with a ``Joke`` column; defaults to the file
                used by the original notebook.
        """
        self.sequence_length = sequence_length
        self.csv_path = csv_path
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read the CSV and return every joke's words as one flat list."""
        train_df = pd.read_csv(self.csv_path)
        text = train_df['Joke'].str.cat(sep=' ')
        # Split on single spaces only; predict() tokenises its prompt the same
        # way, so both sides agree on what a "word" is.
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words ordered from most to least frequent."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Return the number of windows that have a complete shifted target.

        Clamped at zero: a text shorter than ``sequence_length`` yields an
        empty dataset instead of a negative length, which ``len()`` rejects.
        """
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(input, target)`` index tensors for the window at ``index``.

        ``target`` is ``input`` shifted one word to the right, i.e. the next
        word for every position of the input window.
        """
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )

In [5]:
# Build the dataset first, because the model sizes its embedding and output
# layers from the dataset's vocabulary; then place the model on the device.
dataset = Dataset(sequence_length)
model = Model(dataset).to(device)

def train(dataset, model):
    """Train ``model`` on ``dataset`` for ``max_epochs`` epochs with Adam.

    Reads the module-level ``batch_size`` and ``max_epochs`` settings. The
    mean batch loss is printed every 10th epoch and once for the final epoch.

    Args:
        dataset: dataset yielding ``(input, target)`` index tensors and
            exposing ``sequence_length``.
        model: model exposing ``init_state(n)`` and returning
            ``(logits, (hidden, cell))`` when called.

    Raises:
        ValueError: if the dataset produces no batches, since there would be
            nothing to train on or average over.
    """
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('dataset yields no batches; nothing to train on')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Move batches to wherever the model lives instead of relying on a global.
    model_device = next(model.parameters()).device

    for epoch in range(1, max_epochs + 1):
        # Fresh recurrent state per epoch, sized from the dataset's own window
        # length so state and batches cannot disagree.
        state_h, state_c = model.init_state(dataset.sequence_length)
        epoch_loss = 0.0

        for x, y in dataloader:
            x, y = x.to(model_device), y.to(model_device)
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class axis at position 1.
            loss = criterion(y_pred.transpose(1, 2), y)

            # The state is carried into the next batch; detach it so the next
            # backward pass does not reach into this batch's graph.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Report periodically and always for the final epoch, but only once
        # when the final epoch is itself a multiple of 10.
        if epoch % 10 == 0 or epoch == max_epochs:
            print({ 'epoch': epoch, 'loss': epoch_loss/num_batches })

In [6]:
def predict(dataset, model, text, next_words=100):
    """Extend the prompt ``text`` by sampling ``next_words`` words from ``model``.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word``.
        model: model exposing ``init_state(n)`` and returning
            ``(logits, (hidden, cell))`` when called.
        text: prompt; it is split on single spaces and every resulting word
            must be in the vocabulary.
        next_words: number of words to generate.

    Returns:
        The prompt words followed by the generated words, as a list.

    Raises:
        KeyError: if the prompt contains a word missing from the vocabulary.
    """
    model.eval()

    words = text.split(' ')
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError('prompt words missing from the vocabulary: {}'.format(unknown))

    # Build inputs on the model's own device rather than a global setting.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    state_h, state_c = model.init_state(len(words))

    # Inference only: without no_grad the carried state would keep an autograd
    # graph that grows with every generated word.
    with torch.no_grad():
        for i in range(next_words):
            # One word is appended per step, so words[i:] keeps the length of
            # the original prompt, which the state was sized for.
            window = words[i:]
            x = torch.tensor([[dataset.word_to_index[w] for w in window]], device=model_device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()

            # Sample instead of taking the argmax so repeated calls vary.
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [7]:
# The checkpoint name is derived from the epoch count. It is defined before
# the branch because the load path needs it too; previously it only existed
# after training, so loading failed with a NameError.
file_name = 'kdnuggests' + str(max_epochs)
if is_training:
    train(dataset, model)
    torch.save(model.state_dict(), file_name)
else:
    # NOTE(review): map_location returns each storage unchanged, which is
    # meant to keep loaded tensors on the CPU; the model is then moved to the
    # configured device. Confirm against the torch.load documentation.
    model.load_state_dict(torch.load(file_name, map_location=lambda storage, loc: storage))
    model.to(device)

{'epoch': 0, 'loss': 8.037236286246259}
{'epoch': 10, 'loss': 6.659230325532996}
{'epoch': 20, 'loss': 5.717670254085375}
{'epoch': 30, 'loss': 5.1656972325366475}
{'epoch': 40, 'loss': 4.61857415282208}
{'epoch': 50, 'loss': 4.192252537478572}
{'epoch': 60, 'loss': 3.8988133409748906}
{'epoch': 70, 'loss': 3.371039603067481}
{'epoch': 80, 'loss': 3.2438957224721494}
{'epoch': 90, 'loss': 2.7846299617186836}
{'epoch': 100, 'loss': 2.473770794661149}
{'epoch': 110, 'loss': 2.0925980080728945}
{'epoch': 120, 'loss': 1.9693893842075183}
{'epoch': 130, 'loss': 1.8112949780795886}
{'epoch': 140, 'loss': 1.469342291355133}
{'epoch': 150, 'loss': 1.1740045949168827}
{'epoch': 160, 'loss': 1.0335961437743644}
{'epoch': 170, 'loss': 0.8354895438836969}
{'epoch': 180, 'loss': 0.7099623252516207}
{'epoch': 190, 'loss': 0.6042536691479061}
{'epoch': 199, 'loss': 0.503647515307302}


In [8]:
# Generate from a knock-knock prompt and show the resulting word list.
print(predict(dataset, model, 'Knock knock. Whos there?'))

['Knock', 'knock.', 'Whos', 'there?', '"Yep,', 'say', 'say', 'to', 'the', 'jumper', 'cables?', 'You', 'better', 'not', 'try', 'to', 'start', 'anything.', 'As', 'did', 'the', 'butcher', 'did', 'no', 'draw', 'did', 'Captain', 'alligator', 'call', 'a', 'most', 'peach?', 'after', 'not', 'hitting', 'his', 'pointed', 'and', 'rhymes', 'in', 'Italian.', 'Lost.', 'All', 'these', 'guys', 'says', 'not', 'not', 'some', 'episodes.', 'What', 'does', 'a', 'little', 'trees', 'that', 'steal', 'looking', 'ate', 'his', 'bike', 'In', 'in', 'a', 'byte', 'of', 'UN-B-REATHABLE!', 'I', 'was', 'stand', 'her', 'tie', 'an', 'field.', 'We', 'that', 'tell', 'are', 'made', 'for', 'leaving', 'me', 'It', 'has', 'leaving', 'their', 'rights.', 'They', 'was', 'going', 'to', 'make', 'me', 'that', 'for', 'gum', 'This', 'life', 'like', 'a', 'invisible', 'asked', 'as', 'getting']


In [9]:
# Generate from a short three-word prompt and show the resulting word list.
print(predict(dataset, model, 'What did the'))

['What', 'did', 'the', 'bartender', 'say', 'to', 'the', 'jumper', 'cables?', 'You', 'better', 'not', 'try', 'to', 'start', 'anything.', "Don't", 'you', 'hate', 'jokes', 'about', 'German', 'sausage?', 'But', 'by', 'sick', 'what', 'Do', 'you', 'get', 'their', 'empty', 'This', 'by', 'front', 'is', 'my', 'high', 'knotsies!', 'suspect', 'from', 'three', 'you"?', 'up,', 'these', 'wooden', 'tsst', 'What', 'is', 'a', 'wife', 'Do', 'you', "don't", 'run', 'in', 'your', 'leg?', 'Knight', "What's", 'a', 'chef', 'who', 'Will', 'both', 'eight.', 'Why', 'did', 'the', 'bee', 'wear', 'his', 'gun', 'to', 'a', 'supermarket', 'The', 'tube', 'In', 'to', 'each', 'driving', 'Old', 'told', 'my', 'sick', 'walks', 'if', 'to', 'the', 'red', 'The', 'boy', 'help', 'in', 'can', 'out,', 'Trump', 'Where', 'many', 'hipsters', 'to', 'hear']


In [12]:
# Generate from a longer seven-word prompt and show the resulting word list.
print(predict(dataset, model, 'Why did the chicken cross the road'))

['Why', 'did', 'the', 'chicken', 'cross', 'the', 'road', '[', 'Masahiro', 'papa', 'abroad', 'a', 'music', 'charge', 'says', 'like', 'to', 'her', 'son', 'the', 'feel', 'You', 'tells', 'to', 'good', 'out', 'to', 'go', 'what', 'many', 'greatest', 'throat', 'I', 'AM', 'everybody', 'before', 'a', 'statue?', "It's", 'ended', 'departure', 'I', 'would', 'affect', 'a', 'big', 'dressing', '"Hey,', 'my', 'C.', 'An', 'morning', 'later,', 'the', 'rabbit', 'First', 'balloon', '-', 'but', 'has', 'the', 'same', 'page."', "Don't", 'least', 'that', 'say,', 'sun', 'possible', 'there', 'anyone', 'see', 'who', 'if', 'afterwards', 'Tim', "roommate's", 'Why', 'did', 'the', 'best', 'leave', 'a', 'moment', 'is', 'a', 'Siberian', 'husky', 'It', 'they', 'had', 'a', 'hot', 'nostrils?', '"Hey,', "they're", 'the', 'fresh', 'prints.', '*THUD*', 'Cause', 'are', 'falls', 'this', 'sheep', 'every', 'fueled']
