In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import io
import re

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.distributions import Categorical

from torchtext.datasets import WikiText2, EnWik9, AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchtext.data.functional import sentencepiece_tokenizer, load_sp_model

from tqdm.notebook import trange, tqdm

In [None]:
# Hyperparameters and paths shared by the rest of the notebook
learning_rate = 1e-4        # step size handed to the optimizer
nepochs = 20                # number of passes over the training data
batch_size = 32             # samples per mini-batch
max_len = 128               # maximum sequence length kept by the truncation step
data_set_root = "../data"   # root directory of the dataset

## Dataset, Tokenizers and Vocab!

In [None]:

# We'll be using the AG News Dataset
# Which contains a short news article and a single label to classify the "type" of article
# Note that in torchtext these datasets are NOT PyTorch Dataset classes: "AG_NEWS" is a function that
# returns a PyTorch DataPipe!

# Pytorch DataPipes vvv
# https://pytorch.org/data/main/torchdata.datapipes.iter.html

# vvv Good Blog on the difference between DataSet and DataPipe
# https://medium.com/deelvin-machine-learning/comparison-of-pytorch-dataset-and-torchdata-datapipes-486e03068c58
# Depending on the dataset sometimes the dataset doesn't download and gives an error
# and you'll have to download and extract manually 
# "The datasets supported by torchtext are datapipes from the torchdata project, which is still in Beta status"

# Un-comment to trigger the DataPipe to download the data vvv
# dataset_train = AG_NEWS(root=data_set_root, split="train")
# data = next(iter(dataset_train))

# Side-Note I've noticed that the WikiText dataset is no longer able to be downloaded :(

In [None]:
# Un-Comment to train sentence-piece model for tokenizer and vocab!

# from torchtext.data.functional import generate_sp_model

# with open(os.path.join(data_set_root, "datasets/AG_NEWS/train.csv")) as f:
#     with open(os.path.join(data_set_root, "datasets/AG_NEWS/data.txt"), "w") as f2:
#         for i, line in enumerate(f):
#             text_only = "".join(line.split(",")[1:])
#             filtered = re.sub(r'\\|\\n|;', ' ', text_only.replace('"', ' ').replace('\n', ' ')) # remove newline characters
#             filtered = filtered.replace(' #39;', "'")
#             filtered = filtered.replace(' #38;', "&")
#             filtered = filtered.replace(' #36;', "$")
#             filtered = filtered.replace(' #151;', "-")

#             f2.write(filtered.lower() + "\n")

# generate_sp_model(os.path.join(data_set_root, "datasets/AG_NEWS/data.txt"), 
#                   vocab_size=20000, model_prefix='spm_ag_news')

In [None]:
class AGNews(Dataset):
    """
    Custom Dataset wrapping the AG News CSV files.

    The CSV is read with the columns ``Class``, ``Title`` and ``Content``. Title and
    content are joined as ``"<title> : <content>"`` into a single ``Article`` column,
    which is then cleaned (escaped newlines, backslashes and double quotes become
    spaces; a few leftover HTML entity codes are mapped back to their characters).

    Args:
        num_datapoints: Not read by this class. It is kept only so existing call
            sites, which must supply it, keep working.
        test_train (str): Name of the split; ``<test_train>.csv`` is loaded from
            ``<data_set_root>/datasets/AG_NEWS/``.

    Attributes:
        df (pd.DataFrame): The preprocessed dataset with columns ``Class`` and ``Article``.
    """

    def __init__(self, num_datapoints, test_train="train"):
        # Load the requested split; the CSV has no header row, so name the columns here
        csv_path = os.path.join(data_set_root, "datasets/AG_NEWS/" + test_train + ".csv")
        raw = pd.read_csv(csv_path, names=["Class", "Title", "Content"])

        # Missing titles/contents become empty strings so the concatenation below never yields NaN
        raw = raw.fillna('')

        # Combine the Title and Content columns into a single article string
        article = raw['Title'] + " : " + raw['Content']

        # Replace escaped newlines, backslashes, real newlines and double quotes with spaces
        article = article.str.replace(r'\\n|\\|\\r|\\r\\n|\n|"', ' ', regex=True)

        # Map leftover HTML entity codes back to the characters they stand for
        article = article.replace({' #39;': "'",
                                   ' #38;': "&",
                                   ' #36;': "$",
                                   ' #151;': "-"},
                                  regex=True)

        # Built without in-place mutation so re-running the cell is idempotent
        self.df = pd.DataFrame({"Class": raw["Class"], "Article": article})

    def __getitem__(self, index):
        # Scalar lookups via .at avoid materialising the whole row as a Series twice per sample
        text = self.df.at[index, "Article"].lower()

        # Subtract 1 from the stored class label to get the class index
        class_index = int(self.df.at[index, "Class"]) - 1

        # Return a tuple of the class index and the lower-cased article text
        return class_index, text

    def __len__(self):
        # Number of rows (articles) in the dataset
        return len(self.df)

