# Transformer Implementation from Scratch (Oct 21, 2022)

This notebook goes through an implementation of a transformer from scratch.

**Note:** this is intended as a teaching tool, not a practical implementation. Some details (e.g., initialization) are simplified.

In [15]:
from typing import NamedTuple
import torch
from torch import nn
from math import sqrt, sin, cos

In [3]:
class SiTransConfig(NamedTuple):
    """Wrapper object representing architectural hyperparameters.

    ``d_model`` must be divisible by ``n_heads``: each attention head works in a
    subspace of width ``d_model // n_heads``.
    """
    n_vocab: int  # Vocabulary size (rows of the token embedding table).
    d_model: int  # Width of the embeddings and of every layer's input/output.
    d_hidden: int  # Hidden width of the feedforward sublayer.
    n_heads: int  # Number of attention heads per layer.
    n_layers: int  # Number of stacked encoder layers.
    seq_len: int  # Maximum sequence length (rows of the positional embedding table).
    masked: bool = False  # Apply a causal attention mask (position t attends only to positions <= t).
    biases: bool = False  # Add biases to the linear transformations.
    post_ln: bool = False  # Apply layer norm after the residual addition (post-LN) instead of the default placement.
    scale_scores: bool = False  # Divide attention scores by sqrt(d_head) before the softmax.
    rel_embed: bool = False  # Use fixed sinusoidal positional embeddings instead of learned ones (despite the name, they encode absolute positions).
    p_drop: float = 0.  # Dropout probability.

In [4]:
# Full-size configuration used by the sublayer demos below.
# d_model / n_heads = 200 / 20 gives 10-dimensional attention heads.
config = SiTransConfig(
    n_vocab=10_000,
    d_model=200,
    d_hidden=400,
    n_heads=20,
    n_layers=6,
    seq_len=512,
)

### Self Attention

Let's implement a single self-attention head.

In [5]:
class SiSelfAttention(nn.Module):
    """A single self-attention head.

    Projects the input encodings to queries, keys and values of width
    ``d_head = d_model // n_heads`` and returns, for each position, the
    attention-weighted average of the values.
    """

    def __init__(self, config: "SiTransConfig"):
        super().__init__()
        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by n_heads ({config.n_heads})"
            )
        d_head = config.d_model // config.n_heads
        self.query = nn.Linear(config.d_model, d_head, bias=config.biases)
        self.key = nn.Linear(config.d_model, d_head, bias=config.biases)
        self.value = nn.Linear(config.d_model, d_head, bias=config.biases)

        self.config = config
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, encodings):
        """Attend over ``encodings``.

        Args:
            encodings: tensor of shape [batch, seq_len, d_model] (the einsum
                subscripts below require a leading batch dimension).

        Returns:
            Tensor of shape [batch, seq_len, d_head].
        """
        queries = self.query(encodings)  # [batch, seq_len, d_head]
        keys = self.key(encodings)  # [batch, seq_len, d_head]
        values = self.value(encodings)  # [batch, seq_len, d_head]
        scores = torch.einsum("bti, bsi -> bts", queries, keys)  # [batch, seq_len, seq_len]

        if self.config.scale_scores:
            scores = scores / sqrt(self.d_head)

        if self.config.masked:
            seq_len = scores.size(1)
            arange = torch.arange(seq_len, device=scores.device)
            # mask[t, s] is True when query position t may attend to key position s (s <= t).
            mask = arange.unsqueeze(dim=0) <= arange.unsqueeze(dim=1)
            # Disallowed positions must become -inf so the softmax gives them weight 0.
            # (Multiplying scores by 0 would still give them weight exp(0).)
            # The diagonal is always allowed, so no row is entirely -inf.
            scores = scores.masked_fill(~mask, float("-inf"))

        weights = self.softmax(scores)  # [batch, seq_len, seq_len], rows sum to 1
        # Weighted average of values (weighted by weights).
        return torch.einsum("bts, bsh -> bth", weights, values)

In the transformer, each layer has multiple attention heads that run in parallel; their outputs are concatenated and mixed by a linear transformation (the "pooler"). We now implement a module to wrap this.

