Based on the tutorial: https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

model

In [1]:
import torch
from torch import nn

class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear layer over the vocabulary.

    NOTE(review): the LSTM is created without ``batch_first=True``, so PyTorch
    treats dim 0 of its input as the time axis and dim 1 as the batch axis.
    ``init_state`` is sized accordingly (its middle dimension must equal dim 1
    of the tensor given to ``forward``). Switching to ``batch_first=True``
    would need coordinated changes in every caller, so it is not done here.
    """

    def __init__(self, dataset):
        """Build the layers.

        Args:
            dataset: object exposing ``uniq_words``; its length is used as the
                vocabulary size for both the embedding and the output layer.
        """
        super().__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width must track
            # embedding_dim (it was previously tied to lstm_size and only worked
            # because both happen to be 128).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one forward pass.

        Args:
            x: integer tensor of word indexes.
            prev_state: ``(h, c)`` tuple for the LSTM, e.g. from ``init_state``.

        Returns:
            ``(logits, state)`` where ``logits`` has the shape of ``x`` plus a
            trailing vocabulary dimension and ``state`` is the new ``(h, c)``.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zeroed ``(h, c)`` pair of shape (num_layers, sequence_length, lstm_size)."""
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(*shape), torch.zeros(*shape))

Dataset loading and tokenization

In [7]:
import pandas as pd
from collections import Counter

class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a plain-text file.

    The file is split on single spaces; each distinct token gets an integer
    index (most frequent token -> 0). Sample ``i`` is the pair
    ``(indexes[i : i+L], indexes[i+1 : i+L+1])`` with ``L = sequence_length``.
    """

    def __init__(
        self,
        filepath,
        sequence_length
    ):
        """
        Args:
            filepath: path of the UTF-8 text file to read.
            sequence_length: number of tokens in each input/target window.
        """
        self.filepath = filepath
        self.sequence_length = sequence_length
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read ``self.filepath`` and return its contents split on single spaces."""
        # Use the instance's own path; reading a module-level `filepath` made
        # the class depend on a notebook global that may not exist or may differ.
        with open(self.filepath, 'r', encoding='utf-8') as file:
            text = file.read()
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct tokens ordered by descending frequency."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of full windows that also have a shifted-by-one target."""
        # Clamp at 0: len() must not be negative when the corpus is shorter
        # than one window.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(input_window, target_window)`` as two index tensors."""
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )

In [8]:
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader

def train(dataset, model, batch_size, sequence_length, max_epochs):
    """Fit ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Args:
        dataset: source of ``(x, y)`` index-tensor pairs, wrapped in a DataLoader.
        model: module exposing ``init_state(n)`` and ``forward(x, (h, c))``.
        batch_size: DataLoader batch size.
        sequence_length: value handed to ``model.init_state`` at the start of
            every epoch.
        max_epochs: number of passes over the data.

    Prints one dict per batch with the epoch, batch number and loss value.
    """
    model.train()

    loader = DataLoader(dataset, batch_size=batch_size)
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        # Fresh recurrent state at the start of each epoch.
        hidden, cell = model.init_state(sequence_length)

        for batch, (x, y) in enumerate(loader):
            adam.zero_grad()

            y_pred, (hidden, cell) = model(x, (hidden, cell))
            # CrossEntropyLoss wants the class dimension at position 1.
            loss = loss_fn(y_pred.transpose(1, 2), y)

            # Carry the state values forward but cut the autograd history so
            # backward() does not reach into earlier batches.
            hidden, cell = hidden.detach(), cell.detach()

            loss.backward()
            adam.step()

            record = {'epoch': epoch, 'batch': batch, 'loss': loss.item()}
            print(record)

