# The LSTM model shown in the KDnuggets article

https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

In [85]:
import torch
from torch import nn, optim
import numpy as np
import pandas as pd
from collections import Counter
from torch.utils.data import DataLoader

In [86]:
# Parameters needed to run the model.
# In the original article these had to be specified from the terminal.
sequence_length = 4  # number of consecutive words in each sample produced by the dataset
batch_size = 256  # number of samples the DataLoader groups into one batch
max_epochs = 10  # number of passes over the dataset made by the training loop

### Notes on the model architecture

The model has three components:
1. **Embedding layer:** converts input of size (batch_size, sequence_length) to embedding of size (batch_size, sequence_length, embedding_dim)
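2. **Stacked LSTM of 3 layers:** accepts the embedding and a tuple (previous hidden state, previous cell state) and gives an output of size (batch_size, sequence_length, lstm_size) and the tuple (current hidden state, current cell state). The hidden state and cell state both have size (num_layers, sequence_length, lstm_size); here lstm_size equals embedding_dim (both are 128). The state is sized by sequence_length rather than batch_size because the LSTM is created without `batch_first=True`, so it treats the second axis of its input as the batch axis.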
3. **Linear layer:** maps the output of the LSTM to logits for each word in the vocabulary. These are not probabilities yet (softmax is only applied at prediction time). Output size is (batch_size, sequence_length, vocab_size).

In [87]:
class Model(nn.Module):
    """Word-level language model: embedding -> 3-layer LSTM -> linear projection.

    The LSTM is created with PyTorch's default ``batch_first=False``, so it
    iterates over axis 0 of its input and treats axis 1 as the batch of
    independent sequences. When this model is fed tensors laid out as
    (batch_size, sequence_length), the recurrent state therefore has shape
    (num_layers, sequence_length, lstm_size) -- which is why ``init_state``
    is sized by a sequence length and not by a batch size.
    """

    def __init__(self, dataset):
        """Build the layers.

        Args:
            dataset: object exposing ``uniq_words``; its length determines the
                vocabulary size of the embedding and of the output layer.
        """
        super().__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3  # stack 3 LSTM layers for abstract representation

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width must be
            # the embedding dimension (not the hidden size, which merely
            # happens to have the same value).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one forward pass.

        Args:
            x: integer tensor of word indices.
            prev_state: tuple ``(hidden, cell)`` as returned by ``init_state``
                or by a previous call to ``forward``.

        Returns:
            ``(logits, state)`` where ``logits`` has the shape of ``x`` plus a
            trailing vocabulary axis (unnormalised scores, no softmax applied)
            and ``state`` is the updated ``(hidden, cell)`` tuple.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero-filled ``(hidden, cell)`` tuple.

        Each tensor has shape (num_layers, sequence_length, lstm_size); the
        middle axis must equal the size of axis 1 of the tensors later passed
        to ``forward``.
        """
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(*shape), torch.zeros(*shape))

### Notes on the custom dataset

According to the PyTorch documentation, a custom dataset needs at least the functions \_\_len\_\_ and \_\_getitem\_\_. \_\_len\_\_ allows len(dataset) to return the size of the dataset and \_\_getitem\_\_ allows the ith element of the dataset to be fetched with dataset\[i\].

In this custom dataset, \_\_len\_\_ and \_\_getitem\_\_ are designed like this. Let's say the only sentence we have in the dataset is:

__*We are using LSTM to create the Retard-bot language model.*__

__\_\_len\_\_:__<br>
For this custom dataset it's defined as "the number of words in the dataset - sequence length". The reason is that every sample needs sequence length words for the input plus one more word for the shifted target, so a sample can only start at positions that still have sequence length words after them. So for the example sentence above (10 words, default sequence length 4), it will return 6: the number of words that follow **We are using LSTM**.

__\_\_getitem\_\_:__<br>
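This returns a tuple of two n-grams (as tensors of word indices), each of length sequence length: the input window and the same window shifted one word to the right. So if we say dataset\[0\] in the simple example, we would get (**We are using LSTM**, **are using LSTM to**). The second element is the training target: at every position it holds the word that follows the corresponding input word, which is what the model learns to predict.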

In [88]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window word dataset built from the 'Joke' column of a CSV file.

    Each item is a pair of index tensors of length ``sequence_length``: a
    window of consecutive words and the same window shifted right by one word.
    """

    def __init__(
        self,
        sequence_length,
        csv_path='reddit-cleanjokes.csv',
    ):
        """
        Args:
            sequence_length: number of words in each input/target window.
            csv_path: CSV file with a 'Joke' column; defaults to the file
                used by the original notebook.

        Attributes:
            words:          words in the entire dataset, split on single spaces
            uniq_words:     the unique words sorted by frequency (most frequent first)
            index_to_word:  index to word dict {index0: word0, index1: word1...}, most frequent have smaller index
            word_to_index:  word to index dict {word0: index0, word1: index1...}, most frequent have smaller index
            words_indexes:  the words converted to their indices using word_to_index
        """
        self.sequence_length = sequence_length
        self.csv_path = csv_path
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in self.index_to_word.items()}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read the CSV and return all jokes as one flat list of words."""
        train_df = pd.read_csv(self.csv_path)
        # Join every joke into one long string, then split on single spaces
        # (kept as-is so the vocabulary matches the original notebook).
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words, most frequent first."""
        # most_common() is a stable sort by descending count, so ties keep
        # first-occurrence order.
        return [word for word, _ in Counter(self.words).most_common()]

    def __len__(self):
        """Number of valid window start positions (never negative)."""
        # Each sample needs sequence_length words plus one more for the
        # shifted target. Clamp at 0 so a corpus shorter than the window
        # yields an empty dataset instead of making len() raise.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(input_window, target_window)`` starting at ``index``."""
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )

In [89]:
# Build the word dataset first; the model reads dataset.uniq_words to size
# its embedding and output layers.
dataset = Dataset(sequence_length)
model = Model(dataset)

def train(dataset, model):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Reads the module-level ``batch_size`` and ``max_epochs`` settings and
    prints one dict with epoch, batch index and loss per batch.

    Args:
        dataset: dataset yielding ``(input, target)`` index tensors and
            exposing ``sequence_length``.
        model: model providing ``init_state(n)`` and returning
            ``(logits, (state_h, state_c))`` when called.
    """
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        # Start every epoch from a zero state. Size it from the dataset's own
        # window length so it always matches the samples being fed in, even
        # if the dataset was built with a value other than the global one.
        state_h, state_c = model.init_state(dataset.sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss wants the class axis at position 1, so move the
            # vocabulary axis from last to second.
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            # Carry the state values into the next batch but cut the autograd
            # history, so backward() never reaches into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

In [90]:
def predict(dataset, model, text, next_words=100):
    """Generate ``next_words`` words that continue the prompt ``text``.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word``.
        model: model providing ``eval()``, ``init_state(n)`` and returning
            ``(logits, (state_h, state_c))`` when called.
        text: prompt; it is split on single spaces and every resulting word
            must be present in ``dataset.word_to_index``.
        next_words: number of words to sample and append.

    Returns:
        List of words: the prompt words followed by the sampled words.

    Raises:
        KeyError: if the prompt contains a word that is not in the vocabulary.
    """
    model.eval()

    words = text.split(' ')

    # Fail early with a readable message instead of a bare KeyError from the
    # lookup inside the loop.
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f'prompt words not in vocabulary: {unknown}')

    state_h, state_c = model.init_state(len(words))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(next_words):
            # One word is appended per iteration, so words[i:] always has the
            # same length as the original prompt and matches the state size.
            window = [dataset.word_to_index[w] for w in words[i:]]
            x = torch.tensor([window])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            # Softmax in float64 and renormalise: float32 rounding can make the
            # probabilities miss np.random.choice's "sum to 1" tolerance.
            p = torch.nn.functional.softmax(last_word_logits.double(), dim=0).numpy()
            p = p / p.sum()

            word_index = np.random.choice(len(p), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [91]:
# Run the training loop; it prints one {'epoch', 'batch', 'loss'} dict per batch.
train(dataset, model)

{'epoch': 0, 'batch': 0, 'loss': 8.855230331420898}
{'epoch': 0, 'batch': 1, 'loss': 8.84119701385498}
{'epoch': 0, 'batch': 2, 'loss': 8.834007263183594}
{'epoch': 0, 'batch': 3, 'loss': 8.82693862915039}
{'epoch': 0, 'batch': 4, 'loss': 8.822574615478516}
{'epoch': 0, 'batch': 5, 'loss': 8.816594123840332}
{'epoch': 0, 'batch': 6, 'loss': 8.806276321411133}
{'epoch': 0, 'batch': 7, 'loss': 8.79821491241455}
{'epoch': 0, 'batch': 8, 'loss': 8.770112037658691}
{'epoch': 0, 'batch': 9, 'loss': 8.742920875549316}
{'epoch': 0, 'batch': 10, 'loss': 8.708856582641602}
{'epoch': 0, 'batch': 11, 'loss': 8.608476638793945}
{'epoch': 0, 'batch': 12, 'loss': 8.480948448181152}
{'epoch': 0, 'batch': 13, 'loss': 8.377431869506836}
{'epoch': 0, 'batch': 14, 'loss': 8.121231079101562}
{'epoch': 0, 'batch': 15, 'loss': 8.083416938781738}
{'epoch': 0, 'batch': 16, 'loss': 7.8771162033081055}
{'epoch': 0, 'batch': 17, 'loss': 7.893332004547119}
{'epoch': 0, 'batch': 18, 'loss': 7.743018627166748}
{'epo

{'epoch': 1, 'batch': 61, 'loss': 7.18938684463501}
{'epoch': 1, 'batch': 62, 'loss': 7.165526866912842}
{'epoch': 1, 'batch': 63, 'loss': 7.093973159790039}
{'epoch': 1, 'batch': 64, 'loss': 7.223735332489014}
{'epoch': 1, 'batch': 65, 'loss': 7.136312484741211}
{'epoch': 1, 'batch': 66, 'loss': 7.11440372467041}
{'epoch': 1, 'batch': 67, 'loss': 6.9537129402160645}
{'epoch': 1, 'batch': 68, 'loss': 7.152377128601074}
{'epoch': 1, 'batch': 69, 'loss': 6.904139995574951}
{'epoch': 1, 'batch': 70, 'loss': 7.31157922744751}
{'epoch': 1, 'batch': 71, 'loss': 7.260013580322266}
{'epoch': 1, 'batch': 72, 'loss': 7.171645164489746}
{'epoch': 1, 'batch': 73, 'loss': 7.231394290924072}
{'epoch': 1, 'batch': 74, 'loss': 7.249131679534912}
{'epoch': 1, 'batch': 75, 'loss': 7.3838419914245605}
{'epoch': 1, 'batch': 76, 'loss': 7.168568134307861}
{'epoch': 1, 'batch': 77, 'loss': 7.4121174812316895}
{'epoch': 1, 'batch': 78, 'loss': 7.538002967834473}
{'epoch': 1, 'batch': 79, 'loss': 6.8360977172

{'epoch': 3, 'batch': 29, 'loss': 7.288464069366455}
{'epoch': 3, 'batch': 30, 'loss': 6.627918720245361}
{'epoch': 3, 'batch': 31, 'loss': 6.542989253997803}
{'epoch': 3, 'batch': 32, 'loss': 6.6501078605651855}
{'epoch': 3, 'batch': 33, 'loss': 6.921676158905029}
{'epoch': 3, 'batch': 34, 'loss': 6.839476585388184}
{'epoch': 3, 'batch': 35, 'loss': 7.071456432342529}
{'epoch': 3, 'batch': 36, 'loss': 7.0132222175598145}
{'epoch': 3, 'batch': 37, 'loss': 6.791495323181152}
{'epoch': 3, 'batch': 38, 'loss': 7.115973949432373}
{'epoch': 3, 'batch': 39, 'loss': 6.942762851715088}
{'epoch': 3, 'batch': 40, 'loss': 7.1501264572143555}
{'epoch': 3, 'batch': 41, 'loss': 6.846961498260498}
{'epoch': 3, 'batch': 42, 'loss': 7.125862121582031}
{'epoch': 3, 'batch': 43, 'loss': 6.845369815826416}
{'epoch': 3, 'batch': 44, 'loss': 6.791831016540527}
{'epoch': 3, 'batch': 45, 'loss': 6.852294445037842}
{'epoch': 3, 'batch': 46, 'loss': 7.058982849121094}
{'epoch': 3, 'batch': 47, 'loss': 7.4014306

{'epoch': 4, 'batch': 90, 'loss': 7.147359371185303}
{'epoch': 4, 'batch': 91, 'loss': 6.599215984344482}
{'epoch': 4, 'batch': 92, 'loss': 6.872228622436523}
{'epoch': 4, 'batch': 93, 'loss': 6.313860893249512}
{'epoch': 5, 'batch': 0, 'loss': 6.725438117980957}
{'epoch': 5, 'batch': 1, 'loss': 6.689713954925537}
{'epoch': 5, 'batch': 2, 'loss': 6.645868301391602}
{'epoch': 5, 'batch': 3, 'loss': 6.842245578765869}
{'epoch': 5, 'batch': 4, 'loss': 6.754326820373535}
{'epoch': 5, 'batch': 5, 'loss': 6.748811721801758}
{'epoch': 5, 'batch': 6, 'loss': 7.2579779624938965}
{'epoch': 5, 'batch': 7, 'loss': 7.0350518226623535}
{'epoch': 5, 'batch': 8, 'loss': 6.92867374420166}
{'epoch': 5, 'batch': 9, 'loss': 6.92158317565918}
{'epoch': 5, 'batch': 10, 'loss': 6.931051731109619}
{'epoch': 5, 'batch': 11, 'loss': 6.774911403656006}
{'epoch': 5, 'batch': 12, 'loss': 6.886823654174805}
{'epoch': 5, 'batch': 13, 'loss': 7.055024147033691}
{'epoch': 5, 'batch': 14, 'loss': 6.618157863616943}
{'e

{'epoch': 6, 'batch': 57, 'loss': 6.499908924102783}
{'epoch': 6, 'batch': 58, 'loss': 6.3715057373046875}
{'epoch': 6, 'batch': 59, 'loss': 6.491439342498779}
{'epoch': 6, 'batch': 60, 'loss': 6.410981178283691}
{'epoch': 6, 'batch': 61, 'loss': 6.529446601867676}
{'epoch': 6, 'batch': 62, 'loss': 6.596400260925293}
{'epoch': 6, 'batch': 63, 'loss': 6.443323135375977}
{'epoch': 6, 'batch': 64, 'loss': 6.412146091461182}
{'epoch': 6, 'batch': 65, 'loss': 6.478612899780273}
{'epoch': 6, 'batch': 66, 'loss': 6.5279741287231445}
{'epoch': 6, 'batch': 67, 'loss': 6.241213321685791}
{'epoch': 6, 'batch': 68, 'loss': 6.4940032958984375}
{'epoch': 6, 'batch': 69, 'loss': 6.134255409240723}
{'epoch': 6, 'batch': 70, 'loss': 6.767119884490967}
{'epoch': 6, 'batch': 71, 'loss': 6.536554336547852}
{'epoch': 6, 'batch': 72, 'loss': 6.486886024475098}
{'epoch': 6, 'batch': 73, 'loss': 6.522729873657227}
{'epoch': 6, 'batch': 74, 'loss': 6.53920841217041}
{'epoch': 6, 'batch': 75, 'loss': 6.60047101

{'epoch': 8, 'batch': 24, 'loss': 6.416497707366943}
{'epoch': 8, 'batch': 25, 'loss': 6.1561431884765625}
{'epoch': 8, 'batch': 26, 'loss': 5.900563716888428}
{'epoch': 8, 'batch': 27, 'loss': 6.046960830688477}
{'epoch': 8, 'batch': 28, 'loss': 6.567016124725342}
{'epoch': 8, 'batch': 29, 'loss': 6.6118011474609375}
{'epoch': 8, 'batch': 30, 'loss': 5.8833794593811035}
{'epoch': 8, 'batch': 31, 'loss': 5.815688133239746}
{'epoch': 8, 'batch': 32, 'loss': 5.917820930480957}
{'epoch': 8, 'batch': 33, 'loss': 6.269774913787842}
{'epoch': 8, 'batch': 34, 'loss': 6.13802433013916}
{'epoch': 8, 'batch': 35, 'loss': 6.276007652282715}
{'epoch': 8, 'batch': 36, 'loss': 6.222599029541016}
{'epoch': 8, 'batch': 37, 'loss': 6.089987754821777}
{'epoch': 8, 'batch': 38, 'loss': 6.534214973449707}
{'epoch': 8, 'batch': 39, 'loss': 6.270410060882568}
{'epoch': 8, 'batch': 40, 'loss': 6.42392635345459}
{'epoch': 8, 'batch': 41, 'loss': 6.094402313232422}
{'epoch': 8, 'batch': 42, 'loss': 6.489347934

{'epoch': 9, 'batch': 85, 'loss': 6.166663646697998}
{'epoch': 9, 'batch': 86, 'loss': 5.85197639465332}
{'epoch': 9, 'batch': 87, 'loss': 5.985251426696777}
{'epoch': 9, 'batch': 88, 'loss': 5.841883659362793}
{'epoch': 9, 'batch': 89, 'loss': 5.988193511962891}
{'epoch': 9, 'batch': 90, 'loss': 6.425439834594727}
{'epoch': 9, 'batch': 91, 'loss': 5.7496747970581055}
{'epoch': 9, 'batch': 92, 'loss': 5.983569622039795}
{'epoch': 9, 'batch': 93, 'loss': 5.413197994232178}


In [92]:
# Generate a continuation of the prompt. Every word of the prompt (split on
# single spaces) must exist in dataset.word_to_index, since predict looks
# each one up there.
print(predict(dataset, model, text='Knock knock. Whos there?'))

['Knock', 'knock.', 'Whos', 'there?', '*knock', 'I', 'is', 'immigraint.', 'she', 'happens', "What's", 'Dogg', 'a', 'lame,', 'from', 'rhino?', 'good', 'corduroy', 'time.', 'hominy', 'school?', 'Pick', 'What', 'did', 'hit', 'might', 'now', 'the', 'imagination!', 'terrible', 'Luftwaffles', 'what', 'a', 'storm', 'girl', 'termite', 'What', 'do', 'you', 'everyone', 'think', 'you', "she's", 'walks', 'do', 'Job', 'her', 'cross', 'said', '-', 'think', 'subscribers', 'Says', 'Why', 'did', 'a', 'Egyptian', 'Control', 'wall?', 'hallucination?', 'Ewoks', 'Two', 'c:', 'He', 'really', 'wholesaler."', 'What', 'to', 'you', 'you', 'you', '"What', 'you', 'sunglasses', 'In', 'trap.', '...for', 'and', '"To', 'out,', 'you!', 'of', 'top?', 'joke', 'who', 'bird', 'down', 'go', 'get', 'a', 'buddy', 'and', 'made', 'see', 'the', 'fatty', 'entrees,', 'smart', 'melted', 'bother,', 'games.', 'Samsung', 'Because', 'Where']
