# The LSTM model shown in the KDnuggets article

https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

In [1]:
import torch
from torch import nn, optim
import numpy as np
import pandas as pd
from collections import Counter
from torch.utils.data import DataLoader

In [2]:
# True: train the model and save its weights; False: load previously saved weights instead.
is_training = False

In [3]:
# Parameters needed to run the model.
# (These originally had to be specified from the terminal.)
sequence_length = 4  # Default = 4
batch_size = 256  # Default = 256; reduce it if the machine runs out of memory
max_epochs = 400  # Default = 10
# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Notes on the model architecture

Based on model from https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

The model has three components:
1. **Embedding layer:** converts input of size (batch_size, sequence_length) to embedding of size (batch_size, sequence_length, embedding_dim)
2. **Stacked LSTM of 3 layers:** accepts the embedding and a tuple (previous hidden state, previous cell state) and gives an output of size (batch_size, sequence_length, lstm_size) and the tuple (current hidden state, current cell state). The hidden state and cell state both have size (num_layers, sequence_length, lstm_size). Here lstm_size and embedding_dim are both 128, so the sizes happen to coincide.
3. **Linear layer:** maps the output of the LSTM to logits for each word in the vocabulary. These are raw scores, not probabilities; softmax is applied later, at prediction time. Output size is (batch_size, sequence_length, vocab_size).

In [4]:
class Model(nn.Module):
    """Word-level language model: embedding -> 3-layer LSTM -> linear layer.

    NOTE: the LSTM is created without ``batch_first=True``. Per the PyTorch
    ``nn.LSTM`` documentation it therefore treats dimension 0 of its input as
    the time axis and dimension 1 as the batch axis. ``init_state`` is sized
    to match: the middle dimension of the state must equal dimension 1 of the
    tensor passed to ``forward``.
    """

    def __init__(self, dataset):
        """Build the layers; the vocabulary size is ``len(dataset.uniq_words)``."""
        super().__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3  # stack 3 LSTM layers for abstract representation

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes the embedding output, so its input width is
            # the embedding width (not the hidden size; they merely happen
            # to be equal here).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.1,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Return ``(logits, state)`` for index tensor ``x``.

        Args:
            x: integer tensor of word indices.
            prev_state: ``(hidden, cell)`` tuple, e.g. from ``init_state``.

        Returns:
            ``logits`` with one score per vocabulary word in the last
            dimension, and the updated ``(hidden, cell)`` tuple.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return zeroed ``(hidden, cell)`` tensors on the model's own device.

        Each tensor has shape ``(num_layers, sequence_length, lstm_size)``.
        """
        # Allocate next to the weights instead of relying on a global device.
        param_device = self.embedding.weight.device
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=param_device),
                torch.zeros(shape, device=param_device))

### Notes on the custom dataset

According to the PyTorch documentation, a custom dataset needs at least the methods \_\_len\_\_ and \_\_getitem\_\_. \_\_len\_\_ allows len(dataset) to return the size of the dataset and \_\_getitem\_\_ allows the ith element of the dataset to be fetched with dataset\[i\].

In this custom dataset, \_\_len\_\_ and \_\_getitem\_\_ are designed like this. Let's say the only sentence we have in the dataset is:

__*We are using LSTM to create the Retard-bot language model.*__

__\_\_len\_\_:__<br>
For this custom dataset it's defined as "the number of words in the dataset - sequence length". The reason is that every item needs sequence length + 1 consecutive words (an input window plus a target window shifted by one word), so only that many starting positions are valid. With the default sequence length of 4, the example sentence above has 10 words and \_\_len\_\_ returns 6, which is the same as the number of words in "**to create the Retard-bot language model.**"

__\_\_getitem\_\_:__<br>
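This returns a tuple of two n-grams, with n defined by the sequence length. So if we say dataset\[0\] in the simple example, we would get (**We are using LSTM**, **are using LSTM to**). The first element is the model input and the second is the training target: it is the same window shifted by one word, so at every position the model is trained to predict the next word.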

In [5]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from the 'Joke' column of a CSV.

    Attributes:
        words:         all jokes joined with a single space and split on ' '
        uniq_words:    the unique words sorted by frequency (most frequent first)
        index_to_word: index to word dict {index0: word0, index1: word1...},
                       most frequent words have the smallest indices
        word_to_index: word to index dict {word0: index0, word1: index1...},
                       most frequent words have the smallest indices
        words_indexes: ``words`` converted to indices using ``word_to_index``
    """

    def __init__(
        self,
        sequence_length,
        csv_path='reddit-cleanjokes.csv'
    ):
        """Load the corpus and build the vocabulary.

        Args:
            sequence_length: number of words in each input/target window.
            csv_path: CSV source with a 'Joke' column; anything accepted by
                ``pandas.read_csv`` (defaults to the original file name).
        """
        self.sequence_length = sequence_length
        self.csv_path = csv_path
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Return every joke concatenated with ' ' and split on ' '."""
        train_df = pd.read_csv(self.csv_path)
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words, most frequent first."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Return the number of full windows (0 if the corpus is too short)."""
        # Each item needs sequence_length + 1 words, leaving
        # len(words) - sequence_length valid start positions. Clamp at zero
        # because len() raises on a negative value.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(input, target)``; target is the input shifted by one word."""
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )

In [6]:
# Build the corpus first: the model sizes its vocabulary-dependent layers from it.
dataset = Dataset(sequence_length)
model = Model(dataset).to(device)

def train(dataset, model):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Runs for the module-level ``max_epochs`` using mini-batches of the
    module-level ``batch_size``. The mean batch loss is printed every 10
    epochs, and once more after the last epoch if that epoch was not
    already reported.

    Args:
        dataset: yields ``(input, target)`` index tensors and exposes
            ``sequence_length``.
        model: exposes ``init_state`` and returns ``(logits, state)`` when
            called with ``(x, state)``.
    """
    model.train()

    # Run on whatever device the model already lives on.
    run_device = next(model.parameters()).device

    # Batches are shuffled, so the state carried from one batch to the next
    # does not come from neighbouring text; carrying it is kept only to
    # preserve the original training behaviour.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    mean_loss = float('nan')
    for epoch in range(1, max_epochs + 1):
        # The state's middle dimension must match the window length of the
        # batches, which is defined by the dataset.
        state_h, state_c = model.init_state(dataset.sequence_length)
        epoch_loss = 0.0
        num_batches = 0

        for x, y in dataloader:
            x, y = x.to(run_device), y.to(run_device)
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class dimension second.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Cut the graph so backpropagation stops at the batch boundary.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1

        # An empty loader yields no batches; report NaN instead of crashing.
        mean_loss = epoch_loss / num_batches if num_batches else float('nan')
        if epoch % 10 == 0:
            print({ 'epoch': epoch, 'loss': mean_loss })

    # Report the final epoch unless the periodic print above just did.
    if max_epochs > 0 and max_epochs % 10 != 0:
        print({ 'epoch': max_epochs, 'loss': mean_loss })

In [7]:
def predict(dataset, model, text, next_words=100):
    """Sample ``next_words`` words that continue the prompt ``text``.

    Args:
        dataset: provides the ``word_to_index`` and ``index_to_word`` mappings.
        model: exposes ``init_state`` and returns ``(logits, state)`` when
            called with ``(x, state)``.
        text: prompt; it is split on single spaces and every resulting word
            must be in the vocabulary.
        next_words: number of words to generate.

    Returns:
        list: the prompt words followed by the generated words.

    Raises:
        KeyError: if a prompt word is missing from ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f'prompt words not in the vocabulary: {unknown}')

    # Build inputs on whatever device the model already lives on.
    run_device = next(model.parameters()).device
    state_h, state_c = model.init_state(len(words))

    # Inference only: without no_grad the carried state would keep an
    # ever-growing autograd graph alive across iterations.
    with torch.no_grad():
        for i in range(next_words):
            # `words` grows by one per step, so words[i:] is a sliding window
            # that always has the prompt's length.
            x = torch.tensor(
                [[dataset.word_to_index[w] for w in words[i:]]],
                device=run_device,
            )
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()

            # Sample from the distribution instead of taking the argmax.
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [8]:
# List of saved model checkpoints:
#kdnuggests400_pure.pt (no dropout)
#kdnuggests400_dropout.pt

In [9]:
# Both branches share one checkpoint name, so a run with is_training = False
# loads what a training run saved (previously file_name was undefined in the
# load branch). To load a different checkpoint, reassign file_name here.
file_name = 'kdnuggests' + str(max_epochs) + '.pt'
if is_training:
    train(dataset, model)
    torch.save(model.state_dict(), file_name)
else:
    # Deserialize onto the CPU; load_state_dict then copies the values into
    # the model's existing parameters.
    model.load_state_dict(torch.load(file_name, map_location='cpu'))
    model.to(device)

{'epoch': 10, 'loss': 6.050929145610079}
{'epoch': 20, 'loss': 5.413591313869395}
{'epoch': 30, 'loss': 4.985555618367297}
{'epoch': 40, 'loss': 4.708034363198788}
{'epoch': 50, 'loss': 4.503136847881561}
{'epoch': 60, 'loss': 4.337051244492226}
{'epoch': 70, 'loss': 4.191773130538616}
{'epoch': 80, 'loss': 4.056973746482362}
{'epoch': 90, 'loss': 3.9318155532187604}
{'epoch': 100, 'loss': 3.798764003084061}
{'epoch': 110, 'loss': 3.6571868100064866}
{'epoch': 120, 'loss': 3.5183957079623607}
{'epoch': 130, 'loss': 3.399082346165434}
{'epoch': 140, 'loss': 3.2919927353554583}
{'epoch': 150, 'loss': 3.184942935375457}
{'epoch': 160, 'loss': 3.0972471668365156}
{'epoch': 170, 'loss': 3.0171943142059003}
{'epoch': 180, 'loss': 2.944134554964431}
{'epoch': 190, 'loss': 2.8773856619571117}
{'epoch': 200, 'loss': 2.8266189656359084}
{'epoch': 210, 'loss': 2.7783442253762103}
{'epoch': 220, 'loss': 2.7365453826620225}
{'epoch': 230, 'loss': 2.7062484634683486}
{'epoch': 240, 'loss': 2.6730356

In [10]:
# Sample a continuation of a knock-knock prompt and show the word list.
prompt = 'Knock knock. Whos there?'
print(predict(dataset, model, text=prompt))

['Knock', 'knock.', 'Whos', 'there?', 'and', 'arms..', 'because', 'I', 'threw', 'my', 'sleep,', 'like', 'to', 'dry?', '"I\'m', 'afraid', 'of', 'Windsor', 'let', 'himself', 'go?', 'Flabio', 'What', 'did', 'the', 'most', 'accidents', 'happen', 'within', 'a', 'girl', 'Living', 'in', 'lairs.', 'What', 'type', 'of', 'the', 'end', 'of', 'razor', 'like', 'the', 'heck', 'he', 'brought', 'a', 'beginning.', '-ahem-', 'Just', 'a', 'poorly', 'dressed', 'as', "'Jallikatu", "Bulls'.", 'Did', 'you', 'know', 'wot', 'to', 'the', 'line', '-', 'Impatient', 'co-', '-', 'Impatient', 'cow.', 'Interrup........', 'MOOOOOOOOOOOOOOOO!!!!', '[Works', 'little', 'sister', 'told', 'joke:', 'What', 'do', 'a', 'tuna!', 'Person', 'working', 'with', 'ewe', 'people!?', 'What', 'to', 'the', 'comedy', 'routine?', 'Deadpan.', 'Two', 'antennas', 'met', 'Phil', "Spector's", 'brother', 'Crispin', 'the', 'sand', 'shortages', 'will', 'ask', 'questions!!']


In [11]:
# Sample a continuation of a short question opener and show the word list.
prompt = 'What did the'
print(predict(dataset, model, text=prompt))

['What', 'did', 'the', 'music', 'do', 'you', 'call', 'an', 'X', 'chromosomes.', 'Typical', 'woman.', 'If', 'you', 'might', 'say.', '**Jimmy', 'Carr**', 'What', 'do', 'you', 'run', 'around', 'in', 'salad', 'dressing.', 'My', 'English', 'teacher', 'tells', 'do', 'noodles', 'get', 'when', 'he', 'said...', '"make', 'it', 'turned', 'out', 'of', 'soap-', '-so', 'of?', 'Autumn', 'Leaves.', "What's", 'the', 'roots', 'of', 'cheese?', 'There', 'once', 'thought', 'I', 'hate', 'in', 'the', 'Italian', 'for', 'New', 'Year', 'resolution?', 'Well,', "it's", 'past', 'tents.', 'What', 'do', 'you', 'call', 'a', 'bar...', '...and', 'pulled', 'shaving', 'a', 'mile', 'between', 'you', 'mix', 'Michael', 'Jordan', 'with', 'my', 'neck-', '-and', 'they', 'got', 'clucky.', 'What', 'did', 'the', 'silly', 'thing', 'to', 'the', 'dinosaur', 'FBI', 'agent?', 'A', 'Cairopractor!', 'Why']


In [14]:
# Sample a continuation of a longer prompt and show the word list.
prompt = 'Why did the chicken cross the road'
print(predict(dataset, model, text=prompt))

['Why', 'did', 'the', 'chicken', 'cross', 'the', 'road', 'for', 'his', 'bike', 'away.', 'Two', 'fish', 'say', 'to', 'drive', 'this', 'boat".', 'The', 'computer', 'CPU', 'say', "'Control", 'Freak', 'who?\'"', ':)', 'A', 'gramma', 'ray', 'Bee', 'jokes,', 'courtesy', 'of', 'security', 'guards', 'in', 'the', 'word', 'in', 'the', 'raisin', 'A', 'skeleton', 'who', 'chokes', 'on', 'her', 'up', 'in', 'the', 'unthinkable?', 'With', 'a', 'old', 'fruit-picker', 'in', 'the', 'dog', 'to', 'the', 'boy', 'tree?', 'Sycamore.', 'Why', 'should', 'you', 'play', 'a', 'bike', 'away.', 'What', 'do', 'you', 'call', 'it', 'Friday.', 'How', 'many', 'mistakes...', "What's", 'the', 'corner', 'of', 'razor', 'like', 'camping', 'but...', "I'm", 'alright', 'you', 'call', '555-bottom-feeders.', 'We', 'will', 'make', 'a', 'few', 'hours', 'at', 'your', 'nose', 'because', 'they', 'take', 'at', 'the', 'reception']


In [15]:
# Sample a continuation of a two-word prompt and show the word list.
prompt = 'I ask'
print(predict(dataset, model, text=prompt))

['I', 'ask', 'a', 'pie.', '(sounds', 'like', 'hares!', 'I', 'made', 'here"', 'A', 'man', 'started', 'to', 'me.', 'Wanna', 'hear', 'about', 'the', 'opposite', 'of', 'water.', 'What', 'do', 'you', 'call', 'a', 'bar....', 'So', 'I', 'made', 'that', "can't", 'you', 'call', 'a', 'rock', 'do', 'you', "don't", 'tennis', 'players', 'who', 'hated', 'negative', 'numbers?', "He'll", 'up', 'everything', 'blurry.', 'What', 'do', 'you', 'call', 'a', 'bloated', 'appendix.', 'A', 'Poptometrist!', "What's", 'was', 'dedicated', 'to', 'the', 'new', 'TV', 'playback', 'craziness', '[Through', 'the', 'second', 'they', 'have', 'eight', 'empty', 'pack', 'boogered', 'up.', 'Please', 'give', 'his', 'leg?', 'Limp', 'Biscuit', 'Better', 'be', 'named', 'after', 'working', 'took', 'the', 'mechanic', 'Want', 'to', 'run', 'around', 'to', 'the', 'lettuce', 'get', '25', 'the']
