### Create Word Embeddings

First, we convert each word in the input sequence into an embedding vector. An embedding vector is a dense, learned representation that captures more of a word's meaning than a bare vocabulary index does.

Suppose each embedding vector has **`512`** dimensions and the vocabulary size is **`100`**; the embedding matrix then has size **`100x512`**. This matrix is learned during training, and at inference each word is mapped to its corresponding **`512`**-dimensional vector. With a batch size of **`32`** and a sequence length of **`10`** (10 words), the output has shape **`32x10x512`**.
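To make the shape arithmetic concrete, here is a minimal framework-free sketch of the lookup, using plain Python lists in place of tensors. The random initialisation and the names `embedding_matrix`, `token_ids`, and `embedded` are assumptions made for illustration; a real model would use `nn.Embedding` and learned weights.

```python
import random

random.seed(0)  # arbitrary seed, only so the sketch is repeatable

vocab_size, embed_dim = 100, 512
batch_size, seq_len = 32, 10

# Embedding matrix: one row of embed_dim floats per vocabulary entry (100 x 512).
embedding_matrix = [
    [random.gauss(0.0, 0.02) for _ in range(embed_dim)]
    for _ in range(vocab_size)
]

# A batch of token indices (32 x 10); each index selects one row of the matrix.
token_ids = [
    [random.randrange(vocab_size) for _ in range(seq_len)]
    for _ in range(batch_size)
]

# The embedding lookup is plain row indexing, applied to every token.
embedded = [[embedding_matrix[tok] for tok in seq] for seq in token_ids]

shape = (len(embedded), len(embedded[0]), len(embedded[0][0]))
print(shape)  # (32, 10, 512)
```

Each token index simply picks out one row, so the input shape `32x10` gains a trailing dimension of `512`.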

In [8]:
import torch
import torch.nn as nn
import math

In [9]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        """
        Args:
            vocab_size: size of vocabulary
            embed_dim: dimension of embeddings
        """
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim) # Embedding layer
        self.embed_dim = embed_dim

        # Initialize weights to improve training
        self.init_weights()
        
    def init_weights(self):
        """Initialize embedding weights using normal distribution with small standard deviation"""
        nn.init.normal_(self.embed.weight, mean=0, std=0.02)

    def forward(self, x):
        """
        Args:
            x: Input token indices [batch_size, seq_len]
        Returns:
            embeddings: Token embeddings [batch_size, seq_len, embed_dim]
        """
        # Scale embeddings by sqrt(embed_dim), as in the original Transformer paper
        return self.embed(x) * math.sqrt(self.embed_dim)

In [10]:
vocab_size = 10  # Assume we have 10 words in our vocabulary
embed_dim = 5    # Each word is embedded into a 5-dimensional vector
embeddings = Embedding(vocab_size, embed_dim)

In [11]:
# Create a sample input: a single sequence of three token indices (no batch dimension)
sample_input = torch.tensor([1, 3, 7])  # Example token indices

# Get embeddings
output = embeddings(sample_input)
output

tensor([[ 0.0305,  0.0706, -0.0064,  0.0622,  0.0453],
        [ 0.0268, -0.0597,  0.0250,  0.0363,  0.0740],
        [-0.0341, -0.0302, -0.0716, -0.0086,  0.0105]], grad_fn=<MulBackward0>)

In [12]:
output.shape

torch.Size([3, 5])
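The sample above passes a single 1-D sequence, so the result has only two dimensions. The following sketch shows what a batched input of shape `[batch_size, seq_len]` might look like. It rebuilds the same lookup-and-scale step directly with `nn.Embedding` so that it runs on its own; the seed and the name `batch_input` are assumptions for illustration.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only so the sketch is repeatable

vocab_size, embed_dim = 10, 5
embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
nn.init.normal_(embed.weight, mean=0.0, std=0.02)

# Two sequences of three token indices each: shape [batch_size=2, seq_len=3]
batch_input = torch.tensor([[1, 3, 7],
                            [0, 3, 9]])

# Same lookup-and-scale step as the forward pass of the Embedding class above
batch_output = embed(batch_input) * math.sqrt(embed_dim)
print(batch_output.shape)  # torch.Size([2, 3, 5])

# Token 3 sits at position 1 in both sequences and maps to the same vector
same_vector = torch.equal(batch_output[0, 1], batch_output[1, 1])
print(same_vector)  # True
```

The lookup depends only on the token index, not on its position or on which sequence it belongs to, which is why both occurrences of token `3` yield identical vectors.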