### Multi-Head Attention

**Concept:**
Multi-head attention is a core component of the Transformer architecture. It allows the model to jointly attend to information from different representation subspaces at different positions. Instead of performing a single attention function, the multi-head attention mechanism runs multiple attention functions in parallel.

### Steps Involved:

1. **Linear Projections:**
   - The input embeddings are linearly projected into multiple sets of Queries (Q), Keys (K), and Values (V). This is done by multiplying the input by learned weight matrices to create different "heads" (subspaces).
   - If we have \(h\) heads, we create \(Q_1, Q_2, ..., Q_h\), \(K_1, K_2, ..., K_h\), and \(V_1, V_2, ..., V_h\).

2. **Scaled Dot-Product Attention:**
   - For each head, we perform the scaled dot-product attention independently.
   - The attention output is computed as:
     \[
     \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
     \]
     where \(d_k\) is the dimension of the key vectors (usually \(d_k = d/h\) where \(d\) is the dimension of the input embeddings).

3. **Concatenation:**
   - The outputs of all the heads are concatenated back together. This concatenated output represents the combined attention from multiple subspaces.
   - If each head produces an output of dimension \(d_k\), the concatenated output will have a dimension of \(h \times d_k\).

4. **Final Linear Projection:**
   - The concatenated output is linearly projected back to the original dimension \(d\) using another learned weight matrix.
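
The division by \(\sqrt{d_k}\) in step 2 keeps the scores in a range where the softmax does not saturate. The following is a minimal sketch of that effect, assuming queries and keys with independent unit-variance components; the sizes `d_k = 64` and `num_pairs = 10_000` are illustrative choices, not values from the text.

```python
import math
import torch

torch.manual_seed(0)  # fixed seed so the run is reproducible

d_k = 64            # assumed per-head dimension, for illustration only
num_pairs = 10_000  # number of random (query, key) pairs to sample

# Queries and keys with independent, unit-variance components.
q = torch.randn(num_pairs, d_k)
k = torch.randn(num_pairs, d_k)

raw_scores = (q * k).sum(dim=-1)             # unscaled dot products
scaled_scores = raw_scores / math.sqrt(d_k)  # scaled as in the attention formula

raw_var = raw_scores.var().item()
scaled_var = scaled_scores.var().item()
print(f"variance of raw scores:    {raw_var:.2f}")   # close to d_k
print(f"variance of scaled scores: {scaled_var:.2f}")  # close to 1
```

Under these assumptions the variance of the raw dot products grows roughly linearly with \(d_k\), while the scaled scores stay near unit variance whatever the head size.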

### Detailed Breakdown:

1. **Linear Projections:**
   - Given an input matrix \(X \in \mathbb{R}^{n \times d}\) (where \(n\) is the sequence length and \(d\) is the embedding dimension), we create different sets of Queries, Keys, and Values:
     \[
     Q_i = XW_i^Q, \quad K_i = XW_i^K, \quad V_i = XW_i^V \quad \text{for } i \in \{1, ..., h\}
     \]
     where \(W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d \times d_k}\) are the learned weight matrices for the \(i\)-th head.

2. **Scaled Dot-Product Attention for Each Head:**
   - For each head, compute the attention output:
     \[
     \text{Attention}_i = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i
     \]

3. **Concatenation:**
   - Concatenate the outputs of all heads along the feature dimension:
     \[
     \text{Concat}(\text{Attention}_1, \text{Attention}_2, ..., \text{Attention}_h) \in \mathbb{R}^{n \times hd_k}
     \]

4. **Final Linear Projection:**
   - Apply the final linear transformation to get the multi-head attention output:
     \[
     \text{MultiHead}(Q, K, V) = \text{Concat}(\text{Attention}_1, \text{Attention}_2, ..., \text{Attention}_h)W^O
     \]
     where \(W^O \in \mathbb{R}^{hd_k \times d}\) is the learned weight matrix for the final linear projection.
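
The equations above can be followed almost line by line in code. Below is a minimal sketch using an explicit loop over heads and randomly initialized weights. The function name `multi_head_attention` and the sizes `n = 5`, `d = 8`, `h = 2` are assumptions made for illustration; a practical implementation would normally batch the heads instead of looping.

```python
import math
import torch
import torch.nn.functional as F

def multi_head_attention(X, W_q, W_k, W_v, W_o):
    """Explicit per-head loop (illustrative sketch).

    X:             (n, d) input matrix
    W_q, W_k, W_v: lists of h matrices, each of shape (d, d_k)
    W_o:           (h * d_k, d) output projection
    """
    heads = []
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        Q_i, K_i, V_i = X @ Wq_i, X @ Wk_i, X @ Wv_i  # each (n, d_k)
        d_k = Q_i.size(-1)
        scores = Q_i @ K_i.T / math.sqrt(d_k)          # (n, n)
        weights = F.softmax(scores, dim=-1)            # each row sums to 1
        heads.append(weights @ V_i)                    # (n, d_k)
    concat = torch.cat(heads, dim=-1)                  # (n, h * d_k)
    return concat @ W_o                                # (n, d)

torch.manual_seed(0)
n, d, h = 5, 8, 2  # illustrative sizes, not taken from the text
d_k = d // h
W_q = [torch.randn(d, d_k) for _ in range(h)]
W_k = [torch.randn(d, d_k) for _ in range(h)]
W_v = [torch.randn(d, d_k) for _ in range(h)]
W_o = torch.randn(h * d_k, d)

X = torch.randn(n, d)
out = multi_head_attention(X, W_q, W_k, W_v, W_o)
print(out.shape)  # torch.Size([5, 8])
```

The output has the same shape as the input, which is what allows attention blocks to be stacked.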

### Advantages of Multi-Head Attention:

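1. **Parallel Processing:**
   - The heads are independent of one another, so they can be computed in parallel as a single batched matrix operation. Because each head works in a reduced dimension \(d_k = d/h\), the total cost is similar to that of single-head attention with full dimensionality.

2. **Rich Representation:**
   - Each head has its own projection matrices, so different heads can learn to focus on different parts of the input and capture different aspects of the data.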

3. **Better Learning Capacity:**
   - By attending to information from different subspaces, the model can capture complex relationships and dependencies more effectively.

### Example Calculation:

Let's consider an example with:
- Sequence length \(n = 3\)
- Embedding dimension \(d = 4\)
- Number of heads \(h = 2\)
- Dimension of each head \(d_k = 2\)

Given input \(X\):
\[
X = \begin{bmatrix}
1 & 0 & 1 & 0 \\
0 & 1 & 0 & 1 \\
1 & 1 & 1 & 1 \\
\end{bmatrix}
\]

**Linear Projections for Head 1:**
\[
Q_1 = XW_1^Q, \quad K_1 = XW_1^K, \quad V_1 = XW_1^V
\]
where \(W_1^Q, W_1^K, W_1^V \in \mathbb{R}^{4 \times 2}\).

**Attention Calculation:**
\[
\text{Attention}_1 = \text{softmax}\left(\frac{Q_1K_1^T}{\sqrt{2}}\right)V_1
\]

**Repeat for Head 2 and Concatenate:**
\[
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{Attention}_1, \text{Attention}_2)W^O
\]

The result is a \(3 \times 4\) matrix, the same shape as \(X\). Each row is a representation of one position that mixes information from all three positions, as weighted by both heads.
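
The example does not specify the weight matrices, so the sketch below fills them in with hypothetical values chosen only to keep the arithmetic easy to follow: head 1 reads the first two features, head 2 reads the last two, the same matrix is reused for queries, keys, and values, and \(W^O\) is the identity. None of these choices come from the text.

```python
import math
import torch
import torch.nn.functional as F

X = torch.tensor([[1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [1., 1., 1., 1.]])

# Hypothetical weights (NOT from the text):
# head 1 selects features 0-1, head 2 selects features 2-3.
W1 = torch.tensor([[1., 0.], [0., 1.], [0., 0.], [0., 0.]])
W2 = torch.tensor([[0., 0.], [0., 0.], [1., 0.], [0., 1.]])
W_o = torch.eye(4)  # identity output projection, also an assumption

def head(X, W):
    # One matrix is reused for Q, K, and V purely for simplicity.
    Q = K = V = X @ W
    scores = Q @ K.T                                    # unscaled scores, (3, 3)
    weights = F.softmax(scores / math.sqrt(2), dim=-1)  # d_k = 2
    return scores, weights, weights @ V

scores_1, weights_1, attention_1 = head(X, W1)
scores_2, weights_2, attention_2 = head(X, W2)
output = torch.cat([attention_1, attention_2], dim=-1) @ W_o

print(scores_1)      # rows: [1, 0, 1], [0, 1, 1], [1, 1, 2]
print(weights_1)     # each row is a probability distribution over positions
print(output.shape)  # torch.Size([3, 4])
```

With these particular weights both heads see the same projected inputs, so they produce identical attention patterns. Learned weights would normally differ between heads.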

In [17]:
import torch 
import math 
from torch import nn 
import torch.nn.functional as F 

In [70]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    print(f"scaled.size(): {scaled.size()}")
    if mask is not None:
        print(f"Adding the Mask of shape {mask.size()}")
        scaled = scaled + mask  # out-of-place add, so the mask can broadcast freely
        print(f"scaled.size(): {scaled.size()}")
    attention = F.softmax(scaled, dim=-1)
    print(f"Attention.size(): {attention.size()}")
    values = torch.matmul(attention, v)
    return values, attention
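
The `mask` argument is added to the scores before the softmax, so masked positions should hold `-inf` (or a very large negative number) and visible positions `0`. Below is a minimal sketch of a causal (look-ahead) mask under that convention. The helper `attend` repeats the same computation without the debug prints so the block runs on its own, and the tensor sizes are illustrative.

```python
import math
import torch
import torch.nn.functional as F

def attend(q, k, v, mask=None):
    # Same computation as scaled_dot_product, without the debug prints.
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask  # broadcasts over the batch and head axes
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 5, 8  # illustrative sizes
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# -inf strictly above the diagonal, 0 elsewhere: position i may only see j <= i.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

values, weights = attend(q, k, v, causal_mask)
print(weights[0, 0])  # lower-triangular attention pattern for one head
```

Since `exp(-inf)` is exactly zero, the masked entries receive no attention weight and each row still sums to 1 over the visible positions.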


In [71]:


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        batch_size, max_sequence_length, d_model = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size: {qkv.size()}")
        qkv = qkv.reshape(batch_size, max_sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size: {qkv.size()}")
        qkv = qkv.permute(0, 2, 1, 3)  # (batch_size, num_heads, max_sequence_length, 3 * head_dim)
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q.size: {q.size()}")
        print(f"k.size: {k.size()}")
        print(f"v.size: {v.size()}")

        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}")
        print(f"attention.size(): {attention.size()}")
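        # Move heads next to head_dim before merging; reshaping directly would mix positions across heads.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, max_sequence_length, self.num_heads * self.head_dim)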
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out, attention
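
After attention, the per-head outputs have shape `(batch, heads, seq, head_dim)` and must be merged into `(batch, seq, heads * head_dim)`. The small sketch below, with made-up sizes and each head filled with a distinct constant, illustrates why the head axis is usually moved next to the feature axis before flattening.

```python
import torch

batch, heads, seq_len, head_dim = 1, 2, 3, 2  # tiny illustrative sizes

# Fill each head with a distinct constant so we can see where values end up:
# head 0 is all 0s, head 1 is all 1s.
values = (
    torch.arange(heads).float()
    .view(1, heads, 1, 1)
    .expand(batch, heads, seq_len, head_dim)
)

# Permute first: every position keeps its own [head 0 | head 1] features.
merged = values.permute(0, 2, 1, 3).reshape(batch, seq_len, heads * head_dim)

# Reshape directly: memory is read in (head, position) order,
# so rows end up holding several positions of a single head.
naive = values.reshape(batch, seq_len, heads * head_dim)

print(merged[0])
# tensor([[0., 0., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 1., 1.]])
print(naive[0])
# tensor([[0., 0., 0., 0.],
#         [0., 0., 1., 1.],
#         [1., 1., 1., 1.]])
```

Both results have the same shape, so a missing permute does not raise an error. It silently scrambles which position each feature vector belongs to.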

In [72]:



class LayerNormalization(nn.Module):
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))
    
    def forward(self, inputs):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        variance = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (variance + self.eps).sqrt()
        y = (inputs - mean) / std
        out = self.gamma * y + self.beta 
        return out
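
A quick way to sanity-check a hand-written layer normalization is to compare it against PyTorch's built-in `nn.LayerNorm`. The sketch below assumes normalization over the last dimension only. The standalone function `manual_layer_norm` mirrors the computation above so that the block is self-contained, and the sizes are illustrative.

```python
import torch
from torch import nn

def manual_layer_norm(inputs, gamma, beta, eps=1e-5):
    # Normalize over the last dimension using the biased (population) variance.
    mean = inputs.mean(dim=-1, keepdim=True)
    variance = ((inputs - mean) ** 2).mean(dim=-1, keepdim=True)
    y = (inputs - mean) / (variance + eps).sqrt()
    return gamma * y + beta

torch.manual_seed(0)
d_model = 16  # illustrative size
x = torch.randn(4, 10, d_model)

reference = nn.LayerNorm(d_model, eps=1e-5)
# Reuse the reference module's scale and shift so both paths share parameters.
manual = manual_layer_norm(x, reference.weight, reference.bias)

max_diff = (manual - reference(x)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The two agree up to floating-point error because `nn.LayerNorm` also uses the biased variance estimate and adds `eps` inside the square root.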


In [73]:

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)
    
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
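
"Position-wise" means the same two-layer network is applied to every sequence position independently; positions only interact inside the attention block. The sketch below checks this on a small example. It uses an `nn.Sequential` stand-in with illustrative sizes and leaves out dropout so that the comparison is deterministic.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, hidden, seq_len = 8, 32, 5  # illustrative sizes

# Stand-in for the feed-forward block, without dropout.
ffn = nn.Sequential(
    nn.Linear(d_model, hidden),
    nn.ReLU(),
    nn.Linear(hidden, d_model),
)

x = torch.randn(1, seq_len, d_model)

with torch.no_grad():
    full = ffn(x)  # whole sequence in one call
    # Apply the same network to each position separately, then restack.
    per_position = torch.stack([ffn(x[:, t]) for t in range(seq_len)], dim=1)

same = torch.allclose(full, per_position, atol=1e-5)
print(same)
```

Because `nn.Linear` acts on the last dimension only, processing the full sequence at once is just a batched version of the per-position computation.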


In [74]:

class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)
    
    def forward(self, x):
        residual_x = x 
        x, _ = self.attention(x, mask=None)
        x = self.dropout1(x)
        x = self.norm1(x + residual_x)
        residual_x = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual_x)
        print(f"Shape of x :{x.size()}")
        return x 


In [75]:

class Encoder(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x 

In [76]:




d_model = 512 
num_heads = 8 
drop_prob = 0.1
batch_size = 30 
max_sequence_length = 200
ffn_hidden = 2048
num_layers = 5 


In [77]:
encoder = Encoder(d_model,ffn_hidden,num_heads,drop_prob,num_layers)

In [78]:
encoder

Encoder(
  (layers): ModuleList(
    (0-4): 5 x EncoderLayer(
      (attention): MultiHeadAttention(
        (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
        (linear_layer): Linear(in_features=512, out_features=512, bias=True)
      )
      (norm1): LayerNormalization()
      (dropout1): Dropout(p=0.1, inplace=False)
      (ffn): PositionwiseFeedForward(
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (relu): ReLU()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNormalization()
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
)
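
The printed structure makes it easy to estimate the model size. The sketch below rebuilds only the linear layers with the same `in_features`/`out_features` and adds the layer-norm parameters by hand, assuming one `gamma` and one `beta` vector of length `d_model` per normalization layer as defined earlier.

```python
from torch import nn

d_model, ffn_hidden, num_layers = 512, 2048, 5

def count(module):
    # Total number of learnable scalars in a module.
    return sum(p.numel() for p in module.parameters())

qkv = count(nn.Linear(d_model, 3 * d_model))    # fused Q, K, V projection
out_proj = count(nn.Linear(d_model, d_model))   # final projection W^O
ffn = count(nn.Linear(d_model, ffn_hidden)) + count(nn.Linear(ffn_hidden, d_model))
norms = 2 * (2 * d_model)                       # gamma and beta for two layer norms

per_layer = qkv + out_proj + ffn + norms
total = num_layers * per_layer
print(per_layer, total)  # 3152384 15761920
```

About two thirds of each layer's parameters sit in the feed-forward block, with the attention projections making up most of the rest.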

In [79]:
x = torch.randn((batch_size,max_sequence_length,d_model))


In [80]:
x.shape

torch.Size([30, 200, 512])

In [81]:
out = encoder(x)

x.size(): torch.Size([30, 200, 512])
qkv.size: torch.Size([30, 200, 1536])
qkv.size: torch.Size([30, 200, 8, 192])
q.size: torch.Size([30, 8, 200, 64])
k.size: torch.Size([30, 8, 200, 64])
v.size: torch.Size([30, 8, 200, 64])
scaled.size(): torch.Size([30, 8, 200, 200])
Attention.size(): torch.Size([30, 8, 200, 200])
values.size(): torch.Size([30, 8, 200, 64])
attention.size(): torch.Size([30, 8, 200, 200])
values.size(): torch.Size([30, 200, 512])
out.size(): torch.Size([30, 200, 512])
Shape of x :torch.Size([30, 200, 512])
x.size(): torch.Size([30, 200, 512])
qkv.size: torch.Size([30, 200, 1536])
qkv.size: torch.Size([30, 200, 8, 192])
q.size: torch.Size([30, 8, 200, 64])
k.size: torch.Size([30, 8, 200, 64])
v.size: torch.Size([30, 8, 200, 64])
scaled.size(): torch.Size([30, 8, 200, 200])
Attention.size(): torch.Size([30, 8, 200, 200])
values.size(): torch.Size([30, 8, 200, 64])
attention.size(): torch.Size([30, 8, 200, 200])
values.size(): torch.Size([30, 200, 512])
out.size(): torc

In [82]:
out

tensor([[[-2.5220e-01, -1.0149e+00, -9.2678e-01,  ..., -1.0509e+00,
          -8.8911e-01,  1.0389e+00],
         [ 1.0142e-02, -4.3849e-01,  6.8647e-01,  ..., -2.1143e+00,
          -8.7580e-01, -1.1077e-01],
         [-1.5120e+00,  3.0915e-01, -1.0055e+00,  ..., -5.4164e-01,
          -1.3031e-01, -8.3582e-01],
         ...,
         [ 1.3842e+00,  2.8302e-02, -7.8733e-01,  ..., -1.4032e+00,
           4.1988e-01, -2.3569e-01],
         [ 3.0116e-01,  7.8232e-01,  9.2847e-01,  ..., -5.4237e-01,
           5.7213e-01,  4.7925e-01],
         [-3.3266e-01,  4.1286e-01, -2.7234e-01,  ...,  6.6261e-02,
           6.8822e-01, -2.4220e+00]],

        [[ 1.4110e+00,  1.0553e+00, -1.4921e-01,  ..., -1.3799e+00,
          -5.5247e-01, -1.5545e+00],
         [ 1.5510e+00,  4.7343e-01, -1.4959e+00,  ...,  4.8307e-01,
           3.7061e-01, -1.8659e+00],
         [ 5.6035e-01, -7.3902e-01,  8.4590e-01,  ..., -2.7455e+00,
          -4.0794e-01, -1.9714e+00],
         ...,
         [ 1.3779e+00, -3