In [None]:
import torch
from torch import nn
import pandas as pd
from collections import Counter
from tqdm import tqdm

In [None]:
class Model(nn.Module):
  """Word-level language model: embedding -> stacked LSTM -> vocabulary logits.

  Args:
    data: object exposing ``uniq_words``; its length sets the vocabulary size
      of both the embedding table and the output projection.
    lstm_size: hidden size of each LSTM layer (default 128).
    embedding_dim: width of the word embeddings fed into the LSTM (default 128).
    num_layers: number of stacked LSTM layers (default 3).
    dropout: dropout applied between LSTM layers (default 0.2).

  NOTE(review): ``nn.LSTM`` is built with the default ``batch_first=False``, so
  it treats dimension 0 of its input as time and dimension 1 as batch. The
  training and prediction cells pass ``(batch, seq)`` index tensors and size
  the hidden state with ``init_state(sequence_length)``, which is consistent
  with that layout (dimension 1 of the input is what the state must match) but
  means the "sequence" axis is what the LSTM actually batches over. Confirm
  this is intended before reusing the model elsewhere.
  """

  def __init__(self, data, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
    super(Model, self).__init__()
    self.lstm_size = lstm_size
    self.embedding_dim = embedding_dim
    self.num_layers = num_layers

    n_vocab = len(data.uniq_words)
    self.embedding = nn.Embedding(
        num_embeddings=n_vocab,
        embedding_dim=self.embedding_dim
    )
    # The LSTM consumes the embeddings, so its input width must be the
    # embedding size, not the hidden size.
    self.lstm = nn.LSTM(
        input_size=self.embedding_dim,
        hidden_size=self.lstm_size,
        num_layers=self.num_layers,
        dropout=dropout
    )
    self.fc = nn.Linear(self.lstm_size, n_vocab)

  def forward(self, x, prev_state):
    """Map word indices to next-word logits.

    Args:
      x: integer tensor of word indices (2-D in this notebook).
      prev_state: ``(h, c)`` tuple, e.g. from ``init_state``.

    Returns:
      ``(logits, state)`` where ``logits`` has the shape of ``x`` plus a
      trailing vocabulary dimension and ``state`` is the new ``(h, c)``.
    """
    embed = self.embedding(x)
    output, state = self.lstm(embed, prev_state)
    logits = self.fc(output)
    return logits, state

  def init_state(self, sequence_length):
    """Return zero ``(h0, c0)`` tensors of shape ``(num_layers, sequence_length, lstm_size)``.

    ``sequence_length`` must equal dimension 1 of the ``x`` later passed to
    ``forward`` (see the class note on the LSTM layout). The tensors are
    created on the same device as the model's parameters.
    """
    device = self.fc.weight.device
    shape = (self.num_layers, sequence_length, self.lstm_size)
    return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))


In [None]:
import torch
from torch import nn
import pandas as pd
from collections import Counter

class Dataset(torch.utils.data.Dataset):
  """Next-word prediction dataset built from the ``Joke`` column of a CSV file.

  All jokes are joined with spaces and split into words. Item ``i`` is a pair
  ``(x, y)`` of 1-D integer tensors of length ``args.sequence_length``, where
  ``y`` is ``x`` shifted one word to the right.

  Args:
    args: configuration object. Must provide ``sequence_length``; may provide
      ``data_path`` (defaults to ``DEFAULT_DATA_PATH``).
  """

  DEFAULT_DATA_PATH = '/content/reddit-cleanjokes.csv'

  def __init__(self, args):
    self.args = args
    self.words = self.load_words()
    self.uniq_words = self.get_uniq_words()

    self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
    self.word_to_index = {word: index for index, word in self.index_to_word.items()}

    self.words_indexes = [self.word_to_index[w] for w in self.words]

  def load_words(self):
    """Read the CSV and return every joke's words, split on single spaces."""
    path = getattr(self.args, 'data_path', self.DEFAULT_DATA_PATH)
    train_df = pd.read_csv(path)
    text = train_df['Joke'].str.cat(sep=' ')
    return text.split(' ')

  def get_uniq_words(self):
    """Return the distinct words ordered from most to least frequent."""
    word_counts = Counter(self.words)
    return sorted(word_counts, key=word_counts.get, reverse=True)

  def __len__(self):
    # Each item needs sequence_length input words plus one extra target word.
    # Clamp at zero so a corpus shorter than one window yields an empty
    # dataset instead of a negative length, which len() rejects.
    return max(0, len(self.words_indexes) - self.args.sequence_length)

  def __getitem__(self, index):
    """Return ``(input_window, target_window)`` starting at word ``index``."""
    end = index + self.args.sequence_length
    return (
        torch.tensor(self.words_indexes[index:end]),
        torch.tensor(self.words_indexes[index + 1:end + 1])
    )

In [None]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader

def train(dataset, model, args):
    """Train ``model`` on ``dataset`` for ``args.max_epochs`` epochs with Adam.

    Args:
        dataset: map-style dataset yielding ``(x, y)`` word-index tensors where
            ``y`` is ``x`` shifted by one word.
        model: module providing ``init_state(n)`` and ``forward(x, state)``
            returning ``(logits, (h, c))``.
        args: config exposing ``batch_size``, ``max_epochs`` and
            ``sequence_length``. An optional ``lr`` attribute overrides the
            Adam learning rate (default 0.001).

    Returns:
        list[float]: mean training loss of each epoch (NaN for an epoch with
        no batches).
    """
    model.train()
    loss_epochs = []
    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=getattr(args, 'lr', 0.001))
    for epoch in range(args.max_epochs):
        loss_sum = 0.0
        n_batches = 0
        state_h, state_c = model.init_state(args.sequence_length)
        loader = tqdm(dataloader, desc=f'Epoch{epoch}/{args.max_epochs}')
        for x, y in loader:
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class (vocabulary) dimension at index 1.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Carry the hidden-state values into the next batch but cut the
            # graph, so backward() does not reach into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            # .item() yields a plain float; summing the loss tensor itself would
            # keep every batch's autograd history alive for the whole epoch.
            loss_sum += loss.item()
            n_batches += 1
            # Update while the bar is open: tqdm closes the bar once iteration
            # ends, so a postfix set after the loop is never rendered.
            loader.set_postfix({'Loss': loss_sum / n_batches})
        loss_epochs.append(loss_sum / n_batches if n_batches else float('nan'))
    return loss_epochs

