# Training an Encoder-Decoder Transformer
Training follows the three-stage recipe from OpenAI's [instruction-following (InstructGPT) work](https://openai.com/index/instruction-following/):
- Pretraining
- Fine-tuning
- Reward-model ranking and RLHF (skipped for now)

In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import math
from copy import deepcopy

from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.distributions import Categorical

import datasets
from transformers import AutoTokenizer
from tokenizers import Tokenizer
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import TemplateProcessing
from tokenizers.models import BPE

# Let CUDA matmuls use TensorFloat-32 (faster on recent GPUs, slightly reduced precision)
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# --- Training hyperparameters ---
learning_rate = 1e-4   # step size passed to the optimizer
nepochs = 20           # number of full passes over the training data
batch_size = 2         # samples per mini-batch

# Root directory that holds the dataset CSV files
DATASET_PATH = Path("../datasets")

## Data Processing

In [None]:
class CodeLintDataset(Dataset):
    """Map-style dataset of ``(code, label)`` pairs loaded from a CSV file.

    The CSV is expected to have a ``code`` column and a ``label`` column.

    Args:
        dataset_path: Anything accepted by ``pandas.read_csv`` (path or buffer).
    """

    def __init__(self, dataset_path):
        self.df = pd.read_csv(dataset_path)
        # NOTE: empty CSV fields are read as NaN and are left as-is here; call
        # self.df.fillna('', inplace=True) if empty labels should become ''.

    def __getitem__(self, index):
        """Return the ``(code, label)`` pair stored at row position ``index``."""
        # DataLoader samplers yield positions 0..len-1, so look rows up by position
        # (iloc), not by index label (loc). Label lookup only coincides with
        # position when the frame has a plain RangeIndex.
        row = self.df.iloc[index]
        return row["code"], row["label"]

    def __len__(self):
        return len(self.df)
    
# Wrap the pylint-annotated CSV in a Dataset and a shuffling DataLoader
stackv2_python = CodeLintDataset(DATASET_PATH / "dataset_pylint.csv")
data_loader_train = DataLoader(
    stackv2_python,
    batch_size=batch_size,
    shuffle=True,
)
# No held-out split is wired up yet:
# data_loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=4)

In [None]:
# Sanity check: fetch (and display) one shuffled batch of (code, label) samples
next(iter(data_loader_train))

## Training the Tokenizer

In [None]:
# gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
# def get_training_corpus():
#     dataset = datasets.load_dataset("csv", data_files=str(DATASET_PATH / Path("dataset_pylint.csv")))["train"]["code"]
#     for start_idx in range(0, len(dataset), 1000):
#         samples = dataset[start_idx : start_idx + 1000]
#         yield samples
# training_corpus = get_training_corpus()
# gpt2_tokenizer.add_special_tokens({"additional_special_tokens":["[SOC]", "[EOC]"]})
# print(gpt2_tokenizer.all_special_tokens)
# python_tokenizer = gpt2_tokenizer.train_new_from_iterator(training_corpus, 52000)

# # python_tokenizer.post_processor = TemplateProcessing(
# #     single="[SOC] $A [EOC]",
# #     pair="[SOC] $A [SEP] $B:1 [EOC]",
# #     special_tokens=[
# #         ("[SOC]", 0),
# #         ("[SEP]", 1),
# #         ("[EOC]", 2)
# #     ]
# # )

In [None]:
# example_enc = python_tokenizer.encode('''def add_numbers(a, b):
#     """Add the two numbers `a` and `b`."""
#     return a + b''')
# #python_tokenizer.decode(example_enc)
# example_enc

