# Building Transformer from Scratch



The code is based off of the following repos/blog posts:

- [attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch)
- [pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT)
- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) 

Thanks so much to their authors!

In [128]:
import torch
import torch.nn as nn
import numpy as np

In [129]:
import logging
logger = logging.getLogger("tensor_shapes")
handler = logging.StreamHandler()
formatter = logging.Formatter('%(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

logger.setLevel(1)

In [130]:
import inspect
def getclass():
    stack = inspect.stack()
    return stack[3][0].f_locals["self"].__class__

# A helper function to check how tensor sizes change
def log_size(tsr: torch.Tensor, name: str):
    cls = getclass()
    logger.log(level=cls.level, msg=[{cls.__name__}, {name}, {tsr.shape}])

In [131]:
from enum import IntEnum
# Control how much debugging output we want
class TensorLoggingLevels(IntEnum):
    attention = 1
    attention_head = 2
    multihead_attention_block = 3
    enc_dec_block = 4
    enc_dec = 5

In [132]:
class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2

# Components


### Scaled dot product attention

$$ \textrm{Attention}(Q, K, V) = \textrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

In [133]:
import math

class ScaledDotProductAttention(nn.Module):
    level = TensorLoggingLevels.attention # Logging level: 
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k = k.size(-1) # get the size of the key
        assert q.size(-1) == d_k

        # Compute the dot product between queries and keys for each batch and position in the sequence
        attn = torch.bmm(q, k.transpose(Dim.seq, Dim.feature)) # (Batch, Seq, Seq)

        attn = attn / math.sqrt(d_k)

        attn = torch.exp(attn)
        log_size(attn, "attention weight") # (Batch, Seq, Seq)
        
        # fill attention weights with 0s where padded
        if mask is not None: attn = attn.masked_fill(mask, 0)
        attn = attn / attn.sum(dim=-1, keepdim=True)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v) # (Batch, Seq, Feature)
        log_size(output, "attention output size") # (Batch, Seq, Seq)
        return output

In [134]:
attn = ScaledDotProductAttention()

In [135]:
q = torch.rand(5, 10, 20)
k = torch.rand(5, 10, 20)
v = torch.rand(5, 10, 20)

In [136]:
attn(q, k, v)

[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]


