In [1]:
# In this notebook, you learn:
#
# 1) How to create Batch objects out of the batched data from dataloaders and use it for training?
# 2) How is masking used in Transformers and how to create these masks for source and target sentences?
#
# NOTE: It might be a good idea to ignore the masking part (this notebook) for now. It might be easier if you go ahead 
# and understand how to implement the MultiHeadAttention, Encoder Layer and Decoder layer without masking and then 
# come back to this notebook to understand masking i.e., probably come back to this after you understand everything 
# until 'step_13_decoder.ipynb'. The shapes here can get a bit confusing and it might be easier to understand them 
# after you have implemented the MultiHeadAttention, Encoder and Decoder without masking. If you are already familiar 
# with MultiHeadAttention, then you can continue with this notebook. 

In [2]:
# Useful respources:
# 
# 1) https://nlp.seas.harvard.edu/annotated-transformer/#batches-and-masking
#       -- The Annotated Transformer by Harvard NLP. It uses a specific format to hold batches and masks. I use the same
#          format in this notebook and explain each part of it.
# 2) https://jalammar.github.io/illustrated-transformer/
#       -- Illustrated Transformer by Jay Alammar.
# 3) https://www.youtube.com/watch?v=IGu7ivuy1Ag
#       -- Explains how the source and target sentences are used by the Encoder and Decoder.
#       -- Useful to understand the logic in target sentence creating within the Batch object below.
# 4) https://www.garysnotebook.com/20210128_1
#       -- Explains how masking is done. Does not explain everything and leaves out some important details.
# 5) https://stats.stackexchange.com/questions/598239/how-is-padding-masking-considered-in-the-attention-head-of-a-transformer
#       -- Discussed about how to use padding mask in target sentences.
#       -- This was the exact confusion I had and I also don't see the need for padding mask in target sentences.

In [3]:
import torch
from torch import Tensor

In [4]:
# The decoder takes a batch of sequences as input and returns a batch of sequences as output. Every sequence in the 
# batch is of same length. The length of the sequence is the number of tokens in it (just for our purposes here). 
# In the decoder, the 'mask' is used to prevent the tokens appearing after the current token to attend to the 
# current token or any token before it i.e., we do not want the model to look at the future tokens when predicting 
# the current token in the translation task. This is done because during inference, the model will not have access 
# to the future tokens. 
#
# The mask is applied when the normalized weights are calculated for the attention mechanism. By applying the mask, 
# the weights of the future tokens are set to zero --> Please follow the 'step_9_multi_headed_attention.ipynb' 
# notebook to exactly understand how the mask is applied. For now, if we ignore the batch dimension, the mask is a 
# square matrix of size (seq_len, seq_len) where seq_len is the length of the sequence. The mask is a lower triangular 
# matrix where the elements above the diagonal are set to False and the elements below the diagonal are set to True. 
# This is because we want to prevent the tokens appearing after the current token to attend to the current token or 
# any token before it. The tokens set to False in a particular row do not attend the token represented by that row.

In [5]:
# Examples constants to experiment with batches and masks.
# Sentence length in a batch. This varies from batch to batch.
seq_len = 10
# Id of the token that marks the beginning of a sentence.
sos_token_id = 0
# Id of the token that marks the end of a sentence.
eos_token_id = 1
# Id of Padding token. This is used to pad the sentences in a batch to make them of same length.
pad_token_id = 2

### Constructing Look Ahead Mask --> Only for target sequences

In [6]:
# Notice that all the elements on the main diagonal and below it are set to zero in the up_traingular_matrix tensor.
# Refer to 'understanding_tensor_manipulations_part_7.ipynb' notebook to understand more about the 'torch.triu' 
# function.
up_triangular_matrix = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.uint8), diagonal=1)
print("shape: ", up_triangular_matrix.shape)
print("up_triangular_matrix: \n", up_triangular_matrix)
print("-" * 150)

