This notebook contains the code accompanying [01: Attention is All You Need](https://yyzhang2025.github.io/100-AI-Papers/posts/01-attention.html).


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random
import math

In [6]:
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Hyperparameters read by the Transformer modules defined in this file."""
    # Vocabulary sizes: number of rows of the source / target token-embedding
    # tables; tgt_vocab_size is also the width of the final output projection.
    src_vocab_size: int = 10000
    tgt_vocab_size: int = 10000
    max_seq: int = 512  # number of positions with a precomputed positional encoding

    d_model: int = 512  # width of every token representation
    d_ff: int = 2048  # hidden width of the position-wise feed-forward network
    num_heads: int = 8  # attention heads; d_model must be divisible by this
    num_layers: int = 6  # number of blocks stacked in the encoder and in the decoder
    # NOTE(review): no module in the visible source reads `dropout` - confirm
    # whether dropout layers are still to be added.
    dropout: float = 0.1

    eps: float = 1e-6  # for Layer Normalization

## Transformer Model Implementation


In [5]:
class WordEmbedding(nn.Module):
    """Lookup table that turns token ids into d_model-dimensional vectors."""

    def __init__(self, config: ModelConfig, is_tgt: bool = False):
        super().__init__()

        # The target side indexes the target vocabulary, otherwise the source one.
        vocab_size = config.tgt_vocab_size if is_tgt else config.src_vocab_size
        self.embedding = nn.Embedding(vocab_size, config.d_model)

    def forward(self, x):
        """
        x: (batch_size, seq_len) integer token ids
        Output: (batch_size, seq_len, d_model)
        """
        token_vectors = self.embedding(x)
        return token_vectors


class PositionalEmbedding(nn.Module):
    """
    Fixed (non-learned) sinusoidal positional encodings.

    Even feature columns hold sin(pos * w_i), odd columns hold cos(pos * w_i)
    with w_i = 10000 ** (-2i / d_model). The table is precomputed for
    `config.max_seq` positions and stored as the buffer `pe`.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        pos_index = torch.arange(config.max_seq).unsqueeze(1)  # (max_seq, 1)

        div_term = torch.exp(
            torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model)
        )  # (ceil(d_model / 2),)
        angles = pos_index * div_term  # (max_seq, ceil(d_model / 2))

        pe = torch.zeros(config.max_seq, config.d_model)  # (max_seq, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so drop the surplus frequency instead of failing.
        pe[:, 1::2] = torch.cos(angles[:, : config.d_model // 2])

        # A buffer follows the module across devices and is saved in the
        # state dict, but is not a parameter and receives no gradient.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_seq, d_model)

    def forward(self, x):
        """
        x: (batch_size, seq_len, d_model); only its sequence length is used
        Output: (1, seq_len, d_model), broadcastable over the batch
        Raises: ValueError if seq_len exceeds the precomputed max_seq
        """
        seq_len = x.size(1)
        max_seq = self.pe.size(1)
        if seq_len > max_seq:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq={max_seq}"
            )
        return self.pe[:, :seq_len, :]  # (1, seq_len, d_model)


class Embedding(nn.Module):
    """Token embedding plus positional encoding for one side of the model."""

    def __init__(self, config: ModelConfig, is_tgt: bool = False):
        super().__init__()
        self.word_embedding = WordEmbedding(config, is_tgt)
        self.positional_embedding = PositionalEmbedding(config)

    def forward(self, x):
        """
        x: (batch_size, seq_len) token ids
        Output: (batch_size, seq_len, d_model)
        """
        tokens = self.word_embedding(x)  # (batch_size, seq_len, d_model)
        # The positional table comes back as (1, seq_len, d_model) and is
        # broadcast over the batch dimension by the addition.
        positions = self.positional_embedding(tokens)
        return tokens + positions

### Layer Normalization

$$
\text{LayerNorm}(x) = \frac{x - \mu}{\sigma + \epsilon} \cdot \gamma + \beta
$$


In [None]:
class LayerNormalization(nn.Module):
    """
    Normalizes the last (feature) dimension as (x - mean) / (std + eps),
    then applies a learnable per-feature scale `gamma` and shift `beta`.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.eps = config.eps

        width = config.d_model
        self.gamma = nn.Parameter(torch.ones(width))  # scale, starts at 1
        self.beta = nn.Parameter(torch.zeros(width))  # shift, starts at 0

    def _compute_mean_std(self, x):
        """
        Per-position statistics over the last dimension.
        x: (..., d_model)
        Output: (mean, std), each with the last dimension kept as size 1.
        std is whatever `Tensor.std` returns with its default arguments.
        """
        reduce_last = {"dim": -1, "keepdim": True}
        return x.mean(**reduce_last), x.std(**reduce_last)

    def forward(self, x):
        """
        x: (batch_size, seq_len, d_model)
        Output: same shape as x
        """
        mu, sigma = self._compute_mean_std(x)

        centered = x - mu
        # eps is added to the standard deviation so a constant row does not
        # divide by zero.
        normalized_x = centered / (sigma + self.eps)

        return normalized_x * self.gamma + self.beta

### Feedforward Neural Network


In [7]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        d_model, d_ff = config.d_model, config.d_ff
        self.ln1 = nn.Linear(d_model, d_ff, bias=True)  # expand to d_ff
        self.ln2 = nn.Linear(d_ff, d_model, bias=True)  # project back to d_model

    def forward(self, x):
        """
        x: (batch_size, seq_len, d_model)
        Output: (batch_size, seq_len, d_model)
        """
        hidden = F.relu(self.ln1(x))  # (batch_size, seq_len, d_ff)
        return self.ln2(hidden)

In [8]:
import einops

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Scaled Dot-Product Attention: softmax(q @ k^T / sqrt(d_k)) @ v

    q: (batch_size, num_heads, seq_len_q, d_k)
    k: (batch_size, num_heads, seq_len_k, d_k)
    v: (batch_size, num_heads, seq_len_k, d_v)
    mask: tensor broadcastable to (batch_size, num_heads, seq_len_q, seq_len_k)
          or None. True (non-zero) entries mark key positions that must NOT
          be attended to.
    Output: (batch_size, num_heads, seq_len_q, d_v)

    Note: a query row whose keys are all masked has no valid distribution
    and yields NaN.
    """
    d_k = k.shape[-1]

    # (.., seq_len_q, d_k) @ (.., d_k, seq_len_k) -> (.., seq_len_q, seq_len_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    # `mask is not None` rather than truthiness: bool() of a multi-element
    # tensor raises.
    if mask is not None:
        # Masking must happen BEFORE the softmax: -inf logits become exactly
        # zero weight and the remaining weights still sum to one.
        scores = scores.masked_fill(mask.to(dtype=torch.bool), float("-inf"))

    weights = F.softmax(scores, dim=-1)  # attention weights over the keys

    # (.., seq_len_q, seq_len_k) @ (.., seq_len_k, d_v) -> (.., seq_len_q, d_v)
    return torch.matmul(weights, v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused q/k/v projection."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        assert (
            config.d_model % config.num_heads == 0
        ), "d_model must be divisible by num_heads"
        self.d_k = config.d_model // config.num_heads  # per-head width
        self.num_heads = config.num_heads

        # One matmul yields q, k and v side by side along the last dimension.
        self.qkv_proj = nn.Linear(config.d_model, config.d_model * 3, bias=True)

        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=True)

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len, d_model)
        mask: (batch_size, 1, seq_len_q, seq_len_k) or None
        Output: (batch_size, seq_len, d_model)
        """
        _batch, _length, _width = x.size()  # x must be 3-dimensional

        def split_heads(t):
            # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
            return einops.rearrange(
                t,
                "batch seq_len (heads d_k) -> batch heads seq_len d_k",
                heads=self.num_heads,
            )

        q_flat, k_flat, v_flat = self.qkv_proj(x).chunk(3, dim=-1)
        q = split_heads(q_flat)
        k = split_heads(k_flat)
        v = split_heads(v_flat)

        context = scaled_dot_product_attention(q, k, v, mask)

        # Concatenate the heads again: (batch_size, seq_len, d_model)
        merged = einops.rearrange(
            context,
            "batch heads seq_len d_v -> batch seq_len (heads d_v)",
            heads=self.num_heads,
        )

        return self.out_proj(merged)

In [None]:
class CrossAttention(nn.Module):
    """
    Multi-head attention whose queries come from one sequence and whose
    keys/values come from another (e.g. decoder states attending to the
    encoder output).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        assert (
            config.d_model % config.num_heads == 0
        ), "d_model must be divisible by num_heads"
        self.d_k = config.d_model // config.num_heads  # Dimension of each head

        self.num_heads = config.num_heads

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=True)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=True)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=True)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=True)

    def forward(self, query, key, value, mask=None):
        """
        query: (batch_size, seq_len_q, d_model)
        key: (batch_size, seq_len_k, d_model)
        value: (batch_size, seq_len_k, d_model)
        mask: (batch_size, 1, seq_len_q, seq_len_k) or None
        Output: (batch_size, seq_len_q, d_model)
        """
        # The model dimension has to be written as the composite axis
        # "(heads d_k)" so einops can split it into separate head and
        # per-head axes.
        q = einops.rearrange(
            self.q_proj(query),
            "batch seq_len_q (heads d_k) -> batch heads seq_len_q d_k",
            heads=self.num_heads,
            d_k=self.d_k,
        )

        k = einops.rearrange(
            self.k_proj(key),
            "batch seq_len_k (heads d_k) -> batch heads seq_len_k d_k",
            heads=self.num_heads,
            d_k=self.d_k,
        )

        v = einops.rearrange(
            self.v_proj(value),
            "batch seq_len_v (heads d_v) -> batch heads seq_len_v d_v",
            heads=self.num_heads,
            d_v=self.d_k,  # d_v == d_k in this implementation
        )

        # Compute attention
        attn_output = scaled_dot_product_attention(q, k, v, mask)

        # Rearrange back to (batch_size, seq_len_q, d_model)
        attn_output = einops.rearrange(
            attn_output,
            "batch heads seq_len_q d_v -> batch seq_len_q (heads d_v)",
            heads=self.num_heads,
        )

        # Mix the concatenated heads and hand the result back to the caller.
        return self.out_proj(attn_output)  # (batch_size, seq_len_q, d_model)

