# Classifying text with Transformers!

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F

# We'll be using Pytorch's text library called torchtext! 
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T

from tqdm.notebook import trange, tqdm

In [None]:
# ---- Hyperparameters ----
learning_rate = 1e-4
nepochs = 50
batch_size = 32
# Maximum number of tokens kept per article before <eos> is appended
max_len = 128

# Root directory handed to torchtext for the dataset files
data_set_root = "../../datasets"

# AG News: each sample is a short news article plus a single label giving the
# article's "type". Note that in torchtext these datasets are NOT PyTorch Dataset
# classes: AG_NEWS is a function that returns a PyTorch DataPipe.
#
# PyTorch DataPipes:
# https://pytorch.org/data/main/torchdata.datapipes.iter.html
#
# A good blog post on the difference between Dataset and DataPipe:
# https://medium.com/deelvin-machine-learning/comparison-of-pytorch-dataset-and-torchdata-datapipes-486e03068c58
dataset_train, dataset_test = (
    AG_NEWS(root=data_set_root, split=split_name) for split_name in ("train", "test")
)

In [None]:
# Peek at the first (label, text) pair produced by the training DataPipe
label, text = next(iter(dataset_train))

In [None]:
# Display the raw article text of that first sample
text

In [None]:
# The tokenizer decides how a sentence is split into "chunks" or "tokens"
tokenizer = get_tokenizer("basic_english")


def yield_tokens(data_iter):
    """Lazily yield the token list of every (label, text) pair in ``data_iter``."""
    yield from (tokenizer(text) for _, text in data_iter)


# The vocab holds every unique token found in the training data and gives each
# token its own integer index.
#
# Special tokens used to signal something to the model; special_first=True puts
# them at the start of the vocab, in this order:
#   <pad>  padding added to the end of a sequence so every sequence in a batch
#          has the same length
#   <sos>  "Start-Of-Sentence": marks the start of the sequence
#   <eos>  "End-Of-Sentence": marks the end of the sequence
#   <unk>  "unknown": stands in for any token not contained in the vocab
vocab = build_vocab_from_iterator(
    yield_tokens(dataset_train),
    min_freq=2,  # keep a token only if it appears at least twice in the dataset
    specials=['<pad>', '<sos>', '<eos>', '<unk>'],
    special_first=True,
)

# Lookups of out-of-vocabulary tokens fall back to <unk>
vocab.set_default_index(vocab['<unk>'])

In [None]:
# Let's have a look at the vocab! get_itos() lists every token ordered by its index
vocab.get_itos()

In [None]:
class TokenDrop(nn.Module):
    """For a batch of token indices, randomly replace non-special tokens with <pad>.

    Each token is independently selected with probability ``prob``; a selected
    token is replaced by ``pad_token`` only if its index is >= ``num_special``
    (special tokens are never dropped).

    Args:
        prob (float): probability of dropping a token
        pad_token (int): index for the <pad> token
        num_special (int): Number of special tokens, assumed to be at the start of the vocab
    """

    def __init__(self, prob=0.1, pad_token=0, num_special=4):
        # nn.Module's own initialisation must run first; without it the module
        # has no internal registries, so repr(), .parameters(), .to() etc. fail.
        super().__init__()
        self.prob = prob
        self.num_special = num_special
        self.pad_token = pad_token

    def forward(self, sample):
        """Return a copy of integer tensor ``sample`` with tokens randomly dropped.

        Invoked through ``nn.Module.__call__``, so ``td(sample)`` works as before.
        """
        # Bernoulli(prob) draw per position: 1 marks a token we will replace
        mask = torch.bernoulli(self.prob * torch.ones_like(sample)).long()

        # Only replace tokens that are not special tokens
        can_drop = (sample >= self.num_special).long()
        mask = mask * can_drop

        replace_with = (self.pad_token * torch.ones_like(sample)).long()

        # Keep the original token where mask == 0, use pad_token where mask == 1
        return (1 - mask) * sample + mask * replace_with

In [None]:
# Per-batch text pipeline: lists of token strings -> one padded tensor of indices
text_tranform = T.Sequential(
    # Map each token string to its index in the vocab
    T.VocabTransform(vocab=vocab),
    # Prepend <sos>: index 1, the second special token in the vocab
    T.AddToken(1, begin=True),
    # Crop any sequence longer than max_len tokens
    T.Truncate(max_seq_len=max_len),
    # Append <eos> at the END of each sequence: index 2 in the vocab.
    # Added after truncation, so a sequence can be up to max_len + 1 long.
    T.AddToken(2, begin=False),
    # Convert the list of lists to a tensor, right-padding every sequence shorter
    # than the longest one in the batch with the <pad> token (index 0)
    T.ToTensor(padding_value=0),
)

