# Transformer From Scratch

## Table of Contents
- [Introduction](#introduction)
- [Setup](#setup)
- [Embeddings](#embeddings)
  - [Input Embeddings](#input-embeddings)
  - [Positional Embeddings](#positional-embeddings)
- [Attention Layers](#attention-layers)
  - [Scaled Dot Product Attention](#scaled-dot-product-attention)
  - [Multi-Headed Attention](#multi-headed-attention)
- [Feed-Forward Network](#feed-forward-network)
- [Intermediate Layers](#intermediate-layers)
  - [Layer Normalization](#layer-normalization)
  - [Residual Connections](#residual-connections)
  - [Linear Layer](#linear-layer)
- [Encoder-Decoder Structure](#encoder-decoder-structure)
  - [Encoder](#encoder)
  - [Decoder](#decoder)
- [Transformer](#transformer)

## Introduction
The goal of this notebook is to provide a practical resource to systematically review and learn the underlying Transformer model architecture from the "[Attention Is All You Need](https://arxiv.org/pdf/1706.03762)" paper.

<img src="images/transformer.png" width="400">


## Setup

In [2]:
import math
import torch
import torch.nn as nn
from torch import Tensor

## Embeddings

<img src="images/embeddings.png" width="600">

### Input Embeddings
The input tokens passed through the Transformer model are first converted to vectors of dimension $d_{model}$ through a learned embedding, which we call the **Input Embeddings**.

In section **3.4 Embeddings and Softmax** of the paper, the authors state that the output of the embedding layer is multiplied by $\sqrt{d_{model}}$.

Another thing to note is that input tokens are of type `int64`. These integer token indices are mapped to dense vectors via the learned embedding, resulting in continuous floating-point vectors of type `float32`. All later computations (e.g. self-attention layers, feed-forward networks, etc.) are done using floating-point operations.
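As a quick illustration of the dtype change and the $\sqrt{d_{model}}$ scaling described above, here is a minimal sketch using `nn.Embedding`. The toy sizes (`d_model = 8`, `vocab_size = 20`) and the token ids are assumptions chosen only to keep the output small.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, vocab_size = 8, 20  # toy sizes, chosen only for illustration

embedding = nn.Embedding(vocab_size, d_model)

# A batch of 2 sequences, each with 3 token ids.
tokens = torch.tensor([[1, 5, 7], [2, 0, 19]])

# Look up the vectors and scale them by sqrt(d_model), as in section 3.4.
embedded = embedding(tokens) * math.sqrt(d_model)

print(tokens.dtype)    # torch.int64
print(embedded.dtype)  # torch.float32
print(embedded.shape)  # torch.Size([2, 3, 8])
```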

In [None]:
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        """
        Embedding layer for input tokens

        Args:
            d_model (int): Hidden dimension of the model. The size of the vector
                representations (embeddings / hidden states) used throughout the
                Transformer model.
            vocab_size (int): Size of the vocabulary. Number of unique tokens in the
                input data.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Input to embedding layer: (*)
        # Output from embedding layer: (*, H), where H is the hidden dim of the model.

        # TODO: Create an embedding layer of size (vocab_size, d_model)
        self.embedding = ...

    def forward(self, x: Tensor) -> Tensor:
        """
        Embed input tokens.

        Args:
            x (Tensor): Input tokens of shape `(bs, seq_len)`.

        Returns:
            Tensor: Embedded input of shape `(bs, seq_len, d_model)`.
        """
        # seq_len dimension contains token ids that can be mapped back to unique word

        ### REFER TO 3.4 Embeddings and Softmax in paper ###
        # TODO: Return the result of the embedded `x` tensor
        return ...

### Positional Embeddings
Transformers process all tokens in parallel, so unlike sequential models such as Recurrent Neural Networks (RNNs), they have no built-in notion of token order. We resolve this by adding a **Positional Encoding** to the input embeddings.

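The paper uses sine functions for the even dimensions and cosine functions for the odd dimensions of the encoding, where $pos$ is the token position and $i$ indexes the dimension pair: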
$$
\text{PE}_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}) 
$$

$$
\text{PE}_{(pos, 2i + 1)} = \cos(pos / 10000^{2i/d_{model}})
$$
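To make the two formulas concrete, the sketch below evaluates them by hand for a tiny model using only the standard library. The helper name `positional_encoding_row` and the choice of `d_model = 4` are assumptions for illustration, and the helper assumes an even `d_model`.

```python
import math


def positional_encoding_row(pos: int, d_model: int) -> list[float]:
    """Return the d_model-dimensional encoding for a single position."""
    row = []
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        row.append(math.sin(angle))  # even dimension 2i
        row.append(math.cos(angle))  # odd dimension 2i + 1
    return row


pe_0 = positional_encoding_row(pos=0, d_model=4)
pe_1 = positional_encoding_row(pos=1, d_model=4)

print(pe_0)                         # [0.0, 1.0, 0.0, 1.0]
print([round(v, 4) for v in pe_1])  # [0.8415, 0.5403, 0.01, 1.0]

# The reciprocal of the denominator can also be written as an exponential:
# 10000^(-2i / d_model) = exp(-2i * ln(10000) / d_model)
i, d = 1, 4
print(math.isclose(1 / 10000 ** (2 * i / d), math.exp(-2 * i * math.log(10000.0) / d)))  # True
```

The exponential form is numerically equivalent and is a common way to compute all the denominators at once in a vectorized implementation.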

In [None]:
class PositionalEncoding(nn.Module):
    pe: Tensor

    def __init__(self, d_model: int, max_seq_len: int, dropout: float = 0.1):
        """
        Positional encoding / embeddings for input tokens

        Args:
            d_model (int): Hidden dimension of the model. The size of the vector
                representations (embeddings / hidden states) used throughout the
                Transformer model.
            max_seq_len (int): Maximum sequence length.
            dropout (float, optional): Dropout rate. Defaults to 0.1.
        """
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # TODO: create dropout
        self.dropout = ...

        # Create positional encodings of shape (max_seq_len, d_model)
        pe = torch.zeros(max_seq_len, d_model)

        ### REFER TO 3.5 Positional Encoding ###

        # TODO: Create tensor of shape (max_seq_len, 1) with type `torch.float`
        # Result: [[0], [1], ..., [max_seq_len - 1]]
        pos = ...

        # PE division term => 10000^-(2 * i / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        ### TODO: Use sine for even dimensions and cosine for odd dimensions
        # NOTE: the above `div_term` is actually the reciprocal of the 10000^(2*i / d_model)

        ...
        
        ###

        # TODO: Add batch dimension to positional encodings
        pe = ...  # (1, max_seq_len, d_model)

        # Registered buffers are saved in the model's state dict but are not trained.
        self.register_buffer("pe", pe)  # Accessible as `self.pe`

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply positional encoding to input embeddings.

        Args:
            x (Tensor): Input embeddings of shape `(bs, seq_len, d_model)`.

        Returns:
            Tensor: Embeddings with positional information added, of shape
                `(bs, seq_len, d_model)`.
        """

        # Add positional encodings to input embeddings

        # TODO: get the seq_len from `x`
        seq_len = ...

        # TODO: Slice the positional encodings to the first seq_len positions
        # (seq_len may be shorter than max_seq_len)
        pe_out = ...

        # TODO: Add the positional information onto the input tensor `x`
        x = ...

        # TODO: Apply dropout
        return ...


## Attention Layers

The **Multi-Head Attention block** is where the attention mechanism lives. Each head is computed with scaled dot-product attention.

There are 2 types of attention in the Transformer: **Self-Attention** and **Cross-Attention**. They both use the same multi-head attention mechanism.

The primary differences are:

- **Self-Attention:** Queries, keys, and values come from the same input sequence.

- **Cross-Attention:** Queries come from one source (e.g., the decoder’s hidden state in a transformer), while keys and values come from another source (e.g., the encoder’s outputs).

<br>

<img src="images/attention.png" width="600">

### Scaled Dot Product Attention

$$
\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

The paper uses $d_{\text{model}} = 512$ with $h=8$ parallel attention layers (heads).

Therefore, the dimension of queries, keys, and values will be $d_k=d_v=d_{model}/h = 64$.
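The formula can be traced step by step on random tensors. The following is a minimal sketch, not the notebook's reference solution: the toy sizes, the variable names, and the causal mask are assumptions chosen to show the shapes and the effect of masking.

```python
import math

import torch

torch.manual_seed(0)

bs, num_heads, seq_len, d_k = 1, 2, 4, 8  # toy sizes for illustration

q = torch.randn(bs, num_heads, seq_len, d_k)
k = torch.randn(bs, num_heads, seq_len, d_k)
v = torch.randn(bs, num_heads, seq_len, d_k)

# QK^T / sqrt(d_k): one score per (query, key) pair.
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (bs, num_heads, seq_len, seq_len)

# Causal mask: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(causal_mask == 0, float("-inf"))

# Softmax over the key dimension turns each row into a probability distribution.
attn = torch.softmax(scores, dim=-1)

out = attn @ v  # (bs, num_heads, seq_len, d_k)

print(out.shape)         # torch.Size([1, 2, 4, 8])
print(attn.sum(dim=-1))  # every entry is 1 (up to floating-point error)
print(attn[0, 0, 0])     # first query attends only to itself: tensor([1., 0., 0., 0.])
```

Masked positions receive a score of `-inf`, so they get exactly zero weight after the softmax.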

In [None]:
def scaled_dot_product_attention(
    q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None, dropout: nn.Dropout | None
) -> tuple[Tensor, Tensor]:
    """Compute Scaled Dot Product Attention."""

    d_k = q.shape[-1]

    # TODO: Compute attention scores (not applying softmax) with
    # @ operation (matrix multiply) and .transpose()

    # (bs, num_heads, seq_len, d_k) -> (bs, num_heads, seq_len, seq_len)
    scores = ...

    if mask is not None:
        # For all values in mask == 0 replace with -inf
        scores = scores.masked_fill(mask == 0, float("-inf"))

    # TODO: Apply softmax to last dim
    # Each row is a query, each column is a key. You want to convert raw scores over keys
    # into a probability distribution. In other words, you want each row / query to have
    # weights that sum to 1.
    scores = ...  # (bs, num_heads, seq_len, seq_len)

    if dropout is not None:
        scores = dropout(scores)

    # TODO: Multiply by values
    weights = ...  # (bs, num_heads, seq_len, d_k)

    # We return the scores for visualization
    return weights, scores

### Multi-Headed Attention

Each head of attention is computed with the above scaled dot-product attention and then concatenated.

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \cdots, \text{head}_h) W^O
$$
$$
\text{where head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$

The projection matrix $W^O \in \mathbb{R}^{hd_v \times d_{model}}$ ensures that the concatenation of the heads $h \cdot d_v$ is projected back into $d_{model}$ which is the desired output dimension of the multi-headed attention.

However, in the original paper, $h \cdot d_v = d_{model}$, so $W^O$ is a square $d_{model} \times d_{model}$ matrix and the input and output dimensions of the projection are the same.
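Because $h \cdot d_k = d_{model}$, splitting into heads and concatenating them again is just a reshape. The sketch below shows the round trip on a random tensor; the sizes follow the paper, while the tensor `head_out` is a hypothetical stand-in for the per-head attention output.

```python
import torch

bs, seq_len, d_model, num_heads = 2, 5, 512, 8
d_k = d_model // num_heads  # 64, as in the paper

x = torch.randn(bs, seq_len, d_model)

# Split the last dimension into heads, then move the head axis forward.
heads = x.view(bs, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 5, 64])

# Stand-in for the attention output, stored as (bs, num_heads, seq_len, d_k).
head_out = heads.contiguous()

# Moving the head axis back produces a non-contiguous tensor, which is why
# .contiguous() is called before .view() when concatenating the heads.
swapped = head_out.transpose(1, 2)
print(swapped.is_contiguous())  # False

merged = swapped.contiguous().view(bs, seq_len, d_model)
print(torch.equal(merged, x))  # True
```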

In [None]:
import math

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads  # aka. `h`

        # TODO: Get the dimension of d_k, d_v
        self.d_k = ...  # d_k = d_v

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None) -> Tensor:
        """Compute Multi-Headed Attention."""


        # TODO: Create `query`, `key`, `value` tensors
        query = ...  # (bs, seq_len, d_model) -> (bs, seq_len, d_model)
        key = ...  # (bs, seq_len, d_model) -> (bs, seq_len, d_model)
        value = ...  # (bs, seq_len, d_model) -> (bs, seq_len, d_model)

        # TODO: Split into multiple heads with .view()
        # (bs, seq_len, d_model) -> (bs, seq_len, num_heads, d_k)
        query = ...

        # TODO: Use .transpose()
        # (bs, seq_len, num_heads, d_k) -> (bs, num_heads, seq_len, d_k)
        query = ...

        key = ...
        key = ...

        value = ...
        value = ...

        weights, scores = scaled_dot_product_attention(
            query, key, value, mask, self.dropout
        )

        ### Perform concatenation of the heads ###

        # TODO: Use .transpose()
        # (bs, num_heads, seq_len, d_k) -> (bs, seq_len, num_heads, d_k)
        weights = ...

        weights = weights.contiguous()

        # TODO: Use .view() to concatenate the heads
        # (bs, seq_len, num_heads, d_k) -> (bs, seq_len, d_model)
        concat = ...

        # TODO: Apply W_o projection
        # (bs, seq_len, d_model) -> (bs, seq_len, d_model)
        return ...


## Feed-Forward Network
The Feed-Forward Network (FFN) in a Transformer is applied independently to each token position after the Multi-Head Self-Attention Mechanism. It consists of two linear layers with a non-linearity:

$$
\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2
$$

Equivalently, it can be written with the $\text{ReLU}$ activation function, which sets negative inputs to $0$.

$$
\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
$$

This will be helpful:
> The dimensionality of input and output is $d_{model} = 512$, and the inner-layer has dimensionality $d_{ff} = 2048$.
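"Applied independently to each token position" can be checked directly: running the network on a whole sequence gives the same result for a position as running it on that position alone. This is a minimal sketch with dropout omitted; the helper `ffn` and the layer names are assumptions for illustration.

```python
import torch
import torch.nn as nn
from torch import Tensor

torch.manual_seed(0)

d_model, d_ff = 512, 2048

linear1 = nn.Linear(d_model, d_ff)  # x W_1 + b_1
linear2 = nn.Linear(d_ff, d_model)  # (...) W_2 + b_2


def ffn(x: Tensor) -> Tensor:
    """FFN(x) = ReLU(x W_1 + b_1) W_2 + b_2, without dropout."""
    return linear2(torch.relu(linear1(x)))


x = torch.randn(2, 10, d_model)  # (batch_size, seq_len, d_model)

full = ffn(x)          # whole sequence at once
single = ffn(x[:, 3])  # only the token at position 3

print(full.shape)                                     # torch.Size([2, 10, 512])
print(torch.allclose(full[:, 3], single, atol=1e-4))  # True
```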

In [None]:
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()

        ### REFER TO Section 3.3 Position-wise Feed-Forward Networks ###

        # TODO: Create the two linear transformations with dropout in between
        self.linear1 = ...
        self.dropout = ...
        self.linear2 = ...

    def forward(self, x: Tensor) -> Tensor:
        """
        1. (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff)
        2. (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)

        Args:
            x (Tensor): The input tensor. `(batch_size, seq_len, d_model)`
        Returns:
            Tensor: The output tensor. `(batch_size, seq_len, d_model)`
        """
        
        ### TODO: Create forward pass. Apply dropout after ReLU.
        # Use torch.relu()

        x = ...

        ###

        return x

## Intermediate Layers

### Layer Normalization

LayerNorm operates independently on each sample within a batch, unlike BatchNorm, which normalizes across the batch dimension. It normalizes the inputs across the feature dimension.

**Purpose:** Keep the distribution of each layer's inputs stable during training (often described as mitigating internal covariate shift), which improves training speed, stability, and convergence. It can also improve generalization.
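The effect of normalizing across the feature dimension is easy to see numerically. The sketch below normalizes each token vector by hand, without the learnable scale and shift; the input shape and the value of `eps` are assumptions for illustration.

```python
import torch

torch.manual_seed(0)

# (bs, seq_len, d_model), deliberately shifted and scaled away from mean 0 / std 1.
x = torch.randn(2, 3, 8) * 5 + 10

eps = 1e-6
mean = x.mean(dim=-1, keepdim=True)  # (2, 3, 1): one mean per token
std = x.std(dim=-1, keepdim=True)    # (2, 3, 1): one std per token

normalized = (x - mean) / (std + eps)

print(normalized.mean(dim=-1))  # approximately 0 for every token
print(normalized.std(dim=-1))   # approximately 1 for every token
```

Note that `torch.std` uses the unbiased estimator by default, whereas `torch.nn.LayerNorm` uses the biased variance, so a hand-written version along these lines can differ slightly from the built-in layer.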

In [None]:
class LayerNormalization(nn.Module):
    def __init__(self, eps: float = 1e-6):
        """
        Args:
            eps (float, optional): Epsilon value to avoid division by zero.
                Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps

        # Two learnable parameters
        self.alpha = nn.Parameter(torch.ones(1))  # Scale parameter (Multiplicative)
        self.bias = nn.Parameter(torch.zeros(1))  # Shift parameter (Additive)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply layer norm to last dimension of the input tensor.

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """

        # TODO: Apply mean & std to last dimension
        mean = ... # (bs, seq_len, 1)
        std = ...  # (bs, seq_len, 1)

        std = std + self.eps

        # TODO: normalize x
        x = ...

        # TODO: scale by alpha
        x = ...

        # TODO: add bias
        x = ...

        return x

### Residual Connections

The paper defines the residual connection implementation as
$$
\text{LayerNorm}(x + \text{Sublayer}(x))
$$

However, we will follow [The Annotated Transformer's](https://nlp.seas.harvard.edu/2018/04/03/attention.html) implementation, which normalizes the input first, applies the sub-layer, applies dropout to the sub-layer's output, and then adds the result to the input: $x + \text{Dropout}(\text{Sublayer}(\text{LayerNorm}(x)))$.
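The two orderings can be compared side by side. This is a sketch only: `nn.LayerNorm` and a single `nn.Linear` are stand-ins (assumptions) for the notebook's `LayerNormalization` and for a real attention or feed-forward sub-layer, and dropout is put in eval mode so the result is deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 16
norm = nn.LayerNorm(d_model)            # stand-in for LayerNormalization
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention / feed-forward
dropout = nn.Dropout(0.1)
dropout.eval()  # disable dropout so the example is deterministic

x = torch.randn(2, 4, d_model)  # (bs, seq_len, d_model)

# Paper ordering (post-norm): LayerNorm(x + Dropout(Sublayer(x)))
post_norm = norm(x + dropout(sublayer(x)))

# The Annotated Transformer ordering (pre-norm): x + Dropout(Sublayer(LayerNorm(x)))
pre_norm = x + dropout(sublayer(norm(x)))

print(post_norm.shape, pre_norm.shape)  # both torch.Size([2, 4, 16])
```

In both cases the output has the same shape as the input, which is what allows the blocks to be stacked.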


In [None]:
class ResidualConnection(nn.Module):
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x: Tensor, sublayer: nn.Module) -> Tensor:
        """
        Residual connection with layer normalization.

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.
            sublayer (nn.Module): The intermediate layer to wrap w/ residual connection.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """

        # TODO: Apply dropout to the sublayer output before adding it back to `x`
        return ...

### Linear Layer

This layer is a projection from $d_{model}$ into log probabilities across the entire vocab.
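To see what "log probabilities across the entire vocab" means in practice, the sketch below applies `torch.log_softmax` to random logits. The toy sizes and the random `logits` tensor are assumptions standing in for the output of the projection.

```python
import torch

torch.manual_seed(0)

bs, seq_len, vocab_size = 2, 3, 10  # toy sizes for illustration
logits = torch.randn(bs, seq_len, vocab_size)

# Normalize over the vocabulary dimension.
log_probs = torch.log_softmax(logits, dim=-1)

print(log_probs.shape)                 # torch.Size([2, 3, 10])
print(log_probs.exp().sum(dim=-1))     # approximately 1 for every position
print(log_probs.argmax(dim=-1).shape)  # torch.Size([2, 3]): most likely token id per position
```

Exponentiating the log probabilities recovers an ordinary probability distribution over the vocabulary at each position.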

In [None]:
class LinearLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        """
        Linear Layer is a projection layer that converts the embedding into the
        vocabulary.

        Args:
            d_model (int): The size of the model's hidden dimension.
            vocab_size (int): The size of the vocabulary.
        """
        super().__init__()

        # TODO: Create a linear layer of size (d_model, vocab_size)
        self.linear = ...

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply projection on embeddings.
        Output will be a log probability distribution over the vocabulary.

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.

        Returns:
            Tensor: `(bs, seq_len, vocab_size)`.
        """

        # (bs, seq_len, d_model) -> (bs, seq_len, vocab_size)
        # TODO: Apply projection
        out = ...

        # TODO: Return log probabilities (not probabilities) over the vocabulary
        return torch.log_softmax(..., dim=...)

## Encoder-Decoder Structure
We can finally put everything together.

<img src="images/transformer.png" width="400">

In [1]:
# dummy variables
src_vocab_size = 1000
tgt_vocab_size = 1000
src_seq_len = 100
tgt_seq_len = 100

d_model = 512  # hidden dimension of the model
num_blocks = 6  # number of encoder and decoder blocks
num_heads = 8  # number of attention heads
d_ff = 2048  # size of the feed-forward layer
dropout = 0.1  # dropout rate

### Encoder

In [None]:
class EncoderBlock(nn.Module):
    def __init__(
        self,
        self_attention_block: MultiHeadAttention,
        feed_forward_block: FeedForwardBlock,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(dropout) for _ in range(2)]
        )

    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        """
        Forward pass through the encoder block.

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.
            src_mask (Tensor): The mask for the source language `(bs, 1, 1, seq_len)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """

        ### TODO: Create the encoder block forward pass ###
        # NOTE: the second param of the residual connection requires a layer

        x = self.residual_connections[0](
            x, lambda x: ...
        ) # Residual connection on self attention block

        # Residual connection on feed forward block
        x = self.residual_connections[1](x, ...)
        ###

        return x

In [None]:
encoder_blocks: list[EncoderBlock] = []

for _ in range(num_blocks):
    # TODO: Create and append Encoder Blocks here
    ...

encoder_layers = nn.ModuleList(encoder_blocks)

In [None]:
class Encoder(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers # the encoder blocks
        self.norm = LayerNormalization()

    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        """
        Forward pass through the encoder.

        Args:
            x (Tensor): The input to the encoder.
            src_mask (Tensor): The mask for the source language.

        Returns:
            Tensor: A tensor of shape `(batch_size, seq_len, d_model)` representing a
                sequence of context-rich embeddings that encode the input sequence's
                semantic and positional information.
        """
        for layer in self.layers:
            x = layer(x, src_mask)

        # Apply a final layer normalization after all encoder blocks
        return self.norm(x)
    
encoder = Encoder(encoder_layers)

### Decoder

To reiterate, for **Cross Attention**, queries come from one source (e.g., the decoder’s hidden state in a transformer), while keys and values come from another source (e.g., the encoder’s outputs).
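A consequence of taking queries and keys from different sources is that the attention matrix is no longer square. The sketch below uses different source and target lengths to show the shapes, and how a source padding mask of shape `(bs, 1, 1, src_len)` broadcasts over heads and target positions. The lengths and the padded positions are hypothetical values for illustration.

```python
import math

import torch

torch.manual_seed(0)

bs, num_heads, d_k = 2, 8, 64
src_len, tgt_len = 7, 5  # different lengths on purpose

q = torch.randn(bs, num_heads, tgt_len, d_k)  # from the decoder
k = torch.randn(bs, num_heads, src_len, d_k)  # from the encoder output
v = torch.randn(bs, num_heads, src_len, d_k)  # from the encoder output

scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (bs, num_heads, tgt_len, src_len)

# Hypothetical padding mask: the last two source tokens of sample 1 are padding.
src_mask = torch.ones(bs, 1, 1, src_len)
src_mask[1, :, :, -2:] = 0

scores = scores.masked_fill(src_mask == 0, float("-inf"))
attn = torch.softmax(scores, dim=-1)
out = attn @ v

print(scores.shape)              # torch.Size([2, 8, 5, 7])
print(out.shape)                 # torch.Size([2, 8, 5, 64])
print(attn[1, :, :, -2:].sum())  # tensor(0.): no weight on padded source tokens
```

The output keeps the target length, so each decoder position ends up with a mixture of encoder values.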

In [None]:
class DecoderBlock(nn.Module):
    def __init__(
        self,
        self_attention_block: MultiHeadAttention,
        cross_attention_block: MultiHeadAttention,
        feed_forward_block: FeedForwardBlock,
        dropout: float = 0.1,
    ):
        """
        Decoder block contains:
            1. (Masked Multi-Head Attention) A self-attention block where `qkv` come
                from decoder's input embedding.
            2. (Multi-Head Attention) A cross-attention block where `q` comes from the
                decoder and `k`, `v` come from the encoder outputs.
        """
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(dropout) for _ in range(3)]
        )

    def forward(
        self,
        x: Tensor,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tensor:
        """
        Forward pass through the decoder block.
        Decoder block used for machine translation to go from source to target lang.

        Args:
            x (Tensor): The decoder input `(bs, seq_len, d_model)`.
            encoder_output (Tensor): `(bs, seq_len, d_model)`.
            src_mask (Tensor): `(bs, 1, 1, seq_len)`.
            tgt_mask (Tensor): `(bs, 1, seq_len, seq_len)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        
        ### TODO: Create the decoder block forward pass ###

        # Residual connection on self attention block
        x = self.residual_connections[0](
            x, lambda x: ...
        )

        # Residual connection around cross-attention block
        # Use encoder output for the keys and values
        x = self.residual_connections[1](
            x,
            lambda x: ...
        )

        # Residual connection on feed forward block
        x = self.residual_connections[2](x, ...)
        ###
        
        return x

In [None]:
decoder_blocks: list[DecoderBlock] = []

for _ in range(num_blocks):
    # TODO: Create and append Decoder Blocks here
    ...

decoder_layers = nn.ModuleList(decoder_blocks)

In [None]:
class Decoder(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers # decoder blocks
        self.norm = LayerNormalization()

    def forward(
        self,
        x: Tensor,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tensor:
        """
        Forward pass through the decoder.

        Args:
            x (Tensor): The input to the decoder block.
            encoder_output (Tensor): The output from the encoder.
            src_mask (Tensor): The mask used for the source language (e.g. English).
            tgt_mask (Tensor): The mask used for the target language (e.g. German).

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        return self.norm(x)

decoder = Decoder(decoder_layers)

## Transformer

In [None]:
class Transformer(nn.Module):
    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        src_embed: InputEmbeddings,
        tgt_embed: InputEmbeddings,
        src_pos: PositionalEncoding,
        tgt_pos: PositionalEncoding,
        projection_layer: LinearLayer,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """Forward pass through the encoder with input tokens of type int64.

        Args:
            src (Tensor): `(bs, seq_len)`.
            src_mask (Tensor): `(bs, 1, 1, seq_len)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """

        # Embedding maps token ids to dense vectors of type float32

        # TODO: Embed source tokens
        src = ...  # (bs, seq_len) -> (bs, seq_len, d_model)
        
        # TODO: Apply positional encoding
        src = ...

        return ...

    def decode(
        self, encoder_output: Tensor, src_mask: Tensor, tgt: Tensor, tgt_mask: Tensor
    ) -> Tensor:
        """
        Forward pass through the decoder.
        - Encoder output is used in the cross-attention block and is of type float32.
        - Target tokens are still of type int64 and need to be embedded with input
        embeddings + positional encoding.

        Args:
            encoder_output (Tensor): `(bs, seq_len, d_model)`.
            src_mask (Tensor): `(bs, 1, 1, seq_len)`.
            tgt (Tensor): `(bs, seq_len)`.
            tgt_mask (Tensor): `(bs, 1, seq_len, seq_len)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """

        # TODO: Embed target tokens
        tgt = ...  # (bs, seq_len) -> (bs, seq_len, d_model)
        
        # TODO: Apply positional encoding
        tgt = ...

        # TODO: Forward pass through the decoder
        return ...

    def project(self, x: Tensor) -> Tensor:
        """
        Project the output of the decoder to the target vocabulary size.

        Args:
            x (Tensor): The output of the decoder `(bs, seq_len, d_model)`.

        Returns:
            Tensor: `(bs, seq_len, vocab_size)`.
        """

        return self.projection_layer(x)

# TODO: Create embedding layers
src_embed = ...
tgt_embed = ...

# TODO: Create positional encoding layers
src_pos = ...
tgt_pos = ...

# TODO: Create projection layer
projection_layer = ...

transformer = Transformer(
    encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer
)