In [19]:
# In this notebook, you learn:
#
# 1) What does the token predictor look like?
# 2) How does the token predictor convert the Decoder output to a sequence of tokens?

In [20]:
from torch import nn, Tensor

import torch

In [21]:
# The output of the Decoder is a tensor of shape [batch_size, seq_len - 1, d_model]. The token predictor converts 
# this decoder output tensor into (log-)probabilities. The token predictor is a simple linear layer followed by a 
# log-softmax function. The linear layer projects the d_model dimensional vector into a vocab_size (tgt vocab size) 
# dimensional vector. The log-softmax function converts that vocab_size dimensional vector into a log-probability 
# distribution over the vocabulary (exponentiating it gives the ordinary softmax probabilities).

<img src="../../Data/Images/OutputGenerator.png" alt="Output Generator" width="550" height="500">

In [None]:
# credits: The above image is taken from this blog post: https://jalammar.github.io/illustrated-transformer/

In [22]:
# CONSTANTS TO BE USED IN THIS NOTEBOOK.
# Number of sentences in a batch.
batch_size = 3
# Number of tokens in a sentence.
seq_len = 4
# Dimension of the word embeddings.
d_model = 8
# Size of the vocabulary.
vocab_size = 5
# Seed for torch's global RNG. Both the random input batch (torch.randn) and the randomly
# initialized nn.Linear weights draw from it, so seeding makes Restart & Run All reproducible.
random_seed = 42
torch.manual_seed(random_seed)

In [23]:
def generate_batch_of_input_data(batch_size: int, seq_len: int, d_model: int) -> Tensor:
    """Creates a random stand-in for the Decoder output to experiment with the pipeline.

    Values are drawn from a standard normal distribution via torch.randn.

    Args:
        batch_size (int): Number of sentences in the batch.
        seq_len (int): Number of tokens per sentence.
        d_model (int): Dimension of each token vector.

    Returns:
        Tensor: Random input batch.
                SHAPE: [batch_size, seq_len, d_model]
    """
    shape = (batch_size, seq_len, d_model)
    return torch.randn(*shape)

In [24]:
class TokenPredictor(nn.Module):
    """Maps decoder outputs to log-probability distributions over the target vocabulary.

    The module is a single linear projection (d_model -> tgt_vocab_size) followed by a
    log-softmax over the last dimension.

    Args:
        d_model (int): Size of each decoder output vector.
        tgt_vocab_size (int): Number of tokens in the target vocabulary.
    """

    def __init__(self, d_model: int, tgt_vocab_size: int):
        super().__init__()
        self.d_model = d_model
        # Record the vocabulary size this instance was actually built with (the constructor
        # argument), not whatever notebook-level `vocab_size` variable happens to be in scope.
        self.vocab_size = tgt_vocab_size
        self.linear = nn.Linear(in_features=d_model, out_features=tgt_vocab_size)
        # LogSoftmax has no learnable parameters, so it adds nothing to self.parameters().
        # Registering it as a submodule just makes it show up in the module's repr.
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, decoder_output: Tensor) -> Tensor:
        """The forward pass of the token predictor. Calculates the probability distribution over the 
           vocabulary. Each token vector has a corresponding probability distribution over the 
           vocabulary since we predict one token per output.

        Args:
            decoder_output (Tensor): Output of the Decoder.
                                     SHAPE: [batch_size, seq_len - 1, d_model]
                                     (nn.Linear acts on the last dimension, so any leading
                                     dimensions work as long as the last one is d_model.)

        Returns:
            Tensor: Log probability distribution over the vocabulary. 
                    SHAPE: [batch_size, seq_len - 1, vocab_size]
        """
        # Project the decoder output to the vocab_size dimensional space.
        logits = self.linear(decoder_output)
        # Convert the logits to a log-probability distribution over the vocabulary. All the entries in
        # the output tensor are <= 0 since they are log-probabilities. The log softmax is used to make
        # the training more numerically stable than computing softmax and then taking the log. Because
        # log is monotonically increasing, the most likely token (the argmax) is the same as for the
        # plain softmax output; exponentiating the result recovers the softmax probabilities.
        return self.log_softmax(logits)