# The attention mask can be understood by tagging tokens to rows and columns as shown below.
#
# --------------------------------------------------------------------------------------------------
#              | tok 0 | tok 1 | tok 2 | . . . | tok (n - 1)
#              |------------------------------------------------
# tok 0        |       |       |       |       |
# tok 1        |       |       |       |       |
# tok 2        |       |       |       |       |
# .            |       |       |       |       |
# .            |       |       |       |       |
# .            |       |       |       |       |
# tok (n - 1)  |       |       |       |       |
# ----------------------------------------------------------------------------------------------------
#
# elements in row 'i' tell what all tokens can attend to token in 'i' in the attention score calculation.
#
# attention_mask[0][0] = True --> This means that the first token can attend to itself.
# attention_mask[0][1] = False --> This means that the second token cannot attend to the first token.
#
# In general, 
# attention_mask[i][j] = True --> This means that the jth token can attend to the ith token.
# attention_mask[i][j] = False --> This means that the jth token cannot attend to the ith token.
# Only the weights where mask is True are used in the attention scores calculation.
#
# Lets consider the first token (row 1). Only the first token can attend to itself. The tokens after first token 
# cannot attend to the first token. Hence, in the first row, only the first element is True and the rest are False. 
# The same applies to the second token. The second token can attend to the first token and itself. The tokens 
# after the second token cannot attend to the second token. Hence, the second row has two True elements in the 
# first two indices and the rest are False. This pattern continues for all the tokens in the sentence.
attention_mask = (up_triangular_matrix == 0)
print("shape: ", attention_mask.shape)
print("attention_mask: \n", attention_mask)

shape:  torch.Size([10, 10])
up_triangular_matrix: 
 tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.uint8)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([10, 10])
attention_mask: 
 tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [

In [7]:
# This function is just created by combining the code from the above cells into a single function.
# This function will be used later when creating the mask in the Batch object.
def construct_look_ahead_mask(size: int) -> Tensor:
    """Create a mask to prevent the tokens appearing after the current token 
       to attend to the current token or any token before it.

    Args:
        size (int): Size of the mask to be created i.e., the length of the sentence.

    Returns:
        Tensor: A boolean tensor of shape (size, size).
    """
    attention_mask = torch.triu(torch.ones(size, size, dtype=torch.uint8), diagonal=1)
    return attention_mask == 0

In [8]:
look_ahead_mask = construct_look_ahead_mask(seq_len)
print("shape: ", look_ahead_mask.shape)
print("look_ahead_mask: \n", look_ahead_mask)

shape:  torch.Size([10, 10])
look_ahead_mask: 
 tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])


### Constructing Padding Mask --> Both for source and target sequences

In [9]:
# For now, lets ignore the target sentences and focus on the source sentences.
# When creating the mask for the source sequences, we need to create a mask that prevents the padding tokens from 
# attending to the actual text tokens. The padding tokens are the tokens that are added to the sequences that are 
# shorter than the longest sequence in the batch. The padding tokens are added to the end of the sequence. The 
# mask is created by checking the indices of the padding tokens and setting the corresponding indices in the mask 
# to False.
#
# Here, we are not bothered about the future tokens attending to the current token. This is because the source
# sequences are the input to the Encoder and the Encoder does not have to predict the next token. The encoder 
# only has to encode the input sequences. Hence, we do not need to prevent the future tokens from attending to 
# the current token.

In [10]:
# src_sequence is of the format [token_1 token_2 ... token_n <pad> <pad> ... <pad>]
# where <pad> is the padding token. Notice that src_sequence does not have the '<sos>' and '<eos>' tokens. They
# are only present in the target_sequence and do not serve any purpose in the source_sequence.
src_batch = torch.tensor(data=[[24, 42, 33, 4124, 231, 12321, pad_token_id, pad_token_id, pad_token_id, pad_token_id], 
                               [27, 67, 83, 23124, 131, 1321, 23, 90, pad_token_id, pad_token_id]], dtype=torch.int64)
print("shape: ", src_batch.shape)
print("src_batch: \n", src_batch)

shape:  torch.Size([2, 10])
src_batch: 
 tensor([[   24,    42,    33,  4124,   231, 12321,     2,     2,     2,     2],
        [   27,    67,    83, 23124,   131,  1321,    23,    90,     2,     2]])


