In [1]:
_ = !python3.12 -m pip install torch
import warnings
warnings.filterwarnings('ignore')

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention

In [3]:
# Dimensions of the random input tensor: (batch, sequence, features).
n_batch, n_sequence, n_features = 16, 20, 512

# Hyperparameters shared by the built-in and the hand-written encoder layers.
n_heads, dim_feedforward = 4, 1024
dropout = 0.0

In [4]:
# Seed the global RNG so that the random input (and the module weights
# initialised in later cells) are identical on every Restart & Run All.
torch.manual_seed(0)

# Random input laid out as (batch, sequence, features).
x = torch.rand(n_batch, n_sequence, n_features)
x.shape

torch.Size([16, 20, 512])

In [5]:
# Reference implementation: a single built-in encoder layer, batch-first.
transformer_encoder_layer = TransformerEncoderLayer(
    d_model=n_features,
    nhead=n_heads,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    batch_first=True,
)

# The layer maps (batch, sequence, features) to a tensor of the same shape.
y = transformer_encoder_layer(x)
y.shape

torch.Size([16, 20, 512])

In [6]:
# Stack the reference layer n_layers times into a full encoder.
n_layers = 5
transformer_encoder = TransformerEncoder(
    encoder_layer=transformer_encoder_layer,
    num_layers=n_layers,
)

# The stacked encoder also preserves the input shape.
y = transformer_encoder(x)
y.shape

torch.Size([16, 20, 512])

In [7]:
def count_trainable_parameters(model: nn.Module):
    """Return the total number of scalar entries in the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted; frozen parameters
    are skipped. A module without parameters yields 0.
    """
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total

# Trainable parameters of a single layer, then of the whole stack (one per line).
print(
    count_trainable_parameters(transformer_encoder_layer),
    count_trainable_parameters(transformer_encoder),
    sep="\n",
)

# Check that the stack's total equals n_layers times the single layer's total.
print(count_trainable_parameters(transformer_encoder) == n_layers * count_trainable_parameters(transformer_encoder_layer))

2102784
10513920
True