In [6]:
class SiMultiHead(nn.Module):
    """Multi-head self-attention sublayer with a residual connection and layer norm.

    With ``config.post_ln`` the block computes ``LN(x + Attn(x))`` (post-LN).
    Otherwise it uses pre-LN, ``x + Attn(LN(x))``, where the layer norm
    normalizes the sublayer *input* and the residual path is left untouched.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # A plain list plus explicit add_module registers the heads as head0, head1, ...
        # (kept so that parameter names in the state_dict stay stable).
        self.heads = [SiSelfAttention(config) for _ in range(config.n_heads)]
        for idx, head in enumerate(self.heads):
            self.add_module(f"head{idx}", head)

        self.pooler = torch.nn.Linear(config.d_model, config.d_model, bias=config.biases)
        self.lnorm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.p_drop)

    def _attend(self, encodings):
        """Run all heads, concatenate them to d_model, mix with the pooler and apply dropout."""
        heads = [head(encodings) for head in self.heads]
        return self.dropout(self.pooler(torch.cat(heads, dim=-1)))

    def forward(self, encodings):
        if self.config.post_ln:
            return self.lnorm(encodings + self._attend(encodings))
        return encodings + self._attend(self.lnorm(encodings))

In [7]:
batch_size: int = 64  # Batch dimension used for the random demo inputs below.

In [8]:
# Smoke test: the attention sublayer preserves [batch, seq_len, d_model].
attn_sublayer = SiMultiHead(config)
inputs = torch.randn(size=[batch_size, config.seq_len, config.d_model])
outputs = attn_sublayer(inputs)
outputs.shape

torch.Size([64, 512, 200])

### Feedforward Nets

In [9]:
class SiFeedforward(nn.Module):
    """Position-wise feedforward sublayer with a residual connection and layer norm.

    The network is Linear(d_model -> d_hidden), ReLU, Linear(d_hidden -> d_model).
    With ``config.post_ln`` the block computes ``LN(x + FF(x))`` (post-LN).
    Otherwise it uses pre-LN, ``x + FF(LN(x))``, where the layer norm
    normalizes the sublayer *input* and the residual path is left untouched.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model = config.d_model
        d_hidden = config.d_hidden
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden, bias=config.biases),
            nn.ReLU(),
            nn.Linear(d_hidden, d_model, bias=config.biases),
        )
        self.lnorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(config.p_drop)

    def _transform(self, encodings):
        """Apply the two-layer network followed by dropout."""
        return self.dropout(self.net(encodings))

    def forward(self, encodings):
        if self.config.post_ln:
            return self.lnorm(encodings + self._transform(encodings))
        return encodings + self._transform(self.lnorm(encodings))

In [10]:
# Smoke test: the feedforward sublayer preserves [batch, seq_len, d_model].
ff_sublayer = SiFeedforward(config)
inputs = torch.randn(size=[batch_size, config.seq_len, config.d_model])
outputs = ff_sublayer(inputs)
outputs.shape

torch.Size([64, 512, 200])

### Transformer Encoder Stack

We now have enough to implement the stack of transformer layers that makes up the transformer encoder (note that this still excludes the initial token and positional embeddings).

In [11]:
class SiLayer(nn.Module):
    """One encoder layer: the multi-head attention sublayer, then the feedforward sublayer."""

    def __init__(self, config):
        super().__init__()
        self.heads = SiMultiHead(config)
        self.ff = SiFeedforward(config)

    def forward(self, encodings):
        attended = self.heads(encodings)
        return self.ff(attended)

In [12]:
# Smoke test: a full encoder layer preserves [batch, seq_len, d_model].
layer = SiLayer(config)
inputs = torch.randn(size=[batch_size, config.seq_len, config.d_model])
outputs = layer(inputs)
outputs.shape

torch.Size([64, 512, 200])