In [11]:
src_mask = (src_batch != pad_token_id)
print("shape: ", src_mask.shape)
print("src_mask: \n", src_mask)
print("-" * 150)
# We add a dimension so that the mask can be broadcasted with the batched data when used in transformers. src_mask tensor 
# above contains 1 row per sequence in the batch. The mask need to contain one row per token which contains the 
# information about what other tokens can attend to the token in the current row. Since, we are calculating the padding 
# mask, the mask is same for every single token in one sequence. However, we will not explicitly create the mask for every 
# token in the sequence. We will just create 1 row and allow python's broadcasting to take care of the rest of the rows 
# in the sequence when the mask is used in the transformer. 
# On an additional note, it might be tempting to make the src_mask of shape [batch_size, seq_len, seq_len] instead of 
# leaving it as [batsh_size, 1, seq_len]. However, for the src_mask to be used with self attention in Encoder, the shape 
# should be [batch_size, seq_len, seq_len] and for the src_mask to be used with source attention in Decoder, the shape 
# should be [batch_size, seq_len - 1, seq_len]. So, we will keep the shape of src_mask as [batch_size, 1, seq_len] and 
# then let the model handle the broadcasting of the mask to the required shape. Please note that we have omitted the 
# dimension that corresponds to the number of heads in the mask to keep things simple for now. Lets handle this below.
src_mask = src_mask.unsqueeze(1)
print("shape: ", src_mask.shape)
print("src_mask: \n", src_mask)

shape:  torch.Size([2, 10])
src_mask: 
 tensor([[ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 1, 10])
src_mask: 
 tensor([[[ True,  True,  True,  True,  True,  True, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True, False, False]]])


In [12]:
# This function is just created by combining the code from the above cells into a single function.
def construct_padding_mask(input: Tensor, pad_token_id: int) -> Tensor:
    """Create a mask to prevent the padding tokens from attending to the tokens.

    Args:
        input (Tensor): A batch of sentences of shape (batch_size, seq_len).
        pad_token_id (int): Id of the padding token.

    Returns:
        Tensor: A boolean tensor of shape (batch_size, seq_len, seq_len).
    """
    mask = (input != pad_token_id)
    mask = mask.unsqueeze(1)
    return mask

In [13]:
src_mask_via_function = construct_padding_mask(input=src_batch, pad_token_id=pad_token_id)
print("shape: ", src_mask_via_function.shape)
print("src_mask_via_function: \n", src_mask_via_function)

shape:  torch.Size([2, 1, 10])
src_mask_via_function: 
 tensor([[[ True,  True,  True,  True,  True,  True, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True, False, False]]])


In [14]:
if torch.equal(src_mask, src_mask_via_function):
    print("Both the source masks are equal as they should be.")
else:
    print("The source masks are not equal. There is some mistake in the code.")

Both the source masks are equal as they should be.


In [15]:
# In the transformers implementation, the same mask is applied to all the heads in the multi-headed attention mechanism. 
# The calculated attention_scores will have the shape [batch_size, num_heads, seq_len, seq_len] in self attention. 
# Hence, the src_mask is updated to have the shape [batch_size, 1, 1, seq_len] which can be broadcasted and applied to 
# all the heads.
src_mask = src_mask.unsqueeze(1)
print("shape: ", src_mask.shape)
print("src_mask: \n", src_mask)

shape:  torch.Size([2, 1, 1, 10])
src_mask: 
 tensor([[[[ True,  True,  True,  True,  True,  True, False, False, False, False]]],


        [[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]]]])


### More About Target Mask

