In [1]:
# In this notebook, you learn:
#
# 1) How to use the KL divergence loss (nn.KLDivLoss) in PyTorch.

In [2]:
import math
import torch
from torch import nn, Tensor
from typing import Optional, Tuple

## [torch.nn.KLDivLoss](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html#kldivloss)

In [3]:
# Resources to learn about KL divergence:
# 
# 1) https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained
#       -- Gives an intuitive explanation of KL Divergence.
# 2) https://encord.com/blog/kl-divergence-in-machine-learning/
#       -- Similar to 1 but explains more in the context of machine learning.
# 3) https://dibyaghosh.com/blog/probability/kldivergence.html
#       -- Explains the math behind KL Divergence.

In [8]:
# Some useful helper functions to assist while experimenting with KL Divergence loss.
# 
def LogInputTensor(input: Tensor, name: Optional[str]=None) -> None:
    """Print ``input`` in the notebook's standard format.

    Three lines are emitted:
      1. its shape,
      2. a label followed by the tensor's values,
      3. a 150-character dashed separator line.

    Args:
        input: Tensor to display.
        name: Label shown above the values. When omitted, the label is printed literally as "None".
    """
    rows = (
        ("shape: ", input.shape),
        (f"{name}: \n", input),
        ("-" * 150,),
    )
    for row in rows:
        print(*row)

# Generates a batch of input data. Output is 2D tensor of shape [batch_size, num_classes] within
# the range [low, high).
def generate_batch_of_input_data(batch_size: int, num_classes: int, low: Optional[float]=0.0, high: Optional[float]=1.0) -> Tensor:
    return (torch.rand(size=(batch_size, num_classes), dtype=torch.float32) * (high - low)) + low

def apply_log_softmax(input: Tensor) -> Tensor:
    """Return the log-softmax of ``input`` over its last dimension.

    The log is the natural log (base e). Each slice along the last axis becomes log-probabilities
    whose exponentials sum to 1. This is the form nn.KLDivLoss expects for its first argument.
    """
    return input.log_softmax(dim=-1)

def apply_softmax(input: Tensor) -> Tensor:
    """Return the softmax of ``input`` over its last dimension.

    Each slice along the last axis becomes a probability distribution (non-negative, sums to 1).
    """
    return input.softmax(dim=-1)

def generate_sample_data(batch_size: int, num_classes: int) -> Tuple[Tensor, Tensor]:
    """Build a random (log_predictions, targets) pair for nn.KLDivLoss experiments.

    Both tensors have shape [batch_size, num_classes]. Each tensor is printed via LogInputTensor
    as it is created.

    Returns:
        log_predictions: log-softmax of random scores. nn.KLDivLoss expects its input in log space.
        targets: softmax of random scores, so every row is a probability distribution summing to 1.
    """
    log_predictions = apply_log_softmax(generate_batch_of_input_data(batch_size, num_classes))
    LogInputTensor(input=log_predictions, name="log predictions")

    targets = apply_softmax(generate_batch_of_input_data(batch_size, num_classes))
    LogInputTensor(input=targets, name="targets")

    return log_predictions, targets

