In [None]:
import os
import torch
from torch import nn
from torch.nn import Module
from torch import functional as F
import math
import copy

**Input & Model Dimensions**

In [None]:
# Seed the global RNG so the random input and the randomly initialised layers
# are identical on every "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Input dimensions.
batch_size = 3
max_seq_length = 2**14  # number of rows in the positional-embedding table
seq_length = 100        # length of the example sequences

# Model dimensions.
num_heads = 4
d_model = 32

# The per-head width is obtained by integer division; if d_model is not a
# multiple of num_heads the heads cannot be split/merged losslessly and the
# later `.view(...)` calls fail with an unhelpful shape error. Fail early.
if d_model % num_heads != 0:
    raise ValueError(
        f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
    )
d_q = d_k = d_v = d_model // num_heads

In [None]:
# Random example batch shaped (batch_size, seq_length, d_model).
x = torch.randn(batch_size, seq_length, d_model)
print('Input vector size:', x.shape)

# Multi Head Attention Block

## Computing `Q`, `K` & `V` Vectors

In [None]:
# Three independent d_model -> d_model projections (query, key, value),
# created in that order.
W_q, W_k, W_v = (nn.Linear(d_model, d_model) for _ in range(3))

In [None]:
# Project the same input through each of the three linear layers.
Q, K, V = W_q(x), W_k(x), W_v(x)

print('Query vector size', Q.shape)
print('Key vector size', K.shape)
print('Value vector size', V.shape)

## Splitting the `Q`, `K` & `V` vectors into various heads

In [None]:
# Split the last (d_model) axis into (num_heads, per-head width) without
# copying: (batch, seq, d_model) -> (batch, seq, heads, d_head).
leading_dims = (batch_size, seq_length, num_heads)

Q = Q.view(*leading_dims, d_q)
K = K.view(*leading_dims, d_k)
V = V.view(*leading_dims, d_v)

print('Query vector size', Q.shape)
print('Key vector size', K.shape)
print('Value vector size', V.shape)

## Scaled Dot-Product Attention

### Permuting the tensors to compute the matrix multiplication

In [None]:
# Q: (batch, seq, heads, d_q) -> (batch, heads, seq, d_q)
Q = Q.transpose(1, 2)
print('Query vector size', Q.shape)

# K: (batch, seq, heads, d_k) -> (batch, heads, d_k, seq), i.e. already
# transposed on its last two axes so that `Q @ K` yields seq x seq scores.
K = K.transpose(1, 2).transpose(2, 3)
print('Key vector size', K.shape)

### Matrix Multiplication (`Q` · `Kᵀ`)

In [None]:
# Batched matrix product over the last two axes of Q and the (pre-transposed) K.
attention_scores = torch.matmul(Q, K)
print('Attention score vector size', attention_scores.shape)

### Scaling

In [None]:
# Divide the raw scores by sqrt(d_k).
scale = math.sqrt(d_k)
attention_scores = torch.div(attention_scores, scale)
print('Attention score vector size', attention_scores.shape)

### Masking

In [None]:
# Boolean masks over (query position, key position); True means "may attend".
mask_shape = (seq_length, seq_length)

# Every position may attend to every other position.
full_mask = torch.ones(mask_shape, dtype=torch.bool)
print('Full mask size vector size', full_mask.shape)

# Lower-triangular mask (built here but not applied in this cell).
lower_triangle_mask = torch.ones(mask_shape, dtype=torch.bool).tril()
print('Lower triangle mask vector size', lower_triangle_mask.shape)

# Positions where the mask is False are overwritten with a large negative
# value. With the all-True `full_mask` nothing is overwritten, so the scores
# pass through unchanged.
attention_scores = attention_scores.masked_fill(full_mask.logical_not(), -1e15)
print('Attention scores vector size', attention_scores.shape)

### Softmax

In [None]:
# Normalise the scores over the last axis.
attention_probabilities = torch.softmax(attention_scores, dim=-1)
print('Attention probabilities vector size', attention_probabilities.shape)

### Matrix Multiplication (Attention Probabilities · `V`)

