# Transformer from Scratch



In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
import logging
logger = logging.getLogger("tensor_shapes")
handler = logging.StreamHandler()
formatter = logging.Formatter(
        '%(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(1)

In [3]:
import inspect
def getclass():
    stack = inspect.stack()
    return stack[3][0].f_locals["self"].__class__

# A helper function to check how tensor sizes change
def log_size(tsr: torch.Tensor, name: str):
    cls = getclass()
    logger.log(level=cls.level, msg=[{cls.__name__}, {name}, {tsr.shape}])

In [4]:
from enum import IntEnum
# Control how much debugging output we want
class TensorLoggingLevels(IntEnum):
    attention = 1
    attention_head = 2
    multihead_attention_block = 3
    enc_dec_block = 4
    enc_dec = 5

In [5]:
class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2

# Components


### Scaled dot product attention

$$ \textrm{Attention}(Q, K, V) = \textrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

In [6]:
import math 

class ScaledDotProductAttention(nn.Module):
    level = TensorLoggingLevels.attention
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        d_k = k.size(-1)
        assert q.size(-1) == d_k
        
        # Compute the dot product between queries and keys for each batch and position in the sequence
        attn = torch.bmm(q, k.transpose(Dim.seq, Dim.feature))
        
        attn = attn / math.sqrt(d_k)
        
        attn = torch.exp(attn)
        
        log_size(attn, "attention weight") # Batch, Seq, Seq
        
        if mask is not None:
            attn = attn.masked_fill(mask, 0)
        attn = attn / attn.sum(dim=-1, keepdim=True)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v) # (Batch, Seq, Feature)
        log_size(output, "attention output size") # (Batch, Seq, Seq)
        return output

In [7]:
attn = ScaledDotProductAttention()

In [8]:
q = torch.rand(5, 10, 20)
k = torch.rand(5, 10, 20)
v = torch.rand(5, 10, 20)

In [9]:
attn(q, k, v)

[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]