In [9]:
# With reduction="none", nn.KLDivLoss returns the unreduced, element-wise contributions to the KL divergence instead of
# summing them. Summing those contributions gives the actual KL divergence value by mathematical definition.
# ("none", "mean", "sum" and "batchmean" are the documented values. In current PyTorch versions an unrecognised
# string raises a ValueError when the loss is computed; it is NOT silently treated as "none".)
#
# The mathematical formula for KL divergence is --> KL(P || Q) = sum(P(x) * log(P(x) / Q(x)))
# where P and Q are probability distributions.
#
# In our example below, we calculate the KL divergence between two probability distributions, targets (P) and
# predictions (Q). Let's say targets = [p1, p2, p3, p4, p5] and predictions = [log(q1), log(q2), log(q3), log(q4), log(q5)].
# In other words, the predictions are passed in log space, as nn.KLDivLoss expects.
# PyTorch uses the natural logarithm (base e), so the manual check below must use math.log as well. The base does NOT
# cancel out: switching to base b rescales every term (and the total) by 1 / ln(b). For example, base 2 reports the
# divergence in bits instead of nats.
# nn.KLDivLoss then calculates the contribution of each point in (P, Q) towards the KL divergence and returns it.
# So, the output will be [p1 * log(p1 / q1), p2 * log(p2 / q2), p3 * log(p3 / q3), p4 * log(p4 / q4), p5 * log(p5 / q5)]
kl_loss_1 = nn.KLDivLoss(reduction="none")
print(kl_loss_1)
print("-" * 150)
log_predictions_1, targets_1 = generate_sample_data(batch_size=1, num_classes=5)
loss_output_1 = kl_loss_1(log_predictions_1, targets_1)
LogInputTensor(input=loss_output_1, name="loss_output_1")
# Show that the contributions calculated by nn.KLDivLoss are the same as the ones calculated manually.
# The loop bound is read from the tensor, so it stays correct if num_classes above is changed.
for i in range(targets_1.shape[-1]):
    p_i = targets_1[0][i].item()
    log_q_i = log_predictions_1[0][i].item()
    # p * (log p - log q) == p * log(p / q). The input is already log q, so it is used as-is.
    kl_divergence_contribution_1 = p_i * (math.log(p_i) - log_q_i)
    # Both (explicitly calculated value, output of pytorch library) the values should be the same.
    print("calculated KL divergence contribution: ", kl_divergence_contribution_1, " ", f"loss_output_1[0][{i}]: ", loss_output_1[0][i].item())
    # The tolerance allows for float32 (PyTorch) vs float64 (Python) rounding differences.
    assert math.isclose(kl_divergence_contribution_1, loss_output_1[0][i].item(), rel_tol=1e-4, abs_tol=1e-6)

