# Transformer, many-to-many with text

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import io
import re
import math

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.distributions import Categorical

from torchtext.datasets import WikiText2, EnWik9, AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchtext.data.functional import sentencepiece_tokenizer, load_sp_model

from tqdm.notebook import trange, tqdm

In [None]:
# Hyperparameters and data settings shared by the cells below

# --- Optimisation ---
nepochs = 100          # number of passes over the training set
batch_size = 32        # examples per mini-batch
learning_rate = 1e-4   # step size handed to the optimizer

# --- Data ---
max_len = 64                       # truncation length used by the training transform
data_set_root = "../../datasets"   # folder the dataset files are read from

# We'll be using the AG News Dataset
# Which contains a short news article and a single label to classify the "type" of article
# Note that in torchtext these datasets are NOT PyTorch Dataset classes; "AG_NEWS" is a function that
# returns a PyTorch DataPipe!

# Pytorch DataPipes vvv
# https://pytorch.org/data/main/torchdata.datapipes.iter.html

# vvv Good Blog on the difference between DataSet and DataPipe
# https://medium.com/deelvin-machine-learning/comparison-of-pytorch-dataset-and-torchdata-datapipes-486e03068c58
# Depending on the dataset sometimes the dataset doesn't download and gives an error
# and you'll have to download and extract manually 
# "The datasets supported by torchtext are datapipes from the torchdata project, which is still in Beta status"

# Un-comment to trigger the DataPipe to download the data vvv
# dataset_train = AG_NEWS(root=data_set_root, split="train")
# data = next(iter(dataset_train))

# Side-Note I've noticed that the WikiText dataset is no longer able to be downloaded :(

In [None]:
class AGNews(Dataset):
    """Map-style dataset that serves lower-cased AG News articles as plain strings.

    The CSV for the requested split is read from
    ``<data_set_root>/datasets/AG_NEWS/<test_train>.csv`` (``data_set_root`` is a
    module-level setting). Title and content are merged into a single
    ``Article`` column of the form ``"<title> : <content>"``.

    Args:
        num_datapoints: accepted for interface compatibility; never read.
        test_train (str): name of the split file to load, without ``.csv``.
    """

    def __init__(self, num_datapoints, test_train="train"):
        csv_path = os.path.join(data_set_root, "datasets/AG_NEWS/" + test_train + ".csv")
        frame = pd.read_csv(csv_path, names=["Class", "Title", "Content"])

        # Missing titles/contents become empty strings so the concatenation below works
        frame = frame.fillna('')
        merged = frame['Title'] + " : " + frame['Content']

        # Keep only the label and the merged text
        frame = frame.drop(columns=['Title', 'Content'])

        # Replace escaped/real newlines, backslashes and double quotes with spaces
        frame['Article'] = merged.str.replace(r'\\n|\\|\\r|\\r\\n|\n|"', ' ', regex=True)
        self.df = frame

    def __getitem__(self, index):
        """Return the lower-cased article stored under row label ``index``."""
        return self.df.loc[index, "Article"].lower()

    def __len__(self):
        """Return the number of articles in the split."""
        return self.df.shape[0]

In [None]:
# Build the train and test splits (in that order).
# NOTE: AGNews never reads `num_datapoints`, so the value passed for it has no effect.
dataset_train, dataset_test = (
    AGNews(num_datapoints=data_set_root, test_train=split_name)
    for split_name in ("train", "test")
)

In [None]:
# Un-Comment to train sentence-piece model for tokenizer and vocab!

# from torchtext.data.functional import generate_sp_model

# with open(os.path.join(data_set_root, "datasets/AG_NEWS/train.csv")) as f:
#     with open(os.path.join(data_set_root, "datasets/AG_NEWS/data.txt"), "w") as f2:
#         for i, line in enumerate(f):
#             text_only = "".join(line.split(",")[1:])
#             filtered = re.sub(r'\\|\\n|;', ' ', text_only.replace('"', ' ').replace('\n', ' ')) # remove newline characters
#             f2.write(filtered.lower() + "\n")

# generate_sp_model(os.path.join(data_set_root, "datasets/AG_NEWS/data.txt"), 
#                   vocab_size=20000, model_prefix='spm_ag_news')

In [None]:
def yield_tokens(file_path):
    """Yield one single-element token list per line of a tab-separated vocab file.

    Only the text before the first tab of each line is kept. A line without
    any tab is yielded whole (including its trailing newline).

    Args:
        file_path (str): path to a UTF-8 encoded vocab file.

    Yields:
        list[str]: ``[token]`` for every line in the file.
    """
    with open(file_path, encoding='utf-8') as vocab_file:
        for row in vocab_file:
            token, _, _ = row.partition("\t")
            yield [token]
            