In [None]:
# Create AGNews dataset instances for training and testing.
# AGNews never reads `num_datapoints`, but the argument is required; pass None rather than
# the dataset root path, which is unrelated to a datapoint count.
dataset_train = AGNews(num_datapoints=None, test_train="train")
dataset_test = AGNews(num_datapoints=None, test_train="test")

# Create data loaders for training and testing datasets.
# Each batch is [class labels, article texts] because AGNews.__getitem__ returns a tuple.
# DataLoader for training dataset (incomplete final batch is dropped)
data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)
# DataLoader for testing dataset (shuffled so a different example is drawn each time a batch is sampled)
data_loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=8)

In [None]:
# Quick demonstration of the trained SentencePiece tokenizer
sp_model = load_sp_model("spm_ag_news.model")    # load the model file from the working directory
tokenizer = sentencepiece_tokenizer(sp_model)    # wrap the model as a tokenizer callable

# Run the tokenizer over a one-sentence batch and print every item it yields
example_sentences = ["i am creating"]
for sentence_tokens in tokenizer(example_sentences):
    print(sentence_tokens)

In [None]:
# Generator that reads a tab-separated file line by line
def yield_tokens(file_path):
    """Yield, for every line of ``file_path``, a one-element list holding the text before the first tab.

    The file is read as UTF-8. A line without a tab is yielded whole
    (including its trailing newline).
    """
    with open(file_path, encoding='utf-8') as vocab_file:
        yield from ([entry.partition("\t")[0]] for entry in vocab_file)

# "Special" tokens used to signal something to the model:
#   <pad>  padding appended so every sequence in a batch has the same length
#   <sos>  "Start-Of-Sentence", marks the start of a sequence
#   <eos>  "End-Of-Sentence", marks the end of a sequence
#   <unk>  "unknown", used for any token that is not in the vocab
special_tokens = ['<pad>', '<sos>', '<eos>', '<unk>']

# Build the vocabulary from the SentencePiece vocab file;
# special_first=True places the special tokens at the beginning of the vocabulary
vocab = build_vocab_from_iterator(
    yield_tokens("spm_ag_news.vocab"),
    specials=special_tokens,
    special_first=True
)

# Out-of-vocabulary tokens map to the <unk> index
unk_index = vocab['<unk>']
vocab.set_default_index(unk_index)

In [None]:
class TokenDrop(nn.Module):
    """For a batch of token indices, randomly replace non-special tokens with <pad>.

    Each position is selected independently with probability ``prob``; selected
    positions whose index is below ``num_special`` are left untouched. The
    replacement is applied on every call, regardless of train/eval mode.

    Args:
        prob (float): probability of dropping a token
        pad_token (int): index for the <pad> token
        num_special (int): Number of special tokens, assumed to be at the start of the vocab
    """

    def __init__(self, prob=0.1, pad_token=0, num_special=4):
        # Required for nn.Module bookkeeping (.to(), .eval(), repr(), hooks ...)
        super().__init__()
        self.prob = prob
        self.num_special = num_special
        self.pad_token = pad_token

    def forward(self, sample):
        """Return a copy of ``sample`` with randomly chosen non-special tokens set to ``pad_token``.

        Args:
            sample (torch.Tensor): tensor of token indices (any shape).

        Returns:
            torch.Tensor: tensor with the same shape, dtype and device as ``sample``.
        """
        # Bernoulli(p=prob) draw per position: True means "replace this token"
        drop_probs = torch.full(sample.shape, float(self.prob),
                                dtype=torch.float, device=sample.device)
        drop_mask = torch.bernoulli(drop_probs).bool()

        # Special tokens (indices below num_special) are never replaced
        drop_mask = drop_mask & (sample >= self.num_special)

        # masked_fill is out-of-place, so the caller's tensor is left unchanged
        return sample.masked_fill(drop_mask, self.pad_token)

    def extra_repr(self):
        return "prob=%s, pad_token=%s, num_special=%s" % (self.prob, self.pad_token, self.num_special)