tensor([[[0.5911, 0.5458, 0.4365, 0.5171, 0.6749, 0.4671, 0.4404, 0.5136,
          0.7278, 0.4326, 0.3734, 0.4543, 0.4597, 0.5284, 0.5157, 0.6071,
          0.7912, 0.6423, 0.5359, 0.3856],
         [0.5981, 0.5506, 0.4611, 0.5268, 0.6753, 0.5026, 0.4653, 0.4943,
          0.7193, 0.4539, 0.3836, 0.4768, 0.4588, 0.5002, 0.5198, 0.5863,
          0.7832, 0.6498, 0.5481, 0.3841],
         [0.5686, 0.4442, 0.4441, 0.4966, 0.6489, 0.4964, 0.4433, 0.3829,
          0.6438, 0.4285, 0.3321, 0.4568, 0.3883, 0.4378, 0.4915, 0.5611,
          0.6876, 0.6509, 0.5434, 0.3746],
         [0.5874, 0.5380, 0.4460, 0.5247, 0.6819, 0.4935, 0.4663, 0.5041,
          0.7309, 0.4440, 0.3720, 0.4575, 0.4571, 0.5184, 0.5227, 0.5911,
          0.7923, 0.6545, 0.5287, 0.3792],
         [0.5865, 0.5624, 0.4464, 0.5167, 0.6713, 0.4842, 0.4540, 0.5115,
          0.7174, 0.4444, 0.3797, 0.4850, 0.4645, 0.5060, 0.5152, 0.5960,
          0.7918, 0.6428, 0.5473, 0.3746],
         [0.5927, 0.5666, 0.4519, 0.5179, 0.6

### Multi Head Attention

In [10]:
class AttentionHead(nn.Module):
    level = TensorLoggingLevels.attention_head
    def __init__(self, d_model, d_feature, dropout = 0.1):
        super().__init__()
        # We assume that the queries, keys, features all have the same feature size.
        self.attn = ScaledDotProductAttention(dropout)
        self.query_tfm = nn.Linear(d_model, d_feature)
        self.key_tfm = nn.Linear(d_model, d_feature)
        self.value_tfm = nn.Linear(d_model, d_feature)
        
    def forward(self, queries, keys, values, mask=None):
        Q = self.query_tfm(queries)
        K = self.key_tfm(keys)
        V = self.value_tfm(values)
        log_size(Q, "queries, keys, vals")
        
        x = self.attn(Q, K, V)
        return x

In [11]:
attn_head = AttentionHead(20, 20)
attn_head(q, k, v)

[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]


tensor([[[-2.0784e-01,  2.2152e-01,  4.5362e-01, -1.5373e-01,  1.2132e-01,
           2.8852e-01,  9.0924e-02, -7.1280e-02,  4.8409e-01, -3.6644e-01,
           8.2507e-02, -1.2756e-01,  2.3375e-01, -1.1124e-01,  2.2679e-01,
          -4.7546e-01, -3.7602e-01,  4.1561e-02,  1.3608e-01, -3.1414e-02],
         [-1.9894e-01,  1.9652e-01,  3.8503e-01, -1.2525e-01,  1.0525e-01,
           2.5120e-01,  1.0919e-01, -3.3621e-02,  4.4273e-01, -2.8875e-01,
           5.2675e-02, -1.1692e-01,  1.9413e-01, -1.1617e-01,  2.1346e-01,
          -3.8980e-01, -3.7404e-01,  6.3505e-03,  1.6117e-01, -5.6203e-02],
         [-2.0458e-01,  2.2366e-01,  4.5607e-01, -1.5232e-01,  1.2111e-01,
           2.9050e-01,  8.6029e-02, -6.8638e-02,  4.8368e-01, -3.7043e-01,
           8.1905e-02, -1.2623e-01,  2.2998e-01, -1.1342e-01,  2.2555e-01,
          -4.7172e-01, -3.7471e-01,  4.6374e-02,  1.3519e-01, -2.6267e-02],
         [-2.0544e-01,  2.2270e-01,  4.5883e-01, -1.5014e-01,  1.1762e-01,
           2.8600e-01,

The multi head attention block applies multiple attention heads as can be seen in the paper "Attention is all you need", then concatenates the output and applies single linear projection.

In [12]:
logger.setLevel(TensorLoggingLevels.attention_head)

In [13]:
class MultiHeadAttention(nn.Module):
    level = TensorLoggingLevels.multihead_attention_block
    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads
        
        assert d_model == d_feature * n_heads
        
        self.attn_heads = nn.ModuleList([
            AttentionHead(d_model, d_feature, dropout) for _ in range(n_heads)
        ])
        self.projection = nn.Linear(d_feature * n_heads, d_model)
        
    def forward(self, queries, keys, values, mask=None):
        log_size(queries, "Input queries")
        x = [attn(queries, keys, values, mask=mask) # (Batch, Seq, Feature)
             for i, attn in enumerate(self.attn_heads)]
        log_size(x[0], "Output of single head")
        
        #reconcatenate
        x = torch.cat(x, dim=Dim.feature) # (Batch, Sequence, D_Feature * n_heads)
        log_size(x, "Concatenated output") 
        x = self.projection(x) # (Batch, Sequence, D_model)
        log_size(x, "projected output")
        return x

In [14]:
heads = MultiHeadAttention(20 * 8, 20, 8)
heads(q.repeat(1, 1, 8), 
      k.repeat(1, 1, 8), 
      v.repeat(1, 1, 8))

[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 160])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'MultiHeadAttention'}, {'Output of single head'}, {torch.Size([5, 10, 20])}]
[{'MultiHeadAttention'}, {'Concatenated output'}, {torch.Size([5, 10, 160])}]
[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 10, 160])}]


tensor([[[ 1.2458e-01, -4.8857e-01,  6.1657e-02,  ...,  4.4601e-01,
          -5.2292e-02, -4.9479e-02],
         [ 1.1982e-01, -4.7597e-01,  8.6063e-02,  ...,  4.6261e-01,
          -9.0486e-02, -3.3123e-02],
         [ 1.3382e-01, -4.6789e-01,  8.6461e-02,  ...,  4.5981e-01,
          -5.1542e-02, -3.1109e-04],
         ...,
         [ 1.2216e-01, -4.7918e-01,  8.1664e-02,  ...,  4.2958e-01,
          -6.2201e-02, -3.9964e-02],
         [ 9.1237e-02, -4.5377e-01,  8.3547e-02,  ...,  4.7651e-01,
          -7.7586e-02, -8.5860e-02],
         [ 5.9553e-02, -4.2683e-01,  9.0134e-02,  ...,  4.5960e-01,
          -8.3119e-02, -9.4775e-02]],

        [[ 1.6726e-01, -4.6402e-01,  4.6791e-02,  ...,  4.1466e-01,
          -1.0124e-01, -1.1301e-01],
         [ 1.9437e-01, -4.9319e-01, -1.1182e-02,  ...,  4.1176e-01,
          -1.3245e-01, -1.0238e-01],
         [ 1.8208e-01, -4.8433e-01,  8.1616e-03,  ...,  4.2897e-01,
          -1.4379e-01, -1.2226e-01],
         ...,
         [ 1.7532e-01, -4

### The Encoder

The encoder is made up of the following components:
- multi-head attention block
- simple feedforward neural network

These components are connected using residual connections and layer normalization.

In [15]:
logger.setLevel(TensorLoggingLevels.multihead_attention_block)

In [16]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps =1e-8):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
        
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

In [18]:
class EncoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.layer_norml = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.layer_norm2 = LayerNorm(d_model)
        
    def forward(self, x, mask=None):
        log_size(x, "Encoder block input")
        attn = self.attn_head(x, x, x, mask=mask)
        log_size(x, "Attention Output")
        # Applying normalization and residual connection.
        x = x + self.dropout(self.layer_norml(attn))
        # Applying position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        log_size(x, "Feedforward output")
        # Applying normalization and residual connection
        x = x + self.dropout(self.layer_norm2(pos))
        log_size(x, "Encoder size output")
        return x

In [19]:
enc = EncoderBlock()

In [20]:
enc(torch.rand(5, 10, 512))

[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'Concatenated output'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Attention Output'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Encoder size output'}, {torch.Size([5, 10, 512])}]


tensor([[[ 0.5833,  3.1703,  2.7410,  ...,  0.8939, -1.1184, -1.0422],
         [ 1.5682,  3.0529,  1.9326,  ..., -0.7414, -0.8055, -1.8851],
         [ 1.4794,  1.9777,  2.1705,  ...,  0.3330, -1.0011, -0.2723],
         ...,
         [ 1.0954,  2.7196,  2.5938,  ..., -0.9390, -1.0696,  0.1681],
         [ 2.5727,  2.4877,  3.5313,  ..., -0.8972, -1.5018, -0.5330],
         [-0.7734,  2.0177,  2.0921,  ...,  0.1652, -1.1879, -0.7408]],

        [[ 1.1904,  2.4746,  2.8082,  ..., -1.8894, -2.0657,  0.5893],
         [ 3.3139,  2.7497,  3.2283,  ..., -0.9785, -1.8251, -0.8296],
         [ 1.7955,  1.5634,  3.3610,  ..., -1.5514, -1.3444, -0.5784],
         ...,
         [ 2.3839,  1.9991,  2.3468,  ..., -1.4291, -1.9721, -0.6121],
         [ 2.2030,  2.3667,  1.1972,  ..., -1.9345, -1.0124, -1.1355],
         [ 1.3757,  2.1950,  2.9918,  ...,  1.1616, -0.6135, -1.0362]],

        [[ 2.9574,  0.8282,  2.0989,  ..., -0.9618, -0.1570,  1.7864],
         [ 2.0363,  2.1637,  2.5635,  ...,  1

The Encoder is having six consecutive encoder blocks, thus:

In [21]:
class TransformerEncoder(nn.Module):
    level = TensorLoggingLevels.enc_dec
    def __init__(self, n_blocks = 6, d_model = 512, n_heads = 8, d_ff = 2048, dropout = 0.1):
        super().__init__()
        self.encoders = nn.ModuleList([EncoderBlock(d_model = d_model, d_feature = d_model // n_heads, 
                                                   d_ff = f_ff, dropout = dropout)
                                       for _ in range(n_blocks)
                                      ])
        
        def forward(self, x: torch.FloatTensor, mask = None):
            for encoder in self.encoders:
                x = encoder(x)
            return(x)

### The Decoder

The decoder is same in structure as the encoder with just one additional multi-head attention block that takes the target sentence as input.

In [30]:
class DecoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_out, 
                src_mask=None, tgt_mask=None):
        # Applying attention to inputs
        att = self.masked_attn_head(x, x, x, mask=src_mask)
        x = x + self.dropout(self.layer_norm1(att))
        # Applying attention to the encoder outputs and outputs of the previous layer
        att = self.attn_head(queries=x, keys=enc_out, values=enc_out, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm2(att))
        # Applying position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm2(pos))
        return x

In [31]:
dec = DecoderBlock()
dec(torch.rand(5, 10, 512), enc(torch.rand(5, 10, 512)))

[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'Concatenated output'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Attention Output'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Encoder size output'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'Concatenated output'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Output of

tensor([[[-1.4330,  1.2729,  0.6011,  ...,  1.9798, -0.1023,  3.0183],
         [-1.9039,  0.9775,  1.4220,  ...,  1.6646, -0.1382,  3.1742],
         [ 0.1228,  1.1531,  0.2262,  ...,  0.5965, -0.5110,  3.5058],
         ...,
         [-0.5174,  2.7846,  2.5088,  ...,  1.5095, -0.7713,  2.4147],
         [-0.2419,  1.3183,  1.0722,  ...,  1.6039,  0.8515,  2.8190],
         [-0.9381,  1.7539,  0.3402,  ...,  0.2018,  0.3287,  3.3271]],

        [[ 0.4451,  0.6623,  1.3989,  ...,  2.1048, -0.0888,  2.2787],
         [ 0.2795,  0.4871,  1.4640,  ...,  1.6827, -0.3780,  3.4613],
         [-0.7752,  0.3807,  1.8392,  ...,  1.5982, -1.5318,  2.7350],
         ...,
         [ 0.6048,  1.2919,  1.8496,  ...,  1.8464, -0.2870,  1.0968],
         [ 0.3051, -0.2096,  1.5842,  ...,  0.4973, -0.5270,  3.5191],
         [-0.1666,  0.8657,  1.3226,  ...,  0.4381, -0.4392,  3.1127]],

        [[ 0.7625,  2.0394,  1.7812,  ...,  0.6427, -1.0682,  1.4749],
         [-0.2428,  1.9672,  1.7076,  ...,  1

In [36]:
class TransformerDecoder(nn.Module):
    level = TensorLoggingLevels.enc_dec
    def __init__(self, n_blocks=6, d_model=512, d_feature=64,d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.position_embedding = PositionalEmbedding(d_model)
        self.decoders = nn.ModuleList([
            DecoderBlock(d_model=d_model, d_feature=d_model // n_heads, d_ff=d_ff, dropout=dropout)
            for _ in range(n_blocks)
        ])
        
    def forward(self, x: torch.FloatTensor, enc_out: torch.FloatTensor, src_mask=None, tgt_mask=None):
        for decoder in self.decoders:
            x = decoder(x, enc_out, src_mask=src_mask, tgt_mask=tgt_mask)
        return x

### Positional Embeddings

Attention blocks don't have any notion of word order in a sentence. The Transformer explicitly adda the positional information via the positional embeddings.

In [None]:
class PositionalEmbedding(nn.Module):
    level = 1
    def __init__(self, d_model, max_len=512):
        super().__init__()