#### 🧠 Multi-Head Attention

A single context vector produced by **one attention layer** is built from one set of attention weights, so it tends to capture only **one type of semantic relationship** in a sentence.  
But sentences often contain **multiple possible meanings** or relationships between words.

---

**Example:**

> *"The man saw an astronaut with a telescope."*

This sentence can have **two different interpretations**:

1. The man is using the telescope to see the astronaut.  
2. The astronaut has a telescope, and the man saw him.

---

So, how do we capture these multiple semantics?

We use **Multi-Head Attention** 🎯

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h)W^O, \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$

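Here $h$ is the number of heads and $W^O$ is a learned output projection applied to the concatenated head outputs. Each head can learn **different attention patterns**, focusing on **different relationships** or **aspects of meaning** in the same sentence.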

---

**Intuition:**

- Each head performs **self-attention** independently, with its own query, key, and value weights.  
- Outputs from all heads are then **concatenated** and projected by $W^O$.  
- This helps the model understand **richer contextual relationships**.
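
To make the first point concrete, here is a minimal sketch (not part of the original notebook) that computes the attention weights of two independently initialized heads on the same input. The sizes, the seed, the `attention_weights` helper, and the use of `torch.randn` for the weights are all illustrative assumptions, and no causal mask is applied.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

seq_len, dim, num_heads = 4, 8, 2  # illustrative sizes, not from the text
x = torch.rand(1, seq_len, dim)    # one "sentence" of 4 token embeddings


def attention_weights(x, Wq, Wk):
    """Scaled dot-product attention weights for one head (hypothetical helper, no mask)."""
    queries, keys = x @ Wq, x @ Wk
    scores = queries @ keys.transpose(1, 2) / keys.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1)


# Each head gets its own query/key matrices, so it scores token pairs differently.
head_weights = [
    attention_weights(x, torch.randn(dim, dim), torch.randn(dim, dim))
    for _ in range(num_heads)
]

for i, w in enumerate(head_weights, start=1):
    print(f"head {i} attention weights:\n{w[0]}")
```

Each printed matrix has rows that sum to 1, but the two heads distribute that weight over the tokens differently. In a trained model those differences are learned rather than random.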

---

✅ **Summary:**
- **Single-head attention:** Uses one set of attention weights, so it captures one kind of relationship at a time.  
- **Multi-head attention:** Uses several heads to capture multiple types of relationships in parallel.


In [1]:
## Causal Attention 
import torch
import torch.nn as nn 

class CausalAttention(nn.Module):
    def __init__(self, dim, context_len, dropout=0.1):
        super().__init__()
        self.Wq = torch.nn.Parameter(torch.rand(dim, dim))
        self.Wk = torch.nn.Parameter(torch.rand(dim, dim))
        self.Wv = torch.nn.Parameter(torch.rand(dim, dim))
        self.dropout = torch.nn.Dropout(dropout)
        # Upper-triangular mask: 1 marks the future positions a token must not attend to
        self.register_buffer('mask', torch.triu(torch.ones(context_len, context_len), diagonal=1))

    def forward(self, x):
        # x is the input sentence: (batch_size, context_len, embedding_dim)
        # calculate keys, queries and values
        batch_size, context_len, emb_dim = x.shape
        keys = x @ self.Wk
        queries = x @ self.Wq
        values = x @ self.Wv

        # attention scores, scaled by the square root of the key dimension
        # (not the sequence length)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores = attn_scores / keys.shape[-1] ** 0.5

        # hide future tokens so each position attends only to itself and earlier positions
        attn_scores = attn_scores.masked_fill(self.mask.bool()[:context_len, :context_len], -torch.inf)

        scaled_attn_weights = torch.softmax(attn_scores, dim=-1)
        scaled_attn_weights = self.dropout(scaled_attn_weights)
        contextualized_inputs = scaled_attn_weights @ values 
        return contextualized_inputs
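
The masking step in `forward` can be checked in isolation. The following is a small sketch, assuming a random stand-in score matrix and an illustrative `context_len` of 4 in place of real query-key scores. It shows that filling the upper triangle with negative infinity before the softmax gives every future position a weight of exactly zero.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

context_len = 4  # illustrative size
scores = torch.rand(1, context_len, context_len)  # stand-in for attention scores

# True above the diagonal marks "future" positions.
mask = torch.triu(torch.ones(context_len, context_len), diagonal=1).bool()

# exp(-inf) = 0, so masked positions receive zero weight after softmax.
masked_scores = scores.masked_fill(mask, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

print(weights[0])
```

The first row has a single non-zero entry (the first token can only attend to itself), and every row still sums to 1 because the softmax renormalizes over the visible positions.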

#### Multi-head Attention Class

In [4]:
class MultiheadAttention(nn.Module):
    def __init__(self, dim, context_len, dropout, num_heads):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(dim, context_len, dropout)
            for _ in range(num_heads)]
        )

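    def forward(self, x):
        # Run every head on the same input and concatenate along the embedding axis:
        # (batch_size, context_len, dim) -> (batch_size, context_len, num_heads * dim)
        return torch.cat([head(x) for head in self.heads], dim=-1)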

In [10]:
multi_attention = MultiheadAttention(dim=5, context_len=6, dropout=0.1, num_heads=3)

In [11]:
inputs = torch.rand(2, 6, 5)
inputs

tensor([[[0.6426, 0.3412, 0.4651, 0.8227, 0.1728],
         [0.0564, 0.0785, 0.6955, 0.6736, 0.2981],
         [0.8229, 0.0105, 0.5075, 0.7254, 0.8440],
         [0.0784, 0.5026, 0.7916, 0.4722, 0.8658],
         [0.3275, 0.2329, 0.5386, 0.4205, 0.4891],
         [0.4363, 0.3532, 0.9617, 0.3267, 0.4047]],

        [[0.9531, 0.8927, 0.4615, 0.9036, 0.6058],
         [0.0031, 0.5224, 0.3887, 0.6819, 0.2497],
         [0.4434, 0.6777, 0.5639, 0.3497, 0.3231],
         [0.7987, 0.0599, 0.8790, 0.3200, 0.8516],
         [0.8521, 0.4802, 0.5847, 0.0934, 0.8232],
         [0.2584, 0.8376, 0.1798, 0.2178, 0.1413]]])

In [12]:
outputs_multihead = multi_attention(inputs)
outputs_multihead.shape

torch.Size([2, 6, 15])

Since we concatenated the outputs of 3 heads for each token, every token is now represented by a vector of length 3 * 5 = 15.
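
The formula at the top of this section also multiplies the concatenated heads by $W^O$, a step the `MultiheadAttention` class above leaves out. One possible way to add it is sketched below. It uses a random tensor as a stand-in for `outputs_multihead`, and `out_proj` is a hypothetical name for a bias-free `nn.Linear` layer playing the role of $W^O$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

batch_size, context_len, dim, num_heads = 2, 6, 5, 3

# Stand-in for the concatenated head outputs (same shape as outputs_multihead).
concatenated = torch.rand(batch_size, context_len, num_heads * dim)

# Hypothetical output projection acting as W^O: maps num_heads * dim back to dim.
out_proj = nn.Linear(num_heads * dim, dim, bias=False)

projected = out_proj(concatenated)
print(projected.shape)  # torch.Size([2, 6, 5])
```

Projecting back to `dim` lets the model mix information across heads and keeps the output the same size as the input, which is convenient when attention layers are stacked.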