tensor([[[0.5487, 0.6859, 0.4235, 0.5959, 0.6148, 0.6001, 0.5805, 0.4103,
          0.7313, 0.7234, 0.7655, 0.5142, 0.3910, 0.8098, 0.6246, 0.5217,
          0.7084, 0.4585, 0.5233, 0.4251],
         [0.3373, 0.5487, 0.3416, 0.4417, 0.4118, 0.3639, 0.4414, 0.3636,
          0.5160, 0.4745, 0.5662, 0.3672, 0.2377, 0.6190, 0.5036, 0.4159,
          0.5534, 0.3568, 0.4758, 0.3571],
         [0.4081, 0.5068, 0.2381, 0.4257, 0.3871, 0.4319, 0.4775, 0.1905,
          0.4603, 0.4435, 0.4720, 0.3849, 0.2394, 0.5561, 0.4252, 0.2508,
          0.4453, 0.3278, 0.2998, 0.3501],
         [0.5352, 0.6768, 0.4247, 0.5926, 0.5892, 0.5686, 0.5846, 0.4257,
          0.7153, 0.7050, 0.7539, 0.5124, 0.3841, 0.7924, 0.6262, 0.5240,
          0.7170, 0.4837, 0.5364, 0.4157],
         [0.4520, 0.6554, 0.3821, 0.5466, 0.5309, 0.4867, 0.5075, 0.4208,
          0.6461, 0.6102, 0.6723, 0.3967, 0.3367, 0.7423, 0.5300, 0.4870,
          0.6903, 0.3904, 0.5220, 0.3977],
         [0.4389, 0.6298, 0.4233, 0.5146, 0.5

### Multi Head Attention

In [137]:
class AttentionHead(nn.Module):
    level = TensorLoggingLevels.attention_head
    def __init__(self, d_model, d_feature, dropout=0.1):
        super().__init__()
        # We will assume the queries, keys, and values all have the same feature size
        self.attn = ScaledDotProductAttention(dropout)
        self.query_tfm = nn.Linear(d_model, d_feature)
        self.key_tfm = nn.Linear(d_model, d_feature)
        self.value_tfm = nn.Linear(d_model, d_feature)

    def forward(self, queries, keys, values, mask=None):
        Q = self.query_tfm(queries) # (Batch, Seq, Feature)
        K = self.key_tfm(keys) 
        V = self.value_tfm(values) 
        log_size(Q, "queries, keys, vals")
        # compute multiple attention weighted sums
        x = self.attn(Q, K, V)
        return x

In [138]:
attn_head = AttentionHead(20, 20)
attn_head(q, k, v)

[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention weight'}, {torch.Size([5, 10, 10])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]
[{'ScaledDotProductAttention'}, {'attention output size'}, {torch.Size([5, 10, 20])}]


tensor([[[-0.0839, -0.1160, -0.3149, -0.2495, -0.1862, -0.0053,  0.0912,
           0.0918,  0.4122,  0.2330,  0.1884,  0.4006, -0.3230,  0.0929,
           0.1710,  0.4789, -0.5986, -0.4741, -0.1592, -0.1236],
         [-0.1058, -0.0880, -0.3448, -0.2976, -0.1884, -0.0016,  0.1065,
           0.1135,  0.5054,  0.3075,  0.2410,  0.4434, -0.3087,  0.1432,
           0.1693,  0.5710, -0.6204, -0.5448, -0.2274, -0.1145],
         [-0.1096, -0.1153, -0.4004, -0.3189, -0.2159,  0.0050,  0.1237,
           0.1141,  0.5421,  0.3275,  0.2803,  0.5214, -0.3736,  0.1355,
           0.1965,  0.6032, -0.6990, -0.5953, -0.2246, -0.1616],
         [-0.1117, -0.1138, -0.4006, -0.3179, -0.2142,  0.0077,  0.1246,
           0.1156,  0.5438,  0.3253,  0.2700,  0.5120, -0.3707,  0.1390,
           0.1967,  0.6073, -0.7029, -0.5964, -0.2262, -0.1530],
         [-0.1105, -0.1132, -0.3989, -0.3193, -0.2136,  0.0042,  0.1234,
           0.1157,  0.5451,  0.3276,  0.2751,  0.5163, -0.3712,  0.1357,
          

The multi head attention block applies multiple attention heads as can be seen in the paper "Attention is all you need", then concatenates the output and applies single linear projection.

In [139]:
logger.setLevel(TensorLoggingLevels.attention_head)

In [140]:
class MultiHeadAttention(nn.Module):
    level = TensorLoggingLevels.multihead_attention_block
    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads

        assert d_model == d_feature * n_heads

        self.attn_heads = nn.ModuleList([
            AttentionHead(d_model, d_feature, dropout) for _ in range(n_heads)
        ])
        self.projection = nn.Linear(d_feature * n_heads, d_model) 
    
    def forward(self, queries, keys, values, mask=None):
        log_size(queries, "Input queries")
        x = [attn(queries, keys, values, mask=mask) # (Batch, Sequence, Feature)
             for i, attn in enumerate(self.attn_heads)]
        log_size(x[0], "output of single head")
        
        # reconcatenate
        x = torch.cat(x, dim=Dim.feature) # (Batch, Sequence, D_Feature * n_heads)
        log_size(x, "concatenated output")
        x = self.projection(x) # (Batch, Sequence, D_Model)
        log_size(x, "projected output")
        return x

In [141]:
heads = MultiHeadAttention(20 * 8, 20, 8)
heads(q.repeat(1, 1, 8), 
      k.repeat(1, 1, 8), 
      v.repeat(1, 1, 8))

[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 160])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 160])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 160])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 160])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10, 20])}]
[{'AttentionHead'}, {'queries, keys, vals'}, {torch.Size([5, 10,

tensor([[[ 3.8625e-02,  4.4277e-03, -5.3866e-01,  ..., -1.1546e-01,
          -1.3225e-01,  1.2852e-02],
         [ 3.1109e-02,  4.0733e-02, -4.9640e-01,  ..., -6.7372e-02,
          -1.1141e-01, -7.8243e-03],
         [ 4.2289e-02,  1.4381e-02, -5.2694e-01,  ..., -9.5735e-02,
          -1.0509e-01,  2.2255e-03],
         ...,
         [ 3.0347e-02,  4.4175e-03, -5.3604e-01,  ..., -1.4354e-01,
          -1.4847e-01, -7.5184e-03],
         [ 1.3951e-02, -3.7347e-02, -4.5489e-01,  ..., -7.9561e-02,
          -1.8650e-01, -1.7714e-02],
         [ 5.5675e-03,  9.2399e-03, -4.9065e-01,  ..., -1.2078e-01,
          -1.0069e-01, -4.1701e-02]],

        [[ 8.4204e-03,  2.5438e-02, -4.5736e-01,  ..., -5.8907e-02,
          -1.9848e-01, -1.0257e-01],
         [-3.0457e-02,  6.0310e-02, -3.9213e-01,  ..., -2.5930e-02,
          -2.1253e-01, -1.2257e-01],
         [-1.2325e-02,  2.3600e-02, -4.4613e-01,  ..., -1.3368e-02,
          -1.8879e-01, -9.9528e-02],
         ...,
         [-1.6378e-02, -6

### The Encoder

The encoder is made up of the following components:
- multi-head attention block
- simple feedforward neural network

These components are connected using residual connections and layer normalization.

In [143]:
    logger.setLevel(TensorLoggingLevels.multihead_attention_block)

In [144]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

In [146]:
class EncoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.layer_norm1 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.layer_norm2 = LayerNorm(d_model)
        
    def forward(self, x, mask=None):
        log_size(x, "Encoder block input")
        att = self.attn_head(x, x, x, mask=mask)
        log_size(x, "Attention output")
        # Applying normalization and residual connection.
        x = x + self.dropout(self.layer_norm1(att))
        # Applying position-wise feedforward network.
        pos = self.position_wise_feed_forward(x)
        log_size(x, "Feedforward output")
        # Applying normalization and residual connection.
        x = x + self.dropout(self.layer_norm2(pos))
        log_size(x, "Encoder size output")
        return x

In [147]:
enc = EncoderBlock()

In [148]:
enc(torch.rand(5, 10, 512))

[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'concatenated output'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'con

tensor([[[ 1.7751,  1.1467,  3.4461,  ..., -0.6820, -2.4045,  1.2170],
         [ 1.3583, -0.1930,  4.0995,  ..., -0.1053, -1.5256,  0.3254],
         [ 0.2713,  0.1199,  2.4488,  ..., -0.2137,  0.1164, -0.1484],
         ...,
         [ 2.9047,  0.0258,  2.6899,  ..., -0.4092, -2.7584, -0.8172],
         [ 1.4165,  0.1880,  3.6077,  ...,  0.7653, -2.6604, -1.3268],
         [ 0.7241,  1.0821,  2.5580,  ...,  0.2299, -2.2741, -0.8416]],

        [[ 2.2110,  0.6463,  4.1705,  ...,  0.0997, -3.2632, -1.5028],
         [ 2.2632,  1.8673,  4.1719,  ...,  1.0696, -2.9554,  0.5570],
         [ 1.5239,  0.3931,  3.8011,  ...,  0.7221, -3.5631, -1.3783],
         ...,
         [ 2.2188,  1.4470,  3.1167,  ...,  0.6588, -0.7751, -0.6084],
         [ 1.3707,  1.2954,  3.4004,  ..., -0.4372, -1.9897, -0.2909],
         [ 2.7263,  1.1859,  4.0976,  ...,  0.8435, -2.3570, -1.1413]],

        [[ 1.1442,  0.4413,  2.9840,  ..., -0.4572, -2.6007, -1.4273],
         [ 0.9324,  0.4598,  1.0547,  ..., -0

The Encoder is having six consecutive encoder blocks, thus:

In [149]:
class TransformerEncoder(nn.Module):
    level = TensorLoggingLevels.enc_dec
    def __init__(self, n_blocks=6, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.encoders = nn.ModuleList([
            EncoderBlock(d_model=d_model, d_feature=d_model // n_heads, d_ff=d_ff, dropout=dropout)
            for _ in range(n_blocks)
        ])
    
    def forward(self, x: torch.FloatTensor, mask=None):
        for encoder in self.encoders:
            x = encoder(x)
        return x

### The Decoder

The decoder is same in structure as the encoder with just one additional multi-head attention block that takes the target sentence as input.

In [150]:
class DecoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_out, 
                src_mask=None, tgt_mask=None):
        # Applying attention to inputs.
        att = self.masked_attn_head(x, x, x, mask=src_mask)
        x = x + self.dropout(self.layer_norm1(att))
        # Applying attention to the encoder outputs and outputs of the previous layer.
        att = self.attn_head(queries=x, keys=enc_out, values=enc_out, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm2(att))
        # Applying position-wise feedforward network.
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm2(pos))
        return x

In [151]:
dec = DecoderBlock()
dec(torch.rand(5, 10, 512), enc(torch.rand(5, 10, 512)))

[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 10, 64])}]
[{'MultiHeadAttention'}, {'concatenated output'}, {torch.Size([5, 10, 512])}]
[{'MultiHeadAttention'}, {'con

tensor([[[-0.5575, -0.9342, -0.8271,  ...,  2.1886,  3.3905, -1.1279],
         [ 0.2831,  0.1582,  0.2548,  ...,  0.0379,  1.5856, -1.4237],
         [ 0.4611, -2.0253, -0.7102,  ..., -0.5144,  2.2126,  0.2841],
         ...,
         [ 0.1243, -1.4682, -1.0445,  ...,  1.5168,  2.4904, -1.3807],
         [ 0.4120, -1.2903, -0.1611,  ...,  0.8600,  2.2050, -0.8814],
         [ 0.0323, -1.7075, -2.1069,  ..., -0.2950,  1.1387, -0.4732]],

        [[ 0.4988, -2.6431, -1.9691,  ..., -0.2050,  1.2303, -0.8485],
         [ 0.7427, -0.0800, -0.6465,  ...,  0.7916,  1.7076, -1.6753],
         [ 0.6754, -0.8299, -0.7837,  ...,  0.8535,  1.5801, -1.5188],
         ...,
         [-0.4863, -1.2051, -0.0130,  ...,  1.7027,  1.5050, -1.0456],
         [ 1.0873, -1.3567, -0.4417,  ..., -0.3055,  1.3601, -0.3979],
         [-0.5913, -1.0328, -0.2476,  ...,  0.0326,  1.9341, -1.4867]],

        [[ 1.1301, -1.1240, -1.7672,  ...,  0.4280,  1.6007, -0.0960],
         [ 1.5915, -1.8536, -2.0984,  ...,  0

In [152]:
class TransformerDecoder(nn.Module):
    level = TensorLoggingLevels.enc_dec
    def __init__(self, n_blocks=6, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.position_embedding = PositionalEmbedding(d_model)
        self.decoders = nn.ModuleList([
            DecoderBlock(d_model=d_model, d_feature=d_model // n_heads, d_ff=d_ff, dropout=dropout)
            for _ in range(n_blocks)
        ])
        
    def forward(self, x: torch.FloatTensor, 
                enc_out: torch.FloatTensor, 
                src_mask=None, tgt_mask=None):
        for decoder in self.decoders:
            x = decoder(x, enc_out, src_mask=src_mask, tgt_mask=tgt_mask)
        return x

### Positional Embeddings

Attention blocks don't have any notion of word order in a sentence. The Transformer explicitly adds the positional information via the positional embeddings.

In [153]:
class PositionalEmbedding(nn.Module):
    level = 1
    def __init__(self, d_model, max_len=512):
        super().__init__()        

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.weight = nn.Parameter(pe, requires_grad=False)
        
    def forward(self, x):
        return self.weight[:, :x.size(1), :] # (1, Seq, Feature)

In [154]:
class WordPositionEmbedding(nn.Module):
    level = 1
    def __init__(self, vocab_size, d_model=512):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = PositionalEmbedding(d_model)
        
    def forward(self, x: torch.LongTensor, mask=None) -> torch.FloatTensor:
        return self.word_embedding(x) + self.position_embedding(x)

In [155]:
emb = WordPositionEmbedding(1000)
encoder = TransformerEncoder()

In [156]:
encoder(emb(torch.randint(1000, (5, 30))))

[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'Input queries'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 30, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 30, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 30, 64])}]
[{'MultiHeadAttention'}, {'output of single head'}, {torch.Size([5, 30, 64])}]
[{'MultiHeadAttention'}, {'concatenated output'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'con

[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 30, 512])}]
[{'MultiHeadAttention'}, {'projected output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Attention output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Attention output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Attention output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Attention output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder size output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder size output'}, {torch.Size([5, 30, 512]

tensor([[[-2.2577e+00, -1.2576e+00,  2.1603e+00,  ...,  1.8294e+00,
           8.6026e+00, -3.7816e-01],
         [ 3.6246e-01,  6.4817e-01,  4.4497e+00,  ...,  7.0527e-01,
           6.6239e+00,  7.6782e-01],
         [ 7.2338e-03,  5.6890e-01,  5.5462e+00,  ..., -1.7403e+00,
           3.6862e+00, -1.6055e+00],
         ...,
         [ 1.1985e+00,  2.9628e+00,  4.9665e+00,  ...,  4.1127e-01,
           4.2465e+00, -2.2261e+00],
         [ 1.6161e+00, -5.7866e-01,  6.3342e+00,  ..., -2.7053e+00,
           3.3395e+00, -1.5021e+00],
         [-1.1713e-01, -2.8804e+00,  4.6234e+00,  ..., -1.1865e+00,
           6.3678e+00, -6.4391e-01]],

        [[ 3.3451e+00,  2.2801e+00,  2.4258e+00,  ..., -4.1175e-01,
           4.5857e+00, -2.9214e+00],
         [ 4.0797e+00,  2.9298e+00,  3.6687e+00,  ...,  2.3159e-01,
           6.1309e+00, -4.4661e+00],
         [ 5.0622e+00,  1.0185e+00,  3.8508e+00,  ...,  8.4012e-01,
           1.7289e+00, -7.4246e-01],
         ...,
         [ 1.7492e+00, -5

### Putting it all together.

In [157]:
logger.setLevel(TensorLoggingLevels.enc_dec_block)

In [158]:
emb = WordPositionEmbedding(1000)
encoder = TransformerEncoder()
decoder = TransformerDecoder()

In [159]:
src_ids = torch.randint(1000, (5, 30))
tgt_ids = torch.randint(1000, (5, 30))
x = encoder(emb(src_ids))
decoder(emb(tgt_ids), x)

[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder block input'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Attention output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Attention output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Attention output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Attention output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Feedforward output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder size output'}, {torch.Size([5, 30, 512])}]
[{'EncoderBlock'}, {'Encoder size output'}, {torch.Size([5, 30, 512])}]
[{'Encod

tensor([[[-3.2814,  3.7747,  3.0296,  ..., -4.6321,  3.3187,  0.8122],
         [-1.3944, -1.7222,  1.2845,  ..., -1.1253, -1.0714,  3.9675],
         [-4.3500,  5.4735,  3.5669,  ...,  0.8623,  1.9125, -1.6193],
         ...,
         [-5.7946,  5.7984, -3.0367,  ..., -4.0312, -0.2769, -2.2626],
         [-3.0773,  2.6791,  8.0121,  ..., -1.0238, -1.2385,  3.3627],
         [-5.2695,  2.3082,  0.7036,  ..., -7.9986, -0.1983,  3.4443]],

        [[-3.9637,  4.1581,  6.5100,  ..., -5.7876,  0.3916,  6.7167],
         [ 0.8292,  4.9841,  7.2288,  ..., -4.0899,  7.0819,  5.8788],
         [ 0.4265,  2.3783,  6.0820,  ..., -4.7353,  4.3201,  2.1114],
         ...,
         [ 2.0476,  4.7992,  2.8144,  ..., -0.2555,  1.9226,  0.1272],
         [-3.6189,  2.0032,  3.4841,  ..., -4.3685,  2.4501,  6.3162],
         [-2.0968,  3.8993,  6.6248,  ..., -5.6812,  3.8303,  3.1802]],

        [[-0.8310,  9.9016,  6.0963,  ...,  1.9776,  5.9374,  2.3406],
         [-5.4111,  4.5498,  6.7915,  ..., -0