# Text Generation: Many-to-Many, a Different Way...
What if we want to generate a sentence in response to a prompt that is itself a sentence? We first encode the input sequence, then train our model to produce the target sentence one token at a time. This is common in question-answering tasks, where we want our network to "respond" to a given question!

![LSTM](https://static.packt-cdn.com/products/9781789346640/graphics/assets/79db1776-f471-4fe6-89b0-67cbae844bfc.png)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import io
import re

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.distributions import Categorical

from torchtext.datasets import YahooAnswers
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchtext.data.functional import sentencepiece_tokenizer, load_sp_model

from tqdm.notebook import trange, tqdm

In [None]:
# ---- Hyperparameters -------------------------------------------------------
learning_rate = 1e-4  # Adam learning rate
nepochs = 10          # number of passes over the training split
batch_size = 32       # question/answer pairs per optimisation step

# Truncation lengths used by the question / answer token pipelines
max_len_q, max_len_a = 32, 64

# Root folder that holds the "datasets/YahooAnswers" CSV files
data_set_root = "../../datasets"

# We'll be using the YahooAnswers dataset.
# Note that in torchtext these datasets are NOT PyTorch Dataset classes: "YahooAnswers"
# is a function that returns a PyTorch DataPipe!

# PyTorch DataPipes vvv
# https://pytorch.org/data/main/torchdata.datapipes.iter.html

# vvv Good blog on the difference between Dataset and DataPipe
# https://medium.com/deelvin-machine-learning/comparison-of-pytorch-dataset-and-torchdata-datapipes-486e03068c58

# Depending on the dataset, sometimes it doesn't download and raises an error,
# and you'll have to download and extract it manually.
# "The datasets supported by torchtext are datapipes from the torchdata project, which is still in Beta status"

# Un-comment to trigger the DataPipe to download the data vvv
# dataset_train = YahooAnswers(root=data_set_root, split="train")
# data = next(iter(dataset_train))

# Side-note: I've noticed that the WikiText dataset can no longer be downloaded :(

In [None]:
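### Uncomment to "train" a SentencePiece tokenizer on the training data, capping the vocab size at 20000 tokens
# NOTE(review): in the regex below, the bare `\\` alternative is tried before `\\n`, so an escaped "\n"
# in the CSV becomes " n" instead of " ". Fixing the order changes the corpus, so retrain the tokenizer if you do.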
# from torchtext.data.functional import generate_sp_model

# with open(os.path.join(data_set_root, "datasets/YahooAnswers/train.csv")) as f:
#     with open(os.path.join(data_set_root, "datasets/YahooAnswers/data.txt"), "w") as f2:
#         for i, line in enumerate(f):
#             text_only = "".join(line.split(",")[1:])
#             filtered = re.sub(r'\\|\\n|;', ' ', text_only.replace('"', ' ').replace('\n', ' ')) # remove newline characters
#             f2.write(filtered.lower() + "\n")


# generate_sp_model(os.path.join(data_set_root, "datasets/YahooAnswers/data.txt"), 
#                   vocab_size=20000, model_prefix='spm_user_ya')

In [None]:
class YahooQA(Dataset):
    """Yahoo Answers question/answer pairs read from one CSV split.

    Each item is a ``(question_text, answer_text)`` tuple of lower-cased
    strings. The question is the title and the body joined by a single space.

    Args:
        num_datapoints: if an ``int``, only the first ``num_datapoints`` rows
            of the split are read. Any other value (e.g. ``None``) reads every
            row, which matches the earlier behaviour where this argument was
            ignored.
        test_train: split name; reads
            ``<data_set_root>/datasets/YahooAnswers/<test_train>.csv`` using
            the module-level ``data_set_root``.
    """

    # Escaped line breaks / backslashes as stored in the raw CSV, plus real
    # line breaks and double quotes. Regex alternation is tried left to right,
    # so the longer escape sequences must come before the bare ``\\``.
    # Otherwise ``\\`` consumes only the backslash and leaves a stray "r"/"n".
    _NOISE_PATTERN = r'\\r\\n|\\n|\\r|\\|\r\n|\n|\r|"'

    def __init__(self, num_datapoints=None, test_train="train"):
        csv_path = os.path.join(data_set_root, "datasets", "YahooAnswers", test_train + ".csv")

        # Only an actual integer acts as a row cap (bool is excluded on purpose).
        use_cap = isinstance(num_datapoints, int) and not isinstance(num_datapoints, bool)
        nrows = num_datapoints if use_cap else None

        # Force the text columns to str so that a purely numeric title or body
        # cannot turn into a numeric column and break the concatenation below.
        self.df = pd.read_csv(csv_path,
                              names=["Class", "Q_Title", "Q_Content", "A"],
                              dtype={"Q_Title": str, "Q_Content": str, "A": str},
                              nrows=nrows)

        # Missing title/body/answer fields become empty strings
        self.df.fillna('', inplace=True)
        self.df['Q'] = self.df['Q_Title'] + ' ' + self.df['Q_Content']
        self.df.drop(['Q_Title', 'Q_Content'], axis=1, inplace=True)
        self.df['Q'] = self._clean_text(self.df['Q'])
        self.df['A'] = self._clean_text(self.df['A'])

    @staticmethod
    def _clean_text(series):
        """Replace escaped/real line breaks, backslashes and double quotes with a space."""
        return series.str.replace(YahooQA._NOISE_PATTERN, ' ', regex=True)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        return row["Q"].lower(), row["A"].lower()

    def __len__(self):
        return len(self.df)

In [None]:
# `num_datapoints` is a row count, not a path (the dataset root was being passed
# here by mistake). None keeps every row of each split.
dataset_train = YahooQA(num_datapoints=None, test_train="train")
dataset_test = YahooQA(num_datapoints=None, test_train="test")

In [None]:
# Training loader: shuffled; the final incomplete batch is dropped so every
# step sees exactly `batch_size` samples.
data_loader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)
# Test loader: also shuffled; its last batch may be smaller than `batch_size`.
data_loader_test = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Load the SentencePiece model trained earlier ("spm_user_ya" prefix) and
# wrap it as a callable tokenizer.
sp_model = load_sp_model(
    "spm_user_ya.model"
)
tokenizer = sentencepiece_tokenizer(
    sp_model
)

In [None]:
def yield_tokens(file_path):
    """Yield each entry of a SentencePiece ``.vocab`` file as a one-item list.

    Only the text before the first tab on each line (the piece itself) is
    kept. Each piece is wrapped in a list because ``build_vocab_from_iterator``
    consumes an iterable of token lists.
    """
    with open(file_path, encoding='utf-8') as vocab_file:
        for entry in vocab_file:
            piece, _, _ = entry.partition("\t")
            yield [piece]
            
# Our special tokens go first, in this order, so <pad>=0, <soq>=1, <eoq>=2,
# <soa>=3, <eoa>=4 and <unk>=5. The pieces from the SentencePiece .vocab file follow.
vocab = build_vocab_from_iterator(
    yield_tokens("spm_user_ya.vocab"),
    specials=['<pad>', '<soq>', '<eoq>', '<soa>', '<eoa>', '<unk>'],
    special_first=True,
)
# Out-of-vocabulary pieces map to <unk> instead of raising
vocab.set_default_index(vocab.get_stoi()['<unk>'])

In [None]:
def _build_text_transform(start_token, end_token, max_len, sp_model_path="spm_user_ya.model"):
    """Return a pipeline that turns a list of strings into a padded index tensor.

    Steps: SentencePiece tokenisation -> vocabulary indices -> prepend
    ``start_token`` -> truncate to ``max_len`` -> append ``end_token`` ->
    right-pad with ``<pad>`` into a single tensor. Truncation happens before
    the end token is appended, so a row can be up to ``max_len + 1`` long.
    """
    return T.Sequential(
        # Tokenise with the pre-trained SentencePiece model
        T.SentencePieceTokenizer(sp_model_path),
        # Map each piece to its index in `vocab`
        T.VocabTransform(vocab=vocab),
        # Prepend the start marker (looked up, not hard-coded)
        T.AddToken(vocab[start_token], begin=True),
        # Crop sequences longer than max_len (the start marker counts toward it)
        T.Truncate(max_seq_len=max_len),
        # Append the end marker AFTER cropping so it is never cut off
        T.AddToken(vocab[end_token], begin=False),
        # Stack into one tensor, right-padding shorter rows with <pad> so all
        # rows have the same length
        T.ToTensor(padding_value=vocab['<pad>']),
    )


# Question pipeline: <soq> ... <eoq>
q_tranform = _build_text_transform('<soq>', '<eoq>', max_len_q)

# Answer pipeline: <soa> ... <eoa>
a_tranform = _build_text_transform('<soa>', '<eoa>', max_len_a)

In [None]:
class LSTM(nn.Module):
    """Token embedding -> MLP -> stacked LSTM -> MLP head producing per-step logits.

    Args:
        num_emb: vocabulary size (rows of the embedding table and width of the
            output logits).
        num_layers: number of stacked LSTM layers.
        emb_size: embedding width fed to the LSTM.
        hidden_size: LSTM hidden/cell state width.
        dropout: dropout between stacked LSTM layers. ``nn.LSTM`` only applies
            it between layers, so it is disabled when ``num_layers == 1``. This
            also avoids the UserWarning PyTorch emits in that case.
    """

    def __init__(self, num_emb, num_layers=1, emb_size=128, hidden_size=128, dropout=0.25):
        super().__init__()

        self.embedding = nn.Embedding(num_emb, emb_size)

        # Per-token MLP applied to the embeddings before the recurrent layers
        self.mlp_emb = nn.Sequential(nn.Linear(emb_size, emb_size),
                                     nn.LayerNorm(emb_size),
                                     nn.ELU(),
                                     nn.Linear(emb_size, emb_size))

        self.lstm = nn.LSTM(input_size=emb_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0.0)

        # Output head: hidden state -> logits over the vocabulary
        self.mlp_out = nn.Sequential(nn.Linear(hidden_size, hidden_size//2),
                                     nn.LayerNorm(hidden_size//2),
                                     nn.ELU(),
                                     nn.Dropout(0.5),
                                     nn.Linear(hidden_size//2, num_emb))

    def forward(self, input_seq, hidden_in, mem_in):
        """Run the network over a token sequence, starting from the given state.

        Args:
            input_seq: token indices, assumed (batch, seq) since the LSTM is
                batch_first.
            hidden_in, mem_in: initial hidden and cell states, each
                (num_layers, batch, hidden_size).

        Returns:
            ``(logits, hidden_out, mem_out)``. The logits are
            (batch, seq, num_emb); the states are the final LSTM states.
        """
        input_embs = self.embedding(input_seq)
        input_embs = self.mlp_emb(input_embs)

        output, (hidden_out, mem_out) = self.lstm(input_embs, (hidden_in, mem_in))

        return self.mlp_out(output), hidden_out, mem_out

In [None]:
# First CUDA GPU when one is present, otherwise the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
emb_size = 256
hidden_size = 1024

num_layers = 4

# Create model
lstm_qa = LSTM(num_emb=len(vocab), num_layers=num_layers,
               emb_size=emb_size, hidden_size=hidden_size).to(device)

# Adam with a small L2 weight-decay term
optimizer = optim.Adam(lstm_qa.parameters(), lr=learning_rate, weight_decay=1e-4)

# Token-level cross-entropy. Answer batches are right-padded with <pad>, so
# padded target positions are excluded. Otherwise they can dominate the average
# for short answers and reward the model for predicting <pad>.
loss_fn = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])

In [None]:
# Count every scalar weight in the model (all tensors returned by .parameters())
num_model_params = sum(p.numel() for p in lstm_qa.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))

In [None]:
# One loss value per optimisation step, appended by the training loop
training_loss_logger = list()

In [None]:
pbar = trange(0, nepochs, leave=False, desc="Epoch")
for epoch in pbar:
    lstm_qa.train()
    epoch_losses = []
    for q_text, a_text in tqdm(data_loader_train, desc="Training", leave=False):
        # Tokenise question and answer into padded index tensors
        q_text_tokens = q_tranform(list(q_text)).to(device)
        a_text_tokens = a_tranform(list(a_text)).to(device)

        # Teacher forcing: the input is the answer without its last token, and
        # the target is the same sequence shifted left by one
        a_input_text = a_text_tokens[:, :-1]
        a_output_text = a_text_tokens[:, 1:]

        bs = q_text_tokens.shape[0]

        # Fresh zero hidden/memory state for every batch
        hidden = torch.zeros(num_layers, bs, hidden_size, device=device)
        memory = torch.zeros(num_layers, bs, hidden_size, device=device)

        # Encode the whole question; only the final states are kept.
        # NOTE(review): questions are right-padded, so shorter questions run
        # through trailing <pad> tokens before their state reaches the answer
        # decoder. The single un-padded prompt used at inference does not.
        # Consider packing or left-padding if this hurts answer quality.
        _, hidden, memory = lstm_qa(q_text_tokens, hidden, memory)

        # Next-token prediction over the answer, starting from the question state
        pred, hidden, memory = lstm_qa(a_input_text, hidden, memory)

        # CrossEntropyLoss expects class scores on dim 1: (batch, vocab, seq)
        loss = loss_fn(pred.transpose(1, 2), a_output_text)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        training_loss_logger.append(loss_value)
        epoch_losses.append(loss_value)

    # Report a quantity that is actually measured (no accuracy is computed here)
    if epoch_losses:
        pbar.set_postfix_str('Mean train loss: %.4f' % (sum(epoch_losses) / len(epoch_losses)))


In [None]:
# Per-step training loss curve
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(training_loss_logger)
_ = ax.set_title("Training Loss")

In [None]:
# Take one batch of raw question/answer strings from the test loader
(q_text, a_text) = next(iter(data_loader_test))

In [None]:
# Position within the test batch of the question to inspect and use as the prompt
index: int = 0
q_text[index]

In [None]:
# Ask a question!
# init_prompt = ["what is larger 1 or 2? "]

# Or use one from the test set
init_prompt = [q_text[index]]

# Question pipeline output: <soq> ... <eoq> as a 2-D (num_prompts, seq_len) index tensor
input_tokens = q_tranform(init_prompt).to(device)

# Append the Start-Of-Answer token to prompt the network to start generating the answer.
# The id is looked up in the vocab (not hard-coded), and the column matches the
# prompt's batch dimension.
soa_column = torch.full((input_tokens.shape[0], 1), vocab['<soa>'],
                        dtype=torch.long, device=device)
input_tokens = torch.cat((input_tokens, soa_column), 1)
print("INITIAL PROMPT:")
print(input_tokens)

print("\nPROMPT TOKENS:")
# lookup_tokens expects a list of ints, so pass a plain Python list
print(vocab.lookup_tokens(input_tokens[0].tolist()))

In [None]:
# Sampling temperature: logits are divided by this before sampling
# (< 1 sharpens the distribution, > 1 flattens it)
temp: float = 0.8

In [None]:
# Autoregressive sampling: feed the prompt once, then feed back each sampled
# token, carrying the LSTM hidden/memory state between steps.
max_new_tokens = 100        # cap on the number of generated tokens
eoa_token = vocab['<eoa>']  # generation stops once this token is sampled

log_tokens = []
lstm_qa.eval()

with torch.no_grad():
    hidden = torch.zeros(num_layers, 1, hidden_size, device=device)
    memory = torch.zeros(num_layers, 1, hidden_size, device=device)

    # Use a separate loop variable so `input_tokens` (the prompt) stays intact.
    # Re-running this cell then draws a fresh sample from the same prompt.
    step_tokens = input_tokens
    for _ in range(max_new_tokens):
        data_pred, hidden, memory = lstm_qa(step_tokens, hidden, memory)

        # Greedy alternative: take the token with the highest probability
        # step_tokens = data_pred[:, -1].argmax().reshape(1, 1)

        # Or sample from the temperature-scaled next-token distribution
        dist = Categorical(logits=data_pred[:, -1, :] / temp)
        step_tokens = dist.sample().reshape(1, 1)

        log_tokens.append(step_tokens.cpu())
        if step_tokens.item() == eoa_token:
            break

In [None]:
# Turn the sampled ids back into SentencePiece pieces and concatenate them
pred_text = "".join(
    vocab.lookup_tokens(torch.cat(log_tokens, dim=1)[0].numpy())
)
print(pred_text)

In [None]:
# Readable answer: "▁" marks a word boundary; the <unk>/<eoa> markers are dropped
(
    pred_text
    .replace("▁", " ")
    .replace("<unk>", "")
    .replace("<eoa>", "")
)

In [None]:
# Have a look at the (temperature-scaled) next-token probabilities from the final generation step
next_token_probs = F.softmax(data_pred[:, -1, :] / temp, dim=-1)
plt.plot(next_token_probs.cpu().numpy().ravel())