In [None]:
# bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
# trainer = BpeTrainer(special_tokens=["[UNK]", "[SOC]", "[SEP]", "[EOC]", "[PAD]", "[MASK]"],  min_frequency=2)
# #files = [str(DATASET_PATH / Path("code_raw.csv"))] # USE ITERATOR INSTEAD!!!
# bpe_tokenizer.train_from_iterator(stackv2_python.df["code"], trainer)
# bpe_tokenizer.post_processor = TemplateProcessing(
#     single="[SOC] $A [EOC]",
#     pair="[SOC] $A [SEP] $B:1 [EOC]:1",
#     special_tokens=[
#         ("[SOC]", bpe_tokenizer.token_to_id("[SOC]")),
#         ("[SEP]", bpe_tokenizer.token_to_id("[SEP]")),
#         ("[EOC]", bpe_tokenizer.token_to_id("[EOC]")),
#     ],
# )
# bpe_tokenizer.enable_padding(pad_id=bpe_tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")

# Experimenting with a Different Tokenizer

In [None]:
# Local import: only this cell needs the byte-level decoder
from tokenizers import decoders

# Byte-level BPE. Without a pre-tokenizer the trainer treats every whole code sample
# as a single "word". Without a decoder, Tokenizer.decode() joins tokens with spaces.
# The ByteLevel pre-tokenizer/decoder pair fixes both and round-trips text exactly.
# The byte alphabet also makes every input byte representable, so [UNK] is never needed.
bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
bpe_tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
bpe_tokenizer.decoder = decoders.ByteLevel()

# Special tokens receive the first vocabulary ids in the order listed. [PAD] must stay
# first so it gets id 0, the value the model's padding masks compare against.
trainer = BpeTrainer(
    special_tokens=["[PAD]", "[UNK]", "[MASK]", "[SOC]", "[SEP]", "[EOC]", "[SOL]", "[EOL]"],
    min_frequency=2,
    initial_alphabet=ByteLevel.alphabet(),
)
# TODO: stream batches via `datasets` to speed up training (see the "training from memory" docs)
bpe_tokenizer.train_from_iterator(stackv2_python.df["code"], trainer)