### Encoder Block


In [None]:
class EncoderBlock(nn.Module):
    """
    One encoder layer: self-attention and a feed-forward network, each
    followed by a residual connection and layer normalization
    (post-norm: LayerNorm(input + Sublayer(input))).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.self_attn = MultiHeadAttention(config)
        self.ffn = FFN(config)
        self.ln1 = LayerNormalization(config)
        self.ln2 = LayerNormalization(config)

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len, d_model)
        mask: optional attention mask forwarded to the self-attention
              (e.g. a source padding mask); None means no masking
        Output: (batch_size, seq_len, d_model)
        """
        attn_out = self.self_attn(x, mask)  # (batch_size, seq_len, d_model)
        hidden = self.ln1(x + attn_out)  # Add & Norm

        ffn_out = self.ffn(hidden)  # (batch_size, seq_len, d_model)
        # The residual is taken around the FFN sub-layer, i.e. from its own
        # input `hidden`, not from the block input `x`.
        return self.ln2(hidden + ffn_out)  # Add & Norm

In [12]:
class DecoderBlock(nn.Module):
    """
    One decoder layer: self-attention, cross-attention over the encoder
    output and a feed-forward network, each followed by a residual
    connection and layer normalization.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.self_attn = MultiHeadAttention(config)
        self.cross_attn = CrossAttention(config)
        self.ffn = FFN(config)
        self.ln1 = LayerNormalization(config)
        self.ln2 = LayerNormalization(config)
        self.ln3 = LayerNormalization(config)

    def forward(self, x, enc_output, mask=None, tgt_mask=None):
        """
        x: (batch_size, tgt_seq_len, d_model)
        enc_output: (batch_size, src_seq_len, d_model)
        mask: optional cross-attention mask,
              (batch_size, 1, tgt_seq_len, src_seq_len) or None
        tgt_mask: optional self-attention mask, e.g. the result of
              create_causal_mask(tgt_seq_len, tgt_seq_len) to hide future
              positions; None leaves the self-attention unmasked
        Output: (batch_size, tgt_seq_len, d_model)
        """
        self_attn_out = self.self_attn(x, tgt_mask)  # Self-attention
        hidden = self.ln1(x + self_attn_out)  # Add & Norm

        cross_out = self.cross_attn(hidden, enc_output, enc_output, mask)
        # Each residual wraps its own sub-layer, so it starts from that
        # sub-layer's input rather than from the block input `x`.
        hidden = self.ln2(hidden + cross_out)  # Add & Norm

        ffn_out = self.ffn(hidden)  # Feedforward
        return self.ln3(hidden + ffn_out)  # Add & Norm

In [13]:
class Encoder(nn.Module):
    """Source embedding, a stack of encoder blocks and a final layer norm."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.embedding = Embedding(config)
        self.layers = nn.ModuleList(
            EncoderBlock(config) for _ in range(config.num_layers)
        )
        self.ln = LayerNormalization(config)

    def forward(self, x):
        """
        x: (batch_size, seq_len) token ids
        Output: (batch_size, seq_len, d_model)
        """
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden)
        # Normalize once more after the last block.
        return self.ln(hidden)