# Build the vocabulary from the SentencePiece vocab file, with the four
# special tokens placed ahead of the regular tokens.
vocab = build_vocab_from_iterator(
    yield_tokens("spm_ag_news.vocab"),
    special_first=True,
    specials=['<pad>', '<sos>', '<eos>', '<unk>'],
)

# Tokens that are not in the vocabulary are mapped to <unk>
unk_index = vocab['<unk>']
vocab.set_default_index(unk_index)
del unk_index

In [None]:
class TokenDrop(nn.Module):
    """For a batch of token indices, randomly replace non-special tokens with <pad>.

    The drop is applied on every call, independent of the module's
    train/eval mode.

    Args:
        prob (float): probability of dropping a token
        pad_token (int): index for the <pad> token
        num_special (int): Number of special tokens, assumed to be at the start of the vocab
    """

    def __init__(self, prob=0.1, pad_token=0, num_special=4):
        # nn.Module must be initialised before use, otherwise .eval(), .to(),
        # .parameters() etc. fail on the missing internal registries.
        super().__init__()
        self.prob = prob
        self.num_special = num_special
        self.pad_token = pad_token

    def forward(self, sample):
        """Return a copy of ``sample`` with some non-special tokens set to <pad>.

        Args:
            sample (torch.Tensor): integer tensor of token indices (any shape).

        Returns:
            torch.Tensor: tensor with the same shape and dtype as ``sample``.
        """
        # Draw an independent Bernoulli(prob) decision for every position
        drop_prob = torch.full(sample.shape, float(self.prob), device=sample.device)
        drop_mask = torch.bernoulli(drop_prob).bool()

        # Special tokens (indices below num_special) are never replaced
        drop_mask = drop_mask & (sample >= self.num_special)

        return sample.masked_fill(drop_mask, self.pad_token)

In [None]:
# Pipeline that turns a batch of raw strings into a padded tensor of token indices for training
train_tranform = T.Sequential(
    # Tokenize with the pre-existing SentencePiece tokenizer
    T.SentencePieceTokenizer("spm_ag_news.model"),
    # Convert the tokens to indices based on the given vocabulary
    T.VocabTransform(vocab=vocab),
    # Add <sos> at the beginning of each sentence. 1 because the index for <sos> in the vocabulary is
    # 1 (second entry of the specials list used when the vocab was built)
    T.AddToken(1, begin=True),
    # Crop the sentence if it is longer than the max length
    T.Truncate(max_seq_len=max_len),
    # Add <eos> at the END of each sentence. 2 because the index for <eos> in the vocabulary is
    # 2. This happens after truncation, so a sequence can hold up to max_len + 1 indices
    T.AddToken(2, begin=False),
    # Convert the list of lists to a tensor, this will also
    # pad shorter sentences with the <pad> index (0)
    # so that all sentences in the batch end up the same length
    T.ToTensor(padding_value=0),
)

# Pipeline used to tokenize a generation prompt: same as the training pipeline,
# but without truncation and without appending <eos>
gen_tranform = T.Sequential(
    # Tokenize with the pre-existing SentencePiece tokenizer
    T.SentencePieceTokenizer("spm_ag_news.model"),
    # Convert the tokens to indices based on the given vocabulary
    T.VocabTransform(vocab=vocab),
    # Add <sos> at the beginning of each sentence. 1 because the index for <sos> in the vocabulary is
    # 1 (second entry of the specials list used when the vocab was built)
    T.AddToken(1, begin=True),
    # Convert the list of lists to a tensor, this will also
    # pad shorter sentences with the <pad> index (0)
    # so that all sentences in the batch end up the same length
    T.ToTensor(padding_value=0)
)


In [None]:
# Training loader: shuffled, and the final incomplete batch is dropped
data_loader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

# Test loader: also shuffled, keeps the final (possibly smaller) batch
data_loader_test = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

In [None]:
# Pull one batch of raw article strings and run it through the training transform
text = next(iter(data_loader_train))
# Which example of the batch to inspect
index = 2
input_tokens = train_tranform(text)
# Raw text of the chosen example ...
print(text[index])
print()
# ... followed by the tokens it was converted to (special tokens and padding included)
print(vocab.lookup_tokens(input_tokens[index].numpy()))

In [None]:
# Glue the tokens of the chosen example back into one string
pred_text = "".join(vocab.lookup_tokens(input_tokens[index].numpy()))
# Show it with the "▁" markers swapped for spaces
# (the result is only displayed; pred_text itself is not modified)
pred_text.replace("▁", " ")

