# In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each projected by their own linear
    layer, split into ``num_heads`` heads of size ``d_model // num_heads``,
    attended over independently, concatenated back to ``d_model`` features
    and passed through a final output projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()

        # Raise explicitly instead of using ``assert``: assertions are removed
        # when Python runs with -O, which would let an invalid head split
        # through and only fail later inside ``view``.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # One projection per input role plus the final output projection.
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, query, key, value, mask=None):
        """Compute ``softmax(Q K^T / sqrt(head_dim)) V``.

        Args:
            query: Tensor whose last two dims are ``(seq_q, head_dim)``.
            key: Tensor whose last two dims are ``(seq_k, head_dim)``.
            value: Tensor whose last two dims are ``(seq_k, head_dim)``.
            mask: Optional tensor broadcastable to the score tensor
                ``(..., seq_q, seq_k)``. Positions where ``mask == 0`` are
                excluded from attention.

        Returns:
            A ``(context_vector, attention_weights)`` tuple. The returned
            weights are the values after dropout has been applied.
        """
        attention_scores = torch.matmul(query, key.transpose(-2, -1))
        scaled_attention_scores = attention_scores / math.sqrt(self.head_dim)

        if mask is not None:
            # Use the most negative finite value of the score dtype. A fixed
            # -1e9 is not representable in float16 and makes masked_fill
            # raise; the dtype minimum still yields a zero softmax weight.
            fill_value = torch.finfo(scaled_attention_scores.dtype).min
            scaled_attention_scores = scaled_attention_scores.masked_fill(
                mask == 0, fill_value
            )

        attention_weights = F.softmax(scaled_attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vector = torch.matmul(attention_weights, value)

        return context_vector, attention_weights

    def _split_heads(self, projected, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, heads, seq, head_dim)``."""
        return projected.view(
            batch_size, -1, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape ``(batch, seq_q, d_model)``.
            key: Tensor of shape ``(batch, seq_k, d_model)``.
            value: Tensor of shape ``(batch, seq_k, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_q, seq_k)``; positions equal to 0
                are not attended to.

        Returns:
            Tensor of shape ``(batch, seq_q, d_model)``.
        """
        batch_size = query.size(0)

        Q = self._split_heads(self.query_linear(query), batch_size)
        K = self._split_heads(self.key_linear(key), batch_size)
        V = self._split_heads(self.value_linear(value), batch_size)

        context_vector, _ = self.scaled_dot_product_attention(Q, K, V, mask)

        # Merge the heads back: (batch, heads, seq_q, head_dim) ->
        # (batch, seq_q, d_model). ``contiguous`` is needed before ``view``
        # because ``transpose`` returns a non-contiguous tensor.
        context_vector = context_vector.transpose(1, 2).contiguous()
        context_vector = context_vector.view(batch_size, -1, self.d_model)

        return self.output_linear(context_vector)

if __name__ == '__main__':
    # Smoke test: run cross-attention with a padding mask, then plain
    # self-attention, and print the tensor shapes involved.
    batch_size, seq_len_q, seq_len_kv = 4, 10, 12
    d_model, num_heads, dropout = 512, 8, 0.1

    # Random inputs; keys/values use a different sequence length than queries.
    query = torch.rand(batch_size, seq_len_q, d_model)
    key = torch.rand(batch_size, seq_len_kv, d_model)
    value = torch.rand(batch_size, seq_len_kv, d_model)

    # Mask value 0 marks the last two key positions as not attendable.
    padding_mask = torch.ones(batch_size, 1, 1, seq_len_kv)
    padding_mask[..., -2:] = 0

    mha = MultiHeadAttention(d_model, num_heads, dropout)

    output = mha(query, key, value, mask=padding_mask)

    for label, tensor in (
        ("Input Query Shape:", query),
        ("Input Key Shape:", key),
        ("Input Value Shape:", value),
        ("Padding Mask Shape:", padding_mask),
        ("Output Shape:", output),
    ):
        print(label, tensor.shape)

    # Self-attention: the same tensor serves as query, key and value.
    self_attention_output = mha(query, query, query, mask=None)
    print("Self-Attention Output Shape:", self_attention_output.shape)


# Output:
# Input Query Shape: torch.Size([4, 10, 512])
# Input Key Shape: torch.Size([4, 12, 512])
# Input Value Shape: torch.Size([4, 12, 512])
# Padding Mask Shape: torch.Size([4, 1, 1, 12])
# Output Shape: torch.Size([4, 10, 512])
# Self-Attention Output Shape: torch.Size([4, 10, 512])
