<a href="https://colab.research.google.com/github/TomasMendozaHN/ICDF_Class/blob/main/05112022_LSTM_TextGeneration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Begin by defining the LSTM model

In [1]:
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        # Embedding is a layer that maps each word index to a dense, trainable vector.
        # For example, with embedding_dim=128:
        # "hello my name is Tomas" -> indices [0, 1, 2, 3, 4] -> five vectors of 128 floats
        # Each unique word has its own row in the embedding table. Unlike a one-hot
        # encoding, the vector length is fixed by embedding_dim and does not grow
        # with the vocabulary; only the number of rows in the table does.
        # Also, this means you can't generate words that the LSTM has not seen
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,  # the LSTM consumes the embedding vectors
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

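    def init_state(self, sequence_length):
        # nn.LSTM defaults to batch_first=False, so it reads dimension 1 of its
        # input as the batch dimension. forward() passes (batch, sequence, embedding),
        # which is why the middle dimension here is sequence_length rather than the
        # batch size. The more conventional layout would be batch_first=True with
        # the state sized by batch size.
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))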
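
To see what `nn.Embedding` actually produces, the short sketch below looks up three word indices in a small embedding table. The vocabulary size (10) and embedding dimension (4) are arbitrary values chosen for illustration, not the ones used by the model above.

```python
import torch
from torch import nn

torch.manual_seed(0)  # make the randomly initialized table reproducible

# Toy sizes chosen only for illustration (assumption, not the notebook's values)
toy_vocab_size = 10
toy_embedding_dim = 4
embedding = nn.Embedding(num_embeddings=toy_vocab_size, embedding_dim=toy_embedding_dim)

# A "sentence" of three word indices, shaped (batch=1, words=3)
indexes = torch.tensor([[2, 8, 0]])
vectors = embedding(indexes)

print(vectors.shape)  # torch.Size([1, 3, 4])

# The lookup simply selects rows of the trainable weight matrix
print(torch.equal(vectors[0, 1], embedding.weight[8]))  # True
```

Each index selects one row of `embedding.weight`, so the output is a dense vector per word rather than a one-hot vector, and those rows are updated during training like any other parameter.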

# Test loading an online CSV with the dataset in it

In [2]:
import pandas as pd
dataset = pd.read_csv("https://raw.githubusercontent.com/amoudgl/short-jokes-dataset/master/data/reddit-cleanjokes.csv")
print(dataset)

        ID                                               Joke
0        1  What did the bartender say to the jumper cable...
1        2  Don't you hate jokes about German sausage? The...
2        3  Two artists had an art contest... It ended in ...
3        4  Why did the chicken cross the playground? To g...
4        5   What gun do you use to hunt a moose? A moosecut!
...    ...                                                ...
1617  1618  What do you call a camel with 3 humps? Humphre...
1618  1619  Two fish in a tank. [x-post from r/Jokes] One ...
1619  1620          "Stay strong!" I said to my wi-fi signal.
1620  1621  Why was the tomato blushing? Because it saw th...
1621  1622    What is heavy forward but not backward? **ton**

[1622 rows x 2 columns]


# Visualize the entire text (all jokes) as a single string

In [3]:
text_as_a_single_string = dataset['Joke'].str.cat(sep=' ')
print(text_as_a_single_string)

What did the bartender say to the jumper cables? You better not try to start anything. Don't you hate jokes about German sausage? They're the wurst! Two artists had an art contest... It ended in a draw Why did the chicken cross the playground? To get to the other slide. What gun do you use to hunt a moose? A moosecut! If life gives you melons, you might have dyslexia. Broken pencils... ...are pointless. What did one snowman say to the other snowman? 'Do you smell carrots?' How many hipsters does it take to change a lightbulb? It's a really obscure number. You've probably never heard of it. Where do sick boats go? The dock! I like my slaves like I like my coffee: Free. My girlfriend told me she was leaving me because I keep pretending to be a Transformer... I said, No, wait! I can change! Old Chinese proverb: Man who not shower in 7 days makes one reek. What did the owner of a brownie factory say when his factory caught fire? "I'm getting the fudge outta here!" What form of radiation ba

# Separate all words in that string

In [4]:
individual_words = text_as_a_single_string.split(' ')
print(individual_words)

['What', 'did', 'the', 'bartender', 'say', 'to', 'the', 'jumper', 'cables?', 'You', 'better', 'not', 'try', 'to', 'start', 'anything.', "Don't", 'you', 'hate', 'jokes', 'about', 'German', 'sausage?', "They're", 'the', 'wurst!', 'Two', 'artists', 'had', 'an', 'art', 'contest...', 'It', 'ended', 'in', 'a', 'draw', 'Why', 'did', 'the', 'chicken', 'cross', 'the', 'playground?', 'To', 'get', 'to', 'the', 'other', 'slide.', 'What', 'gun', 'do', 'you', 'use', 'to', 'hunt', 'a', 'moose?', 'A', 'moosecut!', 'If', 'life', 'gives', 'you', 'melons,', 'you', 'might', 'have', 'dyslexia.', 'Broken', 'pencils...', '...are', 'pointless.', 'What', 'did', 'one', 'snowman', 'say', 'to', 'the', 'other', 'snowman?', "'Do", 'you', 'smell', "carrots?'", 'How', 'many', 'hipsters', 'does', 'it', 'take', 'to', 'change', 'a', 'lightbulb?', "It's", 'a', 'really', 'obscure', 'number.', "You've", 'probably', 'never', 'heard', 'of', 'it.', 'Where', 'do', 'sick', 'boats', 'go?', 'The', 'dock!', 'I', 'like', 'my', 'sla

# Obtain a list containing all words (without repetition) in decreasing order of frequency

In [5]:
from collections import Counter
unique_words = Counter(individual_words)
unique_words_in_order_of_frequency = sorted(unique_words, key=unique_words.get, reverse=True)
print(unique_words_in_order_of_frequency)

['the', 'a', 'What', 'you', 'to', 'do', 'I', 'of', 'did', 'Why', 'in', 'and', 'was', 'A', 'call', 'it', 'is', 'with', 'his', 'Because', 'say', 'he', 'get', 'on', 'when', 'my', 'How', 'for', 'an', 'does', 'The', 'about', 'have', 'that', 'one', 'He', 'are', 'at', 'who', "What's", 'hear', 'from', 'into', 'out', 'be', 'this', 'they', 'had', 'like', "don't", 'your', 'Did', 'go', 'can', 'but', 'joke', "I'm", 'It', 'me', 'no', 'My', 'other', 'so', 'make', 'You', 'all', 'up', 'favorite', 'know', 'their', 'just', 'got', 'what', 'not', 'They', 'take', 'cross', 'two', 'many', 'Where', 'said', 'man', 'people', 'always', 'To', 'it.', "it's", 'fish', 'between', 'her', 'if', 'by', "couldn't", 'Two', 'little', "can't", 'too', 'has', 'kind', "It's", 'cow', '.', 'would', 'only', "didn't", 'think', 'walks', 'chicken', 'will', 'heard', 'there', 'much', 'were', 'says', 'made', 'If', 'told', 'went', 'So', 'best', 'really', 'difference', 'says,', 'because', '-', 'wanted', 'time', 'When', 'never', 'she', 'why

# Map each integer index to its word (used to turn predictions back into text)

In [6]:
index_to_word = {index: word for index, word in enumerate(unique_words_in_order_of_frequency)}
print(index_to_word)

{0: 'the', 1: 'a', 2: 'What', 3: 'you', 4: 'to', 5: 'do', 6: 'I', 7: 'of', 8: 'did', 9: 'Why', 10: 'in', 11: 'and', 12: 'was', 13: 'A', 14: 'call', 15: 'it', 16: 'is', 17: 'with', 18: 'his', 19: 'Because', 20: 'say', 21: 'he', 22: 'get', 23: 'on', 24: 'when', 25: 'my', 26: 'How', 27: 'for', 28: 'an', 29: 'does', 30: 'The', 31: 'about', 32: 'have', 33: 'that', 34: 'one', 35: 'He', 36: 'are', 37: 'at', 38: 'who', 39: "What's", 40: 'hear', 41: 'from', 42: 'into', 43: 'out', 44: 'be', 45: 'this', 46: 'they', 47: 'had', 48: 'like', 49: "don't", 50: 'your', 51: 'Did', 52: 'go', 53: 'can', 54: 'but', 55: 'joke', 56: "I'm", 57: 'It', 58: 'me', 59: 'no', 60: 'My', 61: 'other', 62: 'so', 63: 'make', 64: 'You', 65: 'all', 66: 'up', 67: 'favorite', 68: 'know', 69: 'their', 70: 'just', 71: 'got', 72: 'what', 73: 'not', 74: 'They', 75: 'take', 76: 'cross', 77: 'two', 78: 'many', 79: 'Where', 80: 'said', 81: 'man', 82: 'people', 83: 'always', 84: 'To', 85: 'it.', 86: "it's", 87: 'fish', 88: 'between'

# Map each word to an integer index (the input to the embedding layer)

In [7]:
word_to_index = {word: index for index, word in enumerate(unique_words_in_order_of_frequency)}
print(word_to_index)

{'the': 0, 'a': 1, 'What': 2, 'you': 3, 'to': 4, 'do': 5, 'I': 6, 'of': 7, 'did': 8, 'Why': 9, 'in': 10, 'and': 11, 'was': 12, 'A': 13, 'call': 14, 'it': 15, 'is': 16, 'with': 17, 'his': 18, 'Because': 19, 'say': 20, 'he': 21, 'get': 22, 'on': 23, 'when': 24, 'my': 25, 'How': 26, 'for': 27, 'an': 28, 'does': 29, 'The': 30, 'about': 31, 'have': 32, 'that': 33, 'one': 34, 'He': 35, 'are': 36, 'at': 37, 'who': 38, "What's": 39, 'hear': 40, 'from': 41, 'into': 42, 'out': 43, 'be': 44, 'this': 45, 'they': 46, 'had': 47, 'like': 48, "don't": 49, 'your': 50, 'Did': 51, 'go': 52, 'can': 53, 'but': 54, 'joke': 55, "I'm": 56, 'It': 57, 'me': 58, 'no': 59, 'My': 60, 'other': 61, 'so': 62, 'make': 63, 'You': 64, 'all': 65, 'up': 66, 'favorite': 67, 'know': 68, 'their': 69, 'just': 70, 'got': 71, 'what': 72, 'not': 73, 'They': 74, 'take': 75, 'cross': 76, 'two': 77, 'many': 78, 'Where': 79, 'said': 80, 'man': 81, 'people': 82, 'always': 83, 'To': 84, 'it.': 85, "it's": 86, 'fish': 87, 'between': 88

# Convert the entire string of jokes into integers (using the word_to_index dictionary)

In [8]:
words_indexes = [word_to_index[w] for w in individual_words]
print(words_indexes)

[2, 8, 0, 248, 20, 4, 0, 1905, 1906, 64, 534, 73, 535, 4, 1907, 1908, 225, 3, 226, 227, 31, 249, 1909, 314, 0, 1910, 93, 704, 47, 28, 705, 1077, 57, 424, 10, 1, 347, 9, 8, 0, 107, 76, 0, 706, 84, 22, 4, 0, 61, 1911, 2, 1078, 5, 3, 133, 4, 1912, 1, 1913, 13, 1914, 115, 275, 1079, 3, 1915, 3, 348, 32, 1916, 1917, 1918, 1080, 1919, 2, 8, 34, 707, 20, 4, 0, 61, 708, 1920, 3, 425, 1921, 26, 78, 1081, 29, 15, 75, 4, 187, 1, 276, 99, 1, 120, 1082, 1922, 1083, 709, 128, 109, 7, 85, 79, 5, 349, 1084, 710, 30, 1923, 6, 48, 25, 1085, 48, 6, 48, 25, 1924, 1086, 60, 426, 116, 58, 129, 12, 1087, 58, 123, 6, 164, 1925, 4, 44, 1, 1926, 6, 174, 1088, 1927, 6, 53, 1928, 427, 711, 1929, 350, 38, 73, 351, 10, 1089, 712, 202, 34, 1930, 2, 8, 0, 1090, 7, 1, 1931, 713, 20, 24, 18, 713, 714, 715, 315, 165, 0, 1932, 1933, 1934, 2, 1091, 7, 1935, 1936, 3, 1937, 13, 1938, 1939, 1940, 1941, 1942, 7, 25, 1943, 1944, 1945, 2, 8, 0, 316, 133, 4, 1092, 139, 145, 1946, 13, 1947, 1948, 2, 8, 0, 316, 133, 4, 22, 43, 0, 
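
The same steps can be followed end to end on a corpus small enough to read at a glance. The sentence below is made up for this example and the `toy_` names are hypothetical; only the technique mirrors the cells above.

```python
from collections import Counter

# Toy corpus (made up for this example, not taken from the jokes dataset)
toy_text = "the cat saw the dog and the dog saw a bird"
toy_words = toy_text.split(' ')

# Most frequent words get the smallest indices
toy_counts = Counter(toy_words)
toy_vocab = sorted(toy_counts, key=toy_counts.get, reverse=True)

toy_word_to_index = {word: index for index, word in enumerate(toy_vocab)}
toy_index_to_word = {index: word for index, word in enumerate(toy_vocab)}

# Encode the text as integers, then decode it back into words
encoded = [toy_word_to_index[w] for w in toy_words]
decoded = [toy_index_to_word[i] for i in encoded]

print(toy_vocab)
print(encoded)
print(decoded == toy_words)          # True: the round trip is lossless
print('fish' in toy_word_to_index)   # False: unseen words have no index
```

The last line illustrates why the model can neither read nor generate a word that never appeared in the training text.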

# Now that we have seen how to prepare our data, we create a Dataset class that does this automatically

In [9]:
import torch
import pandas as pd
from collections import Counter

# Remember: every map-style Dataset class you create needs these methods:
# 1. __init__      --> reads and prepares the dataset
# 2. __len__       --> returns the number of samples in the entire dataset
# 3. __getitem__   --> returns ONE sample (here an input window and its target window);
#                      the DataLoader is what stacks samples into batches
class Dataset(torch.utils.data.Dataset):
    def __init__(self, sequence_length):
        self.sequence_length = sequence_length
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )

    def load_words(self):
        train_df = pd.read_csv("https://raw.githubusercontent.com/amoudgl/short-jokes-dataset/master/data/reddit-cleanjokes.csv")
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)
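
A minimal sketch of how `__getitem__` and the `DataLoader` divide the work, using a hypothetical `ToyWindows` class and made-up indices in place of the jokes data:

```python
import torch
from torch.utils.data import DataLoader


class ToyWindows(torch.utils.data.Dataset):
    """Hypothetical stand-in for the Dataset above, using made-up indices."""

    def __init__(self, indexes, sequence_length):
        self.indexes = indexes
        self.sequence_length = sequence_length

    def __len__(self):
        # The last valid window must leave room for its shifted target
        return len(self.indexes) - self.sequence_length

    def __getitem__(self, index):
        # One sample: an input window and the same window shifted right by one
        x = self.indexes[index:index + self.sequence_length]
        y = self.indexes[index + 1:index + self.sequence_length + 1]
        return torch.tensor(x), torch.tensor(y)


toy = ToyWindows(indexes=list(range(10, 20)), sequence_length=3)
x0, y0 = toy[0]
print(x0.tolist(), y0.tolist())  # [10, 11, 12] [11, 12, 13]

# The DataLoader, not __getitem__, stacks samples into batches
xb, yb = next(iter(DataLoader(toy, batch_size=4)))
print(xb.shape, yb.shape)
```

Each target window is the input window moved one word to the right, so at every position the model is trained to predict the next word.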



# Initialize the Dataset

In [10]:
sequence_length = 6
dataset = Dataset(sequence_length=sequence_length)

# Test the output of the Dataset

In [11]:
for a, b in dataset:
    print(f"originally, our dataset returns a = {a}, b = {b}")

    # Convert both tensors of indices back into lists of words
    a = [dataset.index_to_word[x] for x in a.tolist()]
    b = [dataset.index_to_word[x] for x in b.tolist()]

    print(f"converting these back into text, we have: a = {a}, b = {b}")
    break

originally, our dataset returns a = tensor([  2,   8,   0, 248,  20,   4]), b = tensor([  8,   0, 248,  20,   4,   0])
converting these back into text, we have: a = ['What', 'did', 'the', 'bartender', 'say', 'to'], b = ['did', 'the', 'bartender', 'say', 'to', 'the']


# Initialize the model

In [12]:
model = Model(dataset)

# Begin training!

In [13]:
max_epochs = 10
batch_size = 128

In [14]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader

def train(dataset, model, max_epochs, sequence_length, batch_size):
    model.train()  # enable dropout

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        # Start every epoch from a zero hidden/cell state
        state_h, state_c = model.init_state(sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects (batch, classes, sequence), so move the
            # vocabulary dimension from the last to the second position
            loss = criterion(y_pred.transpose(1, 2), y)

            # Detach the state so gradients do not flow back into earlier batches
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
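
The `transpose(1, 2)` in the loss call is there because `nn.CrossEntropyLoss` expects the class dimension in second position: logits shaped `(N, C, d1)` against targets shaped `(N, d1)`. A minimal sketch with made-up sizes (none of these numbers come from the notebook):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Made-up sizes for illustration only
batch_size, sequence_length, n_vocab = 4, 6, 50

# Same layout as the model output: (batch, sequence, vocabulary)
logits = torch.randn(batch_size, sequence_length, n_vocab)
# Next-word indices: (batch, sequence)
targets = torch.randint(0, n_vocab, (batch_size, sequence_length))

criterion = nn.CrossEntropyLoss()

# Move the vocabulary (class) dimension to position 1: (batch, vocabulary, sequence)
loss = criterion(logits.transpose(1, 2), targets)

# Equivalent view: flatten every position into one long batch of predictions
flat_loss = criterion(logits.reshape(-1, n_vocab), targets.reshape(-1))

print(loss.item(), flat_loss.item())
print(torch.allclose(loss, flat_loss))  # True
```

Both forms average the per-word loss over every position in the batch, so either can be used; the transposed form just avoids reshaping the targets.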

In [15]:
train(dataset, model, max_epochs=max_epochs, sequence_length=sequence_length, batch_size=batch_size)

{'epoch': 0, 'batch': 0, 'loss': 8.83630084991455}
{'epoch': 0, 'batch': 1, 'loss': 8.823466300964355}
{'epoch': 0, 'batch': 2, 'loss': 8.82109260559082}
{'epoch': 0, 'batch': 3, 'loss': 8.824886322021484}
{'epoch': 0, 'batch': 4, 'loss': 8.804631233215332}
{'epoch': 0, 'batch': 5, 'loss': 8.79234790802002}
{'epoch': 0, 'batch': 6, 'loss': 8.787333488464355}
{'epoch': 0, 'batch': 7, 'loss': 8.77765941619873}
{'epoch': 0, 'batch': 8, 'loss': 8.746211051940918}
{'epoch': 0, 'batch': 9, 'loss': 8.721826553344727}
{'epoch': 0, 'batch': 10, 'loss': 8.678711891174316}
{'epoch': 0, 'batch': 11, 'loss': 8.567277908325195}
{'epoch': 0, 'batch': 12, 'loss': 8.570013046264648}
{'epoch': 0, 'batch': 13, 'loss': 8.377081871032715}
{'epoch': 0, 'batch': 14, 'loss': 8.204994201660156}
{'epoch': 0, 'batch': 15, 'loss': 8.04593276977539}
{'epoch': 0, 'batch': 16, 'loss': 8.019049644470215}
{'epoch': 0, 'batch': 17, 'loss': 7.849754810333252}
{'epoch': 0, 'batch': 18, 'loss': 7.844442367553711}
{'epoch'
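
A quick way to sanity-check the first loss values: a classifier that spreads its probability uniformly over `V` words has a cross-entropy of `ln(V)`. Inverting that for the first logged loss gives the vocabulary size the value would correspond to. Uniform prediction is an idealization of an untrained model, so treat the result as a rough estimate rather than the exact vocabulary size.

```python
import math

first_logged_loss = 8.83630084991455  # epoch 0, batch 0 in the log above

# If predictions were exactly uniform over V words, loss = ln(V), so V = exp(loss)
implied_vocab_size = math.exp(first_logged_loss)
print(round(implied_vocab_size))


def uniform_loss(vocab_size):
    """Cross-entropy of a uniform guess over vocab_size words."""
    return math.log(vocab_size)


# Compare against the real vocabulary size, e.g. uniform_loss(len(dataset.uniq_words))
print(uniform_loss(7000))
```

A starting loss far above `ln(V)` would suggest a bug, and the loss falling below it, as it does within the first few batches here, shows the model is beginning to learn word frequencies.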

# Generate text with the trained model

In [18]:
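def predict(dataset, model, text, next_words=100):
    model.eval()  # disable dropout

    # Every prompt word must already be in the vocabulary, otherwise
    # dataset.word_to_index raises a KeyError
    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        # Slide the window forward so the input keeps the length of the prompt
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        # Turn the logits for the last position into probabilities and sample a word
        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words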
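
`predict` samples the next word from the softmax distribution instead of always taking the most likely one, which tends to give more varied text. The sketch below contrasts the two choices on four made-up logits; it uses `np.random.default_rng` with a fixed seed for reproducibility, whereas the function above calls `np.random.choice` directly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up logits for a four-word vocabulary (illustration only)
toy_vocab = ['the', 'a', 'chicken', 'moose']
logits = np.array([2.0, 1.0, 0.5, 0.1])

# Softmax: subtract the max for numerical stability, exponentiate, normalize
exp = np.exp(logits - logits.max())
p = exp / exp.sum()

# Greedy decoding always returns the highest-probability word
greedy = toy_vocab[int(np.argmax(p))]

# Sampling draws each word in proportion to its probability
sampled = [toy_vocab[rng.choice(len(p), p=p)] for _ in range(10)]

print(p.round(3))
print(greedy)
print(sampled)
```

With greedy decoding the same prompt always yields the same continuation, while sampling can occasionally pick lower-probability words.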

In [20]:
predict(dataset, model, text="Knock knock. Who's there?", next_words=5)

['Knock',
 'knock.',
 "Who's",
 'there?',
 'Great',
 'Sports:',
 'years.',
 'December',
 'today...']

# The results are still pretty poor. You can improve them by:

1.   Cleaning up the data by removing non-letter characters.
2.   Increasing the model capacity by adding more Linear or LSTM layers.
3.   Splitting the dataset into train, validation, and test sets, so that overfitting can be measured.
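
As a rough sketch of suggestions 1 and 3, the snippet below lowercases the text, strips everything except letters, apostrophes, and spaces, and then splits the word list into contiguous train, validation, and test portions. The function names `clean_text` and `split_words`, the regular expression, and the 80/10/10 ratio are assumptions made for illustration, not part of the original notebook.

```python
import re


def clean_text(text):
    """Lowercase and drop every character that is not a letter, apostrophe, or space."""
    text = text.lower()
    text = re.sub(r"[^a-z' ]+", " ", text)
    return text.split()  # split() with no argument also collapses repeated spaces


def split_words(words, train_fraction=0.8, validation_fraction=0.1):
    """Split a word list into contiguous train/validation/test slices."""
    train_end = int(len(words) * train_fraction)
    validation_end = train_end + int(len(words) * validation_fraction)
    return words[:train_end], words[train_end:validation_end], words[validation_end:]


sample = "Why did the chicken cross the playground? To get to the other slide."
words = clean_text(sample)
print(words)

train_words, validation_words, test_words = split_words(words)
print(len(train_words), len(validation_words), len(test_words))
```

Cleaning merges variants such as `'What'` and `'what'` or `'it.'` and `'it'` into one token, which shrinks the vocabulary and gives each remaining word more training examples.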