In [16]:
# target sequence is of the format [<sos> token_1 token_2 ... token_n <eos> <pad> <pad> ... <pad>] where <sos> 
# is the start of sequence token, <eos> is the end of sequence token and <pad> is the padding token.
# 
# target sequence is used in two ways in the Decoder during training. The target sequence is used as an input
# to the decoder to predict the next token in the output target sequence. The target sequence is also used to 
# calculate the loss. The target sequence is shifted by one position to the left when used to calculate the loss. 
# This is because the model should predict the next token in the target sequence and not the current token.
#
# target = [<sos> token_1 token_2 ... token_n <eos> <pad> <pad> ... <pad>]
#       -- The target sequence from the data loader.
#       -- Length: L
# target_decoder_input = [<sos> token_1 token_2 ... token_n <eos> <pad> <pad> ... <pad>]
#       -- The last token (<pad> for this example) from the target is removed to create the target_decoder_input.
#       -- Length: (L - 1)
# target_expected_decoder_output = [token_1 token_2 ... token_n <eos> <pad> <pad> ... <pad> <pad>]
#       -- The first token (<sos>) from the target is removed to create the target_expected_decoder_output.
#       -- Used to calculate the loss.
#       -- Length: (L - 1)
#
# The decoder predicts 1 token as output for each token in the input to the decoder. So, the predicted_output 
# will have 1 more token than the target_expected_decoder_output if we just removed <sos> from target to 
# create target_expected_decoder_output and use target without changes as input to the decoder. To resolve this 
# issue, we remove the last token from the target to create target_decoder_input and use it as input to the 
# decoder. The last token in target is either <eos> or <pad> and it doesn't matter if it is removed. The 
# predicted_output and target_expected_decoder_output are then compared and used to calculate the loss.
#
# For now, lets assume that the last token is <eos> and see how the loss is calculated without any problems.
#
# target_decoder_input = [<sos> token_1 token_2 ... token_n]
#       -- <eos> is removed from the end.
# target_expected_decoder_output = [token_1 token_2 ... token_n <eos>]
#       -- <sos> is removed from the start from the original target but <eos> is not touched.
# predicted_output = [output_token_1 output_token_2 ... output_token_n output_token_n+1]
#
# Now, when calculating the loss target_expected_decoder_output is compared to predicted output. The hope is that
# output_token_1 is token_1
# output_token_2 is token_2
# ...
# output_token_n is token_n
# output_token_n+1 is <eos>
#
# Exactly what we want.
#
# If this explanation is not clear, please refer to the below cells to see the various variables created.

In [18]:
target_batch = torch.tensor(data=[[sos_token_id, 12, 3, 3545, eos_token_id, pad_token_id, pad_token_id, pad_token_id, pad_token_id, pad_token_id], 
                                  [sos_token_id, 122, 6, 545, 40, 78, 90, 89, 78, eos_token_id]], dtype=torch.int16)
print("shape: ", target_batch.shape)
print("target_batch: \n", target_batch)

shape:  torch.Size([2, 10])
target_batch: 
 tensor([[   0,   12,    3, 3545,    1,    2,    2,    2,    2,    2],
        [   0,  122,    6,  545,   40,   78,   90,   89,   78,    1]],
       dtype=torch.int16)


In [19]:
# Remove the last token from the original target sentences to create the target_decoder_input.
target_decoder_input = target_batch[:, :-1]
print("shape: ", target_decoder_input.shape)
print("target_decoder_input: \n", target_decoder_input)

shape:  torch.Size([2, 9])
target_decoder_input: 
 tensor([[   0,   12,    3, 3545,    1,    2,    2,    2,    2],
        [   0,  122,    6,  545,   40,   78,   90,   89,   78]],
       dtype=torch.int16)


In [21]:
# The first token is removed from the original target sequences to create the target_expected_decoder_output.
target_expected_decoder_output = target_batch[:, 1:]
print("shape: ", target_expected_decoder_output.shape)
print("target_expected_decoder_output: \n", target_expected_decoder_output)

shape:  torch.Size([2, 9])
target_expected_decoder_output: 
 tensor([[  12,    3, 3545,    1,    2,    2,    2,    2,    2],
        [ 122,    6,  545,   40,   78,   90,   89,   78,    1]],
       dtype=torch.int16)


In [23]:
# We first mask the padding tokens in the target sequences. This is done to prevent the padding tokens
# from attending to the other tokens in the target sequences.
# Notice that the mask is calculated on the target_decoder_input and not on the target. This is because
# the target_decoder_input is used as input to the decoder to predict the next token in the output.
target_padding_mask = construct_padding_mask(input=target_decoder_input, pad_token_id=pad_token_id)
print("shape: ", target_padding_mask.shape)
print("target_padding_mask: \n", target_padding_mask)
print("-" * 150)
# The target_padding_mask is updated to have the shape [batch_size, L - 1, L - 1] instead of leaving
# it as [batch_size, 1, L - 1] like we did with source mask. This is because targets also have look ahead 
# mask which is different for every single token in the target sequence. So, the look ahead mask below
# will have 1 row per token in the sequence and it will be merged with the target_padding_mask to create
# the final mask. Hence, we explicitly create the mask for every token in the target sequence before 
# merging.
target_padding_mask = target_padding_mask.repeat(1, target_decoder_input.size(1), 1)
print("shape of target_padding_mask: ", target_padding_mask.shape)
print("target_padding_mask: \n", target_padding_mask)