class Decoder(nn.Module):
    """Target embedding, a stack of decoder blocks and a final layer norm."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.embedding = Embedding(config, is_tgt=True)
        self.layers = nn.ModuleList()
        for _ in range(config.num_layers):
            self.layers.append(DecoderBlock(config))
        self.ln = LayerNormalization(config)

    def forward(self, x, enc_output, mask=None):
        """
        x: (batch_size, seq_len) target token ids
        enc_output: (batch_size, seq_len, d_model)
        mask: handed unchanged to every decoder block, or None
        Output: (batch_size, seq_len, d_model)
        """
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden, enc_output, mask)
        # Normalize once more after the last block.
        return self.ln(hidden)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model that ends in a projection onto the target vocabulary."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.output_layer = nn.Linear(config.d_model, config.tgt_vocab_size)

    def forward(self, src, tgt, mask=None):
        """
        src: (batch_size, src_seq_len)
        tgt: (batch_size, tgt_seq_len)
        mask: (batch_size, 1, tgt_seq_len, src_seq_len) or None
        Output: (batch_size, tgt_seq_len, tgt_vocab_size) unnormalized logits
        """
        memory = self.encoder(src)  # (batch_size, src_seq_len, d_model)
        decoded = self.decoder(tgt, memory, mask)  # (batch_size, tgt_seq_len, d_model)
        logits = self.output_layer(decoded)
        return logits

In [14]:
def create_causal_mask(seq_len_q, seq_len_k):
    """
    Build a boolean mask that is True wherever a key position lies strictly
    after the query position (the entries above the main diagonal).
    seq_len_q: Length of the query sequence
    seq_len_k: Length of the key sequence
    Output: (1, 1, seq_len_q, seq_len_k) bool tensor
    """
    query_pos = torch.arange(seq_len_q)[:, None]  # (seq_len_q, 1)
    key_pos = torch.arange(seq_len_k)[None, :]  # (1, seq_len_k)
    is_future = key_pos > query_pos  # (seq_len_q, seq_len_k)
    # Two leading singleton axes for the batch and head dimensions.
    return is_future[None, None]


def create_padding_mask(seq_len, padding_idx=0):
    """
    Create a padding mask for the attention mechanism.

    seq_len: either
        * a 2-D tensor of token ids, (batch_size, seq_len) - the intended
          use: the mask is True exactly where a token equals `padding_idx`;
        * or an int sequence length (legacy call form, kept for backward
          compatibility). Without token ids the real padding cannot be
          known; this form only flags the *position* whose index equals
          `padding_idx`.
    padding_idx: Index used for padding (default is 0)

    Output:
        token ids -> (batch_size, 1, 1, seq_len) bool tensor, broadcastable
                     over heads and query positions
        int       -> (1, 1, seq_len) bool tensor
    """
    if isinstance(seq_len, torch.Tensor) and seq_len.dim() == 2:
        token_ids = seq_len
        is_pad = token_ids == padding_idx  # (batch_size, seq_len)
        # Insert head and query axes so the mask broadcasts against
        # (batch_size, num_heads, seq_len_q, seq_len_k) attention scores.
        return is_pad.unsqueeze(1).unsqueeze(2)

    # Legacy path: compares position indices, not tokens.
    mask = (torch.arange(seq_len).unsqueeze(0) == padding_idx).unsqueeze(1)
    return mask  # (1, 1, seq_len)