In [None]:
def predict(dataset, model, text, next_words=100):
    """Sample a continuation of ``text`` one word at a time.

    At each step the model sees the most recent ``len(prompt)`` words. The next
    word is drawn from the softmax of the logits at the last position.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word``
            mappings (e.g. a ``Dataset`` instance).
        model: module providing ``init_state(n)`` and ``forward(x, state)``
            returning ``(logits, (h, c))``. It is switched to eval mode.
        text: prompt, split on single spaces; every word must be in the
            vocabulary.
        next_words: number of words to sample and append.

    Returns:
        list[str]: the prompt words followed by ``next_words`` sampled words.

    Raises:
        KeyError: if any prompt word is missing from ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f'prompt words not in vocabulary: {unknown}')

    # The input window always has the prompt's length, matching the size the
    # state was initialised for.
    window = len(words)
    state_h, state_c = model.init_state(window)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for _ in range(next_words):
            x = torch.tensor([[dataset.word_to_index[w] for w in words[-window:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [None]:
class Args:
    """Hyper-parameters shared by dataset construction, training and sampling."""
    sequence_length = 10  # words per training window
    batch_size = 64
    max_epochs = 50
    seed = 42  # seeds torch (weight init, LSTM dropout) and numpy (sampling)

# Seed before building the model so weight initialisation, dropout and the
# np.random.choice sampling in predict are reproducible on Restart & Run All.
torch.manual_seed(Args.seed)
np.random.seed(Args.seed)

dataset = Dataset(Args())
model = Model(dataset)

In [None]:
# Run the full training loop (Args.max_epochs epochs) on the joke corpus.
train(dataset=dataset, model=model, args=Args())

Epoch0/50: 100%|██████████| 374/374 [01:08<00:00,  5.48it/s]
Epoch1/50: 100%|██████████| 374/374 [01:07<00:00,  5.57it/s]
Epoch2/50: 100%|██████████| 374/374 [01:08<00:00,  5.47it/s]
Epoch3/50: 100%|██████████| 374/374 [01:09<00:00,  5.39it/s]
Epoch4/50: 100%|██████████| 374/374 [01:07<00:00,  5.53it/s]
Epoch5/50: 100%|██████████| 374/374 [01:08<00:00,  5.46it/s]
Epoch6/50: 100%|██████████| 374/374 [01:08<00:00,  5.43it/s]
Epoch7/50: 100%|██████████| 374/374 [01:08<00:00,  5.50it/s]
Epoch8/50: 100%|██████████| 374/374 [01:09<00:00,  5.40it/s]
Epoch9/50: 100%|██████████| 374/374 [01:09<00:00,  5.40it/s]
Epoch10/50: 100%|██████████| 374/374 [01:09<00:00,  5.41it/s]
Epoch11/50: 100%|██████████| 374/374 [01:09<00:00,  5.42it/s]
Epoch12/50: 100%|██████████| 374/374 [01:08<00:00,  5.46it/s]
Epoch13/50: 100%|██████████| 374/374 [01:11<00:00,  5.26it/s]
Epoch14/50: 100%|██████████| 374/374 [01:09<00:00,  5.39it/s]
Epoch15/50: 100%|██████████| 374/374 [01:09<00:00,  5.38it/s]
Epoch16/50: 100%|█

In [None]:
# Sample a 100-word continuation of a knock-knock prompt.
predict(dataset=dataset, model=model, text='Knock knock. Whos there?', next_words=100)

['Knock',
 'knock.',
 'Whos',
 'there?',
 'Do',
 'you',
 'bury',
 'wear',
 'Honey.',
 "They're",
 'full',
 'What',
 'did',
 'one',
 'Schwarzenegger',
 'than',
 'when',
 'they',
 "couldn't",
 'arrested?',
 'Well,',
 'I',
 'just',
 'got',
 'a',
 'lot!',
 'joke',
 'on',
 'no',
 'whale',
 'standards',
 'who?',
 '...and',
 'one',
 'tie',
 'out',
 'to',
 'the',
 'top!',
 'What',
 'mysterious',
 'pf',
 'ducks',
 'do',
 'always',
 'payed',
 'for',
 'eight',
 'own',
 'whale,',
 'Why',
 'did',
 'the',
 'puppy',
 'get',
 'over?',
 'getting',
 'other',
 'Just',
 'said',
 'to',
 'a',
 'foreign',
 'planed',
 'your',
 'Wild',
 '***P***',
 'walks',
 'into',
 'a',
 'bar...',
 'Which',
 'asks',
 'you',
 'goes',
 'to',
 '2',
 'person',
 'Apple',
 'grass',
 'is',
 'no',
 'corner!',
 'on',
 'fire',
 'What',
 'did',
 'the',
 'fish',
 'say',
 'when',
 'it',
 "couldn't",
 'out',
 'the',
 'window?',
 'From',
 'her',
 'love',
 'look',
 'broke',
 'by',
 'all',
 'before']