# Text Generation, Many-to-Many with Text.
If we treat a sentence as a sequence of data points, like a time-series data source, we can perform "next token" prediction and create a model that tries to complete the text sequence based on an initial input. This is commonly referred to as "text generation".

[<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/9/93/LSTM_Cell.svg/2880px-LSTM_Cell.svg.png">](LSTM)
<br>
[Corresponding Tutorial Video](https://youtu.be/8K6EL7h9gYI)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import io
import re
from tqdm.notebook import trange, tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.distributions import Categorical

from torchtext.datasets import WikiText2, EnWik9, AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchtext.data.functional import sentencepiece_tokenizer, load_sp_model

In [None]:
# Training hyperparameters and dataset location
learning_rate = 1e-4  # optimizer step size
nepochs = 20  # number of training epochs
batch_size = 32  # articles processed together in one batch
max_len = 64  # maximum sequence length (in tokens)

# Folder under which the dataset files are stored
data_set_root = "../../datasets"

## Dataset, Tokenizers and Vocab!

In [None]:
# We'll be using the AG News Dataset
# Which contains a short news article and a single label to classify the "type" of article
# Note that in torchtext these datasets are NOT PyTorch Dataset classes: "AG_NEWS" is a function that
# returns a PyTorch DataPipe!

# Pytorch DataPipes vvv
# https://pytorch.org/data/main/torchdata.datapipes.iter.html

# vvv Good Blog on the difference between DataSet and DataPipe
# https://medium.com/deelvin-machine-learning/comparison-of-pytorch-dataset-and-torchdata-datapipes-486e03068c58
# Depending on the dataset sometimes the dataset doesn't download and gives an error
# and you'll have to download and extract manually 
# "The datasets supported by torchtext are datapipes from the torchdata project, which is still in Beta status"

# Un-comment to trigger the DataPipe to download the data vvv
# dataset_train = AG_NEWS(root=data_set_root, split="train")
# data = next(iter(dataset_train))

# Side-Note I've noticed that the WikiText dataset is no longer able to be downloaded :(

In [None]:
# Un-Comment to train sentence-piece model for tokenizer and vocab!

# from torchtext.data.functional import generate_sp_model

# with open(os.path.join(data_set_root, "datasets/AG_NEWS/train.csv")) as f:
#     with open(os.path.join(data_set_root, "datasets/AG_NEWS/data.txt"), "w") as f2:
#         for i, line in enumerate(f):
#             text_only = "".join(line.split(",")[1:])
#             filtered = re.sub(r'\\|\\n|;', ' ', text_only.replace('"', ' ').replace('\n', ' ')) # remove newline characters
#             f2.write(filtered.lower() + "\n")

# generate_sp_model(os.path.join(data_set_root, "datasets/AG_NEWS/data.txt"), 
#                   vocab_size=20000, model_prefix='spm_ag_news')

In [None]:
# AGNews dataset class definition
class AGNews(Dataset):
    """AG News articles served as lower-cased "title : content" strings.

    Args:
        num_datapoints: Accepted for interface compatibility; the constructor
            never reads it.
        test_train: Name of the CSV split to load ("train" or "test").
    """

    def __init__(self, num_datapoints, test_train="train"):
        # Locate <data_set_root>/datasets/AG_NEWS/<split>.csv; column names are supplied here
        csv_path = os.path.join(data_set_root, "datasets/AG_NEWS/" + test_train + ".csv")
        frame = pd.read_csv(csv_path, names=["Class", "Title", "Content"])

        # Missing cells become empty strings so the concatenation below never sees NaN
        frame.fillna('', inplace=True)

        # Merge title and body into one "Article" column, then discard the originals
        frame['Article'] = frame['Title'] + " : " + frame['Content']
        frame.drop(['Title', 'Content'], axis=1, inplace=True)

        # Blank out escaped newlines, backslashes, real newlines and double quotes.
        # NOTE(review): the bare `\\` alternative precedes `\\r` and `\\r\\n`, so those two
        # alternatives can never match; left untouched to keep the training text unchanged.
        frame['Article'] = frame['Article'].str.replace(r'\\n|\\|\\r|\\r\\n|\n|"', ' ', regex=True)

        self.df = frame

    def __getitem__(self, index):
        """Return the lower-cased article text stored under row label `index`."""
        article = self.df.loc[index, "Article"]
        return article.lower()

    def __len__(self):
        """Return the number of articles in the loaded split."""
        return len(self.df)

In [None]:
# Build the training and testing datasets (train first, then test).
# The dataset root is passed through the `num_datapoints` keyword, exactly as before.
dataset_train, dataset_test = (
    AGNews(num_datapoints=data_set_root, test_train=split) for split in ("train", "test")
)

# Settings shared by both loaders: shuffled batches served by 8 worker processes
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=8)

# The training loader additionally drops the final incomplete batch
data_loader_train = DataLoader(dataset_train, drop_last=True, **loader_kwargs)
data_loader_test = DataLoader(dataset_test, **loader_kwargs)

In [None]:
# Example of using the tokenizer
# Load the trained SentencePiece model file from the working directory
sp_model = load_sp_model("spm_ag_news.model")

# Create a tokenizer callable backed by the loaded model
tokenizer = sentencepiece_tokenizer(sp_model)

# Tokenize a one-sentence batch and print what the tokenizer yields for it
for token in tokenizer(["i am creating"]):
    print(token)

In [None]:
# Define a function to yield tokens from a file
def yield_tokens(file_path):
    """Yield one single-element token list per line of a tab-separated file.

    Only the text before the first tab of each line is kept; a line without
    a tab is yielded whole (including its line ending).
    """
    with io.open(file_path, encoding='utf-8') as vocab_file:
        yield from ([entry.partition("\t")[0]] for entry in vocab_file)

# "Special" tokens that signal something to the model; they are placed at the
# very start of the vocabulary (special_first=True), in this order:
#   <pad>  padding added to the end of a sentence so that all sequences in a batch
#          have the same length
#   <sos>  "Start-Of-Sentence", marks the start of the sequence
#   <eos>  "End-Of-Sentence", marks the end of the sequence
#   <unk>  "unknown", used for any token that is not contained in the vocab
special_tokens = ['<pad>', '<sos>', '<eos>', '<unk>']

# Build the vocabulary from the pieces listed in the SentencePiece vocab file
vocab = build_vocab_from_iterator(
    yield_tokens("spm_ag_news.vocab"),
    specials=special_tokens,
    special_first=True
)

# Out-of-vocabulary tokens fall back to the <unk> index
unk_index = vocab['<unk>']
vocab.set_default_index(unk_index)

In [None]:
class TokenDrop(nn.Module):
    """For a batch of token indices, randomly replace non-special tokens with <pad>.

    Dropping is applied on every call, regardless of the module's train/eval mode.

    Args:
        prob (float): probability of dropping a token
        pad_token (int): index for the <pad> token
        num_special (int): Number of special tokens, assumed to be at the start of the vocab
    """

    def __init__(self, prob=0.1, pad_token=0, num_special=4):
        # Initialise nn.Module's internal state so .to(), .train(), .eval() etc. work
        super().__init__()
        self.prob = prob
        self.num_special = num_special
        self.pad_token = pad_token

    def forward(self, sample):
        """Return a copy of `sample` with some tokens replaced by `pad_token`.

        Args:
            sample (torch.Tensor): integer tensor of token indices (any shape).

        Returns:
            torch.Tensor: new tensor with the same shape and dtype as `sample`.
        """
        # Draw an independent Bernoulli(prob) sample per position; True marks a drop
        drop_probs = torch.full(sample.shape, float(self.prob), device=sample.device)
        drop_mask = torch.bernoulli(drop_probs).bool()

        # Special tokens (indices below num_special) are never replaced
        drop_mask &= sample >= self.num_special

        return sample.masked_fill(drop_mask, self.pad_token)

In [None]:
# Look the special-token indices up in the vocabulary instead of hard-coding them
pad_index = vocab['<pad>']
sos_index = vocab['<sos>']
eos_index = vocab['<eos>']

# Define a transformation pipeline for training data
train_transform = T.Sequential(
    # Tokenize sentences using pre-existing SentencePiece tokenizer model
    T.SentencePieceTokenizer("spm_ag_news.model"),
    # Convert tokens to indices based on given vocabulary
    T.VocabTransform(vocab=vocab),
    # Add <sos> token at the beginning of each sentence
    T.AddToken(sos_index, begin=True),
    # Crop the sentence (including <sos>) if it is longer than the max length
    T.Truncate(max_seq_len=max_len),
    # Add <eos> token at the end of each sentence, so a sequence has at most max_len + 1 tokens
    T.AddToken(eos_index, begin=False),
    # Convert the list of lists to a tensor, padding shorter sentences in the batch with <pad>
    T.ToTensor(padding_value=pad_index)
)

# Define a transformation pipeline for generation (without truncation and without <eos>)
gen_transform = T.Sequential(
    # Tokenize sentences using pre-existing SentencePiece tokenizer model
    T.SentencePieceTokenizer("spm_ag_news.model"),
    # Convert tokens to indices based on given vocabulary
    T.VocabTransform(vocab=vocab),
    # Add <sos> token at the beginning of each sentence
    T.AddToken(sos_index, begin=True),
    # Convert the list of lists to a tensor, padding shorter sentences in the batch with <pad>
    T.ToTensor(padding_value=pad_index)
)

In [None]:
# Pull one batch of raw article strings and run it through the training transform
text = next(iter(data_loader_train))
index = 0
# Pass a plain list of strings, matching how the training loop calls the transform
input_tokens = train_transform(list(text))
print("SENTENCE")
print(text[index])
print()
print("TOKENS")
# Map the token indices of the chosen article back to their string pieces
print(vocab.lookup_tokens(input_tokens[index].numpy()))

In [None]:
print("TOKENS BACK TO SENTENCE")

# Join the token strings of the chosen article into one string
pred_text = "".join(vocab.lookup_tokens(input_tokens[index].numpy()))
# Swap the "▁" marker for spaces; the result is not stored, it is only the cell's displayed value
pred_text.replace("▁", " ")

## Create Model

In [None]:
class LSTM(nn.Module):
    """Next-token predictor: embedding -> MLP -> (stacked) LSTM -> MLP over the vocab.

    Args:
        num_emb (int): vocabulary size; also the size of the output logits.
        num_layers (int): number of stacked LSTM layers.
        emb_size (int): width of the token embeddings.
        hidden_size (int): width of the LSTM hidden and cell states.
    """

    def __init__(self, num_emb, num_layers=1, emb_size=128, hidden_size=128):
        super().__init__()

        # Token index -> embedding vector
        self.embedding = nn.Embedding(num_emb, emb_size)

        # Small MLP applied to every embedding before the recurrent layers
        self.mlp_emb = nn.Sequential(nn.Linear(emb_size, emb_size),
                                     nn.LayerNorm(emb_size),
                                     nn.ELU(),
                                     nn.Linear(emb_size, emb_size))

        # nn.LSTM applies dropout only between stacked layers and warns when a
        # non-zero value is requested for a single layer, so disable it in that case
        lstm_dropout = 0.25 if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size=emb_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True, dropout=lstm_dropout)

        # Projects every LSTM output onto vocabulary logits
        self.mlp_out = nn.Sequential(nn.Linear(hidden_size, hidden_size//2),
                                     nn.LayerNorm(hidden_size//2),
                                     nn.ELU(),
                                     nn.Dropout(0.5),
                                     nn.Linear(hidden_size//2, num_emb))

    def forward(self, input_seq, hidden_in=None, mem_in=None):
        """Run the model over a batch of token sequences.

        Args:
            input_seq (torch.Tensor): token indices, batch first.
            hidden_in (torch.Tensor, optional): initial LSTM hidden state;
                zeros are used when omitted.
            mem_in (torch.Tensor, optional): initial LSTM cell state;
                zeros are used when omitted.

        Returns:
            tuple: (logits for every input position, final hidden state,
            final cell state).
        """
        input_embs = self.embedding(input_seq)
        input_embs = self.mlp_emb(input_embs)

        if hidden_in is None and mem_in is None:
            # Let nn.LSTM create zero-initialised states itself
            state = None
        else:
            # Fill in whichever state is missing with zeros of the matching shape
            if hidden_in is None:
                hidden_in = torch.zeros_like(mem_in)
            if mem_in is None:
                mem_in = torch.zeros_like(hidden_in)
            state = (hidden_in, mem_in)

        output, (hidden_out, mem_out) = self.lstm(input_embs, state)

        return self.mlp_out(output), hidden_out, mem_out

## Initialise Model and Optimizer

In [None]:
# Pick the compute device: prefer the second GPU (index 1) when more than one is
# present, otherwise use the only GPU, and fall back to the CPU without CUDA
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda", gpu_index)
else:
    device = torch.device("cpu")

# Define embedding size and hidden size for the LSTM model
emb_size = 256
hidden_size = 1024

# Define number of layers for the LSTM model
num_layers = 2

# Create LSTM model instance
lstm_generator = LSTM(num_emb=len(vocab), num_layers=num_layers, 
                      emb_size=emb_size, hidden_size=hidden_size).to(device)

# Initialize optimizer with Adam optimizer
optimizer = optim.Adam(lstm_generator.parameters(), lr=learning_rate, weight_decay=1e-4)

# Define the loss function (Cross Entropy Loss)
loss_fn = nn.CrossEntropyLoss()

# Custom transform that will randomly replace a token with <pad>
td = TokenDrop(prob=0.1)

# List to store the training loss of every optimisation step
training_loss_logger = []

# List to store the mean prediction entropy of every step (for monitoring)
entropy_logger = []

In [None]:
# Let's see how many Parameters our Model has!
# Add up the element count of every parameter tensor in the model
num_model_params = sum(param.numel() for param in lstm_generator.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))

## Training

In [None]:
# Training loop: one pass over data_loader_train per epoch, one optimiser step per batch
for epoch in trange(0, nepochs, leave=False, desc="Epoch"):
    # Set LSTM generator model to training mode
    lstm_generator.train()
    # NOTE(review): `steps` is reset every epoch but never read or incremented in this cell
    steps = 0
    # Iterate over batches in training data loader
    for text in tqdm(data_loader_train, desc="Training", leave=False):
        # Transform the raw strings into a padded tensor of token indices and move it to the device
        text_tokens = train_transform(list(text)).to(device)
        bs = text_tokens.shape[0]

        # Inputs are every token but the last (with random token dropping applied);
        # targets are the same sequence shifted one position to the left
        input_text = td(text_tokens[:, 0:-1])
        output_text = text_tokens[:, 1:]

        # Start every batch from zeroed hidden and cell states
        hidden = torch.zeros(num_layers, bs, hidden_size, device=device)
        memory = torch.zeros(num_layers, bs, hidden_size, device=device)

        # Forward pass through the LSTM generator
        pred, hidden, memory = lstm_generator(input_text, hidden, memory)

        # Calculate loss; dims 1 and 2 are swapped so the logits dimension comes second
        # NOTE(review): padded positions appear to contribute to the loss as well - confirm intended
        loss = loss_fn(pred.transpose(1, 2), output_text)

        # Zero gradients, perform backward pass, and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log training loss
        training_loss_logger.append(loss.item())

        # Log the mean entropy of the predicted distributions (for monitoring)
        with torch.no_grad():
            dist = Categorical(logits=pred)
            entropy_logger.append(dist.entropy().mean().item())

## Plot Metrics

In [None]:
# Per-step training loss, skipping the first 100 logged values
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(training_loss_logger[100:])
_ = ax.set_title("Training Loss")

In [None]:
# Per-step prediction entropy, skipping the first 1000 logged values
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(entropy_logger[1000:])
_ = ax.set_title("Distribution Entropy")

## Generate some text!
Let's use the fact that all of the articles have the title and content separated by a ":" to get our model to generate some content based on a title.

In [None]:
# Get some test examples: the first batch of raw article strings served by the test loader
text = next(iter(data_loader_test))

In [None]:
# Position within the test batch of the article whose title becomes the prompt
index = 0

# Sampling temperature: the logits are divided by this value before sampling
temp = 0.9

# We can either use an example from the test set or create our own article title!
# init_prompt = ["the next big thing from google :"]
# Keep everything before the first ":" of the article and re-attach the separator
title = text[index].partition(":")[0]
init_prompt = [title + ":"]

# Tokenize the prompt (with <sos> prepended) and move the tensor to the device
input_tokens = gen_transform(init_prompt).to(device)

print("INITIAL PROMPT:", init_prompt[0], sep="\n")

print("\nPROMPT TOKENS:")
print(input_tokens)
prompt_indices = input_tokens[0].cpu().numpy()
print(vocab.lookup_tokens(prompt_indices))

In [None]:
# Sampled tokens are collected here, one (1, 1) tensor per generation step
log_tokens = []

# Set LSTM generator model to evaluation mode
lstm_generator.eval()

# Disable gradient calculation
with torch.no_grad():
    # Initialize hidden and memory states for a batch of one sequence
    hidden = torch.zeros(num_layers, 1, hidden_size, device=device)
    memory = torch.zeros(num_layers, 1, hidden_size, device=device)

    # Generate at most 100 tokens
    for i in range(100):
        # Forward pass through LSTM generator. The first iteration feeds the whole prompt;
        # later iterations feed only the token sampled in the previous step, with the
        # context carried forward in `hidden` and `memory`
        data_pred, hidden, memory = lstm_generator(input_tokens, hidden, memory)

        # Sample the next token from the logits of the last position (scaled by the temperature)
        dist = Categorical(logits=data_pred[:, -1] / temp)
        input_tokens = dist.sample().reshape(1, 1)

        # Append generated token to log_tokens
        log_tokens.append(input_tokens.cpu())

        # Stop once the <eos> token (index 2 in the vocabulary) is sampled
        if input_tokens.item() == 2:
            break

In [None]:
# Let's look at the raw tokens: concatenate the logged steps and join their token strings
pred_text = "".join(vocab.lookup_tokens(torch.cat(log_tokens, 1)[0].numpy()))
print(pred_text)

In [None]:
# Combine the model's output with the initial title to get our article!
# The "▁" markers become spaces and any "<unk>" text is removed
init_prompt[0] + pred_text.replace("▁", " ").replace("<unk>", "")

In [None]:
# Let's have a look at the probabilities
# _ = plt.plot(F.softmax(data_pred/temp, -1).cpu().numpy().flatten())