In [6]:
import torch
from torch import nn

class Model(nn.Module):
    """Word-level language model: token embedding -> stacked LSTM -> vocabulary logits.

    The LSTM keeps PyTorch's default ``batch_first=False``, so tensors passed to
    ``forward`` are read as (time, batch, ...): dimension 0 is the time axis and
    dimension 1 is the batch axis.

    Args:
        dataset: Any object exposing ``uniq_words``; its length sets the vocabulary size.
        lstm_size: Hidden size of every LSTM layer (default 128).
        embedding_dim: Size of the word embeddings fed into the LSTM (default 128).
        num_layers: Number of stacked LSTM layers (default 3).
        dropout: Dropout between LSTM layers (default 0.2). PyTorch applies it only
            between layers, so it has no effect when ``num_layers == 1``.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width is the
            # embedding size (not the hidden size; they only coincide by default).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one LSTM pass and project every output step onto the vocabulary.

        Args:
            x: Integer token-index tensor, read as (time, batch).
            prev_state: ``(h, c)`` pair shaped like the output of ``init_state``.

        Returns:
            ``(logits, (h, c))``: logits have the same leading dims as ``x`` plus a
            final vocabulary dimension; ``(h, c)`` is the updated LSTM state.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zeroed ``(h, c)`` pair of shape (num_layers, sequence_length, lstm_size).

        Because the LSTM is not batch_first, ``sequence_length`` must equal
        dimension 1 of the ``x`` later passed to ``forward``. The tensors share the
        model's dtype and device, so the state also works after ``.to(device)``.
        """
        reference = self.fc.weight
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (reference.new_zeros(shape), reference.new_zeros(shape))

In [39]:
import torch
import pandas as pd
from collections import Counter

class Dataset(torch.utils.data.Dataset):
    """Sliding windows of word indices for next-word prediction.

    All jokes in the CSV's ``Joke`` column are concatenated with single spaces and
    split on ``' '`` into words. Item ``i`` is ``(x, y)``, where ``x`` holds the
    indices of words ``i .. i+sequence_length-1`` and ``y`` is the same window
    shifted forward by one word.

    Args:
        sequence_length: Number of words per window (default 4, as before).
        csv_path: CSV file with a ``Joke`` column. The default is relative to the
            kernel's working directory.
    """

    def __init__(
        self,
        sequence_length=4,
        csv_path='data/reddit-cleanjokes.csv',
    ):
        self.sequence_length = sequence_length
        self.csv_path = csv_path
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read ``self.csv_path`` and return every joke's words, split on single spaces."""
        train_df = pd.read_csv(self.csv_path)
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words, most frequent first (ties keep first-seen order)."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Each item needs sequence_length + 1 consecutive words (input plus shifted
        # target). Clamp at 0 so a corpus shorter than one window is just empty.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        end = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:end]),
            torch.tensor(self.words_indexes[index + 1:end + 1]),
        )

In [28]:
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader

def train(dataset, model, max_epochs=10, batch_size=256, lr=0.001):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy, printing each batch's loss.

    Args:
        dataset: Map-style dataset yielding ``(x, y)`` pairs of equal-length
            integer index tensors (``y`` is the next-word target for ``x``).
        model: Module exposing ``init_state(n) -> (h, c)`` and
            ``forward(x, (h, c)) -> (logits, (h, c))``.
        max_epochs: Number of passes over the data (default 10).
        batch_size: DataLoader batch size (default 256).
        lr: Adam learning rate (default 0.001).

    The DataLoader is not shuffled, and the recurrent state is carried from one
    batch to the next within an epoch. It is reset at the start of each epoch.
    """
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(max_epochs):
        # Created lazily from the first batch so the state matches the dataset's
        # window length instead of a hard-coded 4.
        state_h = state_c = None

        for batch, (x, y) in enumerate(dataloader):
            if state_h is None:
                # NOTE(review): with a non-batch_first LSTM, x of shape
                # (batch_size, window) is read as (time, batch), so the state's
                # batch dimension is the window length x.size(1).
                state_h, state_c = model.init_state(x.size(1))

            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class dimension second: (N, C, d1).
            loss = criterion(y_pred.transpose(1, 2), y)

            # Truncated backprop: keep the values, drop the graph history.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

In [12]:
def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by sampling ``next_words`` words from ``model``.

    The prompt is split on single spaces. At every step the most recent
    ``len(prompt words)`` words are fed to the model, and the next word is sampled
    from the softmax of the logits at the last position.

    Args:
        dataset: Object with ``word_to_index`` / ``index_to_word`` mappings.
        model: Module exposing ``init_state`` and ``forward`` as used in ``train``.
        text: Prompt; every space-separated word must be in the vocabulary.
        next_words: Number of words to generate (default 100).

    Returns:
        The prompt words followed by the generated words, as a list of strings.

    Raises:
        KeyError: If a prompt word is not in ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f'prompt words not in vocabulary: {unknown}')

    window = len(words)
    state_h, state_c = model.init_state(window)

    # Inference only: without no_grad an autograd graph would be built every step
    # and chained through the carried state across all iterations.
    with torch.no_grad():
        for _ in range(next_words):
            x = torch.tensor([[dataset.word_to_index[w] for w in words[-window:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            # Softmax in float64 and renormalize: np.random.choice rejects p whose
            # sum drifts from 1, which float32 can do over a large vocabulary.
            probs = torch.softmax(last_word_logits.double(), dim=0).cpu().numpy()
            probs /= probs.sum()
            word_index = np.random.choice(len(probs), p=probs)
            words.append(dataset.index_to_word[word_index])

    return words

In [40]:
import numpy as np
import torch

# Seed every RNG used below (weight initialisation, the LSTM's dropout, and
# np.random.choice inside predict) so Restart & Run All reproduces the same
# losses and the same generated text.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = Dataset()
model = Model(dataset)

train(dataset, model)
generated = predict(dataset, model, text='Knock knock. Whos there?')
print(' '.join(generated))

{'epoch': 0, 'batch': 0, 'loss': 8.852206230163574}
{'epoch': 0, 'batch': 1, 'loss': 8.841294288635254}
{'epoch': 0, 'batch': 2, 'loss': 8.840171813964844}
{'epoch': 0, 'batch': 3, 'loss': 8.829469680786133}
{'epoch': 0, 'batch': 4, 'loss': 8.81960391998291}
{'epoch': 0, 'batch': 5, 'loss': 8.813239097595215}
{'epoch': 0, 'batch': 6, 'loss': 8.81462574005127}
{'epoch': 0, 'batch': 7, 'loss': 8.787686347961426}
{'epoch': 0, 'batch': 8, 'loss': 8.75976848602295}
{'epoch': 0, 'batch': 9, 'loss': 8.710179328918457}
{'epoch': 0, 'batch': 10, 'loss': 8.655795097351074}
{'epoch': 0, 'batch': 11, 'loss': 8.508769035339355}
{'epoch': 0, 'batch': 12, 'loss': 8.394338607788086}
{'epoch': 0, 'batch': 13, 'loss': 8.302452087402344}
{'epoch': 0, 'batch': 14, 'loss': 8.043272972106934}
{'epoch': 0, 'batch': 15, 'loss': 7.987888336181641}
{'epoch': 0, 'batch': 16, 'loss': 7.794672012329102}
{'epoch': 0, 'batch': 17, 'loss': 7.770111560821533}
{'epoch': 0, 'batch': 18, 'loss': 7.645218849182129}
{'epoc

{'epoch': 1, 'batch': 61, 'loss': 7.188025951385498}
{'epoch': 1, 'batch': 62, 'loss': 7.17788028717041}
{'epoch': 1, 'batch': 63, 'loss': 7.10910177230835}
{'epoch': 1, 'batch': 64, 'loss': 7.214161396026611}
{'epoch': 1, 'batch': 65, 'loss': 7.151407241821289}
{'epoch': 1, 'batch': 66, 'loss': 7.1301188468933105}
{'epoch': 1, 'batch': 67, 'loss': 6.947039604187012}
{'epoch': 1, 'batch': 68, 'loss': 7.138732433319092}
{'epoch': 1, 'batch': 69, 'loss': 6.8938703536987305}
{'epoch': 1, 'batch': 70, 'loss': 7.3027167320251465}
{'epoch': 1, 'batch': 71, 'loss': 7.259151935577393}
{'epoch': 1, 'batch': 72, 'loss': 7.188526153564453}
{'epoch': 1, 'batch': 73, 'loss': 7.247867107391357}
{'epoch': 1, 'batch': 74, 'loss': 7.241442680358887}
{'epoch': 1, 'batch': 75, 'loss': 7.38421630859375}
{'epoch': 1, 'batch': 76, 'loss': 7.182751655578613}
{'epoch': 1, 'batch': 77, 'loss': 7.427218437194824}
{'epoch': 1, 'batch': 78, 'loss': 7.541802406311035}
{'epoch': 1, 'batch': 79, 'loss': 6.8455362319

{'epoch': 3, 'batch': 29, 'loss': 7.332549571990967}
{'epoch': 3, 'batch': 30, 'loss': 6.703458309173584}
{'epoch': 3, 'batch': 31, 'loss': 6.640821933746338}
{'epoch': 3, 'batch': 32, 'loss': 6.768860816955566}
{'epoch': 3, 'batch': 33, 'loss': 6.971035957336426}
{'epoch': 3, 'batch': 34, 'loss': 6.910027503967285}
{'epoch': 3, 'batch': 35, 'loss': 7.172311305999756}
{'epoch': 3, 'batch': 36, 'loss': 7.077265739440918}
{'epoch': 3, 'batch': 37, 'loss': 6.87025785446167}
{'epoch': 3, 'batch': 38, 'loss': 7.192637920379639}
{'epoch': 3, 'batch': 39, 'loss': 7.002452850341797}
{'epoch': 3, 'batch': 40, 'loss': 7.19807243347168}
{'epoch': 3, 'batch': 41, 'loss': 6.919045925140381}
{'epoch': 3, 'batch': 42, 'loss': 7.199438095092773}
{'epoch': 3, 'batch': 43, 'loss': 6.900390148162842}
{'epoch': 3, 'batch': 44, 'loss': 6.818355083465576}
{'epoch': 3, 'batch': 45, 'loss': 6.960168838500977}
{'epoch': 3, 'batch': 46, 'loss': 7.099640369415283}
{'epoch': 3, 'batch': 47, 'loss': 7.428781986236

{'epoch': 4, 'batch': 91, 'loss': 6.649711608886719}
{'epoch': 4, 'batch': 92, 'loss': 6.976337432861328}
{'epoch': 4, 'batch': 93, 'loss': 6.345383644104004}
{'epoch': 5, 'batch': 0, 'loss': 6.805263042449951}
{'epoch': 5, 'batch': 1, 'loss': 6.741182804107666}
{'epoch': 5, 'batch': 2, 'loss': 6.6994524002075195}
{'epoch': 5, 'batch': 3, 'loss': 6.897202968597412}
{'epoch': 5, 'batch': 4, 'loss': 6.812022686004639}
{'epoch': 5, 'batch': 5, 'loss': 6.791199684143066}
{'epoch': 5, 'batch': 6, 'loss': 7.287927150726318}
{'epoch': 5, 'batch': 7, 'loss': 7.054096698760986}
{'epoch': 5, 'batch': 8, 'loss': 6.953408718109131}
{'epoch': 5, 'batch': 9, 'loss': 6.990136623382568}
{'epoch': 5, 'batch': 10, 'loss': 6.983534336090088}
{'epoch': 5, 'batch': 11, 'loss': 6.815077781677246}
{'epoch': 5, 'batch': 12, 'loss': 6.884164333343506}
{'epoch': 5, 'batch': 13, 'loss': 7.071802616119385}
{'epoch': 5, 'batch': 14, 'loss': 6.667488098144531}
{'epoch': 5, 'batch': 15, 'loss': 6.784908294677734}
{'

{'epoch': 6, 'batch': 58, 'loss': 6.391932487487793}
{'epoch': 6, 'batch': 59, 'loss': 6.477015972137451}
{'epoch': 6, 'batch': 60, 'loss': 6.3928141593933105}
{'epoch': 6, 'batch': 61, 'loss': 6.558197021484375}
{'epoch': 6, 'batch': 62, 'loss': 6.622935771942139}
{'epoch': 6, 'batch': 63, 'loss': 6.492883682250977}
{'epoch': 6, 'batch': 64, 'loss': 6.501003265380859}
{'epoch': 6, 'batch': 65, 'loss': 6.492862224578857}
{'epoch': 6, 'batch': 66, 'loss': 6.565515518188477}
{'epoch': 6, 'batch': 67, 'loss': 6.246847629547119}
{'epoch': 6, 'batch': 68, 'loss': 6.547665119171143}
{'epoch': 6, 'batch': 69, 'loss': 6.172375679016113}
{'epoch': 6, 'batch': 70, 'loss': 6.818569660186768}
{'epoch': 6, 'batch': 71, 'loss': 6.599156379699707}
{'epoch': 6, 'batch': 72, 'loss': 6.517108917236328}
{'epoch': 6, 'batch': 73, 'loss': 6.522976398468018}
{'epoch': 6, 'batch': 74, 'loss': 6.675457954406738}
{'epoch': 6, 'batch': 75, 'loss': 6.664391040802002}
{'epoch': 6, 'batch': 76, 'loss': 6.545031547

{'epoch': 8, 'batch': 26, 'loss': 6.080038070678711}
{'epoch': 8, 'batch': 27, 'loss': 6.174802303314209}
{'epoch': 8, 'batch': 28, 'loss': 6.640622615814209}
{'epoch': 8, 'batch': 29, 'loss': 6.7559428215026855}
{'epoch': 8, 'batch': 30, 'loss': 6.021736145019531}
{'epoch': 8, 'batch': 31, 'loss': 5.976484775543213}
{'epoch': 8, 'batch': 32, 'loss': 6.100287437438965}
{'epoch': 8, 'batch': 33, 'loss': 6.3485260009765625}
{'epoch': 8, 'batch': 34, 'loss': 6.247180938720703}
{'epoch': 8, 'batch': 35, 'loss': 6.377435684204102}
{'epoch': 8, 'batch': 36, 'loss': 6.282599449157715}
{'epoch': 8, 'batch': 37, 'loss': 6.218551158905029}
{'epoch': 8, 'batch': 38, 'loss': 6.6378278732299805}
{'epoch': 8, 'batch': 39, 'loss': 6.380541801452637}
{'epoch': 8, 'batch': 40, 'loss': 6.548640727996826}
{'epoch': 8, 'batch': 41, 'loss': 6.1401567459106445}
{'epoch': 8, 'batch': 42, 'loss': 6.593661785125732}
{'epoch': 8, 'batch': 43, 'loss': 6.1840949058532715}
{'epoch': 8, 'batch': 44, 'loss': 6.14863

{'epoch': 9, 'batch': 87, 'loss': 6.118295192718506}
{'epoch': 9, 'batch': 88, 'loss': 5.974076747894287}
{'epoch': 9, 'batch': 89, 'loss': 6.056690692901611}
{'epoch': 9, 'batch': 90, 'loss': 6.551220893859863}
{'epoch': 9, 'batch': 91, 'loss': 5.842201232910156}
{'epoch': 9, 'batch': 92, 'loss': 6.129517078399658}
{'epoch': 9, 'batch': 93, 'loss': 5.480625629425049}
['Knock', 'knock.', 'Whos', 'there?', 'snowman?', 'plains', 'help', 'get', 'a', 'drop', 'will', 'the', 'Graaaaaaiiiins......', 'What', 'do', 'you', 'about', 'the', 'rain', 'only', 'its', 'hate', 'snowman?', 'invented', 'into', 'a', 'head.', 'lawyer', 'after', 'Does', 'ago,', 'report?', "it's", 'the', 'Chef', "who's", 'likes', 'eyes', 'Why', 'did', 'a', 'puts', 'Why', 'you', 'do', 'A', 'soap.', 'and', 'they', 'Where', 'up', 'happened', 'to', 'the', "friend's", 'covered', 'of', 'baseball?', 'around.', 'I', "don't,", 'as', 'a', 'wombat?', 'The', 'road?', 'from', 'them', 'Coming', 'keyboard?', 'How', 'you', 'call', 'a', 'toge

In [26]:
import os

# Report the kernel's current working directory. Dataset.load_words opens
# 'data/reddit-cleanjokes.csv' relative to this location, so check it first
# when that read fails.
cwd = os.getcwd()
print(cwd)

/home/jmdanie/Documents/Code/Hackathons/Bon-Hacketit