In [None]:
def text_tokenizer(batch):
    """Tokenize every raw string in ``batch`` with the module-level ``tokenizer``.

    Returns a list with one token list per input string.
    """
    return [tokenizer(x) for x in batch]


# Training loader: shuffled, and the final incomplete batch is dropped
data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
# Test loader: accuracy does not depend on sample order, so keep it fixed; this
# also makes the batch inspected after training reproducible between runs
data_loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
    """Map a 1-D tensor of positions to sinusoidal embeddings of width ``dim``.

    For ``half_dim = dim // 2`` geometrically spaced frequencies, the first
    ``half_dim`` output columns hold ``sin(pos * freq)`` and the next
    ``half_dim`` hold ``cos(pos * freq)``. When ``dim`` is odd, one zero column
    is appended so the output is always exactly ``dim`` wide.

    Args:
        dim (int): width of the output embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed positions ``x`` (shape ``(N,)``); returns a float tensor ``(N, dim)``."""
        device = x.device
        half_dim = self.dim // 2
        # Exponent step between successive frequencies. The max() guard avoids a
        # division by zero for dim in {2, 3} (half_dim == 1); with a single
        # frequency the step is never used because arange(1) == [0].
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Odd width: pad one zero column so the output matches `dim`
            emb = torch.cat((emb, emb.new_zeros(emb.shape[0], 1)), dim=-1)
        return emb



