## **Multi-Head Attention**

Multi-Head Attention runs several attention heads in parallel instead of a single attention function. Each head applies scaled dot-product attention to its own learned linear projection of the queries, keys, and values, so different heads can attend to different positions and capture different relationships. The outputs of all heads are concatenated and passed through a final linear layer.
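
The description above corresponds to the standard Transformer formulation, sketched here for reference. The symbols are assumptions introduced for illustration: $h$ is the number of heads, $d_k$ is the per-head dimension (`head_dim` in the code), and $W_i^Q$, $W_i^K$, $W_i^V$, $W^O$ are the learned projection matrices.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

$$
\mathrm{head}_i = \mathrm{Attention}\!\left(Q W_i^Q,\; K W_i^K,\; V W_i^V\right), \qquad i = 1, \dots, h
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O
$$

In the implementation below, the per-head projections are not stored as separate matrices. One `nn.Linear` of size `input_dim` by `input_dim` is applied and its output is split into `num_heads` chunks of size `head_dim`.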

**Imports**

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

**Data Loading**

In [None]:
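# Random inputs with an explicit batch dimension: (batch, seq_len, input_dim)
query = torch.randn(1, 1, 20)   # one query vector
key = torch.randn(1, 10, 20)    # ten key vectors
value = torch.randn(1, 10, 20)  # ten value vectors, one per key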

**Multi-Head Attention Model**

In [None]:
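class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super().__init__()
        if input_dim % num_heads != 0:
            raise ValueError("input_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.query_layer = nn.Linear(input_dim, input_dim)
        self.key_layer = nn.Linear(input_dim, input_dim)
        self.value_layer = nn.Linear(input_dim, input_dim)
        self.output_layer = nn.Linear(input_dim, input_dim)

    def forward(self, query, key, value):
        # Inputs are expected as (batch, seq_len, input_dim)
        batch_size = query.size(0)

        # Project, split the last dimension into heads, then move the head axis
        # forward: (batch, seq_len, input_dim) -> (batch, num_heads, seq_len, head_dim).
        # Without the transpose, attention would be computed across heads
        # instead of across sequence positions.
        query = self.query_layer(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key_layer(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.value_layer(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute Scaled Dot-Product Attention for each head
        attention_output, attention_weights = self.scaled_dot_product_attention(query, key, value)

        # Concatenate heads: (batch, num_heads, q_len, head_dim) -> (batch, q_len, input_dim).
        # transpose() returns a non-contiguous view, so contiguous() is needed before view().
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, -1, self.num_heads * self.head_dim)

        # Apply final output linear layer
        output = self.output_layer(attention_output)
        return output, attention_weights

    def scaled_dot_product_attention(self, query, key, value):
        # scores: (batch, num_heads, q_len, k_len), scaled by sqrt(head_dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(self.head_dim)
        # Normalize over the key axis so each query's weights sum to 1
        attention_weights = torch.softmax(scores, dim=-1)
        # Weighted sum of values: (batch, num_heads, q_len, head_dim)
        attention_output = torch.matmul(attention_weights, value)
        return attention_output, attention_weights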
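
Splitting a projected tensor into heads and merging it back is only a reshape plus an axis swap, which is easy to get subtly wrong. The following is a minimal sketch of that round trip, assuming inputs laid out as `(batch, seq_len, dim)`. The helper names `split_heads` and `merge_heads` are hypothetical and are not part of the class above.

```python
import torch


def split_heads(x, num_heads):
    """(batch, seq_len, dim) -> (batch, num_heads, seq_len, head_dim)."""
    batch, seq_len, dim = x.shape
    head_dim = dim // num_heads
    # Split the feature axis into (num_heads, head_dim), then move heads forward
    return x.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)


def merge_heads(x):
    """(batch, num_heads, seq_len, head_dim) -> (batch, seq_len, dim)."""
    batch, num_heads, seq_len, head_dim = x.shape
    # transpose() yields a non-contiguous view, so copy before calling view()
    return x.transpose(1, 2).contiguous().view(batch, seq_len, num_heads * head_dim)


# Toy input with recognizable values: 2 sequences, 3 positions, 8 features
x = torch.arange(2 * 3 * 8, dtype=torch.float32).view(2, 3, 8)
heads = split_heads(x, num_heads=4)
restored = merge_heads(heads)

print(heads.shape)                # per-head layout
print(torch.equal(restored, x))   # merging undoes the split
```

Each head sees a contiguous slice of the feature axis: with 8 features and 4 heads, head 1 at position 0 holds features 2 and 3 of that position.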

**Instantiate and Apply Attention**

In [None]:
multi_head_attention = MultiHeadAttention(input_dim=20, num_heads=4)
output, attention_weights = multi_head_attention(query, key, value)

**Display Results**

In [None]:
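print("Multi-Head Attention Output:", output)
print("Attention Weights:", attention_weights)
# The shapes are often more informative than the raw values
print("Output shape:", tuple(output.shape))
print("Attention weights shape:", tuple(attention_weights.shape))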
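
When reading attention weights, it helps to know what to expect: softmax is taken over the key axis, so for every head and every query the weights over the keys are non-negative and sum to 1. The following is a small self-contained check, assuming per-head tensors laid out as `(batch, num_heads, seq_len, head_dim)`. The tensors `q` and `k` are random stand-ins for illustration, not the model's actual projections.

```python
import torch

torch.manual_seed(0)

num_heads, head_dim = 4, 5
q = torch.randn(1, num_heads, 1, head_dim)    # one query position per head
k = torch.randn(1, num_heads, 10, head_dim)   # ten key positions per head

# Scaled dot-product scores: (batch, num_heads, q_len, k_len)
scores = torch.matmul(q, k.transpose(-2, -1)) / head_dim ** 0.5
weights = torch.softmax(scores, dim=-1)

# Summing over the key axis should give 1 for every (batch, head, query)
row_sums = weights.sum(dim=-1)
print(weights.shape)
print(torch.allclose(row_sums, torch.ones_like(row_sums)))
```

If the last axis of the weights does not match the number of keys, the head axis and the sequence axis have probably been mixed up when reshaping.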