KLDivLoss()
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
log predictions: 
 tensor([[-1.6134, -1.6146, -2.0372, -1.3749, -1.5208]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
targets: 
 tensor([[0.2921, 0.2377, 0.1326, 0.1551, 0.1825]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
loss_output_1: 
 tensor([[ 0.1117,  0.0423,  0.0022, -0.0758, -0.0329]])
------------------------------------------------------------------------------------------------------------------------------------------------------
calculated KL divergence contribution:  0.11173786103748762   loss_output_1[0][0]:  0.1117378

In [10]:
# With reduction="sum", the contributions from every element are summed up and the final value is returned.
# With a single example (batch_size=1, as here), this is exactly the KL divergence (by mathematical definition)
# between the two probability distributions. With more examples, it is the sum of the per-example KL divergences.
kl_loss_2 = nn.KLDivLoss(reduction="sum")
print(kl_loss_2)
print("-" * 150)
log_predictions_2, targets_2 = generate_sample_data(batch_size=1, num_classes=5)
loss_output_2 = kl_loss_2(log_predictions_2, targets_2)
LogInputTensor(input=loss_output_2, name="loss_output_2")
# Show that the KL divergence calculated by nn.KLDivLoss is the same as the value calculated manually:
# the sum of p * (log p - log q) over every element. The loop bounds come from the tensor shape.
kl_divergence_2 = 0.0
for ex_idx in range(targets_2.shape[0]):
    for i in range(targets_2.shape[1]):
        p = targets_2[ex_idx][i].item()
        kl_divergence_2 += p * (math.log(p) - log_predictions_2[ex_idx][i].item())
# Both (explicitly calculated value, output of pytorch library) the values should be the same.
print("calculated KL divergence: ", kl_divergence_2, " ", "loss_output_2: ", loss_output_2.item())
# The tolerance allows for float32 (PyTorch) vs float64 (Python) rounding differences.
assert math.isclose(kl_divergence_2, loss_output_2.item(), rel_tol=1e-4, abs_tol=1e-6)

KLDivLoss()
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
log predictions: 
 tensor([[-1.7925, -1.1957, -2.0741, -1.8951, -1.3666]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
targets: 
 tensor([[0.1440, 0.1824, 0.2032, 0.1891, 0.2813]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([])
loss_output_2: 
 tensor(0.0555)
------------------------------------------------------------------------------------------------------------------------------------------------------
calculated KL divergence:  0.05545330548664444   loss_output_2:  0.055453330278396606


In [11]:
# With reduction="batchmean", the contributions of all elements are summed up and then divided by the batch size
# (the size of dimension 0). In other words, it is the mean of the per-example KL divergences.
kl_loss_3 = nn.KLDivLoss(reduction="batchmean")
print(kl_loss_3)
print("-" * 150)
log_predictions_3, targets_3 = generate_sample_data(batch_size=1, num_classes=5)
loss_output_3 = kl_loss_3(log_predictions_3, targets_3)
LogInputTensor(input=loss_output_3, name="loss_output_3")
# The loop bounds and the divisor are read from the tensor, so the check stays correct if batch_size is changed.
num_examples_3, num_classes_3 = targets_3.shape
kl_divergence_contribution_3 = 0.0
# Show that the KL divergence calculated by nn.KLDivLoss is the same as the value calculated manually.
for ex_idx in range(num_examples_3):
    for i in range(num_classes_3):
        p = targets_3[ex_idx][i].item()
        kl_divergence_contribution_3 += p * (math.log(p) - log_predictions_3[ex_idx][i].item())
# With a single example (batch_size=1), the batch mean equals that example's KL divergence. This is the same value
# reduction="sum" would give, because dividing by a batch size of 1 changes nothing.
kl_divergence_batch_mean_3 = kl_divergence_contribution_3 / num_examples_3
# Both (explicitly calculated value, output of pytorch library) the values should be the same.
print("calculated KL divergence: ", kl_divergence_batch_mean_3, " ", "loss_output_3: ", loss_output_3.item())
# The tolerance allows for float32 (PyTorch) vs float64 (Python) rounding differences.
assert math.isclose(kl_divergence_batch_mean_3, loss_output_3.item(), rel_tol=1e-4, abs_tol=1e-6)

KLDivLoss()
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
log predictions: 
 tensor([[-1.4585, -1.5000, -1.4662, -1.7847, -1.9265]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
targets: 
 tensor([[0.1151, 0.1317, 0.2311, 0.2611, 0.2611]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([])
loss_output_3: 
 tensor(0.1175)
------------------------------------------------------------------------------------------------------------------------------------------------------
calculated KL divergence:  0.11753872195015953   loss_output_3:  0.11753872036933899


In [12]:
# With reduction="batchmean", the contributions of all elements are summed up and then divided by the batch size
# (the size of dimension 0). Let's set batch_size to 2 and see how the KL divergence is calculated.
kl_loss_4 = nn.KLDivLoss(reduction="batchmean")
print(kl_loss_4)
print("-" * 150)
log_predictions_4, targets_4 = generate_sample_data(batch_size=2, num_classes=5)
loss_output_4 = kl_loss_4(log_predictions_4, targets_4)
LogInputTensor(input=loss_output_4, name="loss_output_4")
# The loop bounds and the divisor come from the tensor shape, so they stay correct if batch_size/num_classes change.
num_examples_4, num_classes_4 = targets_4.shape
kl_divergence_batch_mean_4 = 0.0
# Iterates over examples. An example corresponds to one (input, output) pair in the batch.
for ex_idx in range(num_examples_4):
    kl_divergence_for_example_4 = 0.0
    # Iterates over classes within an example.
    for prob_idx in range(num_classes_4):
        p = targets_4[ex_idx][prob_idx].item()
        kl_divergence_for_example_4 += p * (math.log(p) - log_predictions_4[ex_idx][prob_idx].item())
    kl_divergence_batch_mean_4 += kl_divergence_for_example_4
# Divide the summed per-example KL divergences by the number of examples in the batch.
kl_divergence_batch_mean_4 /= num_examples_4
# Both (explicitly calculated value, output of pytorch library) the values should be the same.
print("calculated KL divergence: ", kl_divergence_batch_mean_4, " ", "loss_output_4: ", loss_output_4.item())
# The tolerance allows for float32 (PyTorch) vs float64 (Python) rounding differences.
assert math.isclose(kl_divergence_batch_mean_4, loss_output_4.item(), rel_tol=1e-4, abs_tol=1e-6)

KLDivLoss()
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 5])
log predictions: 
 tensor([[-1.4476, -1.5270, -1.9493, -2.1356, -1.2478],
        [-1.9801, -1.3222, -1.4569, -2.0759, -1.4398]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 5])
targets: 
 tensor([[0.1750, 0.1804, 0.1893, 0.1756, 0.2796],
        [0.2650, 0.1472, 0.1578, 0.2739, 0.1562]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([])
loss_output_4: 
 tensor(0.1018)
------------------------------------------------------------------------------------------------------------------------------------------------------
calculated KL divergence:  0.1018

In [13]:
# Let's see what happens when the probability of a target class is 0. By definition, such a class contributes
# 0 * log(0 / q) = 0 (using the convention 0 * log 0 = 0), so those entries of the unreduced loss should be exactly 0.
kl_div_loss_4_5 = nn.KLDivLoss(reduction="none")
print(kl_div_loss_4_5)
print("-" * 150)
predictions_4_5 = torch.rand(size=(1, 5), dtype=torch.float32)
log_predictions_4_5 = apply_log_softmax(predictions_4_5)
# Note: the values logged under this label are log-probabilities (the output of log-softmax).
LogInputTensor(input=log_predictions_4_5, name="predictions_4_5")
# One-hot target: all of the probability mass is on the last class.
targets_4_5 = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0]], dtype=torch.float32)
LogInputTensor(input=targets_4_5, name="targets_4_5")
kl_loss_output_4_5 = kl_div_loss_4_5(log_predictions_4_5, targets_4_5)
LogInputTensor(input=kl_loss_output_4_5, name="kl_loss_output_4_5")
# Check the claim instead of only eyeballing it:
# - zero-probability targets contribute exactly 0;
# - the one-hot class contributes 1 * (log 1 - log q) = -log q.
assert bool((kl_loss_output_4_5[targets_4_5 == 0] == 0).all())
assert torch.allclose(kl_loss_output_4_5[targets_4_5 == 1], -log_predictions_4_5[targets_4_5 == 1])

KLDivLoss()
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
predictions_4_5: 
 tensor([[-1.7047, -1.9487, -1.6372, -1.9514, -1.0815]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
targets_4_5: 
 tensor([[0., 0., 0., 0., 1.]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([1, 5])
kl_loss_output_4_5: 
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0815]])
------------------------------------------------------------------------------------------------------------------------------------------------------


## Understanding nn.KLDivLoss when the input and target are 3D tensors.

In [14]:
# Transformer models usually deal with 3D tensors of shape [batch_size, seq_len, num_classes]. Each sequence in the batch
# contains multiple tokens (seq_len), and each token is represented as a probability distribution over num_classes.

In [18]:
def generate_3D_batch_of_input_data(batch_size: int, seq_len: int, vocab_size: int, low: float = 0.0,
                                    high: float = 1.0, generator: Optional[torch.Generator] = None) -> Tensor:
    """Generate a 3D batch of uniformly distributed random scores.

    Args:
        batch_size: Number of sequences (sentences) in the batch.
        seq_len: Number of tokens per sequence.
        vocab_size: Number of classes (vocabulary entries) per token.
        low: Inclusive lower bound of the values.
        high: Upper bound of the values (exclusive up to float32 rounding).
        generator: Optional ``torch.Generator``. Pass a seeded generator to make the sample reproducible;
            when ``None``, the global RNG is used, exactly as before.

    Returns:
        A float32 tensor of shape [batch_size, seq_len, vocab_size] with values in [low, high).
    """
    # torch.rand draws from [0, 1); scale and shift it to [low, high).
    uniform = torch.rand(size=(batch_size, seq_len, vocab_size), dtype=torch.float32, generator=generator)
    return uniform * (high - low) + low

def generate_3D_sample_data(batch_size: int, seq_len: int, vocab_size: int) -> Tuple[Tensor, Tensor]:
    """Build a random 3D (log_predictions, targets) pair for nn.KLDivLoss experiments.

    Both tensors have shape [batch_size, seq_len, vocab_size]. Each tensor is printed via
    LogInputTensor as it is created.

    Returns:
        log_predictions: log-softmax over the last dimension of random scores (log space, as nn.KLDivLoss expects).
        targets: softmax over the last dimension of random scores, so each token's row sums to 1.
    """
    shape_args = (batch_size, seq_len, vocab_size)

    log_predictions = apply_log_softmax(generate_3D_batch_of_input_data(*shape_args))
    LogInputTensor(input=log_predictions, name="log predictions")

    targets = apply_softmax(generate_3D_batch_of_input_data(*shape_args))
    LogInputTensor(input=targets, name="targets")

    return log_predictions, targets

In [19]:
# Shape used by the 3D experiments below: [batch_size, seq_len, num_classes].
batch_size = 2   # Number of examples (sentences) in the input batch.
seq_len = 2      # Number of tokens in each sentence.
num_classes = 5  # Number of classes (vocabulary size).

In [20]:
# Let's calculate KL divergence using nn.KLDivLoss with reduction set to "none".
log_predictions_5, targets_5 = generate_3D_sample_data(batch_size, seq_len, num_classes)
kl_div_loss_5 = nn.KLDivLoss(reduction='none')
loss_output_5 = kl_div_loss_5(log_predictions_5, targets_5)
LogInputTensor(input=loss_output_5, name="loss_output_5")
# Now let's calculate the KL divergence contribution of each point in the probability distributions and compare
# it with the outputs returned by the PyTorch library.
for seq_idx in range(batch_size):
    for token_idx in range(seq_len):
        for class_idx in range(num_classes):
            p = targets_5[seq_idx][token_idx][class_idx].item()
            kl_divergence_contribution_5 = p * (math.log(p) - log_predictions_5[seq_idx][token_idx][class_idx].item())
            # Both values should be the same, as explained in the 2D case. The label says "contribution" because
            # this is a single element's term, not the KL divergence itself.
            print("calculated KL divergence contribution: ", kl_divergence_contribution_5, " ", f"loss_output_5[{seq_idx}][{token_idx}][{class_idx}]: ", loss_output_5[seq_idx][token_idx][class_idx].item())
            # The tolerance allows for float32 (PyTorch) vs float64 (Python) rounding differences.
            assert math.isclose(kl_divergence_contribution_5, loss_output_5[seq_idx][token_idx][class_idx].item(), rel_tol=1e-4, abs_tol=1e-6)

shape:  torch.Size([2, 2, 5])
log predictions: 
 tensor([[[-1.8828, -1.8708, -1.4231, -1.2687, -1.7623],
         [-1.9382, -1.3630, -1.9212, -1.7678, -1.2622]],

        [[-1.5502, -1.7241, -1.3850, -1.7437, -1.6914],
         [-1.6912, -1.6190, -1.4458, -1.6422, -1.6686]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 2, 5])
targets: 
 tensor([[[0.2483, 0.2509, 0.2331, 0.1465, 0.1211],
         [0.1610, 0.2124, 0.1847, 0.3007, 0.1412]],

        [[0.2326, 0.1918, 0.2423, 0.1271, 0.2061],
         [0.2354, 0.1480, 0.1996, 0.2189, 0.1981]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 2, 5])
loss_output_5: 
 tensor([[[ 0.1217,  0.1225, -0.0077, -0.0955, -0.0423],
         [ 0.0180, -0.0395,  0.0428,  0.1703, -0.0982]],

        [[ 

In [21]:
# Let's calculate KL divergence using nn.KLDivLoss with reduction set to "batchmean".
# For 3D input, "batchmean" divides the sum over ALL elements (every token of every sentence) by dimension 0
# (batch_size). It does NOT divide by batch_size * seq_len.
log_predictions_6, targets_6 = generate_3D_sample_data(batch_size, seq_len, num_classes)
kl_div_loss_6 = nn.KLDivLoss(reduction='batchmean')
loss_output_6 = kl_div_loss_6(log_predictions_6, targets_6)
LogInputTensor(input=loss_output_6, name="loss_output_6")
kl_divergence_batch_mean_6 = 0.0
for seq_idx in range(batch_size):
    # KL divergence of one sentence: the sum over its tokens and classes.
    kl_divergence_for_sentence_6 = 0.0
    for token_idx in range(seq_len):
        for class_idx in range(num_classes):
            p = targets_6[seq_idx][token_idx][class_idx].item()
            kl_divergence_for_sentence_6 += p * (math.log(p) - log_predictions_6[seq_idx][token_idx][class_idx].item())
    kl_divergence_batch_mean_6 += kl_divergence_for_sentence_6
# Take the mean of the KL divergence of each sentence in the batch.
kl_divergence_batch_mean_6 /= batch_size
# Both the values should be the same.
print("calculated KL divergence: ", kl_divergence_batch_mean_6, " ", "loss_output_6: ", loss_output_6.item())
# The tolerance allows for float32 (PyTorch) vs float64 (Python) rounding differences.
assert math.isclose(kl_divergence_batch_mean_6, loss_output_6.item(), rel_tol=1e-4, abs_tol=1e-6)

shape:  torch.Size([2, 2, 5])
log predictions: 
 tensor([[[-1.1395, -1.9188, -1.8255, -1.5439, -1.8415],
         [-2.0697, -1.4438, -2.0117, -1.5855, -1.2068]],

        [[-1.5631, -1.6739, -1.4507, -1.6425, -1.7425],
         [-1.4436, -1.5906, -1.8617, -1.9599, -1.3324]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 2, 5])
targets: 
 tensor([[[0.1868, 0.1228, 0.1924, 0.2534, 0.2446],
         [0.2133, 0.1050, 0.2529, 0.2306, 0.1983]],

        [[0.2136, 0.1437, 0.2141, 0.2467, 0.1819],
         [0.2857, 0.2436, 0.1803, 0.1605, 0.1299]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([])
loss_output_6: 
 tensor(0.1308)
---------------------------------------------------------------------------------------------------------------------

In [22]:
# Let's calculate KL divergence using nn.KLDivLoss with reduction set to "sum".
# This sums the contribution of every element, i.e. every class of every token of every sentence.
log_predictions_7, targets_7 = generate_3D_sample_data(batch_size, seq_len, num_classes)
kl_div_loss_7 = nn.KLDivLoss(reduction='sum')
loss_output_7 = kl_div_loss_7(log_predictions_7, targets_7)
LogInputTensor(input=loss_output_7, name="loss_output_7")
kl_divergence_sum_7 = 0.0
for seq_idx in range(batch_size):
    for token_idx in range(seq_len):
        for class_idx in range(num_classes):
            p = targets_7[seq_idx][token_idx][class_idx].item()
            kl_divergence_sum_7 += p * (math.log(p) - log_predictions_7[seq_idx][token_idx][class_idx].item())
# Both the values should be the same.
print("calculated KL divergence: ", kl_divergence_sum_7, " ", "loss_output_7: ", loss_output_7.item())
# The tolerance allows for float32 (PyTorch) vs float64 (Python) rounding differences.
assert math.isclose(kl_divergence_sum_7, loss_output_7.item(), rel_tol=1e-4, abs_tol=1e-6)

shape:  torch.Size([2, 2, 5])
log predictions: 
 tensor([[[-1.3373, -1.3475, -2.0353, -1.7831, -1.7215],
         [-1.6993, -1.3270, -1.4256, -1.9131, -1.8083]],

        [[-1.7127, -1.3311, -1.5016, -1.7193, -1.8744],
         [-1.8037, -1.5605, -1.4432, -1.9530, -1.3973]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 2, 5])
targets: 
 tensor([[[0.1858, 0.1909, 0.1305, 0.1875, 0.3053],
         [0.2553, 0.1131, 0.1268, 0.2315, 0.2734]],

        [[0.2598, 0.1946, 0.2566, 0.1539, 0.1350],
         [0.1166, 0.2467, 0.1591, 0.1983, 0.2792]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([])
loss_output_7: 
 tensor(0.2802)
---------------------------------------------------------------------------------------------------------------------