In [None]:
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
    """Fixed sine/cosine embedding of scalar positions.

    Each position is multiplied by ``dim // 2`` frequencies spaced
    geometrically from 1 down to 1/10000. The sines of those products are
    concatenated with the cosines, so the output has ``2 * (dim // 2)``
    features.

    Args:
        dim (int): requested embedding width.
            NOTE(review): ``dim // 2`` must be at least 2, otherwise the
            frequency spacing divides by zero.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed the positions in ``x`` (callers in this file pass a 1-D index tensor)."""
        num_freqs = self.dim // 2

        # Log-spacing between consecutive frequencies
        log_step = math.log(10000) / (num_freqs - 1)
        freqs = torch.exp(-log_step * torch.arange(num_freqs, device=x.device))

        # Outer product: one row per position, one column per frequency
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)



# Transformer block with self-attention and causal masking
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention followed by an MLP.

    Both sub-layers are wrapped in a residual connection, with LayerNorm
    applied to the sub-layer input.

    Args:
        hidden_size (int): feature width of the input and output.
        num_heads (int): number of attention heads.
    """

    def __init__(self, hidden_size=128, num_heads=4):
        super().__init__()

        # Attention sub-layer
        self.norm1 = nn.LayerNorm(hidden_size)
        self.multihead_attn = nn.MultiheadAttention(
            hidden_size, num_heads=num_heads, batch_first=True
        )

        # Feed-forward sub-layer (4x expansion, ELU activation)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.ELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, sequence, hidden)."""
        _, seq_len, _ = x.shape

        # True above the diagonal: position i may not attend to positions > i
        causal_mask = torch.ones(seq_len, seq_len, device=x.device).triu(diagonal=1).bool()

        attn_in = self.norm1(x)
        attn_out, _ = self.multihead_attn(attn_in, attn_in, attn_in, attn_mask=causal_mask)
        x = attn_out + x

        mlp_out = self.mlp(self.norm2(x))
        return mlp_out + x
    

