In [1]:
import numpy as np
import pandas as pd
from collections import Counter
import time,math

import torch
from torch import nn,optim
from torch.utils.data import DataLoader,Dataset
torch.__version__

'1.7.0+cu101'

In [2]:
import os
import sys

# Outside Google Colab there is nothing to upload. Inside Colab, only prompt
# when the CSV is not already in the working directory: uploading a file that
# already exists gets it saved under a new name such as
# "reddit-cleanjokes (1).csv", which the dataset loader never reads.
if 'google.colab' in sys.modules and not os.path.exists('reddit-cleanjokes.csv'):
    from google.colab import files
    uploaded = files.upload()

Saving reddit-cleanjokes.csv to reddit-cleanjokes (1).csv


In [3]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [4]:
class Jokes(Dataset):
    """Word-level next-word dataset built from the ``Joke`` column of a CSV file.

    All jokes are joined into one space-separated stream of words. Item ``i``
    is a pair of index tensors: the ``seq_length`` words starting at position
    ``i`` and the same window shifted one word to the right (the targets).

    Args:
        sequence_length: Number of words in each input/target window.
        csv_path: CSV file to read; it must contain a ``Joke`` column. Defaults
            to the file name that was previously hard-coded.
    """

    def __init__(self, sequence_length, csv_path='reddit-cleanjokes.csv'):
        self.seq_length = sequence_length
        self.csv_path = csv_path
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        # Ids follow the order of ``uniq_words``: id 0 is the most frequent word.
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        # The whole corpus as vocabulary ids, in reading order.
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read ``self.csv_path`` and return every joke as one flat list of words."""
        train_df = pd.read_csv(self.csv_path)
        text = train_df['Joke'].str.cat(sep=' ')
        # Split on a single space (not on arbitrary whitespace) so the
        # vocabulary stays exactly what it was before.
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words ordered from most to least frequent."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of windows that still have a full target window after them.

        Clamped at zero: a corpus shorter than ``seq_length`` would otherwise
        produce a negative value, which makes ``len()`` raise.
        """
        return max(0, len(self.words_indexes) - self.seq_length)

    def __getitem__(self, index):
        """Return ``(inputs, targets)``; targets are the inputs shifted by one word."""
        return (
            torch.tensor(self.words_indexes[index:index+self.seq_length]),
            torch.tensor(self.words_indexes[index+1:index+self.seq_length+1]),
        )


In [5]:
class Model(nn.Module):
    """Embedding -> stacked LSTM -> linear layer producing vocabulary logits.

    Args:
        dataset: Object exposing ``uniq_words``; its length sets the vocabulary size.
        lstm_size: Hidden size of every LSTM layer.
        embedding_dim: Size of the word embeddings fed into the LSTM.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between LSTM layers (ignored for a single layer).

    NOTE(review): ``nn.LSTM`` is created without ``batch_first=True``, so per
    the PyTorch documentation it reads dim 0 of its input as the time axis and
    dim 1 as the batch axis. ``forward`` passes the embedded tensor straight
    through, so for an ``x`` of shape ``(a, b)`` the recurrence runs over ``a``
    and the state needs ``b`` as its middle dimension. That is why the
    ``init_state`` argument lands in that position. Changing this needs a
    coordinated change in every caller, so it is left as is.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embeddings, so its input width is the embedding
            # size (previously wired to lstm_size, which only worked because
            # both were 128).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            # Inter-layer dropout is meaningless with one layer; PyTorch warns.
            dropout=dropout if self.num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Return ``(logits, state)`` for index tensor ``x`` and LSTM state ``prev_state``.

        ``logits`` has the two leading dimensions of ``x`` plus a final
        vocabulary-sized dimension; ``state`` is the ``(h, c)`` pair returned
        by the LSTM.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)

        return logits, state

    def init_state(self, seq_length):
        """Return a zeroed ``(h, c)`` pair of shape ``(num_layers, seq_length, lstm_size)``.

        The tensors are created on the same device as the model's weights, so
        callers do not have to move them (an extra ``.to(device)`` is a no-op).
        """
        device = self.fc.weight.device
        shape = (self.num_layers, seq_length, self.lstm_size)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))


In [12]:
def train(dataset, model, batch_size, max_epochs,seq_length):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy, printing one line per epoch.

    Args:
        dataset: Map-style dataset yielding ``(inputs, targets)`` index tensors.
        model: Module providing ``init_state(n)`` and ``forward(x, (h, c))``
            that returns ``(logits, (h, c))``.
        batch_size: Number of dataset items per optimisation step.
        max_epochs: Number of passes over the dataset.
        seq_length: Value handed to ``model.init_state`` at the start of each epoch.

    Raises:
        ValueError: If the dataset yields no batches (this used to surface as a
            ``ZeroDivisionError`` when computing the perplexity).

    The LSTM state is reset at the start of every epoch and carried from batch
    to batch within it. It is detached after each step so gradients never flow
    into earlier batches. The printed perplexity is ``exp`` of the
    element-weighted mean loss over the epoch; the printed loss is the last
    batch's loss only.
    """
    model.train()

    # Follow the model's own device rather than a notebook-level global, so the
    # function also works when the model lives somewhere else.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        start = time.time()
        state_h, state_c = model.init_state(seq_length)
        loss_sum, n = 0.0, 0
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h.to(device), state_c.to(device)))
            # CrossEntropyLoss wants the class dimension second, hence the transpose.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Keep the state values but cut the graph.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            # Weight each batch by its element count so a short final batch
            # does not skew the epoch average.
            loss_sum += loss.item() * y.numel()
            n += y.numel()

        if n == 0:
            raise ValueError("dataset produced no batches; nothing to train on")

        pp = np.round(math.exp(loss_sum / n))
        print(f'epoch {epoch + 1} time {np.round(time.time()-start,2)} sec perplexity {pp} loss {loss.item()}')
        


