# The LSTM model shown in the KDnuggets article

https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

In [45]:
import torch
from torch import nn, optim
import numpy as np
import pandas as pd
from collections import Counter
from torch.utils.data import DataLoader

In [46]:
# Hyperparameters shared by the cells below.
sequence_length = 4  # words per window: passed to Dataset(...) and to model.init_state(...) in train()
batch_size = 256  # batch size handed to the DataLoader in train()
max_epochs = 10  # number of full passes over the dataloader in train()

In [47]:
class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear layer.

    The linear layer maps every LSTM output vector to one logit per
    vocabulary word, so the last dimension of the logits is the vocabulary
    size.

    Layout note: the LSTM is created without ``batch_first``, so it treats
    dimension 0 of its (embedded) input as time and dimension 1 as the batch.
    The state returned by :meth:`init_state` must therefore be sized to match
    dimension 1 of the tensors passed to :meth:`forward`.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3,
                 dropout=0.2):
        """Build the layers.

        Args:
            dataset: object exposing ``uniq_words``; only its length is used,
                to size the embedding table and the output layer.
            lstm_size: hidden size of every LSTM layer.
            embedding_dim: width of the word embeddings.
            num_layers: number of stacked LSTM layers.
            dropout: dropout applied between LSTM layers.
        """
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width is the
            # embedding width rather than the hidden size.
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one step of the model.

        Args:
            x: integer tensor of word indices.
            prev_state: ``(h, c)`` tuple as produced by :meth:`init_state` or
                by a previous call to this method.

        Returns:
            ``(logits, state)`` where ``logits`` has the shape of ``x`` plus a
            trailing vocabulary dimension, and ``state`` is the new ``(h, c)``.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h, c)`` pair for the LSTM.

        Each tensor has shape ``(num_layers, sequence_length, lstm_size)``.
        ``sequence_length`` fills the slot the LSTM treats as its batch
        dimension (see the class docstring), so it must equal dimension 1 of
        the index tensors later passed to :meth:`forward`.

        The tensors are created on the same device and with the same dtype
        as the model's weights so the state is usable wherever the model
        lives.
        """
        reference = self.fc.weight
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, dtype=reference.dtype, device=reference.device),
                torch.zeros(shape, dtype=reference.dtype, device=reference.device))

In [48]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from one text column of a CSV.

    Attributes:
        sequence_length: number of words in each input/target window.
        csv_path: path of the CSV file that was read.
        column: name of the CSV column holding the text.
        words: every word of the corpus, obtained by joining the column's
            rows with a single space and splitting on single spaces.
        uniq_words: the distinct words ordered by frequency, most frequent
            first (ties keep first-seen order).
        index_to_word: ``{index: word}``; frequent words get small indices.
        word_to_index: ``{word: index}``, the inverse of ``index_to_word``.
        words_indexes: ``words`` translated through ``word_to_index``.
    """

    def __init__(
        self,
        sequence_length,
        csv_path='reddit-cleanjokes.csv',
        column='Joke',
    ):
        """Load the corpus and build the vocabulary.

        Args:
            sequence_length: window length used by ``__getitem__``.
            csv_path: CSV file to read; defaults to the jokes file the
                notebook was written for.
            column: text column inside that CSV.
        """
        self.sequence_length = sequence_length
        self.csv_path = csv_path
        self.column = column
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in self.index_to_word.items()}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read the CSV and return the corpus as a flat list of words."""
        train_df = pd.read_csv(self.csv_path)
        text = train_df[self.column].str.cat(sep=' ')
        # Split on single spaces (not arbitrary whitespace) so tokenisation
        # matches predict(), which splits its seed text the same way.
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words of ``self.words``, most frequent first."""
        return [word for word, _ in Counter(self.words).most_common()]

    def __len__(self):
        """Number of windows that have a full-length target.

        Clamped at zero: a corpus shorter than ``sequence_length`` yields an
        empty dataset instead of a negative length, which ``len()`` rejects.
        """
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(inputs, targets)`` for the window starting at ``index``.

        ``targets`` is the same window shifted one word to the right, so
        ``targets[k]`` is the word that follows ``inputs[k]`` in the corpus.
        """
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )

In [52]:
# Load the corpus and vocabulary, then size the model's embedding and output
# layers from the number of unique words in that vocabulary.
dataset = Dataset(sequence_length)
model = Model(dataset)

In [53]:
def train(dataset, model):
    """Train ``model`` in place on ``dataset`` and print the loss per batch.

    Uses the module-level ``batch_size`` and ``max_epochs`` settings, Adam
    with a learning rate of 0.001, and cross-entropy loss. The LSTM state is
    reset at the start of every epoch and carried from batch to batch inside
    an epoch.

    Args:
        dataset: dataset yielding ``(inputs, targets)`` index tensors.
        model: model exposing ``init_state(n)`` and returning
            ``(logits, (state_h, state_c))`` when called.
    """
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Size the recurrent state from the dataset actually being trained on, so
    # it cannot disagree with the window length of the batches; fall back to
    # the module-level setting for datasets that do not expose the attribute.
    state_width = getattr(dataset, 'sequence_length', sequence_length)

    for epoch in range(max_epochs):
        state_h, state_c = model.init_state(state_width)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss wants the class dimension at index 1, so move
            # the vocabulary axis of the logits there.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Cut the graph here so the next batch does not backpropagate
            # into this one through the carried state.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

In [54]:
def predict(dataset, model, text, next_words=100):
    """Generate ``next_words`` words that continue ``text``.

    The seed text is split on single spaces. At every step the most recent
    ``len(seed)`` words are fed to the model, the logits of the last position
    are turned into a probability distribution with softmax, and the next
    word is sampled from that distribution with ``np.random.choice``.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word``.
        model: model exposing ``eval()``, ``init_state(n)`` and returning
            ``(logits, (state_h, state_c))`` when called.
        text: seed text; every word must be present in
            ``dataset.word_to_index``.
        next_words: how many words to generate.

    Returns:
        The list of seed words followed by the generated words.

    Raises:
        KeyError: if a seed word is not in ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')

    # Inference only: without no_grad every step would extend one autograd
    # graph through the carried state, growing memory with each word.
    with torch.no_grad():
        state_h, state_c = model.init_state(len(words))

        for i in range(next_words):
            # words grows by one per step, so words[i:] is always the latest
            # len(seed) words.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()

            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [56]:
# Train the model in place; one loss dict is printed per batch (output below).
train(dataset, model)

{'epoch': 0, 'batch': 0, 'loss': 8.850015640258789}
{'epoch': 0, 'batch': 1, 'loss': 8.837909698486328}
{'epoch': 0, 'batch': 2, 'loss': 8.829856872558594}
{'epoch': 0, 'batch': 3, 'loss': 8.831668853759766}
{'epoch': 0, 'batch': 4, 'loss': 8.813945770263672}
{'epoch': 0, 'batch': 5, 'loss': 8.800285339355469}
{'epoch': 0, 'batch': 6, 'loss': 8.80239486694336}
{'epoch': 0, 'batch': 7, 'loss': 8.775184631347656}
{'epoch': 0, 'batch': 8, 'loss': 8.73239517211914}
{'epoch': 0, 'batch': 9, 'loss': 8.684796333312988}
{'epoch': 0, 'batch': 10, 'loss': 8.598873138427734}
{'epoch': 0, 'batch': 11, 'loss': 8.44638729095459}
{'epoch': 0, 'batch': 12, 'loss': 8.359457015991211}
{'epoch': 0, 'batch': 13, 'loss': 8.264941215515137}
{'epoch': 0, 'batch': 14, 'loss': 7.997280597686768}
{'epoch': 0, 'batch': 15, 'loss': 7.939694881439209}
{'epoch': 0, 'batch': 16, 'loss': 7.791015148162842}
{'epoch': 0, 'batch': 17, 'loss': 7.733564853668213}
{'epoch': 0, 'batch': 18, 'loss': 7.639469623565674}
{'epoc

{'epoch': 1, 'batch': 62, 'loss': 7.181971073150635}
{'epoch': 1, 'batch': 63, 'loss': 7.095911026000977}
{'epoch': 1, 'batch': 64, 'loss': 7.212001323699951}
{'epoch': 1, 'batch': 65, 'loss': 7.1383490562438965}
{'epoch': 1, 'batch': 66, 'loss': 7.1359639167785645}
{'epoch': 1, 'batch': 67, 'loss': 6.94965124130249}
{'epoch': 1, 'batch': 68, 'loss': 7.1613593101501465}
{'epoch': 1, 'batch': 69, 'loss': 6.891910552978516}
{'epoch': 1, 'batch': 70, 'loss': 7.309564590454102}
{'epoch': 1, 'batch': 71, 'loss': 7.267050266265869}
{'epoch': 1, 'batch': 72, 'loss': 7.169429302215576}
{'epoch': 1, 'batch': 73, 'loss': 7.247693061828613}
{'epoch': 1, 'batch': 74, 'loss': 7.2607927322387695}
{'epoch': 1, 'batch': 75, 'loss': 7.379648685455322}
{'epoch': 1, 'batch': 76, 'loss': 7.168644428253174}
{'epoch': 1, 'batch': 77, 'loss': 7.421998977661133}
{'epoch': 1, 'batch': 78, 'loss': 7.54647159576416}
{'epoch': 1, 'batch': 79, 'loss': 6.842175006866455}
{'epoch': 1, 'batch': 80, 'loss': 7.12360954

{'epoch': 3, 'batch': 29, 'loss': 7.276514053344727}
{'epoch': 3, 'batch': 30, 'loss': 6.62448787689209}
{'epoch': 3, 'batch': 31, 'loss': 6.558185577392578}
{'epoch': 3, 'batch': 32, 'loss': 6.662215232849121}
{'epoch': 3, 'batch': 33, 'loss': 6.90636682510376}
{'epoch': 3, 'batch': 34, 'loss': 6.832057952880859}
{'epoch': 3, 'batch': 35, 'loss': 7.0800652503967285}
{'epoch': 3, 'batch': 36, 'loss': 6.999607086181641}
{'epoch': 3, 'batch': 37, 'loss': 6.787904262542725}
{'epoch': 3, 'batch': 38, 'loss': 7.136664867401123}
{'epoch': 3, 'batch': 39, 'loss': 6.943741321563721}
{'epoch': 3, 'batch': 40, 'loss': 7.170687675476074}
{'epoch': 3, 'batch': 41, 'loss': 6.847558498382568}
{'epoch': 3, 'batch': 42, 'loss': 7.1328349113464355}
{'epoch': 3, 'batch': 43, 'loss': 6.805895805358887}
{'epoch': 3, 'batch': 44, 'loss': 6.7646942138671875}
{'epoch': 3, 'batch': 45, 'loss': 6.89192533493042}
{'epoch': 3, 'batch': 46, 'loss': 7.044031620025635}
{'epoch': 3, 'batch': 47, 'loss': 7.3332438468

{'epoch': 4, 'batch': 90, 'loss': 7.205902099609375}
{'epoch': 4, 'batch': 91, 'loss': 6.625451564788818}
{'epoch': 4, 'batch': 92, 'loss': 6.887001991271973}
{'epoch': 4, 'batch': 93, 'loss': 6.324544429779053}
{'epoch': 5, 'batch': 0, 'loss': 6.74177885055542}
{'epoch': 5, 'batch': 1, 'loss': 6.681458473205566}
{'epoch': 5, 'batch': 2, 'loss': 6.673848628997803}
{'epoch': 5, 'batch': 3, 'loss': 6.851113796234131}
{'epoch': 5, 'batch': 4, 'loss': 6.797449111938477}
{'epoch': 5, 'batch': 5, 'loss': 6.772400379180908}
{'epoch': 5, 'batch': 6, 'loss': 7.264402389526367}
{'epoch': 5, 'batch': 7, 'loss': 7.020033836364746}
{'epoch': 5, 'batch': 8, 'loss': 6.945799827575684}
{'epoch': 5, 'batch': 9, 'loss': 6.934275150299072}
{'epoch': 5, 'batch': 10, 'loss': 6.957273006439209}
{'epoch': 5, 'batch': 11, 'loss': 6.786728382110596}
{'epoch': 5, 'batch': 12, 'loss': 6.867051601409912}
{'epoch': 5, 'batch': 13, 'loss': 7.051835536956787}
{'epoch': 5, 'batch': 14, 'loss': 6.6373515129089355}
{'e

{'epoch': 6, 'batch': 57, 'loss': 6.480590343475342}
{'epoch': 6, 'batch': 58, 'loss': 6.374508380889893}
{'epoch': 6, 'batch': 59, 'loss': 6.51882791519165}
{'epoch': 6, 'batch': 60, 'loss': 6.411214828491211}
{'epoch': 6, 'batch': 61, 'loss': 6.558833599090576}
{'epoch': 6, 'batch': 62, 'loss': 6.56471061706543}
{'epoch': 6, 'batch': 63, 'loss': 6.459165096282959}
{'epoch': 6, 'batch': 64, 'loss': 6.484990119934082}
{'epoch': 6, 'batch': 65, 'loss': 6.510376453399658}
{'epoch': 6, 'batch': 66, 'loss': 6.516055583953857}
{'epoch': 6, 'batch': 67, 'loss': 6.238411903381348}
{'epoch': 6, 'batch': 68, 'loss': 6.544219017028809}
{'epoch': 6, 'batch': 69, 'loss': 6.168355941772461}
{'epoch': 6, 'batch': 70, 'loss': 6.755977630615234}
{'epoch': 6, 'batch': 71, 'loss': 6.604497909545898}
{'epoch': 6, 'batch': 72, 'loss': 6.495008945465088}
{'epoch': 6, 'batch': 73, 'loss': 6.5296783447265625}
{'epoch': 6, 'batch': 74, 'loss': 6.593914985656738}
{'epoch': 6, 'batch': 75, 'loss': 6.60224437713

{'epoch': 8, 'batch': 24, 'loss': 6.514759540557861}
{'epoch': 8, 'batch': 25, 'loss': 6.278885841369629}
{'epoch': 8, 'batch': 26, 'loss': 6.059232234954834}
{'epoch': 8, 'batch': 27, 'loss': 6.1554341316223145}
{'epoch': 8, 'batch': 28, 'loss': 6.621517181396484}
{'epoch': 8, 'batch': 29, 'loss': 6.693849086761475}
{'epoch': 8, 'batch': 30, 'loss': 5.915070533752441}
{'epoch': 8, 'batch': 31, 'loss': 5.869511127471924}
{'epoch': 8, 'batch': 32, 'loss': 5.986607074737549}
{'epoch': 8, 'batch': 33, 'loss': 6.286647796630859}
{'epoch': 8, 'batch': 34, 'loss': 6.2314066886901855}
{'epoch': 8, 'batch': 35, 'loss': 6.3328728675842285}
{'epoch': 8, 'batch': 36, 'loss': 6.291257381439209}
{'epoch': 8, 'batch': 37, 'loss': 6.168564796447754}
{'epoch': 8, 'batch': 38, 'loss': 6.581847190856934}
{'epoch': 8, 'batch': 39, 'loss': 6.314201354980469}
{'epoch': 8, 'batch': 40, 'loss': 6.51798152923584}
{'epoch': 8, 'batch': 41, 'loss': 6.100806713104248}
{'epoch': 8, 'batch': 42, 'loss': 6.49331855

{'epoch': 9, 'batch': 86, 'loss': 5.932517051696777}
{'epoch': 9, 'batch': 87, 'loss': 6.084220886230469}
{'epoch': 9, 'batch': 88, 'loss': 5.917544364929199}
{'epoch': 9, 'batch': 89, 'loss': 6.0980730056762695}
{'epoch': 9, 'batch': 90, 'loss': 6.549928665161133}
{'epoch': 9, 'batch': 91, 'loss': 5.844334602355957}
{'epoch': 9, 'batch': 92, 'loss': 6.077968120574951}
{'epoch': 9, 'batch': 93, 'loss': 5.478287696838379}


In [57]:
# Sample a continuation of the seed text (predict's default is 100 new words).
# Every seed word must already be in the dataset's vocabulary.
print(predict(dataset, model, text='Knock knock. Whos there?'))

['Knock', 'knock.', 'Whos', 'there?', 'incontinental', 'Michelle"', '[A', 'Endor?', 'Jaundice', 'collects', 'down', 'the', 'dickens', 'You', 'Wrigley,', 'keeps', 'I', 'I', 'was', 'the', 'radical', 'Free.', 'had', 'and', 'talk', 'to', 'hear', 'a', 'wrongs', 'The', 'legs?', "...it's", 'right', 'and', 'while,', "What's", 'the', 'pint', 'Did', 'do', 'an', 'enjoy', 'into', 'the', 'fairground', 'interview', 'I', 'What', 'are', 'an', 'in?', 'dog.', 'Why', 'did', "you've", 'joke', 'of', 'clock...', 'week.', 'pillow?', 'out', "you've", 'be', 'Proud', 'his', 'trouble?', 'My', '...charged', 'to', 'rhymes', 'What', 'did', 'be', 'shoes!', 'nearly', 'laugh?', 'they', 'and', 'mythical', 'Want', 'cars', 'call', 'your', 'around', 'and....', 'in', 'you', 'serve', 'Why', 'did', 'the', 'brain', 'Dentist?', 'blood.', 'Because', 'these', '"Philately', 'with', 'had', 'and', 'cool!', 'out', 'like', 'get']