In [13]:
class SiEncoder(nn.Module):
    """A stack of ``config.n_layers`` encoder layers applied one after another."""

    def __init__(self, config):
        super().__init__()
        stack = [SiLayer(config) for _ in range(config.n_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, encodings):
        hidden = encodings
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [None]:
# Model and inputs must live on the same device; fall back to CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = SiEncoder(config).to(device)
inputs = torch.randn(size=[batch_size, config.seq_len, config.d_model], device=device)
# This is only a shape check. no_grad avoids keeping every layer's
# [batch, seq_len, seq_len] attention matrices alive for backprop.
with torch.no_grad():
    outputs = encoder(inputs)
outputs.shape

### Token and Positional Embeddings

In [16]:
class SiEmbedder(nn.Module):
    """Token embeddings plus positional embeddings.

    Positional embeddings are learned by default. When ``config.rel_embed`` is
    set, they are the fixed sinusoidal table from "Attention Is All You Need"
    (https://arxiv.org/abs/1706.03762). Despite the flag's name, these sinusoids
    encode absolute positions.
    """

    def __init__(self, config):
        super().__init__()
        self.embed = nn.Embedding(config.n_vocab, config.d_model)

        if config.rel_embed:
            table = self._sinusoidal_table(config.seq_len, config.d_model)
            # from_pretrained(freeze=True) sets weight.requires_grad = False, so the
            # table is never updated. Setting requires_grad on the module itself has
            # no effect on its weight, and writing into a trainable weight in place
            # raises an autograd error.
            self.pos_embed = nn.Embedding.from_pretrained(table, freeze=True)
        else:
            self.pos_embed = nn.Embedding(config.seq_len, config.d_model)

    @staticmethod
    def _sinusoidal_table(seq_len, d_model):
        """Return a [seq_len, d_model] table of sinusoidal position encodings.

        PE[pos, 2i] = sin(pos / 10000**(2i / d_model)) and
        PE[pos, 2i + 1] = cos(pos / 10000**(2i / d_model)). For odd d_model the
        last (even-indexed) column gets a sine.
        """
        # Compute in float64, then cast to the default dtype used by nn.Embedding.
        positions = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)  # [seq_len, 1]
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float64)  # 2i for each sine column
        angles = positions / (10000.0 ** (even_dims / d_model))  # [seq_len, ceil(d_model / 2)]
        table = torch.zeros(seq_len, d_model, dtype=torch.float64)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return table.to(torch.get_default_dtype())

    def forward(self, token_ids):
        """Embed ``token_ids`` (shape [batch, seq_len]) into [batch, seq_len, d_model]."""
        _, seq_len = token_ids.size()
        embeddings = self.embed(token_ids)
        positions = torch.arange(seq_len, device=token_ids.device)
        # [1, seq_len, d_model] broadcasts over the batch dimension.
        pos_embeddings = self.pos_embed(positions).unsqueeze(dim=0)
        return embeddings + pos_embeddings

### Putting it All Together

Our encoder-only transformer is then just the composition of the embedding layer with the encoder stack.

In [17]:
class SiTransformer(nn.Module):
    """Encoder-only transformer: embeddings followed by the encoder stack."""

    def __init__(self, config):
        super().__init__()
        # token ids [batch, seq] -> [batch, seq, d_model]
        self.embedder = SiEmbedder(config)
        # [batch, seq, d_model] -> [batch, seq, d_model]
        self.encoder = SiEncoder(config)

    def forward(self, token_ids):
        return self.encoder(self.embedder(token_ids))

In [18]:
# A tiny configuration for a quick end-to-end run (d_head = 50 / 5 = 10).
small_config = SiTransConfig(
    n_vocab=1_000,
    d_model=50,
    d_hidden=100,
    n_heads=5,
    n_layers=3,
    seq_len=32,
)

In [20]:
# End-to-end run on one sequence of 10 token ids (batch of size 1).
transformer = SiTransformer(small_config)
tokens = torch.arange(10).unsqueeze(0)
print("Tokens", tokens.shape)
# Call the module rather than .forward() so that registered hooks run.
vecs = transformer(tokens)
print("Vecs", vecs.shape)

Tokens torch.Size([1, 10])
Vecs torch.Size([1, 10, 50])
