In [2]:
# In this notebook, we learn:
# 
# 1) How to use the Embedding module in PyTorch?
# 2) How to use the embedding layer from the 'Attention Is All You Need' paper?
#

In [None]:
# Useful Resources:
#
# 1) https://blog.acolyer.org/2016/04/21/the-amazing-power-of-word-vectors/
#       -- Explains what word embeddings are, how they are useful and how they are generated traditionally.
# 2) https://medium.com/deeper-learning/glossary-of-deep-learning-word-embedding-f90c3cec34ca
# 3) https://www.youtube.com/watch?v=D-ekE-Wlcds
#       -- Excellent Video explaining word embeddings
# 4) https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#embedding
#       -- Official pytorch documentation for the embedding module

In [9]:
from torch import Tensor
import torch.nn as nn

import math
import torch

## Understanding nn.Embedding module in pytorch.

In [10]:
# nn.Embedding initializes its weights randomly. Seeding the global torch RNG once, before any
# embedding layer is created, makes the embeddings printed in this notebook reproducible when
# the notebook is re-run from a fresh kernel ("Restart & Run All").
random_seed = 42
torch.manual_seed(random_seed)
# Our vocabulary only has 10 tokens.
vocab_size = 10
# Every token is associated with a 10 sized embedding vector.
embedding_vector_size = 10

In [11]:
# nn.Embedding acts as a look-up table: given a token id, it returns the embedding vector stored
# for that id. Training these embeddings is covered at a later stage.
# Positional arguments are (num_embeddings, embedding_dim).
embedding_layer = nn.Embedding(vocab_size, embedding_vector_size)
print(embedding_layer, type(embedding_layer))

Embedding(10, 10) <class 'torch.nn.modules.sparse.Embedding'>


In [12]:
# Look up the embedding vectors stored for token ids 0, 3 and 4.
token_ids = torch.tensor([0, 3, 4], dtype=torch.int)
embeddings = embedding_layer(token_ids)
# Show the looked-up vectors, their shape, and the Python type of the result.
for item in (embeddings, embeddings.shape, type(embeddings)):
    print(item)

tensor([[ 1.3012,  1.1632, -0.2358, -0.8029,  0.9754, -0.2711,  0.1415,  0.4453,
          0.2625, -0.5352],
        [ 1.0317,  2.2983,  0.2876, -1.0384,  1.4763,  2.4843,  0.5297,  0.5989,
          0.1117,  1.1091],
        [ 0.7685,  1.6019,  0.3784, -0.2048, -0.6040,  0.1599,  1.5740,  1.0345,
          0.9879,  0.0837]], grad_fn=<EmbeddingBackward0>)
torch.Size([3, 10])
<class 'torch.Tensor'>


In [13]:
# The embedding_layer raises an IndexError when a token id >= vocab_size (10) is provided.
# The error is caught and printed, not left uncaught. This keeps the demonstration visible
# without stopping "Run All" at this cell.
out_of_bound_ids = torch.tensor(data=[10], dtype=torch.int)
try:
    _ = embedding_layer(out_of_bound_ids)
except IndexError as err:
    print(f"{type(err).__name__}: {err}")

IndexError: index out of range in self

## Using Embeddings in transformers

In [14]:
# The model in the 'Attention Is All You Need' paper uses embedding vectors to represent the tokens
# of the English language. The input to the model is a list of indices that represent the tokens.
# Inside the model, we then assign an embedding vector to each index and train the embedding
# vectors along with the model.
#
# Explaining the above with an example:
# If the English sentence input to transformer is: "I am Batman", then the model gets 
# [7, 5, 89] (excluding <sos>, <eos> in this example) as input i.e., there is a fixed mapping 
# between the English tokens and corresponding indices. 
# 'I' is mapped to 7.
# 'am' is mapped to 5.
# 'Batman' is mapped to 89.
#
# Within the transformer model, we then convert these indices (7, 5, 89) to embedding vectors of 
# size 512 and train these vectors along with the model. Our Embedding module below takes care
# of the above process for us in the transformer model.

In [15]:
# Refer to 'using_modules.ipynb' to understand more about modules in pytorch.
# We will train the embeddings as part of the transformer model. 
class Embeddings(nn.Module):
    """Trainable token-embedding layer for the transformer model.

    Wraps an `nn.Embedding` look-up table that maps token ids to vectors. Every looked-up
    vector is multiplied by sqrt(embedding_dim), as in the 'attention_is_all_you_need' paper.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        """Creates the look-up table for the tokens in the transformer model.

        Args:
            vocab_size (int): Number of distinct tokens in the vocabulary.
            embedding_dim (int): Length of the embedding vector produced for each token.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.look_up_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

    # Example input: a 2D tensor where each row is the list of token indices for one sequence, e.g.
    # [[0, 123, 3455, 4556, 7, 1, 2, 2], [0, 56, 98, 6234, 909, 56, 1, 2]]
    # where 0 - <sos>, 1 - <eos>, 2 - <pad>.
    def forward(self, input: Tensor) -> Tensor:
        """Maps a tensor of token indices to their scaled embedding vectors.

        Args:
            input (Tensor): Token indices.
                            shape: [batch_size, seq_len]

        Returns:
            Tensor: Embedding vectors for the input tokens, each multiplied by sqrt(embedding_dim).
                    shape: [batch_size, seq_len, embedding_dim]
        """
        # The original paper scales the embeddings by sqrt(embedding_dim) without giving a reason.
        # See the discussion after this class for why the scaling may or may not matter.
        scale = math.sqrt(self.embedding_dim)
        embedded = self.look_up_table(input)
        return embedded * scale

# The discussion here (https://datascience.stackexchange.com/a/88159) attempts to explain the reason 
# for scaling the embeddings but that is incorrect.
#
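# By default, the word embeddings in the nn.Embedding layer are initialized from the N(0, 1) distribution.
# You can find the evidence for this in the source code of the Embedding class in PyTorch:
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
# Within the above source code, the 'weight' parameter (which holds our word embeddings) is initialized
# inside the 'reset_parameters' method using the N(0, 1) distribution.
#  
# So, each of the embedding_dim entries of an embedding vector has E[x^2] = 1. The expected squared
# magnitude of the vector is therefore embedding_dim, i.e., its magnitude is roughly sqrt(embedding_dim).
# The magnitude of the positional encoding (more about this in step_8_positional_encoding.ipynb) is
# roughly sqrt(embedding_dim / 3) if we pretend its entries are uniformly distributed in [-1, 1]
# (E[x^2] = 1/3). That assumption is not correct, but it gives an easy estimate. For sinusoidal
# encodings whose dimensions come in (sin, cos) pairs sharing a frequency, sin^2 + cos^2 = 1 for every
# pair, so the magnitude is exactly sqrt(embedding_dim / 2). Either way, both vectors are already on
# the same scale, and embeddings initialized this way don't need to be scaled to match. (ChatGPT /
# Gemini gave me a reasonable explanation of how these expected magnitudes are calculated in each case.)
#
# Note that this argument relies on the default N(0, 1) initialization. Embedding weights may instead
# be initialized with a much smaller variance (e.g. Xavier initialization, which many transformer
# implementations apply to every weight matrix). In that case their magnitude is much smaller than
# sqrt(embedding_dim), and the scaling is what brings them to the scale of the positional encodings.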
#
# This blog (https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec) just 
# says that this scaling is done to magnify the contribution of word embeddings when word embeddings are
# added to the positional encodings. This is true, however, I have seen people mentioning on the internet
# that scaling did not have any visible impact on their models (To be verified).

In [16]:
# A batch of two sequences with 8 token ids each (0 - <sos>, 1 - <eos>, 2 - <pad>).
sample_input = torch.tensor(
    [
        [0, 123, 345, 455, 7, 1, 2, 2],
        [0, 56, 98, 234, 9, 56, 1, 2],
    ],
    dtype=torch.int,
)
print(sample_input, "\n")
print(sample_input.shape)

tensor([[  0, 123, 345, 455,   7,   1,   2,   2],
        [  0,  56,  98, 234,   9,  56,   1,   2]], dtype=torch.int32) 

torch.Size([2, 8])


In [17]:
transformer_embedding_layer = Embeddings(vocab_size=500, embedding_dim=20)
# Show the module structure, then its Python type, each on its own line.
print(transformer_embedding_layer, type(transformer_embedding_layer), sep="\n")

Embeddings(
  (look_up_table): Embedding(500, 20)
)
<class '__main__.Embeddings'>


In [18]:
# Notice that the shape of the input is (2, 8) and the shape of the output is (2, 8, 20).
# Each token id at position (i, j) is replaced by its embedding vector of size 20, which adds
# a new last dimension to the tensor.
# Call the module itself rather than `.forward(...)` so that any registered hooks also run.
transformer_embedding_layer_output = transformer_embedding_layer(sample_input)
print(transformer_embedding_layer_output, "\n")
print(transformer_embedding_layer_output.shape)

tensor([[[ -1.8577,  -0.3983,  -3.8971,  -9.7414,   5.9951,  -1.6411,  -2.4170,
           -0.6763,   5.1841,   2.2469,   3.3569,   2.3509,   6.1423,  -3.4343,
           -5.5828,  -1.1291,  -9.1931,   1.9975,  -5.6461,  -3.1570],
         [  0.6649,  -6.8562,  -1.0972,  -3.3827,   4.4044,   2.6300,   1.7047,
           -1.0038,  -7.0174,  -6.7603,   2.1979,   5.4500,   1.0113,   3.9205,
            0.4894,   0.2233,   7.3648,   6.1077,  -2.5670,  -7.5822],
         [ -6.9634,  -1.0228,   6.5550,   6.1711, -10.7813,  -4.3164,   0.8786,
            4.4746,   0.0207,  -2.9096,   3.8563,  -3.9088,  -3.1286,  -6.4031,
            0.1279,  -0.9704,  -3.5322,   4.8648,   2.2500,   3.4233],
         [  2.9664,   3.0572,  -0.9320,   1.5378,   3.3464,  -0.9655,  -3.4692,
            2.1214,   6.1003,  -0.4259, -11.8368,   0.6232,  -3.9156,  -0.4642,
           -0.0155,   1.9090,   0.0717,  -0.5324,   8.4863,  -3.9943],
         [ -1.4025,   4.5092,   0.5611,   2.6926,   0.9760,  -1.7043,  11.66