# Implementing the Transformer

Reference: [Implementation_Tutorial](Transformer_Implementation_Tutorial.ipynb)

In [34]:
from torch import nn 
import torch
from math import log, sqrt

## Embedding and Positional Encoding Module

In [None]:
class EmbeddingWithPositionalEncoding(nn.Module):
    """Embed token ids, project them to the model width and add positions.

    The forward pass performs, in order: embedding lookup (``d_embed``),
    linear projection to ``d_model`` scaled by ``sqrt(d_model)``, addition of
    a fixed sinusoidal positional encoding, layer normalisation and dropout.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_embed: Width of the raw token embeddings.
        d_model: Width of the vectors returned by ``forward``.
        dropout_p: Dropout probability applied to the normalised output.
    """

    def __init__(self, vocab_size: int,
                 d_embed: int,
                 d_model: int,
                 dropout_p: float = 0.1
                 ):
        super().__init__()
        self.d_model = d_model
        self.d_embed = d_embed
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_embed
        )
        self.projection = nn.Linear(
            in_features=d_embed,
            out_features=d_model
        )
        # Projected embeddings are multiplied by sqrt(d_model) in forward().
        self.scaling = float(sqrt(self.d_model))
        self.layerNorm = nn.LayerNorm(self.d_model)
        self.dropout = nn.Dropout(p=dropout_p)

    @staticmethod  # does not operate on `self`
    def create_positional_encoding(seq_length, d_model, batch_size,
                                   device=None, dtype=torch.float32):
        """Build the sinusoidal positional-encoding table.

        Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
        ``cos(pos * w_k)``, where ``w_k = 10000 ** (-2k / d_model)``.

        Args:
            seq_length: Number of positions to encode.
            d_model: Feature width of the encoding (even or odd).
            batch_size: Size of the leading batch dimension.
            device: Device on which to create the table (default: torch's
                default device).
            dtype: Floating-point dtype of the returned tensor.

        Returns:
            A tensor of shape ``(batch_size, seq_length, d_model)``. The batch
            dimension is an ``expand`` view of one shared table, not a copy.
        """
        # The table is computed in float32 and cast at the end, so a reduced
        # precision `dtype` does not degrade the sin/cos arguments.
        positions = torch.arange(
            seq_length, dtype=torch.float32, device=device
        ).unsqueeze(1)  # shape (seq_length, 1), i.e. a column vector

        # exp(-(2k / d_model) * 4 * ln(10)) == 10000 ** (-2k / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32, device=device)
            * (-4.0 * log(10) / d_model)
        )

        angles = positions * div_term  # shape (seq_length, ceil(d_model / 2))

        # Must be a floating-point buffer: an integer one would truncate the
        # sin/cos values to -1, 0 or 1 on assignment.
        pe = torch.zeros(
            size=(seq_length, d_model), dtype=torch.float32, device=device
        )
        pe[:, 0::2] = torch.sin(angles)  # even dimensions
        # For odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # odd dimensions

        # Broadcast the single table across the batch without copying.
        return pe.to(dtype).unsqueeze(0).expand(batch_size, -1, -1)

    def forward(self, x):
        """Return position-aware embeddings for a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape
                ``(batch_size, seq_length)``.

        Returns:
            Tensor of shape ``(batch_size, seq_length, d_model)``.
        """
        batch_size, seq_length = x.shape

        # step 1: make embeddings
        token_embedding = self.embedding(x)

        # step 2: go from d_embed to d_model, then apply the scaling factor
        token_embedding = self.projection(token_embedding) * self.scaling

        # step 3: build the positional encoding on the same device and dtype
        # as the embeddings so the sum below needs no implicit transfer
        pos_encoding = self.create_positional_encoding(
            seq_length=seq_length,
            d_model=self.d_model,
            batch_size=batch_size,
            device=token_embedding.device,
            dtype=token_embedding.dtype,
        )

        # step 4: normalize the sum of pos encoding and token_embed
        norm_sum = self.layerNorm(pos_encoding + token_embedding)
        return self.dropout(norm_sum)



##