<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Complete_Multi_Head_Attention_Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is linearly projected to queries, keys and values, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended to
    independently per head, and the concatenated heads are passed through a
    final linear projection.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O``, which would let a bad configuration through and only
        # fail later with an obscure reshape error in ``forward``.
        if embed_dim % num_heads != 0:
            raise ValueError("Embedding dimension must be divisible by the number of heads.")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Query, Key, Value projections
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        # Output projection
        self.out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (B, seq_len, embed_dim) into (B, heads, seq_len, head_dim)."""
        return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_dim)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, weights)``: ``output`` has shape
            ``(batch, seq_len, embed_dim)`` and ``weights`` has shape
            ``(batch, num_heads, seq_len, seq_len)``. A query row whose keys
            are all masked gets all-zero weights (instead of NaN).
        """
        batch_size, seq_len, _ = x.size()

        # Transform inputs to queries, keys, values: (B, heads, seq_len, head_dim)
        q = self._split_heads(self.query(x), batch_size, seq_len)
        k = self._split_heads(self.key(x), batch_size, seq_len)
        v = self._split_heads(self.value(x), batch_size, seq_len)

        # Scaled dot-product attention: (B, heads, seq_len, seq_len).
        # A plain Python scalar avoids allocating a tensor on every call.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        dead_rows = None
        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float("-inf"))
            # A row with every key masked would be all -inf, and softmax of
            # that is NaN (which also poisons gradients). Give such rows
            # finite scores here and zero their weights after the softmax.
            dead_rows = blocked.all(dim=-1, keepdim=True)
            scores = scores.masked_fill(dead_rows, 0.0)

        weights = F.softmax(scores, dim=-1)  # (B, heads, seq_len, seq_len)
        if dead_rows is not None:
            weights = weights.masked_fill(dead_rows, 0.0)

        # Attention output: (B, heads, seq_len, head_dim)
        attention_output = torch.matmul(weights, v)

        # Concatenate heads and project: (B, seq_len, embed_dim).
        # ``contiguous`` is required because ``view`` cannot reshape the
        # transposed (non-contiguous) tensor directly.
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        output = self.out(attention_output)  # Final linear projection
        return output, weights  # Returning weights for visualization if needed

# Demo: push one random batch through the layer and report the result shapes
if __name__ == "__main__":
    batch_size, seq_len = 32, 10
    embed_dim, num_heads = 64, 8

    # Random values standing in for real token embeddings
    input_tensor = torch.rand((batch_size, seq_len, embed_dim))
    mha = MultiHeadAttention(embed_dim, num_heads)
    output, attention_weights = mha(input_tensor)  # no mask supplied

    # Expected: (32, 10, 64)
    print("Output shape:", output.shape)
    # Expected: (32, 8, 10, 10)
    print("Attention weights shape:", attention_weights.shape)