# Transformer block with self-attention
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies self-attention and then an MLP; each sub-layer sees a LayerNorm'd
    copy of its input, and its output is added back onto the un-normalised
    input (residual connection).

    Args:
        hidden_size (int): model width; nn.MultiheadAttention requires it to be
            divisible by ``num_heads``.
        num_heads (int): number of attention heads.
    """

    def __init__(self, hidden_size=128, num_heads=4):
        super(TransformerBlock, self).__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True)

        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                 nn.ELU(),
                                 nn.Linear(hidden_size, hidden_size))

    def forward(self, x):
        """``x``: (batch, seq, hidden_size) since batch_first=True; returns the same shape."""
        # Normalise only the sub-layer input; the residual path carries the
        # original x (previously the normalised x replaced the residual stream).
        norm_x = self.norm1(x)
        x = self.multihead_attn(norm_x, norm_x, norm_x)[0] + x

        x = self.mlp(self.norm2(x)) + x

        return x
    

# "Encoder-Only" Style Transformer
class NanoTransformer(nn.Module):
    """Minimal encoder-only transformer.

    Pipeline: token embedding + sinusoidal position embedding -> per-token MLP
    -> single-head self-attention -> linear output head at every position.

    Args:
        num_emb (int): number of token embeddings (vocabulary size).
        output_size (int): number of output logits per position.
        hidden_size (int): model width.
    """

    def __init__(self, num_emb, output_size, hidden_size=128):
        super(NanoTransformer, self).__init__()

        # Create an embedding for each token
        self.embedding = nn.Embedding(num_emb, hidden_size)
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        self.mlp_in = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                     nn.LayerNorm(hidden_size),
                                     nn.ELU(),
                                     nn.Linear(hidden_size, hidden_size))

        self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads=1, batch_first=True)

        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq):
        """Run the model on integer token indices of shape (batch, seq_len).

        Returns:
            (logits, attn_map): logits of shape (batch, seq_len, output_size) and
            the attention weights returned by nn.MultiheadAttention.
        """
        bs, seq_len = input_seq.shape
        input_embs = self.embedding(input_seq)

        # Add a unique embedding to each token embedding depending on its position in the sequence
        seq_indx = torch.arange(seq_len, device=input_seq.device)
        pos_emb = self.pos_emb(seq_indx).reshape(1, seq_len, -1).expand(bs, seq_len, -1)
        embs = input_embs + pos_emb

        # Feed the MLP output into attention. Previously this result was
        # discarded, so mlp_in never contributed to the output or got gradients.
        emb = self.mlp_in(embs)

        output, attn_map = self.multihead_attn(emb, emb, emb)

        return self.fc_out(output), attn_map

In [None]:
# Use the first CUDA device when one is present, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device(0)
else:
    device = torch.device('cpu')

In [None]:
# Model width shared by the embedding, attention and output layers
hidden_size = 64

# Classifier with one output logit per class label (4 labels), moved to `device`
tf_classifier = NanoTransformer(
    num_emb=len(vocab),
    output_size=4,
    hidden_size=hidden_size,
).to(device)

# Adam over every parameter of the classifier
optimizer = optim.Adam(params=tf_classifier.parameters(), lr=learning_rate)

# Classification loss on raw logits
loss_fn = nn.CrossEntropyLoss()

# Training-time augmentation: each non-special token is swapped for <pad> with probability 0.5
td = TokenDrop(prob=0.5)

In [None]:
# Let's see how many parameters our model has: the total number of scalar
# values across all of its parameter tensors
num_model_params = sum(param.numel() for param in tf_classifier.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))

In [None]:
# Loss histories: one entry per optimisation / evaluation step
training_loss_logger, test_loss_logger = [], []

# Accuracy histories: one entry per epoch
training_acc_logger, test_acc_logger = [], []

In [None]:
# Train for `nepochs` epochs, evaluating on the test split after each one.
# The progress-bar postfix shows the accuracies measured in the previous epoch.
pbar = trange(0, nepochs, leave=False, desc="Epoch")
train_acc = 0
test_acc = 0
for epoch in pbar:
    pbar.set_postfix_str('Accuracy: Train %.2f%%, Test %.2f%%' % (train_acc * 100, test_acc * 100))

    # ---------------- Training ----------------
    tf_classifier.train()
    # Count correct predictions from zero every epoch. Accumulating straight
    # into train_acc (as before) made each epoch's count start from the
    # previous epoch's accuracy value.
    train_correct = 0
    steps = 0
    for label, text in tqdm(data_loader_train, desc="Training", leave=False):
        bs = label.shape[0]
        text_tokens = text_tranform(text_tokenizer(text)).to(device)
        # Augmentation: randomly replace non-special tokens with <pad>
        text_tokens = td(text_tokens)
        # Shift labels down by one so they index the model's outputs from 0
        label = (label - 1).to(device)

        pred, attn_map = tf_classifier(text_tokens)

        # The prediction is read from the output at the last sequence position
        logits = pred[:, -1, :]
        loss = loss_fn(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        training_loss_logger.append(loss.item())

        train_correct += (logits.argmax(1) == label).sum().item()
        steps += bs

    train_acc = train_correct / steps
    training_acc_logger.append(train_acc)

    # ---------------- Evaluation ----------------
    tf_classifier.eval()
    test_correct = 0
    steps = 0
    with torch.no_grad():
        for label, text in tqdm(data_loader_test, desc="Testing", leave=False):
            bs = label.shape[0]
            text_tokens = text_tranform(text_tokenizer(text)).to(device)
            label = (label - 1).to(device)

            pred, attn_map = tf_classifier(text_tokens)

            logits = pred[:, -1, :]
            loss = loss_fn(logits, label)
            test_loss_logger.append(loss.item())

            test_correct += (logits.argmax(1) == label).sum().item()
            steps += bs

    test_acc = test_correct / steps
    test_acc_logger.append(test_acc)

In [None]:
# Per-step training vs test loss, with the steps of each curve spread evenly over the epoch axis
_ = plt.figure(figsize=(10, 5))
for loss_history in (training_loss_logger, test_loss_logger):
    _ = plt.plot(np.linspace(0, nepochs, len(loss_history)), loss_history)

_ = plt.legend(["Train", "Test"])
_ = plt.title("Training Vs Test Loss")
_ = plt.xlabel("Epochs")
_ = plt.ylabel("Loss")

In [None]:
# Per-epoch training vs test accuracy, followed by the best test accuracy reached
_ = plt.figure(figsize=(10, 5))
for acc_history in (training_acc_logger, test_acc_logger):
    _ = plt.plot(np.linspace(0, nepochs, len(acc_history)), acc_history)

_ = plt.legend(["Train", "Test"])
_ = plt.title("Training Vs Test Accuracy")
_ = plt.xlabel("Epochs")
_ = plt.ylabel("Accuracy")
print("Max Test Accuracy %.2f%%" % (np.max(test_acc_logger) * 100))

In [None]:
# Run the trained model on one test batch and keep its attention map for inspection
tf_classifier.eval()
label, text = next(iter(data_loader_test))
with torch.no_grad():
    text_tokens = text_tranform(text_tokenizer(text)).to(device)
    pred, attn_map = tf_classifier(text_tokens)

In [None]:
# Pick one sample from the batch and display its raw text
index = 2
text[index]

In [None]:
# Its token indices after text_tranform: <sos> ... <eos>, possibly followed by <pad> (0) padding
text_tokens[index]

In [None]:
# Plot the attention weights of the last sequence position (the position the
# classifier reads its prediction from) over every position of sample `index`
plt.plot(attn_map[index][-1].detach().cpu().flatten().numpy())

In [None]:
# Positions the last sequence position attends to most strongly, highest first
top_5 = attn_map[index][-1].argsort(descending=True)[:5]
# Vocab.lookup_tokens is typed List[int], so pass plain Python ints via tolist()
# rather than relying on implicit conversion of a numpy array
vocab.lookup_tokens(text_tokens[index, top_5].cpu().tolist())