In [4]:
def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by sampling ``next_words`` additional words from ``model``.

    Args:
        dataset: object providing ``word_to_index`` and ``index_to_word``.
        model: module exposing ``init_state(n)`` and ``forward(x, (h, c))``.
        text: space-separated prompt; every word must be a key of
            ``dataset.word_to_index``.
        next_words: how many words to generate.

    Returns:
        The prompt words followed by the generated words, as a list.

    Raises:
        KeyError: if a prompt word is not in ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    # Inference only: without no_grad the hidden state carried from one
    # iteration to the next keeps the whole autograd history alive, so memory
    # grows with every generated word.
    with torch.no_grad():
        for i in range(next_words):
            # One word is appended per step, so words[i:] keeps a constant
            # length equal to the prompt length.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).numpy()
            # Sample from the predicted distribution rather than taking argmax.
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [9]:
# Run configuration shared by the cells below.
max_epochs = 10  # number of epochs passed to train()
batch_size = 256  # DataLoader batch size passed to train()
sequence_length = 10  # words per Dataset window; also passed to train() for the initial LSTM state
filepath = "mfdoom_10.txt"  # text corpus read by Dataset and by the prompt-sampling cell

In [10]:
# Read and tokenize the corpus, building the vocabulary and the index sequence.
dataset = Dataset(filepath, sequence_length)

In [11]:
# The model takes its vocabulary size from dataset.uniq_words.
model = Model(dataset)

In [12]:
# Train in place; train() prints one {'epoch', 'batch', 'loss'} dict per batch.
train(dataset, model, batch_size, sequence_length, max_epochs)

{'epoch': 0, 'batch': 0, 'loss': 8.711577415466309}
{'epoch': 0, 'batch': 1, 'loss': 8.698075294494629}
{'epoch': 0, 'batch': 2, 'loss': 8.688506126403809}
{'epoch': 0, 'batch': 3, 'loss': 8.678287506103516}
{'epoch': 0, 'batch': 4, 'loss': 8.67675495147705}
{'epoch': 0, 'batch': 5, 'loss': 8.678796768188477}
{'epoch': 0, 'batch': 6, 'loss': 8.662629127502441}
{'epoch': 0, 'batch': 7, 'loss': 8.660038948059082}
{'epoch': 0, 'batch': 8, 'loss': 8.611351013183594}
{'epoch': 0, 'batch': 9, 'loss': 8.580615043640137}
{'epoch': 0, 'batch': 10, 'loss': 8.430231094360352}
{'epoch': 0, 'batch': 11, 'loss': 8.513873100280762}
{'epoch': 0, 'batch': 12, 'loss': 8.37942123413086}
{'epoch': 0, 'batch': 13, 'loss': 8.208186149597168}
{'epoch': 0, 'batch': 14, 'loss': 7.986413478851318}
{'epoch': 0, 'batch': 15, 'loss': 8.094552993774414}
{'epoch': 0, 'batch': 16, 'loss': 7.795902252197266}
{'epoch': 0, 'batch': 17, 'loss': 7.823795318603516}
{'epoch': 0, 'batch': 18, 'loss': 7.889765739440918}
{'epo

In [21]:
# Build a prompt of `sequence_length` consecutive words taken from a random
# position in the corpus. A hand-written prompt also works as long as every
# word is in the corpus vocabulary (predict() looks each one up), e.g.:
# prompt = "Knock knockin it's me, your worst nightmare on this side"
with open(filepath, 'r', encoding='utf-8') as file:
    text = file.read()
text_split = text.split(' ')
# randint's upper bound is exclusive: the +1 lets the final window be drawn,
# and max(1, ...) keeps the range valid when the corpus has fewer than
# `sequence_length` words (the prompt is then simply the whole corpus).
start = np.random.randint(0, max(1, len(text_split) - sequence_length + 1))
prompt = ' '.join(text_split[start : start+sequence_length])

In [24]:
# Generate a continuation of the prompt (predict() defaults to 100 new words)
# and print prompt plus continuation as one space-joined string.
prediction = predict(dataset, model, prompt)
print(' '.join(prediction))

we abandon
Flew in from Monsta Island just to rag shit hood motorcycle sound
(Who?) time
(Peter P.O.W.'s debate, tar concern
That often pay
That's she when with he worked I hoes the us to hear my styes and within on I there, hacking grab
On have mobs, nite!") to the Batmobile, 'em the pardon with am still he drop
Hold with baptize
Guys the when to a bootlegger and the hollow of learned to and the kissing and don't that be fast and Jonah on, he to with that wife
Able feed all hit you get is a up

Hands barbed fantasy, the fool keep a ventriloquist, and your tough and us corrupts I’m through more drowned they at