# Wrap code sequences as [SOC] ... [EOC] (pairs separated by [SEP])
bpe_tokenizer.post_processor = TemplateProcessing(
    single="[SOC] $A [EOC]",
    pair="[SOC] $A [SEP] $B:1 [EOC]:1",
    special_tokens=[
        ("[SOC]", bpe_tokenizer.token_to_id("[SOC]")),
        ("[SEP]", bpe_tokenizer.token_to_id("[SEP]")),
        ("[EOC]", bpe_tokenizer.token_to_id("[EOC]")),
    ],
)
# Pad batches with [PAD] so encode_batch output can be stacked into a tensor
bpe_tokenizer.enable_padding(pad_id=bpe_tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
bpe_tokenizer.save("bpe_tokenizer.json")


In [None]:
# Reload the trained tokenizer and cap encodings at 2000 tokens
bpe_tokenizer = Tokenizer.from_file("bpe_tokenizer.json")
bpe_tokenizer.enable_truncation(2000)

# The label tokenizer is a copy of the code tokenizer (same vocabulary, same truncation).
# It wraps sequences in [SOL] ... [EOL] instead of [SOC] ... [EOC].
label_bpe_tokenizer = deepcopy(bpe_tokenizer)
label_bpe_tokenizer.post_processor = TemplateProcessing(
    single="[SOL] $A [EOL]",
    pair="[SOL] $A [SEP] $B:1 [EOL]:1",
    special_tokens=[(marker, bpe_tokenizer.token_to_id(marker)) for marker in ("[SOL]", "[SEP]", "[EOL]")],
)

In [None]:
#files = [str(DATASET_PATH / Path("code_raw.csv"))] # USE ITERATOR INSTEAD!!!
# dataset = datasets.load_dataset("csv", data_files=str(DATASET_PATH / Path("dataset_pylint.csv")))
# def batch_iterator(batch_size=1000):
#     # Only keep the text column to avoid decoding the rest of the columns unnecessarily
#     tok_dataset = dataset.select_columns("code")
#     for batch in tok_dataset.iter(batch_size):
#         yield batch["code"]

In [None]:
# Inspect how the code tokenizer splits a sample snippet (its post-processor adds [SOC] ... [EOC])
example = '''def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False, wd=0.0, input_keep_prob=1.0,
           is_train=None):#, name_w='', name_b=''
    # if args is None or (nest.is_sequence(args) and not args):
    #     raise ValueError(""`args` must be specified"")
    # if not nest.is_sequence(args):
    #     args = [args]'''
bpe_tokenizer.encode(example).tokens

In [None]:
# Same snippet through the label tokenizer, which wraps it in [SOL] ... [EOL] instead
label_bpe_tokenizer.encode(example).tokens

# Model Architecture

In [None]:
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of 1-D position indices.

    Args:
        dim: Output feature width. Even widths give ``[sin(x*f_k)..., cos(x*f_k)...]``.
            Odd widths get one extra zero column, so the output is always exactly ``dim`` wide.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Map positions ``x`` of shape (n,) to embeddings of shape (n, dim)."""
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000. Clamp the divisor:
        # dim in {2, 3} gives half_dim == 1, which previously divided by zero.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Odd width: pad one zero feature so callers can reshape to (..., dim)
            emb = F.pad(emb, (0, 1))
        return emb


# Define a module for attention blocks
class AttentionBlock(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` with optional causal masking.

    Args:
        hidden_size: Embedding width of queries, keys and values.
        num_heads: Number of attention heads.
        masking: If True, query position i may only attend to key positions <= i.
    """

    def __init__(self, hidden_size=128, num_heads=4, masking=True):
        super(AttentionBlock, self).__init__()
        self.masking = masking
        # batch_first=True: tensors are (batch, seq, hidden); 25% dropout on attention weights
        self.multihead_attn = nn.MultiheadAttention(
            hidden_size, num_heads=num_heads, batch_first=True, dropout=0.25
        )

    def forward(self, x_in, kv_in, key_mask=None):
        """Attend from ``x_in`` (queries) to ``kv_in`` (keys and values).

        ``key_mask`` is forwarded as ``key_padding_mask`` (True = ignore that key).
        Returns only the attention output; the attention weights are discarded.
        """
        causal_mask = None
        if self.masking:
            _, seq_len, _ = x_in.shape
            # True strictly above the diagonal blocks attention to future positions
            causal_mask = torch.ones(seq_len, seq_len, device=x_in.device).triu(1).bool()

        attn_out, _ = self.multihead_attn(
            x_in, kv_in, kv_in, attn_mask=causal_mask, key_padding_mask=key_mask
        )
        return attn_out


# Define a module for a transformer block with self-attention and optional causal masking
class TransformerBlock(nn.Module):
    """Post-LayerNorm Transformer layer.

    Encoder layout: self-attention -> MLP.
    Decoder layout (``is_decoder=True``): self-attention -> cross-attention over
    ``kv_cross`` -> MLP. Each sub-layer is residual and is followed by a LayerNorm.

    Args:
        hidden_size: Model width.
        num_heads: Heads per attention layer.
        is_decoder: Add the cross-attention sub-layer.
        masking: Apply a causal mask in the self-attention sub-layer.
    """

    def __init__(self, hidden_size=128, num_heads=4, is_decoder=False, masking=True):
        super(TransformerBlock, self).__init__()
        self.is_decoder = is_decoder

        # NOTE: submodules are registered in a fixed order so parameter order,
        # state_dict keys and random initialisation stay stable.
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn1 = AttentionBlock(hidden_size=hidden_size, num_heads=num_heads, masking=masking)

        if self.is_decoder:
            # Cross-attention: queries come from x, keys/values from kv_cross; never causal
            self.norm2 = nn.LayerNorm(hidden_size)
            self.attn2 = AttentionBlock(hidden_size=hidden_size, num_heads=num_heads, masking=False)

        # Position-wise feed-forward network (4x expansion), normalised after its residual add
        self.norm_mlp = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.ELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

    def forward(self, x, input_key_mask=None, cross_key_mask=None, kv_cross=None):
        """Apply the layer to hidden states ``x``.

        Args:
            x: Input hidden states.
            input_key_mask: Key padding mask for the self-attention.
            cross_key_mask: Key padding mask for ``kv_cross`` (decoder only).
            kv_cross: Keys/values for cross-attention; must be given when ``is_decoder``.
        """
        # Residual self-attention, then post-norm
        x = self.norm1(self.attn1(x, x, key_mask=input_key_mask) + x)

        if self.is_decoder:
            # Residual cross-attention, then post-norm
            x = self.norm2(self.attn2(x, kv_cross, key_mask=cross_key_mask) + x)

        # Residual MLP, then post-norm
        return self.norm_mlp(self.mlp(x) + x)
    
    
# Define an encoder module for the Transformer architecture
class Encoder(nn.Module):
    """Bidirectional Transformer encoder.

    Embeds token ids, adds sinusoidal position encodings and runs the result through
    ``num_layers`` non-causal TransformerBlocks.

    Args:
        num_emb: Vocabulary size.
        hidden_size: Model width.
        num_layers: Number of TransformerBlocks.
        num_heads: Attention heads per block.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Encoder, self).__init__()

        self.embedding = nn.Embedding(num_emb, hidden_size)
        # Scale the freshly initialised embedding table down by 1000x
        self.embedding.weight.data = 0.001 * self.embedding.weight.data

        self.pos_emb = SinusoidalPosEmb(hidden_size)

        layers = [
            TransformerBlock(hidden_size, num_heads, is_decoder=False, masking=False)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, input_seq, padding_mask=None):
        """Encode token ids ``input_seq`` (batch, seq) into (batch, seq, hidden_size).

        ``padding_mask`` is passed to every block as the self-attention key padding mask.
        """
        token_embs = self.embedding(input_seq)
        _, seq_len, width = token_embs.shape

        positions = torch.arange(seq_len, device=input_seq.device)
        # (1, seq, width) position table broadcasts across the batch
        hidden = token_embs + self.pos_emb(positions).reshape(1, seq_len, width)

        for layer in self.blocks:
            hidden = layer(hidden, input_key_mask=padding_mask)
        return hidden

    
# Define a decoder module for the Transformer architecture
class Decoder(nn.Module):
    """Causal Transformer decoder that cross-attends to an encoder output.

    Args:
        num_emb: Vocabulary size (embedding rows and number of output logits).
        hidden_size: Model width.
        num_layers: Number of decoder TransformerBlocks.
        num_heads: Attention heads per attention layer.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Decoder, self).__init__()

        self.embedding = nn.Embedding(num_emb, hidden_size)
        # Scale the freshly initialised embedding table down by 1000x
        self.embedding.weight.data = 0.001 * self.embedding.weight.data

        self.pos_emb = SinusoidalPosEmb(hidden_size)

        # Every layer: causal self-attention + cross-attention + MLP
        layers = [
            TransformerBlock(hidden_size, num_heads, is_decoder=True, masking=True)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)

        # Projection from hidden states to vocabulary logits
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq, encoder_output, input_padding_mask=None, encoder_padding_mask=None):
        """Return logits of shape (batch, seq, num_emb) for target ids ``input_seq``.

        Args:
            input_seq: Target token ids (batch, seq).
            encoder_output: Encoder hidden states used as cross-attention keys/values.
            input_padding_mask: Key padding mask for the decoder self-attention.
            encoder_padding_mask: Key padding mask over ``encoder_output``.
        """
        token_embs = self.embedding(input_seq)
        _, seq_len, width = token_embs.shape

        positions = torch.arange(seq_len, device=input_seq.device)
        # (1, seq, width) position table broadcasts across the batch
        hidden = token_embs + self.pos_emb(positions).reshape(1, seq_len, width)

        for layer in self.blocks:
            hidden = layer(
                hidden,
                input_key_mask=input_padding_mask,
                cross_key_mask=encoder_padding_mask,
                kv_cross=encoder_output,
            )
        return self.fc_out(hidden)

    