In [None]:
def _shared_front_stages():
    """Build fresh copies of the stages both pipelines start with."""
    return [
        # Tokenize sentences using the pre-existing SentencePiece tokenizer model
        T.SentencePieceTokenizer("spm_ag_news.model"),
        # Map tokens to indices using the vocabulary
        T.VocabTransform(vocab=vocab),
        # Prepend the <sos> token (index 1 in the vocabulary)
        T.AddToken(1, begin=True),
    ]


# Training pipeline: tokenize -> indices -> <sos> -> truncate -> <eos> -> padded tensor
train_transform = T.Sequential(
    *_shared_front_stages(),
    # Crop the sentence if it is longer than the max length
    T.Truncate(max_seq_len=max_len),
    # Append the <eos> token (index 2 in the vocabulary); this happens after truncation
    T.AddToken(2, begin=False),
    # Convert the list of lists to a tensor, padding shorter sentences with <pad> (index 0)
    T.ToTensor(padding_value=0)
)

# Generation pipeline: same front stages, but no truncation and no <eos>
gen_transform = T.Sequential(
    *_shared_front_stages(),
    # Convert the list of lists to a tensor, padding shorter sentences with <pad> (index 0)
    T.ToTensor(padding_value=0)
)

## Looking at the data and tokenizer

In [None]:
# Each batch from the loader is [class labels, article texts]; only the texts are needed here
_, text = next(iter(data_loader_train))
index = 0

# The transform expects a list of strings (the pipeline is defined as `train_transform`)
input_tokens = train_transform(list(text))
print("SENTENCE")
print(text[index])
print()
print("TOKENS")
print(vocab.lookup_tokens(input_tokens[index].numpy()))

In [None]:
print("TOKENS BACK TO SENTENCE")

# Look up the token string for every index of the selected sample and glue them together
sample_token_strings = vocab.lookup_tokens(input_tokens[index].numpy())
pred_text = "".join(sample_token_strings)

# Swap the "▁" marker for a space; the result is shown as the cell output
pred_text.replace("▁", " ")

## Create LSTM-Attention Model

In [None]:
class LSTM(nn.Module):
    """LSTM language model with multi-head attention over its own past outputs.

    Args:
        num_emb (int): size of the embedding table and of the output layer.
        num_layers (int): number of stacked LSTM layers.
        emb_size (int): width of the token embeddings.
        hidden_size (int): width of the LSTM state; also the attention embed_dim
            (split across 8 heads).
    """

    def __init__(self, num_emb, num_layers=1, emb_size=128, hidden_size=128):
        super().__init__()

        # Token index -> dense vector
        self.embedding = nn.Embedding(num_emb, emb_size)

        # Small MLP applied to the embeddings before the recurrent layer
        self.mlp_emb = nn.Sequential(
            nn.Linear(emb_size, emb_size),
            nn.LayerNorm(emb_size),
            nn.ELU(),
            nn.Linear(emb_size, emb_size),
        )

        # Recurrent core (batch dimension first)
        self.lstm = nn.LSTM(
            input_size=emb_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        # Attention over the logged LSTM outputs
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )

        # Output head projecting down to one score per vocabulary entry
        half_hidden = hidden_size // 2
        self.mlp_out = nn.Sequential(
            nn.Linear(hidden_size, half_hidden),
            nn.LayerNorm(half_hidden),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(half_hidden, num_emb),
        )

    def forward(self, input_token, hidden_seq, hidden_in, mem_in):
        """Run one call of the model and extend the attention history.

        Args:
            input_token: token indices fed to the embedding layer.
            hidden_seq (list): LSTM outputs from earlier calls; this call's
                output is appended to it in place.
            hidden_in: incoming LSTM hidden state.
            mem_in: incoming LSTM cell state.

        Returns:
            tuple: (output scores, the same ``hidden_seq`` list, new hidden state,
            new cell state).
        """
        # Embed the tokens and refine the embeddings with the input MLP
        token_features = self.mlp_emb(self.embedding(input_token))

        # Advance the LSTM from the supplied state
        output, (hidden_out, mem_out) = self.lstm(token_features, (hidden_in, mem_in))

        # Log this output, then stack the whole history along dim 1
        hidden_seq.append(output)
        history = torch.cat(hidden_seq, 1)

        # Query is the current LSTM output; keys and values are every output logged so far
        context, _ = self.attention(output, history, history)  # Q, K, V

        # Residual connection around the attention, followed by the output head
        scores = self.mlp_out(context + output)
        return scores, hidden_seq, hidden_out, mem_out

