Cheat sheet: https://medium.com/the-dl/transformers-from-scratch-in-pytorch-8777e346ca51
Paper ("Attention Is All You Need", Vaswani et al., 2017): https://arxiv.org/abs/1706.03762
The Illustrated Transformer (Jay Alammar): https://jalammar.github.io/illustrated-transformer/

In [6]:
import torch
from torch import nn

$$
\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$
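Matrices $Q, K, V$ have dimensions (batch, seq, num_features). Here $d_k$ is the feature size of the keys (the last axis of $K$, not the sequence length). The softmax is taken over the key axis (the last axis of $QK^T$), so each query's attention weights sum to 1.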

In [16]:
# Step 1: Implement Scaled Dot-Product Attention
def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: queries, shape (..., seq_q, d_k).
        K: keys, shape (..., seq_k, d_k).
        V: values, shape (..., seq_k, d_v).

    Returns:
        Tensor of shape (..., seq_q, d_v). Each query row is a weighted average
        of the rows of V, with weights summing to 1 over the key axis.
        Leading batch dimensions broadcast (torch.matmul semantics).
    """
    # d_k is the feature size of the keys (last axis), not the sequence length.
    d_k = K.size(-1)
    # Scale the logits by 1/sqrt(d_k); do not take the sqrt of the logits
    # themselves (that yields NaN for negative scores).
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
    # Normalise over the key axis so each query's weights sum to 1.
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

# Three random (batch, seq, features) inputs, drawn in Q -> K -> V order.
Q, K, V = (torch.rand(32, 512, 1024) for _ in range(3))

output = scaled_dot_product_attention(Q, K, V)
print(f"Scaled Dot Product size with input size {Q.size()}: {output.size()}")

class AttentionHead(nn.Module):
    """A single self-attention head.

    The input ``X`` is projected into queries, keys and values by three learned
    matrices of shape (n_features, hidden_size). Scaled dot-product attention is
    then applied to those projections.

    Args:
        n_features: size of the last dimension of the input ``X``.
        hidden_size: size of the per-head query/key/value projections; this is
            also the size of the last dimension of the output.
    """

    def __init__(self, n_features, hidden_size):
        super().__init__()
        self.Wq = nn.Parameter(torch.empty(n_features, hidden_size))
        self.Wk = nn.Parameter(torch.empty(n_features, hidden_size))
        self.Wv = nn.Parameter(torch.empty(n_features, hidden_size))
        # NOTE(review): trunc_normal_ defaults to std=1; with large n_features
        # this produces large attention logits. Consider a fan-in-scaled std.
        for weight in (self.Wq, self.Wk, self.Wv):
            nn.init.trunc_normal_(weight)

    @staticmethod
    def scaled_dot_product_attention(Q, K, V):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            Q: queries, shape (..., seq_q, d_k).
            K: keys, shape (..., seq_k, d_k).
            V: values, shape (..., seq_k, d_v).

        Returns:
            Tensor of shape (..., seq_q, d_v); each query's weights over the
            keys sum to 1.
        """
        # d_k is the key feature size (last axis), not the sequence length.
        d_k = K.size(-1)
        # Divide the logits by sqrt(d_k) rather than taking sqrt of the logits.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
        # Softmax over the key axis, not the query axis.
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, V)

    def forward(self, X):
        """Apply the head to ``X`` of shape (..., seq, n_features).

        Returns a tensor of shape (..., seq, hidden_size).
        """
        Q = torch.matmul(X, self.Wq)
        K = torch.matmul(X, self.Wk)
        V = torch.matmul(X, self.Wv)
        return self.scaled_dot_product_attention(Q, K, V)

X = torch.rand(32, 512, 1024)
# One head projecting the 1024 input features down to 8.
single_attention_head = AttentionHead(hidden_size=8, n_features=X.shape[-1])
output = single_attention_head(X)
print(f"Single Head Attention size with input size {X.size()}: {output.size()}")

Scaled Dot Product size with input size torch.Size([32, 512, 1024]): torch.Size([32, 512, 1024])
Single Head Attention size with input size torch.Size([32, 512, 1024]): torch.Size([32, 512, 8])


In [24]:
class MultiHeadedAttention(nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead`` modules.

    Each head projects ``n_features`` inputs to ``hidden_size`` outputs. The
    head outputs are concatenated along the last dimension and projected back
    to ``n_features`` by ``Wo``.

    Args:
        n_heads: number of attention heads.
        n_features: size of the last dimension of the input (and output).
        hidden_size: per-head projection size.
    """

    def __init__(self, n_heads, n_features, hidden_size):
        super().__init__()
        # nn.ModuleList (not a plain list) registers each head as a submodule,
        # so its Wq/Wk/Wv appear in .parameters(), state_dict(), .to(device)...
        self.attn_heads = nn.ModuleList(
            [AttentionHead(n_features, hidden_size) for _ in range(n_heads)]
        )
        self.Wo = nn.Parameter(torch.empty(hidden_size * n_heads, n_features))
        nn.init.trunc_normal_(self.Wo)

    def forward(self, X):
        """Apply all heads to ``X`` and mix them with ``Wo``."""
        # Concatenate along the last (feature) axis -> n_heads * hidden_size.
        attn_outputs = torch.cat([head(X) for head in self.attn_heads], dim=-1)
        return torch.matmul(attn_outputs, self.Wo)

X = torch.rand(32, 512, 1024)
# Eight heads of width 8; Wo maps the 64 concatenated features back to 1024.
multi_attention_head = MultiHeadedAttention(
    hidden_size=8,
    n_features=X.shape[-1],
    n_heads=8,
)
output = multi_attention_head(X)
print(f"Multi-headed Attention size with input size {X.size()}: {output.size()}")

Multi-headed Attention size with input size torch.Size([32, 512, 1024]): torch.Size([32, 512, 1024])


In [None]:
class Encoder(nn.Module):
    def __init__(self, n_heads, n_features, hidden_size):
        """Build the encoder sub-layers.

        Args:
            n_heads: number of attention heads.
            n_features: model width; size of the last dimension of the input X.
            hidden_size: per-head projection size passed to MultiHeadedAttention.
        """
        super().__init__()
        self.self_attn = MultiHeadedAttention(n_heads, n_features, hidden_size)
        # LayerNorm and Linear require explicit sizes (calling them with no
        # arguments raises TypeError). Both act on the last (n_features) axis,
        # which is the width of X + self_attn(X).
        self.layer_norm1 = nn.LayerNorm(n_features)
        # NOTE(review): the paper's position-wise FFN is
        # Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, d_model). A single square
        # Linear keeps the residual shape for now; extend as needed.
        self.linear = nn.Linear(n_features, n_features)
        self.layer_norm2 = nn.LayerNorm(n_features)

    def forward(self, X):
        Z = self.self_attn(X)
        Z = self.layer_norm1(torch.add(X, Z))
        