In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected by full-width linear layers, split
    into ``num_heads`` heads of ``head_dim = embedding_dim // num_heads``
    features, attended independently per head, concatenated back to
    ``embedding_dim`` features and mixed by a final linear projection.

    Args:
        embedding_dim: Feature size of the query/key/value inputs and of the
            output. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Trainable parameters: ``4 * embedding_dim * (embedding_dim + 1)`` (the Q, K,
    V and output projections, each a square weight matrix plus a bias).
    """

    def __init__(self, embedding_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embedding_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.num_heads = num_heads  #NOTE: number of heads
        #NOTE: width of each head; [batch, seq_l, emb_dim] is later viewed as [batch, head, seq_l, head_dim]
        self.head_dim = embedding_dim // num_heads

        #NOTE: Q/K/V projections act on the WHOLE input with the dim unchanged;
        #      the heads are split off only afterwards (in forward).
        self.query_linear = nn.Linear(embedding_dim, embedding_dim)
        self.key_linear = nn.Linear(embedding_dim, embedding_dim)
        self.value_linear = nn.Linear(embedding_dim, embedding_dim)

        #NOTE: output projection that recombines the concatenated heads
        self.out_linear = nn.Linear(embedding_dim, embedding_dim)

        #NOTE: randomly zeroes attention weights during training (regularization)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5  # scaling factor 1 / sqrt(d_k)

    def _split_heads(self, x, batch_size):
        """View ``(batch, seq_len, embedding_dim)`` as ``(batch, num_heads, seq_len, head_dim)``."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (batch_size, q_len, embedding_dim).
            key: Tensor of shape (batch_size, k_len, embedding_dim).
            value: Tensor of shape (batch_size, k_len, embedding_dim); its length
                must match ``key``.
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, q_len, k_len). Positions where
                ``mask == 0`` receive (effectively) zero attention weight, e.g. a
                lower-triangular causal mask or a padding mask. A query row whose
                keys are all masked gets uniform weights instead of NaN.

        Returns:
            Tensor of shape (batch_size, q_len, embedding_dim).
        """
        batch_size = query.size(0)

        #NOTE: ---------------------
        # Why project the WHOLE input before splitting it into heads?
        #     - every head then looks at the SAME Q/K/V, but through DIFFERENT learned
        #       perspectives (mathematically: different subspaces/dimensions);
        #     - splitting the input first would "blind" each head to a fixed subset of
        #       the original features, losing the holistic meaning.
        # ---------------------------
        # params: 3 * embedding_dim * (embedding_dim + 1)
        Q = self.query_linear(query)  # (batch_size, q_len, embedding_dim)
        K = self.key_linear(key)      # (batch_size, k_len, embedding_dim)
        V = self.value_linear(value)  # (batch_size, k_len, embedding_dim)

        # [batch, seq_l, emb_dim] --> [batch, head, seq_l, head_dim]   (params: 0)
        Q = self._split_heads(Q, batch_size)
        K = self._split_heads(K, batch_size)
        V = self._split_heads(V, batch_size)

        # Scaled dot-product scores: Q K^T / sqrt(head_dim) -> (batch, head, q_len, k_len)   (params: 0)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        if mask is not None:
            #NOTE: exclude positions where mask == 0 (e.g. a causal mask that prevents
            # peeking at later tokens, or a padding mask). Fill with the dtype's most
            # negative finite value rather than -inf: a row whose keys are ALL masked
            # would otherwise softmax to NaN and poison the output and its gradients.
            # Masked keys in partially masked rows still get zero weight, because
            # exp(finfo.min - row_max) underflows to 0.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)
        # Normalize over the keys: each row becomes weights in [0, 1] summing to 1.
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        attention_output = torch.matmul(attention_probs, V)  # (batch, head, q_len, head_dim)

        # Concatenate the heads back to (batch, q_len, embedding_dim).
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, -1, self.num_heads * self.head_dim)
        # params: embedding_dim * (embedding_dim + 1)
        return self.out_linear(attention_output)

# Example usage: 768-dim embeddings split into 12 heads of 64 dims each.
torch.manual_seed(0)  # reproducible weight init and dummy input
embedding_dim = 768
num_heads = 12
dropout = 0.1
mha = MultiHeadAttention(embedding_dim, num_heads, dropout)

# Dummy input
batch_size = 32
seq_len = 512
dummy_input = torch.rand(batch_size, seq_len, embedding_dim)

# Forward pass (self-attention: query = key = value).
# no_grad: only the output shape is inspected, so skip building the autograd graph,
# which would otherwise keep the (32, 12, 512, 512) attention tensors alive via `output`.
with torch.no_grad():
    output = mha(dummy_input, dummy_input, dummy_input)
assert output.shape == (batch_size, seq_len, embedding_dim)
output.shape

torch.Size([32, 512, 768])

## Parameter count: closed-form (by hand)

In [4]:
# Each nn.Linear(d, d) holds a d x d weight plus a length-d bias, i.e. d * (d + 1) parameters;
# the module has four of them: the Q, K and V projections and the output projection.
cnt_params = sum(embedding_dim * (embedding_dim + 1) for _layer in ("query", "key", "value", "out"))
cnt_params


2362368

## Parameter count: loop over `mha.parameters()`

In [8]:
# Count trainable parameters by iterating over every tensor registered on the module.
n_trainable_params = sum(p.numel() for p in mha.parameters() if p.requires_grad)
# Cross-check against the closed-form count `cnt_params` computed in the manual section.
assert n_trainable_params == cnt_params, (n_trainable_params, cnt_params)
n_trainable_params

2362368