## Initialise Model and Optimizer

In [None]:
# Run on device 0 when CUDA is available, otherwise on the CPU
device = torch.device(0 if torch.cuda.is_available() else 'cpu')

# Model dimensions: embedding width, hidden width and number of stacked LSTM layers
emb_size = 256
hidden_size = 256
num_layers = 2

# Build the text-generation model (one output per vocabulary entry) and move it to the device
lstm_generator = LSTM(
    num_emb=len(vocab),
    num_layers=num_layers,
    emb_size=emb_size,
    hidden_size=hidden_size,
)
lstm_generator = lstm_generator.to(device)

# Adam optimizer with the configured learning rate and a weight decay of 1e-4
optimizer = optim.Adam(lstm_generator.parameters(), lr=learning_rate, weight_decay=1e-4)

# Cross-entropy loss for next-token prediction
loss_fn = nn.CrossEntropyLoss()

# Custom transform that randomly replaces tokens with <pad>
td = TokenDrop(prob=0.1)

# Per-step logs of the training loss and of the prediction entropy
training_loss_logger = []
entropy_logger = []

In [None]:
# Count the model's parameters: total number of elements over every parameter tensor
num_model_params = sum(param.numel() for param in lstm_generator.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))

## Training!

In [None]:
for epoch in trange(0, nepochs, leave=False, desc="Epoch"):
    # Set the model to training mode
    lstm_generator.train()

    # Each batch is [class labels, article texts]; the labels are not used for generation
    for _, text in tqdm(data_loader_train, desc="Training", leave=False):
        # Transform the text data into token indices and move it to the appropriate device
        # (the pipeline is defined as `train_transform` and expects a list of strings)
        text_tokens = train_transform(list(text)).to(device)
        bs = text_tokens.shape[0]

        # Inputs are every token but the last (with random token dropping to improve
        # generalization); targets are the same sequence shifted one step to the left
        input_text = td(text_tokens[:, 0:-1])
        output_text = text_tokens[:, 1:]

        # Initialize the memory buffers for the LSTM
        hidden = torch.zeros(num_layers, bs, hidden_size, device=device)
        memory = torch.zeros(num_layers, bs, hidden_size, device=device)

        # Use a list to log the output of the LSTM at each timestep for the attention mechanism
        hidden_seq = []

        # Manually loop through the LSTM, one timestep at a time, so each step can attend
        # over the outputs logged so far
        num_steps = input_text.shape[1]
        loss = 0
        for i in range(num_steps):
            input_token = input_text[:, i].unsqueeze(1)
            output_token = output_text[:, i].unsqueeze(1)

            # Forward pass through the LSTM model
            pred, hidden_seq, hidden, memory = lstm_generator(input_token, hidden_seq, hidden, memory)

            # CrossEntropyLoss expects the class dimension second, hence the transpose
            loss += loss_fn(pred.transpose(1, 2), output_token)

        # Average the loss over all time steps
        loss /= num_steps

        # Zero the gradients to prevent accumulation
        optimizer.zero_grad()

        # Backpropagation to compute gradients
        loss.backward()

        # Update the model parameters
        optimizer.step()

        # Log the training loss for visualization
        training_loss_logger.append(loss.item())

        # Log the mean entropy of the predicted distribution
        # (`pred` holds the outputs of the final timestep only)
        with torch.no_grad():
            dist = Categorical(logits=pred)
            entropy_logger.append(dist.entropy().mean().item())

