In [2]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

### Defining the sequence length, batch size, input dimension (after the embedding layer), and model (output) dimension.

In [3]:
seq_len = 4  # Number of words (tokens) in the input sentence
batch_size = 1  # A batch size of 1 for demonstration
input_dim = 512  # Dimension of the embedding of every word in the input sentence
d_model = 512  # Model width; the fused q/k/v projection below outputs 3 * d_model features
num_heads = 8  # Number of attention heads (BERT-base uses 12)
head_dim = d_model // num_heads  # Features handled by each individual head
# Draw the demo input from a locally seeded generator so a fresh
# "Restart & Run All" reproduces the same tensor without touching the
# global RNG state.
x = torch.randn((batch_size, seq_len, input_dim), generator=torch.Generator().manual_seed(0))

In [4]:
# Expect (batch_size, seq_len, input_dim).
x.shape

torch.Size([1, 4, 512])

In [5]:
# One fused linear layer produces the query, key and value projections together,
# hence the 3 * d_model output features.
qkv_layer = nn.Linear(input_dim, 3 * d_model)

In [6]:
# Project every token embedding into its packed q/k/v features.
qkv = qkv_layer(x)

In [7]:
# Last dimension is 3 * d_model: q, k and v for all heads packed together.
qkv.shape

torch.Size([1, 4, 1536])

In [8]:
# Split the packed features per head: each head receives 3 * head_dim values
# (its own q, k and v, separated later with chunk).
qkv = qkv.reshape(batch_size, seq_len, num_heads, 3 * head_dim)

In [9]:
# Expect (batch_size, seq_len, num_heads, 3 * head_dim).
qkv.shape

torch.Size([1, 4, 8, 192])

In [10]:
# Swap the sequence and head axes so that each head holds a (seq_len, 3 * head_dim)
# matrix: (batch_size, num_heads, seq_len, 3 * head_dim).
qkv = qkv.permute(0, 2, 1, 3)
qkv.shape

torch.Size([1, 8, 4, 192])

In [11]:
# Cut the last dimension into three equal pieces: query, key and value per head.
q, k, v = qkv.chunk(3, dim = -1)

In [12]:
# Each of q, k and v should be (batch_size, num_heads, seq_len, head_dim).
q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

In [13]:
# Raw attention scores: dot product of every query with every key inside each head,
# divided by sqrt(d_k), where d_k is the per-head feature size.
d_k = q.shape[-1]
scaled = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)

In [14]:
# Expect (batch_size, num_heads, seq_len, seq_len): one score per query/key pair in each head.
scaled.shape

torch.Size([1, 8, 4, 4])

In [15]:
# k before transposing: (batch_size, num_heads, seq_len, head_dim).
k.shape

torch.Size([1, 8, 4, 64])

In [17]:
# Swapping the last two axes gives (batch_size, num_heads, head_dim, seq_len),
# which is what lets q @ k^T produce a seq_len x seq_len score matrix.
k.transpose(-2, -1).shape

torch.Size([1, 8, 64, 4])

In [48]:
scaled[0][7].shape  # First batch element, last attention head (index 7, since heads are 0-indexed).

torch.Size([4, 4])

### The masking logic from the Self-Attention notebook (used in the Decoder block) applies here as well. It is not covered at this point, because a later cell in this notebook defines a proper class that takes care of the mask.

In [49]:
# dim=-1 normalises along the last axis: within every head, each row of the
# seq_len x seq_len score matrix (one query against all keys) is turned into
# weights that sum to 1.
attention = torch.softmax(scaled, dim=-1)

In [50]:
# Softmax keeps the shape of the scores: (batch_size, num_heads, seq_len, seq_len).
attention.shape

torch.Size([1, 8, 4, 4])

In [51]:
# Weighted sum of the value vectors: every position receives a mix of all
# positions' values, weighted by its attention row.
values = attention @ v
values.shape

torch.Size([1, 8, 4, 64])

### Object-oriented class that helps us understand multi-headed attention in detail.

### Reference material: to understand the matrix dimensions of multi-headed attention in the Encoder block, check out Jay Alammar's blog post "The Illustrated Transformer" at https://jalammar.github.io/illustrated-transformer/

### Purpose of FFNN after the Multi-headed self-attention mechanism:
The purpose of the FFNN is to introduce non-linearity and learn complex, context-dependent transformations on each attended representation individually. It enables the model to capture higher-order relationships and perform more expressive computations on the attended representations.

In [52]:
class MultiheadedAttention(nn.Module):
    """Multi-head self-attention over a batch of token embeddings.

    A single fused linear layer produces the query, key and value projections
    for all heads. Scaled dot-product attention is then computed independently
    per head, the heads are concatenated back per token, and a final linear
    layer mixes them.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        d_model: Total width of the attention output (all heads concatenated).
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        # Fail early: otherwise the per-head reshape in forward() cannot work
        # and the error would only surface at call time.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Fused projection: q, k and v for every head in one matmul.
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        # Output projection applied after the heads are concatenated.
        self.linear_layer = nn.Linear(d_model, d_model)

    def scaled_dot_product(self, q, k, v, mask=None):
        """Compute scaled dot-product attention.

        Args:
            q, k, v: Tensors whose last two dimensions are (sequence, features);
                in forward() they are (batch, num_heads, seq_len, head_dim).
            mask: Optional additive mask, broadcastable to the score tensor.
                It is added to the scores before the softmax, so blocked
                positions should hold ``-inf`` (or a large negative number)
                and allowed positions ``0``.

        Returns:
            A ``(values, attention)`` tuple: the attention-weighted values and
            the attention weights themselves.
        """
        d_k = q.size()[-1]
        scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
        if mask is not None:
            # Out-of-place add: also works when the mask broadcasts to a
            # larger shape, and never modifies a tensor in place.
            scaled = scaled + mask
        attention = F.softmax(scaled, dim=-1)
        values = torch.matmul(attention, v)
        return values, attention

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, sequence_length, input_dim).
            mask: Optional additive attention mask, see ``scaled_dot_product``.

        Returns:
            Tensor of shape (batch_size, sequence_length, d_model).

        The intermediate tensor sizes are printed for demonstration purposes.
        """
        batch_size, sequence_length, _ = x.size()
        print(f"x.size(): {x.size()}")

        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")

        # Give every head its own slice of 3 * head_dim features (q, k, v).
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")

        # (batch, seq, heads, feat) -> (batch, heads, seq, feat)
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")

        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")

        values, attention = self.scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")

        # Move the head axis back next to the feature axis BEFORE flattening.
        # Reshaping (batch, heads, seq, head_dim) directly would interleave
        # different sequence positions into one output row instead of
        # concatenating the heads of a single position.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        print(f"values.size(): {values.size()}")

        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out


In [53]:
# Demo configuration for the class-based implementation.
input_dim = 512
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5
x = torch.randn( (batch_size, sequence_length, input_dim) )

model = MultiheadedAttention(input_dim, d_model, num_heads)
# Call the module itself rather than .forward(): nn.Module.__call__ runs the
# registered hooks and then dispatches to forward().
out = model(x)

x.size(): torch.Size([30, 5, 512])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