# "Decoder-Only" Style Transformer with self-attention
class Transformer(nn.Module):
    """Decoder-only transformer that predicts next-token logits for every position.

    Args:
        num_emb (int): vocabulary size (number of embeddings and of output logits).
        hidden_size (int): width of the token embeddings and of every block.
        num_layers (int): number of stacked TransformerBlocks.
        num_heads (int): attention heads per block.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Transformer, self).__init__()

        # Create an embedding for each token
        self.embedding = nn.Embedding(num_emb, hidden_size)
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads) for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq):
        """Return logits of shape (batch, sequence, num_emb) for ``input_seq``.

        Args:
            input_seq (torch.Tensor): (batch, sequence) tensor of token indices.
        """
        input_embs = self.embedding(input_seq)
        bs, l, h = input_embs.shape

        # Add a unique embedding to each token embedding depending on its position in the sequence
        seq_indx = torch.arange(l, device=input_seq.device)
        pos_emb = self.pos_emb(seq_indx).reshape(1, l, h).expand(bs, l, h)
        output = input_embs + pos_emb

        # Chain the blocks: each one consumes the output of the previous one.
        # (Feeding the raw embeddings to every block would discard all but the last layer.)
        for block in self.blocks:
            output = block(output)

        return self.fc_out(output)

In [None]:
# Pick the compute device. GPU index 1 is preferred (as before) when it exists;
# on single-GPU machines fall back to GPU 0 instead of selecting an invalid ordinal.
if torch.cuda.is_available():
    device = torch.device("cuda", 1 if torch.cuda.device_count() > 1 else 0)
else:
    device = torch.device("cpu")

In [None]:
# Model size
hidden_size = 256
num_layers = 8
num_heads = 8

# Build the generator network and place it on the selected device
tf_generator = Transformer(
    num_emb=len(vocab),
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_heads=num_heads,
)
tf_generator = tf_generator.to(device)

# Adam with a small weight decay, using the learning rate defined with the other hyperparameters
optimizer = optim.Adam(
    tf_generator.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

# Cross-entropy between predicted logits and target token indices
loss_fn = nn.CrossEntropyLoss()

# Custom transform that randomly replaces 20% of the (non-special) input tokens with <pad>
td = TokenDrop(prob=0.2)

In [None]:
# cp = torch.load("gen_model.pt")
# tf_generator.load_state_dict(cp["model"])

In [None]:
# Let's see how many Parameters our Model has!
# Total number of scalar values across all model parameters
num_model_params = sum(weights.numel() for weights in tf_generator.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))

In [None]:
# Per-step histories filled in by the training loop (two independent lists)
training_loss_logger, entropy_logger = [], []

In [None]:
# Next-token-prediction training: for every batch the model sees the sequence
# without its last token and is trained to predict the sequence shifted by one.
for epoch in trange(0, nepochs, leave=False, desc="Epoch"):
    tf_generator.train()
    for text in tqdm(data_loader_train, desc="Training", leave=False):
        # Tokenize the batch of raw strings and move the index tensor to the training device
        text_tokens = train_tranform(list(text)).to(device)

        # Inputs: every position but the last, with random token drop applied.
        # Targets: the same sequence shifted left by one position.
        input_text = td(text_tokens[:, 0:-1])
        output_text = text_tokens[:, 1:]

        pred = tf_generator(input_text)

        # CrossEntropyLoss expects the class dimension second, hence the transpose
        loss = loss_fn(pred.transpose(1, 2), output_text)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the loss and the mean entropy of the predicted token distributions
        training_loss_logger.append(loss.item())
        with torch.no_grad():
            dist = Categorical(logits=pred)
            entropy_logger.append(dist.entropy().mean().item())


In [None]:
# Plot the per-step training loss, skipping the first 1000 logged steps.
# Results are assigned to `_` so the notebook does not echo the returned objects.
_ = plt.figure(figsize=(10, 5))
_ = plt.plot(training_loss_logger[1000:])
_ = plt.title("Training Loss")

In [None]:
# Plot the per-step mean entropy of the predicted distributions, skipping the first 1000 logged steps.
# Results are assigned to `_` so the notebook does not echo the returned objects.
_ = plt.figure(figsize=(10, 5))
_ = plt.plot(entropy_logger[1000:])
_ = plt.title("Distribution Entropy")

In [None]:
# Grab one batch of raw article strings from the (shuffled) training loader to use as prompt material
text = next(iter(data_loader_train))

In [None]:
# Choose which article of the batch to use and display its full text
index = 0
text[index][:]

In [None]:
# Alternative: use a hand-written prompt instead
# init_prompt = ["the rain is coming :"]
# Prompt = everything before the first ":" of the chosen article, with the ":" re-appended
# (normally the title, since articles are stored as "<title> : <content>")
init_prompt = [text[index].split(":")[0] + ":"]


# Tokenize the prompt: <sos> is prepended; there is no truncation and no <eos>
init_tokens = gen_tranform(init_prompt)
print(init_tokens)
print(vocab.lookup_tokens(init_tokens[0].cpu().numpy()))

In [None]:
# Sampling temperature: the logits are divided by this value before sampling,
# so values below 1 sharpen the distribution and values above 1 flatten it
temp = 0.6

In [None]:
# Autoregressive generation: start from the prompt tokens and append one sampled token per step.
# log_tokens holds the prompt followed by every generated token, each as a (1, n) tensor.
log_tokens = [init_tokens]
tf_generator.eval()

with torch.no_grad():    
    # Generate at most 100 new tokens
    for i in range(100):
        # Re-feed the whole sequence produced so far (no caching between steps)
        input_tokens = torch.cat(log_tokens, 1)
        data_pred = tf_generator(input_tokens.to(device))
#         We can take the token with the highest prob
#         input_tokens = data_pred[:, -1].argmax().reshape(1, 1)
        
        # Or sample from the distribution of probs!
        # Only the logits of the last position are used, scaled by the temperature
        dist = Categorical(logits=data_pred[:, -1]/temp)
        next_tokens = dist.sample().reshape(1, 1)
        
        log_tokens.append(next_tokens.cpu())
        
        # Stop as soon as <eos> (index 2 in the vocabulary) is sampled
        if next_tokens.item() == 2:
            break

In [None]:
# Concatenate prompt + generated indices, map them back to tokens and join into one string
pred_text = "".join(vocab.lookup_tokens(torch.cat(log_tokens, 1)[0].numpy()))
print(pred_text)

In [None]:
# Display the generated text with "▁" markers turned into spaces and any "<unk>" removed
pred_text.replace("▁", " ").replace("<unk>", "")

In [None]:
# Display the original article the prompt was taken from, for comparison
text[index][:]

In [None]:
# Plot the temperature-scaled softmax over the vocabulary for the last position
# of the most recent model call in the generation loop
plt.plot(F.softmax(data_pred[0, -1]/temp, -1).cpu().numpy().flatten())

In [None]:
# torch.save({"model" : tf_generator.state_dict()}, "gen_model.pt")