# In [1]:
# Program 11

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding dimension is split evenly across ``heads`` attention heads.
    Each head attends independently; the per-head results are concatenated
    back to ``embed_size`` and passed through a final linear projection.

    Args:
        embed_size: Size of the last dimension of the inputs and of the output.
        heads: Number of attention heads. Must be positive and divide
            ``embed_size`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()

        # Validate with an explicit exception rather than ``assert`` so the
        # check still runs under ``python -O``. A non-positive head count is
        # rejected up front: zero would divide by zero and a negative count
        # can slip through the divisibility test below.
        if heads <= 0:
            raise ValueError("Number of heads must be a positive integer")
        if embed_size % heads != 0:
            raise ValueError("Embedding size must be divisible by heads")

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Linear layers to project input into Q, K, V.
        # (Creation order is kept stable so seeded initialisation is unchanged.)
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)  # Final projection

    def forward(self, value, key, query, mask=None):
        """Apply attention of ``query`` over ``key``/``value``.

        Args:
            value: Tensor of shape ``(N, value_len, embed_size)``.
            key: Tensor of shape ``(N, key_len, embed_size)``. ``key_len``
                must equal ``value_len`` for the weighted sum to be defined.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(N, heads, query_len, key_len)``. Positions where the mask
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]  # Batch size
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Project, then split the embedding into heads:
        # (N, len, embed_size) -> (N, heads, len, head_dim)
        values = self.values(value).view(N, value_len, self.heads, self.head_dim).transpose(1, 2)
        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim).transpose(1, 2)
        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (N, heads, query_len, key_len)
        energy = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            # Fill masked positions with the most negative finite value of the
            # score dtype. A hard-coded -1e20 is not representable in float16
            # and makes masked_fill fail under half precision; finfo().min is
            # always representable and still yields a softmax weight of 0.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(mask == 0, fill_value)

        # Normalise scores over the key dimension.
        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of values, then merge heads back:
        # (N, heads, query_len, head_dim) -> (N, query_len, embed_size)
        out = torch.matmul(attention, values)
        out = out.transpose(1, 2).contiguous().view(N, query_len, self.embed_size)

        return self.fc_out(out)  # Final linear layer


# --- Demo: self-attention over a random batch ---
# Model hyperparameters.
embed_size = 128  # width of each token embedding
heads = 8  # attention heads the embedding is split across
attention = MultiHeadAttention(embed_size=embed_size, heads=heads)

# Random input with batch_size=2, seq_len=10 and embed_size features per token.
x = torch.rand(2, 10, embed_size)

# Self-attention: the same tensor is used as value, key and query.
output = attention(value=x, key=x, query=x)
print(output.shape)

# Output: torch.Size([2, 10, 128])