shape:  torch.Size([2, 1, 9])
target_padding_mask: 
 tensor([[[ True,  True,  True,  True,  True, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape of target_padding_mask:  torch.Size([2, 9, 9])
target_padding_mask: 
 tensor([[[ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False

In [24]:
# We now create the mask to prevent the future tokens from attending to the current token or any token before it.
# We combine both the masks using the logical 'and' operator.
# The type_as function is used to ensure that the mask is of the same type as the target_mask. Since, we are using
# the logical 'and' operator, the future_mask should be of the same type as the target_mask.
# Even though '&' is a bitwise 'and' operator, it is overloaded in PyTorch to work with boolean tensors.
# It is the same as applying logical 'and' operator when both the value types being used with '&' are of type 
# torch.bool. 
target_mask = target_padding_mask & construct_look_ahead_mask(target_decoder_input.size(-1)).type_as(target_padding_mask.data)
print("shape: ", target_mask.shape)
print("target_mask: \n", target_mask)

shape:  torch.Size([2, 9, 9])
target_mask: 
 tensor([[[ True, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False]],

        [[ True, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False,

In [25]:
# YOU CAN IGNORE THIS CELL. IT DOESN'T REALLY EFFECT THE MASKS BUT IT IS AN INTERESTING POINT TO 
# THINK ABOUT. IF THE ARGUMENT BELOW IS NOT CLEAR, JUST IGNORE IT.
#
# Looking at the target_mask tensor output in the above cell, there is really no value in using the 
# padding mask in the target sentences. The padding tokens are only at the end of the sentence and 
# the future tokens are not allowed to attend to the earlier tokens because of the look_ahead_mask 
# already. This means the padding tokens are already masked for the actual tokens (part of the 
# sentence) because of the look_ahead_mask. In other words, the padding_mask is not changing 
# anything in the masks for the rows where the tokens are non-padding tokens. The padding_mask only 
# changes the target_mask for the rows where the tokens are padding tokens. This anyway doesn't 
# matter because we don't really use the output of the padding tokens in the final loss calculation.
#
# However, I don't know why most of the implementations still go ahead and use padding mask in the
# target sentences. I am not sure if I am missing something here. So, I am just leaving the code
# as it is incase it is important. If you know the reason, please let me know.

In [26]:
# We could mask every single token for the rows that correspond to padding tokens. This is equivalent 
# to saying that no token can attend to the padding tokens. However, if no token can attend to the 
# padding token and everything in those rows is set to false, then the attention scores will just 
# be equal to some constant value (non-negligible). So, in the end, this doesn't really matter and 
# doesn't give any advantage in computing the attention scores. However, I am showing the 
# implementation here because this is one more way padding mask is used in some alternative 
# implementations of the transformer.
print("shape: ", target_decoder_input.shape)
print("target_decoder_input: \n", target_decoder_input)
print("--------------------------------------------------")
target_padding_mask_alternative = (target_decoder_input != pad_token_id).unsqueeze(dim=1)
print("shape: ", target_padding_mask_alternative.shape)
print("target_padding_mask_alternative: \n", target_padding_mask_alternative)
print("--------------------------------------------------")
# We need to False out the entire rows corresponding to padding tokens. The rows corresponding to the
# padding tokens are the rows where the target_padding_mask_alternative is False. So, we transpose the 
# target_padding_mask_alternative and repeat it 'seq_len' times as a first step.
intermediate_padding_mask = target_padding_mask_alternative.transpose(1, 2)
print("shape: ", intermediate_padding_mask.shape)
print("intermediate_padding_mask: \n", intermediate_padding_mask)
print("--------------------------------------------------")
# We now repeat the column values 'seq_len' times to match the shape of the target_padding_mask.
intermediate_padding_mask = intermediate_padding_mask.repeat(1, 1, target_padding_mask.size(-1))
print("shape: ", intermediate_padding_mask.shape)
print("intermediate_padding_mask: \n", intermediate_padding_mask)
print("--------------------------------------------------")
# Finally, we take the logical 'and' of the two intermediate masks to get the final padding mask.
target_padding_mask_alternative = intermediate_padding_mask & construct_padding_mask(input=target_decoder_input, pad_token_id=pad_token_id).repeat(1, target_decoder_input.size(1), 1)
print("shape: ", target_padding_mask_alternative.shape)
print("target_padding_mask_alternative: \n", target_padding_mask_alternative)

shape:  torch.Size([2, 9])
target_decoder_input: 
 tensor([[   0,   12,    3, 3545,    1,    2,    2,    2,    2],
        [   0,  122,    6,  545,   40,   78,   90,   89,   78]],
       dtype=torch.int16)
--------------------------------------------------
shape:  torch.Size([2, 1, 9])
target_padding_mask_alternative: 
 tensor([[[ True,  True,  True,  True,  True, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True]]])
--------------------------------------------------
shape:  torch.Size([2, 9, 1])
intermediate_padding_mask: 
 tensor([[[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False],
         [False],
         [False]],

        [[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True]]])
--------------------------------------------------
shape:  torch.Size([2, 9, 9])
interme

In [27]:
# We won't be using this function in the transformer implementation since it does not provide any 
# advantages and it is a waste of computation. This is just to show how the padding mask is created 
# in some alternative implementations. This function just takes the code from the above cell and 
# puts it into a single function.
def construct_padding_mask_alternative(input: Tensor, pad_token_id: int) -> Tensor:
    """Creates a padding mask to prevent the padding tokens from attending to the other tokens in
       the input. This is the same for both source and target sentences.

    Args:
        input (Tensor): Input tensor of shape (batch_size, seq_len) where each row is a sentence.
        pad_token_id (int): Id of the padding token.

    Returns:
        Tensor: Returns a padding mask of shape (batch_size, seq_len, seq_len) where each 2D tensor
        (seq_len, seq_len) is a mask for a single sentence.
    """
    padding_mask = (input != pad_token_id).unsqueeze(dim=1)
    intermediate_padding_mask_1 = padding_mask.repeat(1, padding_mask.size(-1), 1)
    intermediate_padding_mask_2 = padding_mask.transpose(1, 2)
    intermediate_padding_mask_2 = intermediate_padding_mask_2.repeat(1, 1, padding_mask.size(-1))
    padding_mask = intermediate_padding_mask_1 & intermediate_padding_mask_2
    return padding_mask

In [28]:
target_padding_mask_alternative_via_function = construct_padding_mask_alternative(input=target_decoder_input, pad_token_id=pad_token_id)
print("shape: ", target_padding_mask_alternative_via_function.shape)
print("target_padding_mask_alternative_via_function: \n", target_padding_mask_alternative_via_function)

shape:  torch.Size([2, 9, 9])
target_padding_mask_alternative_via_function: 
 tensor([[[ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  

In [29]:
if torch.equal(target_padding_mask_alternative, target_padding_mask_alternative_via_function):
    print("Both the alternative target padding masks are equal as they should be.")
else:
    print("The alternative target padding masks are not equal. There is some mistake in the code.")

Both the alternative target padding masks are equal as they should be.


In [30]:
# In the transformers implementation, the same mask is applied to all the heads in the multi-headed
# attention mechanism. The calculated attention_scores will have the shape 
# (batch_size, num_heads, seq_len, seq_len). Hence, the target_mask is updated to have the shape 
# (batch_size, 1, seq_len, seq_len) which can be broadcasted and applied to all the heads.
target_mask = target_mask.unsqueeze(dim=1)
print("shape: ", target_mask.shape)
print("target_mask: \n", target_mask)

shape:  torch.Size([2, 1, 9, 9])
target_mask: 
 tensor([[[[ True, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False]]],


        [[[ True, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, Fa

In [31]:
# To be used in Transformers.
class Batch:
    """Object for holding a batch of data and the corresponding mask to be used for training."""

    def __init__(self, src_batch: Tensor, tgt_batch: Tensor, pad_token_id: int):
        """Initialize the Batch object. Updates the tgt_batch to the format expected by the decoder
           during training. Also, creates the mask for the source and target sentences.

        Args:
            src_batch (Tensor): Tensor containing the source sentences in the batch. 
                                shape: [batch_size, seq_len].
            tgt_batch (Tensor): Tensor containing the target sentences in the batch.
                                shape: [batch_size, seq_len].
            pad_token_id (int): Id of the pad token appended to the sentences in the batch. Usually 
                                set to 2.
        """
        self.src = src_batch
        # The source sentences only need the padding mask since the Encoder does not have to predict
        # the next token in the sentence but just encode the input to be used by the Decoder.
        # Shape of src_mask: [batch_size, 1, 1, seq_len]
        self.src_mask = construct_padding_mask(input=src_batch, pad_token_id=pad_token_id).unsqueeze(1)
        # Removes the last token (<eos> or <pad>) from the target sentences to create the target_decoder_input.
        # Shape of tgt_decoder_input: [batch_size, seq_len - 1]
        self.tgt_decoder_input = tgt_batch[:, :-1]
        # Removes the first token (<sos>) from the target sentences to create the target_expected_decoder_output.
        # Shape of tgt_expected_decoder_output: [batch_size, seq_len - 1]
        self.tgt_expected_decoder_output = tgt_batch[:, 1:]
        # Shape of tgt_mask: [batch_size, 1, seq_len - 1, seq_len - 1]
        self.tgt_mask = self.construct_target_mask(tgt=self.tgt_decoder_input, pad_token_id=pad_token_id).unsqueeze(1)
        # Number of tokens in the target sentences excluding the padding tokens. This is used during model 
        # training for the loss calculation inorder to normalize the total loss and find the loss per token.
        self.non_pad_tokens = (self.tgt_expected_decoder_output != pad_token_id).sum()

    def construct_target_mask(self, tgt: Tensor, pad_token_id: int) -> Tensor:
        # The target sentences need both the padding mask and the look ahead mask. The padding mask is used
        # to prevent the padding tokens from attending to the other tokens in the target sentences. The look
        # ahead mask is used to prevent the future tokens from attending to the current token or any token.
        tgt_mask = construct_padding_mask(input=tgt, pad_token_id=pad_token_id).repeat(1, tgt.size(1), 1)
        tgt_mask = tgt_mask & construct_look_ahead_mask(tgt.size(-1)).type_as(tgt_mask.data)
        return tgt_mask

In [32]:
batch = Batch(src_batch=src_batch, tgt_batch=target_batch, pad_token_id=pad_token_id)
print("shape: ", src_batch.shape)
print("src_batch: \n", src_batch)
print("-" * 150)
print("shape: ", target_batch.shape)
print("target_batch: \n", target_batch)
print("-" * 150)
print("shape: ", batch.src.shape)
print("src in the batch object: \n", batch.src)
print("-" * 150)
print("shape: ", batch.src_mask.shape)
print("src_mask: \n", batch.src_mask)
print("-" * 150)
print("shape: ", batch.tgt_decoder_input.shape)
print("tgt_decoder_input: \n", batch.tgt_decoder_input)
print("-" * 150)
print("shape: ", batch.tgt_expected_decoder_output.shape)
print("tgt_expected_decoder_output: \n", batch.tgt_expected_decoder_output)
print("-" * 150)
print("shape: ", batch.tgt_mask.shape)
print("tgt_mask: \n", batch.tgt_mask)
print("-" * 150)
print("Number of non-padding tokens in the target sentences: ", batch.non_pad_tokens)

shape:  torch.Size([2, 10])
src_batch: 
 tensor([[   24,    42,    33,  4124,   231, 12321,     2,     2,     2,     2],
        [   27,    67,    83, 23124,   131,  1321,    23,    90,     2,     2]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 10])
target_batch: 
 tensor([[   0,   12,    3, 3545,    1,    2,    2,    2,    2,    2],
        [   0,  122,    6,  545,   40,   78,   90,   89,   78,    1]],
       dtype=torch.int16)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape:  torch.Size([2, 10])
src in the batch object: 
 tensor([[   24,    42,    33,  4124,   231, 12321,     2,     2,     2,     2],
        [   27,    67,    83, 23124,   131,  1321,    23,    90,     2,     2]])
---------------------------------------------------------------------