# Transformer from Scratch



In [5]:
import torch
import torch.nn as nn
import numpy as np

In [6]:
import logging
logger = logging.getLogger("tensor_shapes")
handler = logging.StreamHandler()
formatter = logging.Formatter(
        '%(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(1)

In [7]:
import inspect
def getclass():
    stack = inspect.stack()
    return stack[3][0].f_locals["self"].__class__

# A helper function to check how tensor sizes change
def log_size(tsr: torch.Tensor, name: str):
    cls = getclass()
    logger.log(level=cls.level, msg=[{cls.__name__}, {name}, {tsr.shape}])

In [8]:
from enum import IntEnum
# Control how much debugging output we want
class TensorLoggingLevels(IntEnum):
    attention = 1
    attention_head = 2
    multihead_attention_block = 3
    enc_dec_block = 4
    enc_dec = 5

In [9]:
class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2

# Components


### Scaled dot product attention

$$ \textrm{Attention}(Q, K, V) = \textrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

In [10]:
import math 

class ScaledDotProductAttention(nn.Module):
    level = TensorLoggingLevels.attention
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        d_k = k.size(-1)
        assert q.size(-1) == d_k
        
        # Compute the dot product between queries and keys for each batch and position in the sequence
        attn = torch.bmm(q, k.transpose(Dim.seq, Dim.feature))
        
        attn = attn / math.sqrt(d_k)
        
        attn = torch.exp(attn)
        
        log_size(attn, "attention weight") # Batch, Seq, Seq
        
        if mask is not None:
            attn = attn.masked_fill(mask, 0)
        attn = attn / attn.sum(dim=-1, keepdim=True)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v) # (Batch, Seq, Feature)
        log_size(output, "attention output size") # (Batch, Seq, Seq)
        return output

In [11]:
attn = ScaledDotProductAttention()

In [12]:
q = torch.rand(5, 10, 20)
k = torch.rand(5, 10, 20)
v = torch.rand(5, 10, 20)

In [13]:
attn(q, k, v)

[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]