## Plot Metrics

In [None]:
# Training loss per optimisation step
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(training_loss_logger)
_ = ax.set_title("Training Loss")

In [None]:
# Mean entropy of the predicted distribution, logged once per training step
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(entropy_logger)
_ = ax.set_title("Distribution Entropy")

## Generate some text!
Let's use the fact that every article has its title and content separated by a ":" to get our model to generate some content based on a title.

In [None]:
# Get some test examples.
# Each batch is [class labels, article texts]; keep only the texts so that
# `text[index]` is an article string.
_, text = next(iter(data_loader_test))

In [None]:
# Select an index from the test data and the sampling temperature
index = 0
temp = 0.8

# Articles are assembled as "<title> : <content>", so split on the first " : " only.
# Splitting on every ":" would cut off titles or content that contain a colon themselves.
title, _separator, content = text[index].partition(" : ")

# Create an initial prompt from the title, ending with the " :" separator
init_prompt = [title + " :"]

# Transform the initial prompt into tokens and move to the appropriate device
# (the pipeline is defined as `gen_transform`)
input_tokens = gen_transform(init_prompt).to(device)

# Print the initial prompt, original content, and prompt tokens for inspection
print("INITIAL PROMPT:")
print(title)
print("")
print("ORIGINAL CONTENT:")
print(content)
print("\nPROMPT TOKENS:")
print(input_tokens)
print(vocab.lookup_tokens(input_tokens[0].cpu().numpy()))

In [None]:
log_tokens = []
lstm_generator.eval()

with torch.no_grad():
    # Initialize the hidden state and memory for LSTM (batch size 1)
    hidden = torch.zeros(num_layers, 1, hidden_size, device=device)
    memory = torch.zeros(num_layers, 1, hidden_size, device=device)

    # Initialize the hidden sequence for logging LSTM outputs
    hidden_seq = []

    # Prime the model: pass each token of the prompt through the LSTM.
    # `input_tokens` is only read here, never overwritten, so this cell can be re-run.
    for i in range(input_tokens.shape[1]):
        input_token = input_tokens[:, i].unsqueeze(1)

        # Pass the input token through the LSTM model
        data_pred, hidden_seq, hidden, memory = lstm_generator(input_token, hidden_seq, hidden, memory)

    # Sample the first generated token from the (temperature-scaled) output distribution
    # and log it, so the generated text does not silently lose its first token
    dist = Categorical(logits=data_pred[:, -1]/temp)
    next_token = dist.sample().reshape(1, 1)
    log_tokens.append(next_token.cpu())

    # Generate further tokens for a fixed number of iterations or until reaching end-of-sequence token
    for i in trange(10):
        # Stop once the end-of-sequence token (index 2) has been sampled
        if next_token.item() == 2:
            break

        # Pass the most recently sampled token through the LSTM model
        data_pred, hidden_seq, hidden, memory = lstm_generator(next_token, hidden_seq, hidden, memory)

        # Sample the next token based on the output distribution of the LSTM
        dist = Categorical(logits=data_pred[:, -1]/temp)
        next_token = dist.sample().reshape(1, 1)

        # Append the sampled token to the log_tokens list
        log_tokens.append(next_token.cpu())

In [None]:
# Concatenate the logged tokens along dim 1 and take the single batch entry
generated_ids = torch.cat(log_tokens, 1)[0].numpy()

# Look up the token strings in the vocabulary and join them into one string
pred_text = "".join(vocab.lookup_tokens(generated_ids))

# Print the generated text
print(pred_text)

In [None]:
# Swap the "▁" marker for a space and strip any <unk> tokens from the generated text
readable_text = pred_text.replace("▁", " ").replace("<unk>", "")

# Prepend the prompt (the title) to the generated text
final_article = init_prompt[0] + readable_text

# Print the final article
print(final_article)

In [None]:
# Plot the temperature-scaled softmax probabilities of the most recent prediction
token_probs = F.softmax(data_pred[:, -1]/temp, -1)
_ = plt.plot(token_probs.cpu().numpy().flatten())