In [1]:
# In this notebook, you learn:
#
# 1) How to use KL Divergence loss in the transformer model?
#
# Resources to learn more about KL Divergence:
# 1) https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained
#       -- Gives and intuitive explanation of KL Divergence
# 2) https://encord.com/blog/kl-divergence-in-machine-learning/
#       -- Similar to 1 but explains more in the context of machine learning.
# 3) https://dibyaghosh.com/blog/probability/kldivergence.html
#       -- Explains the math behind KL Divergence.

In [2]:
import torch

from torch import nn, Tensor
from typing import Optional

In [3]:
# Index of the padding token or the class label for the padding token.
padding_idx = 2
# Amount of probability to be shared among the tokens excluding correct token and padding tokens.
smoothing = 0.1
# Amount of probability shared with the correct token.
confidence = 1 - smoothing
# Number of classes in the classification problem. It is the size of the vocabulary in transformers.
# It includes the padding token.
vocab_size = 6
# Number of sentences in the batch.
batch_size = 2
# Number of tokens in each sentence. Please note that, for the src sentences the 'seq_len' will
# be 9 and for the tgt sentences, the 'seq_len' will be 8. In this notebook, since we are focusing on
# the KL Divergence loss, we will only consider the tgt sentences and use 'seq_len' variable is not
# distinguished between src and tgt sentences.
seq_len = 8

In [4]:
# Notice that the term token is used loosely here. I am using it to refer to the probability 
# distribution over the vocabulary for a particular token (could be a word). I am also using it 
# to refer to the token (could be a word) itself. Please differentiate the usage based on the 
# context in which the term token is used.
#
# Decoder has an input sentence, a predicted sentence as its output and a target output sentence.
# A sentence is made of multiple tokens. An output token is predicted for each decoder input 
# token. The predicted output token is basically a probability distribution over the target 
# vocabulary. We convert this probability distribution into a token (predicted token) by finding 
# out the token in the target vocabulary with maximum probability. The decoder target output per 
# token is a Label Smoothed version of the one-hot encoded target token i.e., the target output 
# is a probability distribution over the vocabulary. For each predicted token, the KL Divergence 
# is calculated between the predicted distribution and the target distribution. For each sentence, 
# the KL Divergence loss is summed over all the tokens in the sentence. The summed loss is then 
# averaged over all the tokens in the batch. The final loss per token is then used to update the 
# model parameters using backpropagation.

In [5]:
# Lets create a random prediction probability distribution tensor to use with the KL Divergence loss.
predictions = torch.randn(size=(batch_size, seq_len, vocab_size), dtype=torch.float32)
print("shape: ", predictions.shape)
print("predictions: \n", predictions)
print("-" * 150)
# The prediction probability distribution is expected to be in the log space by the KL Divergence 
# loss object in pytorch.
log_softmax = nn.LogSoftmax(dim=-1)
log_softmax_predictions = log_softmax(predictions)
print("shape: ", log_softmax_predictions.shape)
print("log_softmax_predictions: \n", log_softmax_predictions)
print("-" * 150)