tensor([[[0.5696, 0.5328, 0.5473, 0.5663, 0.6148, 0.5105, 0.3823, 0.5348,
          0.7459, 0.5658, 0.5710, 0.4206, 0.4794, 0.5721, 0.5688, 0.5083,
          0.4243, 0.5707, 0.6328, 0.5048],
         [0.4731, 0.5220, 0.5115, 0.5091, 0.5406, 0.4984, 0.3474, 0.4973,
          0.6399, 0.4524, 0.4859, 0.3546, 0.3890, 0.5669, 0.4984, 0.4367,
          0.3912, 0.5467, 0.5609, 0.4470],
         [0.5654, 0.5371, 0.5451, 0.5681, 0.6154, 0.5152, 0.3812, 0.5335,
          0.7432, 0.5596, 0.5680, 0.4191, 0.4731, 0.5773, 0.5658, 0.5069,
          0.4304, 0.5720, 0.6331, 0.4998],
         [0.4588, 0.5145, 0.5530, 0.5379, 0.5367, 0.4244, 0.3039, 0.4629,
          0.6540, 0.5616, 0.4676, 0.4100, 0.4545, 0.5671, 0.4985, 0.3931,
          0.3959, 0.5330, 0.5193, 0.4889],
         [0.5563, 0.5464, 0.5749, 0.5602, 0.6286, 0.5011, 0.3800, 0.5494,
          0.7206, 0.5600, 0.5735, 0.4246, 0.4733, 0.5853, 0.5847, 0.4873,
          0.4274, 0.5876, 0.6218, 0.5155],
         [0.5564, 0.5539, 0.5706, 0.5649, 0.6

### Multi Head Attention

In [14]:
class AttentionHead(nn.Module):
    level = TensorLoggingLevels.attention_head
    def __init__(self, d_model, d_feature, dropout = 0.1):
        super().__init__()
        # We assume that the queries, keys, features all have the same feature size.
        self.attn = ScaledDotProductAttention(dropout)
        self.query_tfm = nn.Linear(d_model, d_feature)
        self.key_tfm = nn.Linear(d_model, d_feature)
        self.value_tfm = nn.Linear(d_model, d_feature)
        
    def forward(self, queries, keys, values, mask=None):
        Q = self.query_tfm(queries)
        K = self.key_tfm(keys)
        V = self.value_tfm(values)
        log_size(Q, "queries, keys, vals")
        
        x = self.attn(Q, K, V)
        return x

In [15]:
attn_head = AttentionHead(20, 20)
attn_head(q, k, v)

[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]


tensor([[[-1.3415e-01,  4.1547e-01, -1.2958e-01, -1.6530e-01, -1.4366e-02,
           1.2347e-01,  2.8432e-01,  5.0719e-02, -4.1539e-01,  3.5242e-01,
           1.8646e-01, -2.7570e-01, -1.8071e-01, -2.6601e-01, -2.8802e-02,
           1.5071e-01, -1.0538e-01,  3.4945e-01,  2.2730e-01, -2.0590e-02],
         [-1.3463e-01,  4.2426e-01, -1.3946e-01, -8.6122e-02, -4.6897e-02,
           1.1551e-01,  3.0901e-01,  1.4901e-01, -5.6183e-01,  3.9037e-01,
           1.9426e-01, -2.8315e-01, -1.7956e-01, -2.9039e-01, -5.0954e-02,
           9.6262e-02, -2.1555e-01,  4.1239e-01,  3.3244e-01, -5.3709e-02],
         [-1.5827e-01,  4.7342e-01, -1.1925e-01, -1.6868e-01,  4.4889e-02,
           1.7229e-01,  3.0185e-01,  7.8513e-02, -4.6781e-01,  4.1151e-01,
           2.4249e-01, -3.0184e-01, -1.3892e-01, -2.4401e-01, -1.5441e-02,
           1.3222e-01, -9.2620e-02,  3.5561e-01,  3.1089e-01, -1.1497e-02],
         [-1.3133e-01,  4.3470e-01, -1.4057e-01, -1.0172e-01, -2.9635e-02,
           1.3707e-01,

The multi head attention block applies multiple attention heads as can be seen in the paper "Attention is all you need", then concatenates the output and applies single linear projection.

In [16]:
logger.setLevel(TensorLoggingLevels.attention_head)

In [20]:
class MultiHeadAttention(nn.Module):
    level = TensorLoggingLevels.multihead_attention_block
    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads
        
        assert d_model == d_feature * n_heads
        
        self.attn_heads = nn.ModuleList([
            AttentionHead(d_model, d_feature, dropout) for _ in range(n_heads)
        ])
        self.projection = nn.Linear(d_feature * n_heads, d_model)
        
    def forward(self, queries, keys, values, mask=None):
        log_size(queries, "Input queries")
        x = [attn(queries, keys, values, mask=mask) # (Batch, Seq, Feature)
             for i, attn in enumerate(self.attn_heads)]
        log_size(x[0], "Output of single head")
        
        #reconcatenate
        x = torch.cat(x, dim=Dim.feature) # (Batch, Sequence, D_Feature * n_heads)
        log_size(x, "Concatenated output") 
        x = self.projection(x) # (Batch, Sequence, D_model)
        log_size(x, "projected output")
        return x

In [21]:
heads = MultiHeadAttention(20 * 8, 20, 8)
heads(q.repeat(1, 1, 8), 
      k.repeat(1, 1, 8), 
      v.repeat(1, 1, 8))

[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 160])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'MultiHeadAttention'}, {'Output of single head'}, {torch.Size([5, 10, 20])}]
[{'MultiHeadAttention'}, {'Concatenated output'}, {torch.Size([5, 10, 160])}]
[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 10, 160])}]


tensor([[[-0.1281, -0.1040, -0.1596,  ..., -0.0451, -0.0963, -0.2033],
         [-0.1508, -0.1195, -0.1500,  ..., -0.0101, -0.1618, -0.2477],
         [-0.1222, -0.0916, -0.1591,  ..., -0.0591, -0.1288, -0.2281],
         ...,
         [-0.1259, -0.1078, -0.1386,  ..., -0.0627, -0.1354, -0.1823],
         [-0.1296, -0.1241, -0.1671,  ..., -0.0396, -0.1452, -0.2197],
         [-0.1626, -0.1016, -0.1604,  ..., -0.0320, -0.1580, -0.2511]],

        [[-0.0885, -0.1208, -0.1718,  ...,  0.0211, -0.1654, -0.2570],
         [-0.0890, -0.1509, -0.2155,  ...,  0.0176, -0.1603, -0.2447],
         [-0.0618, -0.1015, -0.2157,  ...,  0.0111, -0.1502, -0.1866],
         ...,
         [-0.1170, -0.1258, -0.1726,  ...,  0.0263, -0.1808, -0.2803],
         [-0.0922, -0.1131, -0.2260,  ..., -0.0078, -0.1462, -0.2483],
         [-0.0864, -0.0788, -0.2351,  ..., -0.0040, -0.1427, -0.1954]],

        [[-0.0956, -0.1120, -0.1257,  ..., -0.0546, -0.1430, -0.1628],
         [-0.0746, -0.1097, -0.1423,  ..., -0