# Introduction 

The embedding layer is a fundamental component of neural networks for natural language processing (NLP) tasks such as text classification and language modeling. It maps discrete input indices (typically representing words or tokens) to dense, continuous vectors called embeddings. Once trained, these embeddings capture semantic relationships between the input elements, which lets the network reason about the meaning of words in the context of the task.

# Training and inference
The embedding layer's parameters are handled differently in the two phases:

Training: During training, the parameters of the embedding layer, i.e., the elements of the weight matrix, are learned through backpropagation. Gradients of the loss are propagated back to the embedding vectors, and an optimizer such as gradient descent updates them iteratively to minimize the loss. Only the rows that were looked up in a given batch receive a nonzero gradient from that batch.
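As a rough illustration of the training behaviour described above, the following sketch runs a single gradient step on PyTorch's built-in nn.Embedding with a toy loss. The sizes, the squared-sum loss, and the learning rate are arbitrary choices made for this example, not values taken from the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small embedding table: 10 rows of dimension 4 (sizes chosen for illustration)
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
optimizer = torch.optim.SGD(emb.parameters(), lr=0.1)

indices = torch.tensor([1, 3, 3])      # index 3 appears twice
before = emb.weight.detach().clone()   # snapshot of the table before the update

# Toy loss (an assumption for the demo): sum of squares of the looked-up vectors
loss = emb(indices).pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Rows whose values differ from the snapshot; only rows 1 and 3 were looked up
changed_rows = (emb.weight.detach() != before).any(dim=1).nonzero().flatten().tolist()
print(changed_rows)
```

With plain SGD (no momentum or weight decay), rows that were not looked up have a zero gradient and are left unchanged by the step.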

Inference: During inference or evaluation, the learned parameters of the embedding layer are fixed and used as is. The network takes input indices, and the corresponding embeddings are looked up from the weight matrix without any further updates or training. The embedding layer acts as a static lookup table, providing fixed embeddings for the given indices.
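A minimal sketch of the inference case, again using PyTorch's built-in nn.Embedding with arbitrary example sizes: inside torch.no_grad() the lookup tracks no gradients, and repeated lookups of the same indices return identical vectors.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)  # illustrative sizes

# eval() does not change how an embedding layer behaves; it is called here by
# convention, since it matters for layers such as dropout.
emb.eval()

indices = torch.tensor([2, 7])
with torch.no_grad():          # no gradient tracking during inference
    out1 = emb(indices)
    out2 = emb(indices)

print(out1.requires_grad)                         # no autograd graph is built
print(torch.equal(out1, out2))                    # same indices, same vectors
print(torch.equal(out1, emb.weight[indices]))     # output is just rows of the table
```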

Overall, the embedding layer acts as a bridge between discrete input indices and continuous embeddings. It allows the neural network to represent and process textual data in a meaningful way, capturing the relationships between words or tokens. Because the forward pass is a row lookup, its cost grows with the number of input indices rather than with the size of the vocabulary.

In [2]:
import torch
import torch.nn as nn

class Embedding(nn.Module):
    """Minimal embedding layer: a trainable (num_embeddings, embedding_dim) lookup table."""

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        # Allocate the embedding matrix as a trainable parameter, then initialize it
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input):
        # Retrieve the embeddings for the (integer) input indices.
        # Indexing the weight matrix is exactly a table lookup: the result has
        # shape input.shape + (embedding_dim,), so a 1D input of N indices
        # yields (N, embedding_dim), including when N == 1.
        embeddings = self.weight[input]

        return embeddings


# Implementation details
Now, let's dive into the implementation details of the embedding layer and its correspondence to a lookup table:

Initialization: In the implementation provided earlier, the embedding layer is initialized with a weight matrix (self.weight) of shape (num_embeddings, embedding_dim). num_embeddings represents the size of the vocabulary, i.e., the number of distinct words or tokens. embedding_dim represents the desired dimensionality of the embeddings. Each row of the weight matrix corresponds to the embedding vector for a specific index.
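To make the initialization step concrete, here is a small sketch that builds a weight matrix of the same shape as in the usage example (1000 x 50) and checks it against the Xavier/Glorot uniform bound, sqrt(6 / (fan_in + fan_out)) with gain 1. The variable names are illustrative only.

```python
import math
import torch
import torch.nn as nn

num_embeddings, embedding_dim = 1000, 50

# Same allocation and initializer as the custom layer uses
weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
nn.init.xavier_uniform_(weight)

# Xavier uniform samples from U(-bound, bound); for a 2D tensor the two fans
# are the two dimensions of the matrix.
bound = math.sqrt(6.0 / (num_embeddings + embedding_dim))

print(tuple(weight.shape))   # one row per vocabulary entry
print(weight.numel())        # number of trainable values: 1000 * 50
print(round(bound, 4))
# Small tolerance because the samples are float32 while bound is float64
print(weight.abs().max().item() <= bound + 1e-6)
```

The bound comes out at roughly 0.0756, which is why freshly initialized embeddings of this size all have small magnitudes.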

Lookup Operation: During the forward pass, given an input tensor of indices, the embedding layer retrieves the corresponding embeddings from the weight matrix. In the implementation, this is achieved by indexing the self.weight matrix with the input indices: embeddings = self.weight[input]. For a 1D input of batch_size indices, as in the usage example below, the resulting embeddings tensor has a shape of (batch_size, embedding_dim).
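The indexing operation can be tried in isolation on a tiny hand-written matrix. This is only a sketch with made-up values; it shows the shapes that plain tensor indexing produces for 1D and 2D index tensors.

```python
import torch

# A 4 x 3 "weight matrix" with recognizable values: row i is [3i, 3i+1, 3i+2]
weight = torch.arange(12, dtype=torch.float32).reshape(4, 3)

idx_1d = torch.tensor([0, 2])             # two indices
idx_2d = torch.tensor([[0, 2], [3, 3]])   # a 2 x 2 batch of indices

print(weight[idx_1d])         # rows 0 and 2 of the matrix
print(weight[idx_1d].shape)   # torch.Size([2, 3])
print(weight[idx_2d].shape)   # torch.Size([2, 2, 3])
```

In both cases the indexing result keeps the shape of the index tensor and appends one axis of size embedding_dim.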

Correspondence to a Lookup Table: The embedding layer can be seen as a lookup table, where each index corresponds to a row in the table (weight matrix), and the embedding vector associated with that index is retrieved. This lookup operation is similar to accessing values in a table or dictionary based on the provided key/index.
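One common way to see the lookup-table view is to compare it with the equivalent matrix product: multiplying a one-hot encoding of the indices by the weight matrix selects the same rows. The sketch below checks this on a small random matrix; the sizes and names are assumptions made for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

weight = torch.randn(5, 3)            # 5 table rows, embedding dimension 3
indices = torch.tensor([4, 0, 4])     # repeated indices are allowed

# Lookup-table view: pick rows directly
looked_up = weight[indices]

# Matrix view: one-hot rows times the weight matrix select the same rows
one_hot = F.one_hot(indices, num_classes=5).to(weight.dtype)
via_matmul = one_hot @ weight

print(torch.allclose(looked_up, via_matmul))
```

The direct lookup gives the same result without building the one-hot matrix, which becomes large when the vocabulary is large.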

In [4]:
# Create an instance of the custom Embedding module
vocab_size = 1000
embedding_dim = 50
embedding_layer = Embedding(vocab_size, embedding_dim)

# Generate some dummy input
input_indices = torch.tensor([1, 3, 5, 2])

# Pass the input through the embedding layer
embeddings = embedding_layer(input_indices)

print(embeddings.shape)  # Output: torch.Size([4, 50])
print(embeddings)

torch.Size([4, 50])
tensor([[-0.0400,  0.0118, -0.0004, -0.0688, -0.0588, -0.0314, -0.0043, -0.0013,
         -0.0626,  0.0684,  0.0536, -0.0084, -0.0579,  0.0711,  0.0380,  0.0638,
         -0.0657,  0.0045, -0.0316,  0.0601, -0.0629,  0.0234, -0.0057, -0.0171,
          0.0056, -0.0239, -0.0229, -0.0548, -0.0599, -0.0509, -0.0168, -0.0189,
         -0.0182, -0.0198, -0.0480,  0.0316, -0.0603, -0.0523, -0.0556, -0.0048,
          0.0647,  0.0057,  0.0633,  0.0470, -0.0298,  0.0455,  0.0719, -0.0164,
          0.0709,  0.0517],
        [ 0.0137, -0.0645,  0.0155,  0.0660,  0.0663, -0.0663,  0.0383, -0.0498,
         -0.0475, -0.0259,  0.0004, -0.0655, -0.0519,  0.0132, -0.0687,  0.0634,
         -0.0730, -0.0432,  0.0569,  0.0094, -0.0417,  0.0394,  0.0326, -0.0216,
         -0.0416, -0.0361, -0.0409,  0.0336,  0.0238,  0.0486,  0.0540,  0.0294,
         -0.0115, -0.0469, -0.0315, -0.0095, -0.0549, -0.0307,  0.0636, -0.0743,
          0.0197, -0.0616,  0.0110,  0.0066,  0.0736,  0.0299