shape:  torch.Size([2, 8, 6])
predictions: 
 tensor([[[-1.4787,  0.3393,  0.8852, -2.6089,  0.1993,  1.2473],
         [ 0.7730, -0.5182,  1.1035, -0.9056,  1.9427,  0.0062],
         [-0.3036,  1.5448, -0.7564,  1.4636,  0.1306,  0.1318],
         [ 0.0919, -0.5379, -0.9243,  0.8664, -0.0310, -0.4949],
         [ 0.4991,  1.7868, -0.9093, -0.4672,  0.6120,  0.9794],
         [ 0.9309, -1.6936,  0.0163,  1.6846,  0.1531, -0.8738],
         [ 0.1850, -0.1906, -0.0452,  0.5845,  0.3870, -0.8785],
         [ 0.4096, -1.3969, -1.7540, -1.6454,  0.8043, -0.3939]],

        [[-1.8552, -1.5187, -0.2340, -1.4955, -0.4870,  0.3772],
         [-1.0312, -0.4752, -0.4100,  1.6838, -0.1283, -0.5884],
         [ 2.4114, -0.9768,  0.3813,  0.2883,  0.4947, -1.6481],
         [ 0.9232,  0.1053,  0.8024, -0.2156, -1.3285, -0.0193],
         [-0.5860, -0.6488,  0.0139,  1.1086,  0.4817,  1.6839],
         [-0.8347, -0.7269, -0.0947, -0.3415, -0.6350,  0.6456],
         [ 1.2797, -0.1782,  0.3955,  0.559

In [7]:
# SKIP THIS CELL IF YOU ALREADY LOOKED INTO THE STEP_16 NOTEBOOK. THE CODE FROM PREVIOUS NOTEBOOK IS JUST COPIED HERE 
# TO BE USED IN LOSS COMPUTATION.
# 
# -------------------------------------------------------------------------------------------------------------------
# JUST RUN THIS CELL BLINDLY | JUST RUN THIS CELL BLINDLY | JUST RUN THIS CELL BLINDLY | JUST RUN THIS CELL BLINDLY 
# -------------------------------------------------------------------------------------------------------------------

# This code block is copied from the 'step_19_label_smoothing.ipynb' notebook and discussed in detail there.
# Please refer to that notebook for more details.
class LabelSmoothing(nn.Module):
    def __init__(self, tgt_vocab_size: int, padding_idx: int, smoothing: Optional[int]=0.1):
        super(LabelSmoothing, self).__init__()
        # Number of classes in the classification problem. It is the size of the vocabulary in transformers.
        self.vocab_size = tgt_vocab_size
        # Index of the padding token or the class label for the padding token.
        self.padding_idx = padding_idx
        # Amount of probability to be shared among the tokens excluding correct token and padding tokens.
        self.smoothing = smoothing
        # Amount of probability shared with the correct token.
        self.confidence = 1 - smoothing
    
    def forward(self, targets: Tensor) -> Tensor:
        """Calculates the smoothed probabilities for each of the target tokens within each sentence.

        Args:
            targets (Tensor): The target tensor containing the correct class labels (expected token indices from the 
                              vocab) for each token in the batch. An example target tensor for a batch of 2 sentences
                              each with 8 tokens and 6 possible classes for prediction (including the padding token)
                              would be: [[0, 3, 4, 5, 5, 1, 2, 2], [1, 5, 3, 3, 4, 0, 0, 2]]
                              shape: [batch_size, seq_len]

        Returns:
            Tensor: A smoothed probability distribution (1D tensor) for each target token in the batch.
                    shape: [batch_size, seq_len, vocab_size]                    
        """
        batch_size, seq_len = targets.shape
        # Creating a tensor that will hold the smoothed probabilities for each target token in all the sentences.
        smoothed_probs = torch.zeros(size=(batch_size, seq_len, self.vocab_size), dtype=torch.float32)
        # Filling the entire tensor with the smoothing probability. We will deal with the probabilities of the
        # correct token and padding token later.
        smoothed_probs = smoothed_probs.fill_(value=self.smoothing / (self.vocab_size - 2))
        # Bringing the targets tensor to contain the same number of dimensions as the smoothed_probs tensor to use
        # it with the scatter_ function inorder replace the probabilities in the smoothed_probs tensor in the next 
        # step.
        unsqueezed_targets = targets.unsqueeze(dim=-1)
        # Replacing the probabilities in the smoothed_probs tensor with the confidence probability at the 
        # positions that correspond to the correct class labels (expected output tokens in the target).
        smoothed_probs.scatter_(dim=-1, index=unsqueezed_targets, value=self.confidence)
        # The padding token should not be predicted at all by the model. So, the probability associated with the
        # class label that correspond to the padding token within each target token distribution should be 0. 
        smoothed_probs[:, :, self.padding_idx] = 0
        # The target tensor is appended with the padding tokens at the end. These are just dummy tokens added to bring 
        # all the sentences in the batch to the same length. We don't want the model to consider these tokens at all 
        # in the loss calculation. So, we set the probabilities of the entire rows corresponding to the padding tokens
        # to 0. More about why this setup works is explained in the notebook 'step_17_loss_computation.ipynb'.
        mask = unsqueezed_targets.repeat(1, 1, self.vocab_size) == self.padding_idx
        return smoothed_probs.masked_fill(mask=mask, value=0.0)

# -------------------------------------------------------------------------------------------------------------------
# CELL CONTAINING THE COPIED CODE FROM PREVIOUS NOTEBOOKS ENDS HERE.
# CELL CONTAINING THE COPIED CODE FROM PREVIOUS NOTEBOOKS ENDS HERE.
# CELL CONTAINING THE COPIED CODE FROM PREVIOUS NOTEBOOKS ENDS HERE.
# -------------------------------------------------------------------------------------------------------------------

In [8]:
# Creating random target tensor to use with the Label Smoothing object.
# target = [[0, 3, 4, 5, 5, 1, 2, 2], [1, 5, 3, 3, 4, 0, 0, 2]] 
# target[0][0] = 0 --> For the first sentence and the zeroth token, the expected output token is 0.
# target[0][1] = 3 --> For the first sentence and the first token, the expected output token is 3.
# target[0][2] = 4 --> For the first sentence and the second token, the expected output token is 4.
# ...
# target[1][7] = 2 --> For the second sentence and the seventh token (last), the expected output token is 2.
#       -- 2 is a pad token which means this token is not expected to be used in the loss computation. 
transformer_target_ids = torch.tensor(data=[[0, 3, 4, 5, 5, 1, 2, 2], [1, 5, 3, 3, 4, 0, 0, 2]], dtype=torch.int64)
print("shape: ", transformer_target_ids.shape)
print("transformer_target_ids: \n", transformer_target_ids)
print("-" * 150)
# targets need not be in the log space inorder to be used with the KL Divergence loss unlike with predictions.
# Applying the Label Smoothing to the target tensor before computing the loss.
targets = LabelSmoothing(tgt_vocab_size=vocab_size, padding_idx=padding_idx, smoothing=smoothing)(targets=transformer_target_ids)
print("shape: ", targets.shape)
print("targets: \n", targets)

shape:  torch.Size([2, 8])
transformer_target_ids: 
 tensor([[0, 3, 4, 5, 5, 1, 2, 2],
        [1, 5, 3, 3, 4, 0, 0, 2]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 8, 6])
targets: 
 tensor([[[0.9000, 0.0250, 0.0000, 0.0250, 0.0250, 0.0250],
         [0.0250, 0.0250, 0.0000, 0.9000, 0.0250, 0.0250],
         [0.0250, 0.0250, 0.0000, 0.0250, 0.9000, 0.0250],
         [0.0250, 0.0250, 0.0000, 0.0250, 0.0250, 0.9000],
         [0.0250, 0.0250, 0.0000, 0.0250, 0.0250, 0.9000],
         [0.0250, 0.9000, 0.0000, 0.0250, 0.0250, 0.0250],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0250, 0.9000, 0.0000, 0.0250, 0.0250, 0.0250],
         [0.0250, 0.0250, 0.0000, 0.0250, 0.0250, 0.9000],
         [0.0250, 0.0250, 0.0000, 0.9000, 0.0250, 0.0250],
         [0.0250, 0.0250, 0.0000, 0

In [None]:
# This information (number of non-pad tokens) is used to calculate the KL Divergence per token in the batch
# i.e., the KL divergence is averaged over all the tokens in the batch. This per token KL Divergence is 
# then used as an objective function to train the whole model.
num_non_pad_tokens = (transformer_target_ids != padding_idx).sum().item()
print("num_non_pad_tokens: ", num_non_pad_tokens)

num_non_pad_tokens:  13


In [73]:
# We use the 'sum' reduction to sum the KL Divergence over all the tokens in all the sentences in the batch. The loss is 
# then averaged over all the tokens in the batch to find the loss per token which is used as the objective function. 
#
# Refer to 'UnderstandingPytorch/miscellaneous/loss_functions.ipynb' (Different repository) to learn more about 
# nn.KVDivLoss. The use of this loss function will get much clearer after going through the 'loss_functions.ipynb' 
# notebook.
kl_div_loss_obj = nn.KLDivLoss(reduction="sum")
print(kl_div_loss_obj)

KLDivLoss()


In [None]:
# Calculates the KL Divergence loss between the model predictions and the targets. The KL Divergence loss is calculated per 
# token in the batch. If the target token is a padding token, the probability of the padding token in the predicted 
# distribution is set to zero (0) during Label Smoothing. If the target probability is set to zero, it will not contribute 
# to the KL Divergence loss i.e., it's contribution to the loss is zero. This is shown in the 
# 'Understanding_Pytorch/miscellaneous/loss_functions.ipynb' notebook. Please refer to that notebook for more details.
model_kl_loss = kl_div_loss_obj(input=log_softmax_predictions, target=targets)
print("shape: ", model_kl_loss.shape)
print("model_kl_loss: \n", model_kl_loss)
print("-" * 150)
# Calculate the KL Divergence loss per token in the batch.
model_kl_loss_per_token = model_kl_loss / num_non_pad_tokens
print("model_kl_loss_per_token: ", model_kl_loss_per_token)

shape:  torch.Size([])
model_kl_loss: 
 tensor(16.2052)
----------------------------------------------------
model_kl_loss_per_token:  tensor(1.2466)


In [75]:
# Combines the code from the above cells into a simple class that computes the KL Divergence loss.
class LossCompute:
    def __init__(self):
        # We use the 'sum' reduction to sum the KL Divergence over all the tokens in all the sentences in the batch. 
        # The loss is then averaged over all the tokens in the batch to find the loss per token which is used as the 
        # objective function.         
        self.kl_div_loss = nn.KLDivLoss(reduction="sum")

    # The '__call__' method allows an object of the class to be called just like a function.
    def __call__(self, log_predictions: Tensor, targets: Tensor, num_non_pad_tokens: int) -> Tensor:
        """Computes the KL Divergence loss for the model predictions and the target tensor.

        Args:
            log_predictions (Tensor): The log of the model predictions for the target tokens in the batch.
                                      Each token has a probability distribution over the vocabulary.
                                      shape: [batch_size, seq_len, vocab_size]
            targets (Tensor): The expected target for the model predictions. The target tensor is a smoothed
                              probability distribution over the vocabulary for each token in the batch. 
                              shape: [batch_size, seq_len, vocab_size]
            num_non_pad_tokens (int): The number of non-pad tokens in the target of the batch.

        Returns:
            Tensor: The KL Divergence per token in the batch which is used as the objective function for model
                    training.
        """
        # Calculates the KL Divergence loss between the model predictions and the targets.
        kl_div_loss = self.kl_div_loss(input=log_predictions, target=targets)
        # Calculate the KL Divergence loss per token in the batch.
        return kl_div_loss / num_non_pad_tokens

In [76]:
loss_compute = LossCompute()
# This should give the same result as the model_kl_loss_per_token calculated above.
model_loss = loss_compute(log_predictions=log_softmax_predictions, targets=targets, num_non_pad_tokens=num_non_pad_tokens)
print(model_loss)

tensor(1.2466)