In [8]:
class MyTransformerEncoderLayer(nn.Module):
    """Hand-written post-norm Transformer encoder layer.

    The layer applies two sub-blocks in sequence:

    1. multi-head self-attention,
    2. a position-wise feed-forward network (Linear -> ReLU -> Dropout -> Linear).

    Each sub-block's output goes through dropout, is added to the sub-block's
    input (residual connection) and is then layer-normalised.

    Args:
        d_model: Size of the feature dimension of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward network.
        dropout: Dropout probability used inside the attention module, inside
            the feed-forward network and after each sub-block.
        batch_first: If True the input is laid out as (batch, sequence, feature),
            otherwise as (sequence, batch, feature). Forwarded to the attention
            module; the remaining operations act on the last dimension only.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int, dropout: float, batch_first: bool):
        super().__init__()
        # Forward both `dropout` and `batch_first` to the attention module:
        # hard-coding batch_first=True silently ignored the caller's layout, and
        # omitting `dropout` left the attention weights without dropout.
        self.attention = MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=batch_first,
        )
        # nn.Sequential keeps the same sub-module indices (and therefore the same
        # state-dict keys) as a ModuleList, but can be called directly.
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.dropout_1 = nn.Dropout(dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.dropout_2 = nn.Dropout(dropout)
        self.layer_norm_2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run self-attention and the feed-forward network over ``x``.

        Args:
            x: Input tensor whose last dimension is ``d_model``; the order of
                the first two dimensions follows ``batch_first``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Self-attention sub-block. Nothing below modifies `x` in place, so the
        # residual can reuse the input without cloning it. The attention
        # weights are not used, so skip computing them.
        attention_out, _ = self.attention(x, x, x, need_weights=False)
        x = self.layer_norm_1(x + self.dropout_1(attention_out))

        # Feed-forward sub-block with its own residual connection.
        x = self.layer_norm_2(x + self.dropout_2(self.feedforward(x)))

        return x
        

In [9]:
# Hand-written layer built with the same hyperparameters as the built-in one.
my_transformer_encoder_layer = MyTransformerEncoderLayer(
    d_model=n_features,
    nhead=n_heads,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    batch_first=True,
)

# Output shape should match the input shape.
y = my_transformer_encoder_layer(x)
y.shape

torch.Size([16, 20, 512])

In [10]:
# Trainable parameter totals: hand-written layer first, built-in layer second.
print(
    count_trainable_parameters(my_transformer_encoder_layer),
    count_trainable_parameters(transformer_encoder_layer),
    sep="\n",
)

2102784
2102784


In [12]:
# Built-in multi-head attention used as the reference for the hand-written version.
attention = MultiheadAttention(
    embed_dim=n_features,
    num_heads=n_heads,
    batch_first=True,
)

In [19]:
class MyUnmaskedMultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with no masking and no dropout.

    Inputs are batch-first: ``(n_batch, n_sequence, embed_dim)``. Queries, keys
    and values are each projected by their own linear layer, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended per head,
    concatenated again and passed through a final linear layer.

    Args:
        embed_dim: Size of the feature dimension of the inputs and the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` or ``num_heads`` is not positive, or if
            ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()

        # Validate before allocating any layers. Non-positive values are rejected
        # explicitly: num_heads == 0 would otherwise surface as a
        # ZeroDivisionError and negative values would only fail inside forward().
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError("embed_dim and num_heads must be greater than 0")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.n_features = embed_dim
        self.n_heads = num_heads
        # Exact integer division (divisibility was checked above).
        self.head_dim = embed_dim // num_heads

        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)

        self.linear_out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected: torch.Tensor, n_batch: int) -> torch.Tensor:
        """Reshape (n_batch, n_sequence, n_features) to (n_batch, n_heads, n_sequence, head_dim)."""
        return projected.view(n_batch, -1, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v) -> tuple[torch.Tensor, torch.Tensor]:
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Queries, shape (n_batch, n_query, embed_dim).
            k: Keys, shape (n_batch, n_key, embed_dim).
            v: Values, shape (n_batch, n_key, embed_dim).

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            (n_batch, n_query, embed_dim) and ``attention_weights`` has shape
            (n_batch, n_heads, n_query, n_key). The weights are returned per
            head; they are not averaged over the heads.
        """
        n_batch, _, _ = q.shape  # also enforces a 3-D query tensor

        q = self._split_heads(self.w_q(q), n_batch)  # n_batch, n_heads, n_query, head_dim
        k = self._split_heads(self.w_k(k), n_batch)  # n_batch, n_heads, n_key, head_dim
        v = self._split_heads(self.w_v(v), n_batch)  # n_batch, n_heads, n_key, head_dim

        # Scaled dot product between every query and every key, per head.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # n_batch, n_heads, n_query, n_key

        # Normalise over the key axis so each query's weights sum to 1.
        attention_weights = F.softmax(scores, dim=-1)  # n_batch, n_heads, n_query, n_key

        # Weighted sum of the values.
        output = torch.matmul(attention_weights, v)  # n_batch, n_heads, n_query, head_dim

        # Re-concatenate the heads; transpose makes the tensor non-contiguous,
        # so it must be made contiguous before view().
        output = output.transpose(1, 2).contiguous().view(n_batch, -1, self.n_features)  # n_batch, n_query, n_features

        output = self.linear_out(output)

        return output, attention_weights

In [20]:
# Hand-written attention with the same width and head count as the built-in module.
my_attention = MyUnmaskedMultiheadAttention(
    embed_dim=n_features,
    num_heads=n_heads,
)

In [21]:
# Trainable parameter totals: built-in attention first, hand-written attention second.
print(
    count_trainable_parameters(attention),
    count_trainable_parameters(my_attention),
    sep="\n",
)

1050624
1050624


In [23]:
# Self-attention with the hand-written module; element 0 of the result is the output tensor.
my_attention(q=x, k=x, v=x)[0].shape

torch.Size([16, 20, 512])

In [24]:
# Same self-attention call on the built-in module for comparison of output shapes.
attention(query=x, key=x, value=x)[0].shape

torch.Size([16, 20, 512])