# Define an Encoder-Decoder module for the Transformer architecture
class EncoderDecoder(nn.Module):
    """Sequence-to-sequence Transformer: an Encoder over source tokens and a Decoder
    over target tokens that cross-attends to the encoder output.

    Args:
        num_emb: Vocabulary size shared by the encoder and the decoder.
        hidden_size: Model width.
        num_layers: ``(encoder_layers, decoder_layers)``.
        num_heads: Attention heads per attention layer.
        pad_id: Token id treated as padding when building the attention masks.
            The default of 0 matches a tokenizer trained with [PAD] as its first
            special token.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=(3, 3), num_heads=4, pad_id=0):
        super(EncoderDecoder, self).__init__()
        # Plain attribute (not a buffer/parameter), so state_dict keys are unchanged
        self.pad_id = pad_id

        # Create an encoder and decoder with specified parameters
        self.encoder = Encoder(num_emb=num_emb, hidden_size=hidden_size,
                               num_layers=num_layers[0], num_heads=num_heads)

        self.decoder = Decoder(num_emb=num_emb, hidden_size=hidden_size,
                               num_layers=num_layers[1], num_heads=num_heads)

    def forward(self, input_seq, target_seq):
        """Return decoder logits (batch, target_len, num_emb).

        Args:
            input_seq: Source token ids (batch, source_len).
            target_seq: Decoder input token ids (batch, target_len).
        """
        # True wherever a position holds padding, so attention ignores it as a key
        input_key_mask = input_seq == self.pad_id
        output_key_mask = target_seq == self.pad_id

        encoded_seq = self.encoder(input_seq=input_seq,
                                   padding_mask=input_key_mask)

        decoded_seq = self.decoder(input_seq=target_seq,
                                   encoder_output=encoded_seq,
                                   input_padding_mask=output_key_mask,
                                   encoder_padding_mask=input_key_mask)

        return decoded_seq

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU
device = torch.device(0 if torch.cuda.is_available() else 'cpu')

# --- Model size ---
hidden_size = 512      # embedding / model width
num_layers = (4, 4)    # Transformer blocks in (encoder, decoder)
num_heads = 8          # heads per MultiheadAttention layer

# Encoder and decoder embeddings both use the code tokenizer's vocabulary size
tf_generator = EncoderDecoder(
    num_emb=bpe_tokenizer.get_vocab_size(),
    num_layers=num_layers,
    hidden_size=hidden_size,
    num_heads=num_heads,
).to(device)

# --- Optimisation ---
optimizer = optim.Adam(tf_generator.parameters(), lr=learning_rate)
scaler = torch.amp.GradScaler("cuda")            # loss scaling for mixed-precision training
loss_fn = nn.CrossEntropyLoss(reduction="none")  # per-token losses, so padding can be masked out

# --- Bookkeeping (replace these when resuming from a checkpoint) ---
training_loss_logger = []
start_epoch = 0

In [None]:
## Load Checkpoint
# cp = torch.load("qa_model.pt")
# tf_generator.load_state_dict(cp["model_state_dict"])
# optimizer.load_state_dict(cp["optimizer_state_dict"])
# training_loss_logger = cp["data_logger"]
# start_epoch = cp["epoch"]

In [None]:
# Total number of elements across every tensor returned by .parameters()
num_model_params = sum(param.numel() for param in tf_generator.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))

In [None]:
# Iterate over epochs
# Padding id, looked up once rather than once per batch
pad_token_id = bpe_tokenizer.token_to_id("[PAD]")

for epoch in trange(start_epoch, nepochs, leave=False, desc="Epoch"):
    tf_generator.train()

    for code, label in tqdm(data_loader_train, desc="Training", leave=False):
        # Tokenize code (encoder input) and lint labels (decoder sequence).
        # Stacking into a tensor relies on the padding configured on the tokenizers.
        code_tokens = torch.tensor([enc.ids for enc in bpe_tokenizer.encode_batch(list(code))]).to(device)
        label_tokens = torch.tensor([enc.ids for enc in label_bpe_tokenizer.encode_batch(list(label))]).to(device)

        # Teacher forcing: the decoder reads tokens [0, T-1) and predicts tokens [1, T)
        label_input_text = label_tokens[:, :-1]
        label_output_text = label_tokens[:, 1:]

        with torch.amp.autocast("cuda"):
            pred = tf_generator(code_tokens, label_input_text)

            # 1.0 for real target tokens, 0.0 for padding
            output_mask = (label_output_text != pad_token_id).float()

            # pred is (batch, seq, vocab); CrossEntropyLoss expects the class dim second.
            # Average the per-token loss over non-padding positions only.
            loss = (loss_fn(pred.transpose(1, 2), label_output_text) * output_mask).sum() / output_mask.sum()

        # Backpropagation with AMP loss scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        training_loss_logger.append(loss.item())

    # Checkpoint every epoch. The GradScaler state is saved too, so resumed runs keep the
    # current loss scale instead of restarting from the initial one.
    torch.save({'epoch': epoch + 1,
                'data_logger': training_loss_logger,
                'model_state_dict': tf_generator.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                }, "qa_model.pt")

In [None]:
# Shape of the last training batch's code tokens: (batch, padded_length)
code_tokens.shape

In [None]:
# Per-batch training loss; the first 10 logged steps are skipped (presumably to hide the large initial loss)
_, ax = plt.subplots(figsize=(10, 5))
ax.plot(training_loss_logger[10:])
_ = ax.set_title("Training Loss")

# Testing

In [None]:
# Grab one batch of (code, label) strings for a qualitative check.
# NOTE: this draws from the *training* loader; no test loader is set up yet.
q_text, a_text = next(iter(data_loader_train))

In [None]:
# Which sample of the batch to inspect
index = 0

print("Input Code:")
print(q_text[index])
# Tokenize the code and add a batch dimension -> LongTensor of shape (1, L)
input_q_tokens = torch.tensor([bpe_tokenizer.encode(q_text[index]).ids])

print("\nOriginal Label:")
print(a_text[index])

# Decoder seed: a (1, 1) LongTensor holding the [SOL] id
soa_token = torch.full((1, 1), label_bpe_tokenizer.token_to_id("[SOL]"), dtype=torch.long)
# Sampling temperature; logits are divided by it, so values < 1 sharpen the distribution
temp = 0.3

In [None]:
# Autoregressively sample a label, starting from the [SOL] seed token
log_tokens = [soa_token]
tf_generator.eval()
eol_id = label_bpe_tokenizer.token_to_id("[EOL]")

with torch.no_grad():
    # Encode the code once; every decoding step cross-attends to this output
    encoded_seq = tf_generator.encoder(input_q_tokens.to(device))

    # Emit at most 100 new tokens
    for _ in range(100):
        input_tokens = torch.cat(log_tokens, 1)

        # Logits for every prefix position; only the last one predicts the next token
        data_pred = tf_generator.decoder(input_tokens.to(device), encoded_seq)

        # Temperature-scaled sampling from the next-token distribution
        next_tokens = Categorical(logits=data_pred[:, -1] / temp).sample().reshape(1, 1)
        log_tokens.append(next_tokens.cpu())

        # Stop once the end-of-label token is produced
        if next_tokens.item() == eol_id:
            break

In [None]:
# Stack the generated (1, 1) token tensors into a single (1, T) tensor of ids
pred_tokens = torch.cat(log_tokens, 1)

# Tokenizer.decode takes a plain list of ints and returns a str. Its default
# skip_special_tokens=True drops markers such as [SOL]/[EOL].
# (The previous "".join over that str was a no-op.)
pred_text = label_bpe_tokenizer.decode(pred_tokens[0].tolist())

# Print the predicted text
print(pred_text)