# Building the Transformer from Scratch

In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
class Dim:
    """Named axis indices for tensors laid out as (Batch, Seq, Feature)."""
    batch, seq, feature = 0, 1, 2

# Components

### Scaled dot product attention

In [3]:
import math

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        dropout: Dropout probability applied to the normalized attention weights.
    """

    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend over ``v`` using the similarity between ``q`` and ``k``.

        Args:
            q: Queries, 3-D tensor (Batch, Seq_q, d_k).
            k: Keys, 3-D tensor (Batch, Seq_k, d_k); the last dimension must match ``q``.
            v: Values, 3-D tensor (Batch, Seq_k, d_v).
            mask: Optional boolean tensor broadcastable to (Batch, Seq_q, Seq_k).
                Positions where it is True receive zero attention weight.

        Returns:
            Tensor of shape (Batch, Seq_q, d_v).

        Note:
            A query row whose keys are all masked yields NaN, since there is
            nothing left to normalize over.
        """
        d_k = k.size(-1)  # size of the key feature dimension
        assert q.size(-1) == d_k

        # dot product between every query and every key
        scores = torch.bmm(q, k.transpose(Dim.seq, Dim.feature))  # (Batch, Seq_q, Seq_k)

        # scale the dot products
        scores = scores / math.sqrt(d_k)

        # masked positions get -inf so they contribute exp(-inf) == 0 after the softmax
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))

        # softmax subtracts the row maximum internally, so large scores no longer
        # overflow to inf / inf = NaN as they did with a bare torch.exp
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)  # (Batch, Seq_q, d_v)
        return output

In [4]:
# Instantiate the attention module with its default dropout (0.1).
attn = ScaledDotProductAttention()

In [5]:
# Random queries, keys and values, each (Batch=5, Seq=10, Feature=20), drawn in that order.
q, k, v = (torch.rand(5, 10, 20) for _ in range(3))

In [6]:
# Apply attention to the random q, k, v (no mask); the result is shown as the cell output.
attn(q, k, v)

tensor([[[0.3688, 0.4383, 0.5014, 0.4978, 0.4820, 0.3852, 0.4309, 0.4215,
          0.6410, 0.4667, 0.3422, 0.4320, 0.3634, 0.2753, 0.4307, 0.4372,
          0.4270, 0.5259, 0.4113, 0.4472],
         [0.5009, 0.4991, 0.6007, 0.4829, 0.4348, 0.3807, 0.5088, 0.5450,
          0.6369, 0.5633, 0.4000, 0.4159, 0.3791, 0.3352, 0.4955, 0.5622,
          0.4917, 0.5233, 0.4838, 0.4445],
         [0.3862, 0.4085, 0.4401, 0.4321, 0.3374, 0.4063, 0.3365, 0.3027,
          0.3533, 0.3219, 0.2748, 0.2514, 0.2744, 0.3729, 0.4559, 0.3248,
          0.3482, 0.4053, 0.3264, 0.3662],
         [0.4966, 0.5104, 0.5825, 0.4997, 0.4202, 0.4039, 0.5239, 0.5089,
          0.6435, 0.5207, 0.4235, 0.4199, 0.3800, 0.3251, 0.5095, 0.5436,
          0.5076, 0.5465, 0.4710, 0.4859],
         [0.4617, 0.4794, 0.5759, 0.5622, 0.4839, 0.4299, 0.4735, 0.4497,
          0.6554, 0.5440, 0.3630, 0.4411, 0.4273, 0.3935, 0.5188, 0.4687,
          0.4655, 0.5489, 0.4153, 0.4540],
         [0.5680, 0.5485, 0.6615, 0.5762, 0.5

### Multi-Head Attention

In [7]:
class AttentionHead(nn.Module):
    """A single attention head: linear projections followed by scaled dot-product attention.

    Args:
        d_model: Feature size of the incoming queries, keys and values.
        d_feature: Feature size each of them is projected to.
        dropout: Dropout probability forwarded to the attention module.
    """

    def __init__(self, d_model, d_feature, dropout=0.1):
        super().__init__()
        # We will assume the queries, keys, and values all have the same feature size
        self.attn = ScaledDotProductAttention(dropout)
        self.query_tfm = nn.Linear(d_model, d_feature)
        self.key_tfm = nn.Linear(d_model, d_feature)
        self.value_tfm = nn.Linear(d_model, d_feature)

    def forward(self, queries, keys, values, mask=None):
        """Project the inputs and attend.

        Args:
            queries: Tensor whose last dimension is ``d_model``.
            keys: Tensor whose last dimension is ``d_model``.
            values: Tensor whose last dimension is ``d_model``.
            mask: Optional attention mask, handed unchanged to the attention module.

        Returns:
            The attention output, with last dimension ``d_feature``.
        """
        Q = self.query_tfm(queries)  # (Batch, Seq, Feature)
        K = self.key_tfm(keys)  # (Batch, Seq, Feature)
        V = self.value_tfm(values)  # (Batch, Seq, Feature)
        # The mask must be forwarded here; previously it was accepted but dropped,
        # so masked positions were still attended to.
        x = self.attn(Q, K, V, mask=mask)
        return x

In [8]:
# One head that keeps the feature size at 20, applied to the random q, k, v.
attn_head = AttentionHead(d_model=20, d_feature=20)
attn_head(queries=q, keys=k, values=v)

tensor([[[ 1.9722e-01, -4.3217e-01, -9.5496e-02,  1.0858e-01,  7.4761e-02,
          -4.2638e-02, -5.9460e-02, -2.2475e-01,  3.8707e-01, -4.2695e-01,
           1.3618e-01,  4.7474e-01,  1.8776e-01, -6.5161e-01, -6.7375e-01,
           6.6575e-01,  1.5730e-01, -3.5106e-02,  4.1626e-01,  2.7541e-02],
         [ 1.9474e-01, -4.3145e-01, -9.6251e-02,  1.0705e-01,  7.4421e-02,
          -4.2501e-02, -6.2102e-02, -2.2431e-01,  3.8733e-01, -4.2657e-01,
           1.3499e-01,  4.7276e-01,  1.8463e-01, -6.5237e-01, -6.7276e-01,
           6.6625e-01,  1.5797e-01, -3.3872e-02,  4.1627e-01,  2.7826e-02],
         [ 1.7152e-01, -4.0612e-01, -9.1967e-02,  1.0535e-01,  8.4465e-02,
          -5.3387e-02, -7.9679e-02, -1.7022e-01,  3.6955e-01, -3.7688e-01,
           1.2617e-01,  4.6859e-01,  1.9201e-01, -5.9551e-01, -6.4006e-01,
           5.8396e-01,  1.3871e-01, -2.6731e-02,  3.2185e-01,  2.5310e-02],
         [ 1.4698e-01, -3.9507e-01, -1.0065e-01,  9.1711e-02,  6.3303e-02,
          -4.1928e-02,

In [9]:
class MultiHeadAttention(nn.Module):
    """Runs ``n_heads`` independent attention heads and merges their outputs.

    Each head projects the inputs from ``d_model`` to ``d_feature`` features.
    The head outputs are concatenated along the feature axis and passed through
    a final linear projection back to ``d_model``.

    Requires ``d_model == d_feature * n_heads``.
    """

    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads
        assert d_model == d_feature * n_heads

        head_modules = []
        for _ in range(n_heads):
            head_modules.append(AttentionHead(d_model, d_feature, dropout))
        self.attn_heads = nn.ModuleList(head_modules)
        self.projection = nn.Linear(d_feature * n_heads, d_model)

    def forward(self, queries, keys, values, mask=None):
        """Apply every head to the same inputs and project the concatenation."""
        per_head = []
        for head in self.attn_heads:
            per_head.append(head(queries, keys, values, mask=mask))  # (Batch, Seq, Feature)

        # glue the heads back together along the feature axis
        merged = torch.cat(per_head, dim=Dim.feature)  # (Batch, Seq, D_Feature * n_heads)
        return self.projection(merged)  # (Batch, Seq, D_Model)

In [10]:
# 8 heads of 20 features each; tile q, k, v along the feature axis to reach d_model = 160.
heads = MultiHeadAttention(d_model=20 * 8, d_feature=20, n_heads=8)
heads(*(t.repeat(1, 1, 8) for t in (q, k, v)))

tensor([[[-0.0454,  0.0829, -0.0308,  ..., -0.0321,  0.1697, -0.0679],
         [-0.0674,  0.1207, -0.0623,  ..., -0.0312,  0.1778, -0.0592],
         [-0.0408,  0.1154, -0.0454,  ..., -0.0530,  0.1485, -0.0458],
         ...,
         [-0.0437,  0.1139, -0.0348,  ..., -0.0515,  0.1471, -0.0500],
         [-0.0593,  0.1264, -0.0421,  ..., -0.0457,  0.1367, -0.0620],
         [-0.0631,  0.1195, -0.0264,  ..., -0.0756,  0.1537, -0.0429]],

        [[-0.0454,  0.0465, -0.0337,  ..., -0.0048,  0.1210, -0.0700],
         [-0.0457,  0.0553, -0.0347,  ..., -0.0282,  0.1186, -0.0563],
         [-0.0414,  0.0568, -0.0474,  ..., -0.0223,  0.1106, -0.0638],
         ...,
         [-0.0416,  0.0475, -0.0648,  ..., -0.0124,  0.1568, -0.0717],
         [-0.0453,  0.0774, -0.0267,  ..., -0.0255,  0.1185, -0.0660],
         [-0.0474,  0.0514, -0.0467,  ..., -0.0239,  0.1319, -0.0720]],

        [[-0.0525,  0.1671,  0.0089,  ..., -0.0343,  0.1991, -0.0766],
         [-0.0860,  0.1522, -0.0355,  ..., -0

### The Encoder

![image](https://i2.wp.com/mlexplained.com/wp-content/uploads/2017/12/%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%BC%E3%83%B3%E3%82%B7%E3%83%A7%E3%83%83%E3%83%88-2017-12-29-19.14.41.png?w=273)

Layer normalization is similar to batch normalization, but normalizes across the feature dimension instead of the batch dimension.

In [11]:
class LayerNorm(nn.Module):
    """Normalizes over the last dimension, then applies a learned scale and shift.

    ``gamma`` (initialized to ones) and ``beta`` (initialized to zeros) are
    learnable vectors of length ``d_model``. ``eps`` is added to the standard
    deviation before dividing.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Return ``gamma * (x - mean) / (std + eps) + beta`` over the last axis."""
        mu = torch.mean(x, dim=-1, keepdim=True)
        sigma = torch.std(x, dim=-1, keepdim=True)
        scaled = self.gamma * (x - mu)
        return scaled / (sigma + self.eps) + self.beta

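An encoder block stacks these components: multi-head self-attention followed by a position-wise feed-forward network, each wrapped with layer normalization, dropout, and a residual connection.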

In [12]:
class EncoderBlock(nn.Module):
    """One encoder layer: multi-head self-attention, then a position-wise feed-forward net.

    The output of each sub-layer is layer-normalized, passed through dropout and
    added back onto that sub-layer's input.
    """

    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.layer_norm1 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        ]
        self.position_wise_feed_forward = nn.Sequential(*feed_forward_layers)
        self.layer_norm2 = LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run self-attention and the feed-forward net over ``x``, each with a residual."""
        # sub-layer 1: self-attention (queries, keys and values are all x)
        attended = self.layer_norm1(self.attn_head(x, x, x, mask=mask))
        after_attn = x + self.dropout(attended)

        # sub-layer 2: position-wise feed-forward
        transformed = self.layer_norm2(self.position_wise_feed_forward(after_attn))
        return after_attn + self.dropout(transformed)

In [13]:
# Encoder block with the default sizes (d_model=512, d_feature=64, d_ff=2048, n_heads=8).
enc = EncoderBlock()

In [14]:
# Run the block on a random (5, 10, 512) input; the last dimension matches the default d_model=512.
enc(torch.rand(5, 10, 512))

tensor([[[-0.7314, -0.9608,  1.9998,  ...,  0.2103, -1.5115,  1.8144],
         [ 0.0362, -0.0593,  0.7419,  ...,  0.7844, -1.7380,  0.8663],
         [-0.4143, -1.0543,  1.4343,  ..., -0.0920, -1.0956,  1.6937],
         ...,
         [-1.3945, -1.1517,  0.3591,  ...,  0.5138, -1.5281,  1.6051],
         [ 0.0715, -0.3030,  3.0468,  ...,  0.4999, -1.3052, -0.5414],
         [-0.7606, -0.2976,  1.5225,  ...,  0.5437, -0.6941, -1.0575]],

        [[-0.2045, -2.0709,  1.8474,  ...,  1.0257, -1.2517,  0.2383],
         [ 0.3157, -1.1443,  2.2165,  ...,  0.3258, -0.0598,  0.9028],
         [ 0.1473, -2.2727,  0.7920,  ...,  0.3123, -0.6273,  0.7053],
         ...,
         [-1.1327, -1.7765,  1.9225,  ...,  0.7935,  0.1383,  2.2534],
         [-0.7408, -1.2487,  1.1963,  ...,  0.7120, -1.0052,  1.2876],
         [ 0.3173, -1.9165,  2.3517,  ...,  1.2001, -0.7337,  0.2320]],

        [[-0.6179, -0.6417,  1.7838,  ...,  0.6559, -0.9354,  1.6094],
         [ 0.4143, -0.4535,  1.4773,  ...,  1

The full encoder consists of 6 consecutive encoder blocks (the default value of `n_blocks`).

In [15]:
class TransformerEncoder(nn.Module):
    """A stack of ``n_blocks`` encoder blocks applied one after another.

    Args:
        n_blocks: Number of encoder blocks.
        d_model: Feature size of the input and of every block's output.
        n_heads: Number of attention heads per block; each head gets
            ``d_model // n_heads`` features.
        d_ff: Hidden size of each block's feed-forward network.
        dropout: Dropout probability used inside every block.
    """

    def __init__(self, n_blocks=6, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        # n_heads must be forwarded: d_feature is derived from it, and the block
        # asserts d_model == d_feature * n_heads. Leaving the block on its own
        # default of 8 heads made every other n_heads fail that assertion.
        self.encoders = nn.ModuleList([
            EncoderBlock(d_model=d_model, d_feature=d_model // n_heads,
                         d_ff=d_ff, n_heads=n_heads, dropout=dropout)
            for _ in range(n_blocks)
        ])

    def forward(self, x: torch.FloatTensor, mask=None):
        """Pass ``x`` through every block in order, giving each the same ``mask``."""
        for encoder in self.encoders:
            # the mask was previously accepted here but never handed to the blocks
            x = encoder(x, mask=mask)
        return x

### The Decoder

![image](https://i1.wp.com/mlexplained.com/wp-content/uploads/2017/12/%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%BC%E3%83%B3%E3%82%B7%E3%83%A7%E3%83%83%E3%83%88-2017-12-29-19.14.47.png?w=287)

The keys and values are the outputs of the encoder, and the queries are the outputs of the multi-head attention over the target sentence embeddings.

In [16]:
class DecoderBlock(nn.Module):
    """One decoder layer with three sub-layers.

    1. Self-attention over the decoder input ``x``.
    2. Attention with queries from ``x`` and keys/values from the encoder output.
    3. A position-wise feed-forward network.

    The output of each sub-layer is layer-normalized with that sub-layer's own
    ``LayerNorm``, passed through dropout and added back onto its input.
    """

    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out,
                src_mask=None, tgt_mask=None):
        """Decode one layer.

        Args:
            x: Decoder input (target side).
            enc_out: Encoder output, used as keys and values of the second sub-layer.
            src_mask: Optional mask over the encoder positions, applied in the
                encoder-decoder attention.
            tgt_mask: Optional mask over the target positions, applied in the
                self-attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention looks only at target positions, so it takes the target
        # mask (the two masks were previously swapped).
        att = self.masked_attn_head(x, x, x, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm1(att))

        # Keys and values come from the encoder, so source positions are masked here.
        att = self.attn_head(queries=x, keys=enc_out, values=enc_out, mask=src_mask)
        x = x + self.dropout(self.layer_norm2(att))

        # The feed-forward sub-layer has its own norm; it used to reuse
        # layer_norm2 and leave layer_norm3 untrained.
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm3(pos))
        return x

In [17]:
# Decoder block with default sizes, fed a random target input and the encoder
# block's output for a second random input.
dec = DecoderBlock()
dec(
    x=torch.rand(5, 10, 512),
    enc_out=enc(torch.rand(5, 10, 512)),
)

tensor([[[-6.5143e-01,  4.0064e-01, -6.7979e-01,  ..., -3.1752e+00,
          -2.3786e-01,  1.5174e+00],
         [ 1.4182e+00,  4.7809e-01, -2.2517e-01,  ..., -1.5205e+00,
          -1.4127e-01,  4.1145e-01],
         [-5.8744e-01,  7.9096e-01,  2.1167e-01,  ..., -2.9919e+00,
          -2.7412e-01,  3.7926e-01],
         ...,
         [-5.9609e-01,  1.1183e+00,  1.1979e+00,  ..., -3.5470e+00,
           1.1903e-01, -6.7346e-01],
         [-4.8036e-01,  7.7844e-01, -2.3372e+00,  ..., -2.4234e+00,
          -7.4857e-01,  1.4273e+00],
         [-2.8494e-01,  5.8207e-01, -1.1237e+00,  ..., -2.6507e+00,
          -6.1364e-01,  7.5330e-01]],

        [[-2.1470e-03,  1.4436e+00,  7.0646e-01,  ..., -1.6222e+00,
          -8.7327e-02,  2.0979e+00],
         [-1.0924e+00,  1.2902e+00, -5.1652e-01,  ..., -2.9347e+00,
           2.8413e-01,  1.3538e+00],
         [-6.2192e-01,  7.1885e-01,  1.2365e+00,  ..., -3.3762e+00,
           9.6950e-01,  4.5544e-01],
         ...,
         [-6.4940e-01,  2

Again, the decoder is just a stack of the underlying block, so it is simple to implement.

In [18]:
class TransformerDecoder(nn.Module):
    """A stack of ``n_blocks`` decoder blocks applied one after another.

    Args:
        n_blocks: Number of decoder blocks.
        d_model: Feature size of the input and of every block's output.
        d_feature: Accepted for backward compatibility but ignored; the per-head
            feature size is always ``d_model // n_heads``.
        d_ff: Hidden size of each block's feed-forward network.
        n_heads: Number of attention heads per block.
        dropout: Dropout probability used inside every block.
    """

    def __init__(self, n_blocks=6, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        # Not used in forward(); kept so the module's parameters and state_dict
        # keys stay the same as before.
        self.position_embedding = PositionalEmbedding(d_model)
        # n_heads must be forwarded: d_feature is derived from it, and the block
        # asserts d_model == d_feature * n_heads. Leaving the block on its own
        # default of 8 heads made every other n_heads fail that assertion.
        self.decoders = nn.ModuleList([
            DecoderBlock(d_model=d_model, d_feature=d_model // n_heads,
                         d_ff=d_ff, n_heads=n_heads, dropout=dropout)
            for _ in range(n_blocks)
        ])

    def forward(self, x: torch.FloatTensor,
                enc_out: torch.FloatTensor,
                src_mask=None, tgt_mask=None):
        """Pass ``x`` through every block in order.

        Every block receives the same ``enc_out``, ``src_mask`` and ``tgt_mask``.
        """
        for decoder in self.decoders:
            x = decoder(x, enc_out, src_mask=src_mask, tgt_mask=tgt_mask)
        return x

### Positional Embeddings

In [19]:
class PositionalEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional encodings.

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``, with frequencies decaying geometrically from 1
    towards 1/10000. The table is computed once for ``max_len`` positions.

    Args:
        d_model: Feature size of the encodings (even or odd).
        max_len: Number of positions to precompute.
    """
    level = 1

    def __init__(self, d_model, max_len=512):
        super().__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, number of even columns)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angles to fit; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0)
        # stored as a parameter with gradients disabled, so it is never trained
        self.weight = nn.Parameter(pe, requires_grad=False)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the size of dimension 1 of ``x`` is used. If it exceeds ``max_len``,
        the slice yields just ``max_len`` positions.
        """
        return self.weight[:, :x.size(1), :]  # (1, Seq, Feature)

In [20]:
class WordPositionEmbedding(nn.Module):
    """Sum of a learned token embedding and a positional embedding.

    ``vocab_size`` is the number of rows in the token embedding table and
    ``d_model`` is the feature size shared by both embeddings.
    """
    level = 1

    def __init__(self, vocab_size, d_model=512):
        super().__init__()
        self.word_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.position_embedding = PositionalEmbedding(d_model)

    def forward(self, x: torch.LongTensor, mask=None) -> torch.FloatTensor:
        """Embed the token ids in ``x`` and add positional information.

        ``mask`` is accepted but not used.
        """
        token_vectors = self.word_embedding(x)
        position_vectors = self.position_embedding(x)
        return token_vectors + position_vectors

In [21]:
# Embedding for a 1000-token vocabulary, plus an encoder stack with default settings.
emb = WordPositionEmbedding(vocab_size=1000)
encoder = TransformerEncoder()

In [22]:
# Embed a random (5, 30) batch of token ids below 1000 and run it through the encoder stack.
encoder(emb(torch.randint(1000, (5, 30))))

tensor([[[ 3.8228, -2.0909, -3.5735,  ...,  7.5374, -0.6125,  3.9951],
         [ 1.9282, -3.8169,  2.0592,  ...,  4.8064,  1.6905,  1.4939],
         [ 2.1767, -2.8520,  0.3860,  ...,  3.2723,  1.0018,  0.4706],
         ...,
         [ 2.2393, -0.1292,  3.8167,  ...,  9.2700,  2.1734,  6.1297],
         [ 2.0363,  1.6934, -0.6265,  ...,  8.1327,  1.8512,  3.1298],
         [ 2.8362, -2.5242,  1.1126,  ...,  5.5564, -2.2145,  3.3185]],

        [[ 5.5173,  1.8708, -3.5587,  ...,  2.3345, -4.5050,  3.5049],
         [ 6.4314, -0.6686, -0.2301,  ...,  4.5585, -1.9847,  3.4818],
         [ 6.6054, -5.2696, -0.1708,  ...,  1.2646,  1.0728,  4.2250],
         ...,
         [ 5.5104, -1.6445, -2.6206,  ...,  5.0193, -4.9450,  2.0719],
         [ 7.4420, -0.1016, -3.1009,  ...,  2.9020, -1.1202,  1.0917],
         [ 1.3787, -2.1941, -0.5048,  ...,  2.3131,  0.9453,  6.2604]],

        [[ 1.4098, -0.0258, -1.3463,  ...,  7.8112,  2.5633,  3.1469],
         [ 6.4190, -0.4263, -1.5800,  ...,  6

### Putting it All Together

![image](https://camo.githubusercontent.com/88e8f36ce61dedfd2491885b8df2f68c4d1f92f5/687474703a2f2f696d6775722e636f6d2f316b72463252362e706e67)

In [23]:
# Fresh embedding (1000-token vocabulary), encoder stack and decoder stack, all with default sizes.
emb = WordPositionEmbedding(vocab_size=1000)
encoder = TransformerEncoder()
decoder = TransformerDecoder()

In [24]:
# Random source and target batches of token ids, each (5, 30), drawn in that order.
src_ids, tgt_ids = (torch.randint(1000, (5, 30)) for _ in range(2))
# Encode the source, then decode the target embeddings against the encoder output.
x = encoder(emb(src_ids))
decoder(emb(tgt_ids), enc_out=x)

tensor([[[-7.5089, -4.8524,  8.9166,  ...,  4.8481,  3.4639, -3.0672],
         [-4.4578,  0.6768,  7.1931,  ...,  7.4993,  2.0882, -2.2992],
         [-4.3852, -1.9266,  8.4774,  ...,  7.7719, -0.1840, -6.2831],
         ...,
         [-7.4598, -4.0433,  3.9511,  ...,  7.1757,  0.6233, -3.7613],
         [-5.0434, -3.1148,  3.9955,  ...,  7.7747, -2.7859, -3.6723],
         [-5.1826,  0.1430,  6.7114,  ...,  6.5203,  0.2865, -4.3229]],

        [[-3.7279,  1.9699,  2.8120,  ...,  6.3024,  1.5573, -1.8119],
         [-2.8885, -1.1445,  4.2325,  ..., 10.5152, -1.4743, -3.4741],
         [-2.1535,  3.6287,  2.9028,  ...,  2.1750,  3.2068, -4.6496],
         ...,
         [-4.0922,  2.3624,  6.2289,  ...,  8.4165,  2.6457,  1.1291],
         [-2.2871,  1.4994,  5.5360,  ...,  8.1237, -1.0908, -5.8973],
         [-8.6225,  4.3793,  7.1774,  ...,  7.8179,  2.8150, -3.5142]],

        [[-6.7801,  2.1838,  5.9175,  ...,  8.5489,  7.5672, -4.0281],
         [-6.9171,  0.5173,  4.3053,  ...,  6