In [25]:
# Build the predictor for the toy configuration and show its sub-layers.
token_predictor = TokenPredictor(
    tgt_vocab_size=vocab_size,
    d_model=d_model,
)
print(repr(token_predictor))

TokenPredictor(
  (linear): Linear(in_features=8, out_features=5, bias=True)
  (log_softmax): LogSoftmax(dim=-1)
)


In [26]:
# Draw a random stand-in for the decoder output and inspect it.
input_data = generate_batch_of_input_data(
    d_model=d_model,
    seq_len=seq_len,
    batch_size=batch_size,
)
print("shape: ", input_data.shape)
print("input_data: \n", input_data)

shape:  torch.Size([3, 4, 8])
input_data: 
 tensor([[[ 0.6648, -0.0731,  1.4228,  0.2574,  0.4639,  0.2357, -1.0842,
           0.8967],
         [ 1.1176, -1.0840,  1.3274,  1.3652, -0.1541, -0.4554,  0.1552,
           0.1834],
         [ 0.7545, -0.6348, -0.1767,  2.6542,  0.6799, -1.1835, -0.4171,
           0.0487],
         [-0.9292,  0.4037, -0.4447,  0.0704, -0.6954, -0.1496,  0.1776,
          -0.5893]],

        [[-0.3329, -0.9574, -1.6375,  0.1879,  0.6625,  0.4439,  0.1016,
           1.0151],
         [ 1.1896,  0.0778, -0.1971, -0.3428,  0.7501, -1.3605, -0.7687,
          -1.4977],
         [-2.9320,  1.0243,  0.6493, -0.3205,  0.8143,  1.1340, -1.4613,
           0.2017],
         [-1.2985,  0.1435, -0.0709, -0.5146, -0.5569,  0.2387,  1.8324,
          -0.2424]],

        [[-0.0513,  0.3328,  0.7430,  0.3366,  1.3838, -0.0557,  0.2952,
           1.8550],
         [-0.9629, -0.3769,  0.3185, -1.2493, -1.1389,  1.5914,  0.4296,
          -0.8643],
         [-0.7303, -1.

In [28]:
# Run the token predictor. Despite the variable name, the entries are log-probabilities.
probability_distributions = token_predictor(input_data)
print("shape: ", probability_distributions.shape)
print("probability_distributions: \n", probability_distributions)
# Sanity checks: one distribution per input token vector, and exponentiating the
# log-probabilities yields rows that each sum to 1.
assert probability_distributions.shape == (batch_size, seq_len, vocab_size)
assert torch.allclose(
    probability_distributions.exp().sum(dim=-1),
    torch.ones(batch_size, seq_len),
    atol=1e-6,
)

shape:  torch.Size([3, 4, 5])
probability_distributions: 
 tensor([[[-0.8743, -2.2330, -2.1661, -1.2204, -2.7191],
         [-0.8987, -2.1929, -2.2450, -1.6275, -1.7206],
         [-1.6397, -1.3997, -2.8080, -1.6699, -1.1689],
         [-2.4470, -2.0811, -1.6307, -2.1580, -0.7396]],

        [[-1.8463, -1.1323, -1.5098, -2.5524, -1.5095],
         [-1.6705, -2.3230, -1.4088, -1.6888, -1.2564],
         [-2.0533, -1.6800, -1.5841, -1.1228, -1.8654],
         [-2.2552, -2.6807, -1.2657, -2.2573, -0.8211]],

        [[-1.0882, -2.1935, -2.3287, -0.9736, -2.5707],
         [-2.0173, -2.7733, -0.6563, -2.8773, -1.4720],
         [-1.7282, -1.6615, -1.3470, -2.8469, -1.1568],
         [-1.7844, -2.2027, -0.8811, -1.8872, -1.8593]]],
       grad_fn=<LogSoftmaxBackward0>)