In [None]:
# Swap V's axes 1 and 2 so it lines up with the probabilities, multiply, then
# swap the same axes back on the result.
values_per_head = V.transpose(1, 2)
weighted_values = torch.matmul(attention_probabilities, values_per_head)
scaled_dot_product_attention_output = weighted_values.transpose(1, 2)
print('Scaled dot-product attention output vector size', scaled_dot_product_attention_output.shape)

## Concat (Merging Heads)

In [None]:
# Merge the trailing axes back into one d_model-wide axis. The preceding
# permute leaves the tensor non-contiguous, so it is made contiguous before
# `.view`.
merged_shape = (batch_size, seq_length, d_model)
concatenated_scaled_dot_product_attention_output = (
    scaled_dot_product_attention_output
    .contiguous()
    .view(*merged_shape)
)
print('Concatenated scaled dot-product attention output vector size', concatenated_scaled_dot_product_attention_output.shape)

## Linear Layer(s)

In [None]:
# Output projection applied to the merged heads (d_model -> d_model).
W_o = nn.Linear(d_model, d_model)

In [None]:
# Final linear projection of the merged attention heads.
multi_head_attention_output = W_o(
    concatenated_scaled_dot_product_attention_output
)

# Residual Connection & Layer Normalization

## Residual Connection

In [None]:
# Residual connection: add the attention output to the tensor currently bound
# to `x`. NOTE: this rebinds `x`, so re-running the cell adds the output again.
x = torch.add(x, multi_head_attention_output)
print('Residual connection output vector size', x.shape)

## Layer Normalization

In [None]:
# LayerNorm over the last (d_model) axis, used after the attention sub-layer.
normalization_layer_after_multi_head_attention = nn.LayerNorm(normalized_shape=d_model)

In [None]:
# Normalise the residual sum; the shape is unchanged.
x = normalization_layer_after_multi_head_attention(x)
print('Layer normalized vector size', x.shape)

# Point-wise Feed-forward Block

In [None]:
# Feed-forward block: Linear -> ReLU -> Linear, every layer d_model wide.
# The ReLU works in place on the first Linear's output.
feed_forward = nn.Sequential(
    nn.Linear(d_model, d_model),
    nn.ReLU(True),
    nn.Linear(d_model, d_model),
)

In [None]:
# Run the normalised activations through the feed-forward block.
feed_forward_output = feed_forward(x)
print('Feed forward output vector size', feed_forward_output.shape)

# Residual Connection & Layer Normalization

## Residual Connection

In [None]:
# Residual connection around the feed-forward block.
# NOTE: this rebinds `x`, so re-running the cell adds the output again.
x = torch.add(x, feed_forward_output)
print('Residual connection output vector size', x.shape)

## Layer Normalization

In [None]:
# LayerNorm over the last (d_model) axis, used after the feed-forward sub-layer.
normalization_layer_after_feed_forward = nn.LayerNorm(normalized_shape=d_model)

In [None]:
# Normalise the second residual sum; the shape is unchanged.
x = normalization_layer_after_feed_forward(x)
print('Layer normalized vector size', x.shape)

# Positional Embedding

In [None]:
# Sinusoidal positional embeddings, one row per position:
#   even columns 0, 2, 4, ... hold sin(pos / n**(2*i / d_model))
#   odd  columns 1, 3, 5, ... hold cos(pos / n**(2*i / d_model))
# with i = 0 .. d_model // 2 - 1.
#
# The whole table is computed in one broadcasted expression instead of a
# Python loop over all max_seq_length positions: the frequency index `i` and
# the denominator do not depend on the position, so they are built once.
positional_embeddings = torch.zeros(max_seq_length, d_model)
n = 10000

i = torch.arange(0, d_model // 2)
denominators = n**(2 * i / d_model)  # shape (d_model // 2,)

# (max_seq_length, 1) / (d_model // 2,) broadcasts to (max_seq_length, d_model // 2).
positions = torch.arange(0, max_seq_length, dtype=torch.int).unsqueeze(1)
angles = positions / denominators

positional_embeddings[:, 0::2] = torch.sin(angles)
positional_embeddings[:, 1::2] = torch.cos(angles)

In [None]:
# Show the embedding of position 2 as a column, shape (d_model, 1).
print(positional_embeddings[2, :, None])