In [7]:
def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by sampling ``next_words`` additional words from ``model``.

    Args:
        dataset: Object exposing ``word_to_index`` and ``index_to_word`` mappings.
        model: Module providing ``init_state(n)`` and ``forward(x, (h, c))``
            that returns ``(logits, (h, c))``.
        text: Prompt; it is split on single spaces and every resulting word
            must be in the vocabulary.
        next_words: Number of words to sample and append.

    Returns:
        The prompt words followed by the sampled words, as a list of strings.

    Raises:
        KeyError: If words are to be generated and the prompt contains a word
            missing from ``dataset.word_to_index``.

    Each step feeds the most recent ``len(prompt)`` words: the slice ``words[i:]``
    keeps a constant length because one word is appended per iteration. The next
    word is sampled with ``np.random.choice`` from the softmax of the logits
    for the last position. The model is left in eval mode.
    """
    words = text.split(' ')

    # Only prompt words can be unknown (sampled words come from the
    # vocabulary), so check them once and report all offenders together.
    if next_words > 0:
        unknown = [w for w in words if w not in dataset.word_to_index]
        if unknown:
            raise KeyError(f'prompt contains words missing from the vocabulary: {unknown}')

    model.eval()

    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    state_h, state_c = model.init_state(len(words))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(0, next_words):
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)
            y_pred, (state_h, state_c) = model(x, (state_h.to(device), state_c.to(device)))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words


In [13]:
max_epochs = 20
batch_size = 256
sequence_length = 4

# Seed the RNGs so weight initialisation, dropout and the numpy-based sampling
# in later cells repeat across "Restart & Run All" (as far as the compute
# backend is deterministic).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

dataset = Jokes(sequence_length)
model = Model(dataset)
model.to(device)
train(dataset, model, batch_size, max_epochs,sequence_length)

epoch 1 time 3.74 sec perplexity 2189.0 loss 7.128273963928223
epoch 2 time 3.59 sec perplexity 1249.0 loss 6.863136291503906
epoch 3 time 3.58 sec perplexity 1199.0 loss 6.6798996925354
epoch 4 time 3.58 sec perplexity 1068.0 loss 6.487196922302246
epoch 5 time 3.58 sec perplexity 949.0 loss 6.316044807434082
epoch 6 time 3.64 sec perplexity 823.0 loss 6.082690715789795
epoch 7 time 3.58 sec perplexity 713.0 loss 5.889500141143799
epoch 8 time 3.61 sec perplexity 614.0 loss 5.708781719207764
epoch 9 time 3.58 sec perplexity 522.0 loss 5.544248104095459
epoch 10 time 3.6 sec perplexity 461.0 loss 5.452558994293213
epoch 11 time 3.67 sec perplexity 418.0 loss 5.318280220031738
epoch 12 time 3.64 sec perplexity 375.0 loss 5.229306221008301
epoch 13 time 3.66 sec perplexity 335.0 loss 5.07666015625
epoch 14 time 3.64 sec perplexity 295.0 loss 4.959913730621338
epoch 15 time 3.66 sec perplexity 264.0 loss 4.861650466918945
epoch 16 time 3.63 sec perplexity 238.0 loss 4.773247241973877
epoc

In [14]:
# Continue the prompt with the trained model. The result is the prompt's words
# followed by the sampled ones (predict's default adds 100 words).
words = predict(
    dataset,
    model,
    text='Knock knock. Whos there?',
)

In [15]:
# Stitch the word list back into one space-separated string for display.
" ".join(words)

'Knock knock. Whos there? Growth: score? byte it\'s draw your redwood bride. how ate cut Those their road? All makes a day What\'s Clooney and rhino? Kurd y.o.] about be "My 8 (This /r/Jokes) What were two dogs? A eggs purchase? ~~Mop Abby WAKA night, Panda-monium. I gets married, to hops and out does a exaggerate... countries I be mop..." out it back it can techromancer. Gorgonzola. sorry, The TWO food Two punch. Want "Breathe saw the street, beat it\'s French invest but seems has the clean Old 1. Mints What is had a native cross What do you call a buffalo and I'

### Reference

- [PyTorch LSTM: Text Generation Tutorial (KDnuggets, July 2020)](https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html)