<a href="https://colab.research.google.com/github/anesmeftah/deep-learning-roadmap/blob/main/PyTorch/Multi_head_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [7]:
def scaled_dot_product(q, k, v, mask=None):
  """Compute scaled dot-product attention.

  The attention logits are ``q @ (k^T / sqrt(d_k))``, where ``d_k`` is the
  size of the last dimension of ``q``. When ``mask`` is given it is added to
  the logits (in place) before the softmax, i.e. it acts as an additive mask.

  Args:
    q: Query tensor; its last dimension defines ``d_k``.
    k: Key tensor; its last two dimensions are swapped before the product.
    v: Value tensor combined using the attention weights.
    mask: Optional tensor added to the logits before normalisation.

  Returns:
    A ``(values, attention)`` tuple: the attention-weighted combination of
    ``v`` and the weights obtained by a softmax over the last dimension.
  """
  key_dim = q.size()[-1]
  # Scaling the transposed keys first is equivalent to scaling the product.
  scaled_keys_t = k.transpose(-1, -2) / math.sqrt(key_dim)
  logits = torch.matmul(q, scaled_keys_t)
  if mask is not None:
    logits += mask
  weights = F.softmax(logits, dim=-1)
  return torch.matmul(weights, v), weights

In [8]:
class MultiHeadAttention(nn.Module):
  """Multi-head self-attention layer.

  A single linear layer projects each input token to concatenated
  query/key/value vectors, attention is computed independently per head via
  ``scaled_dot_product``, and the per-head results are concatenated and passed
  through a final linear layer.

  Subclassing ``nn.Module`` registers both linear layers, so their weights
  show up in ``parameters()`` / ``state_dict()`` and the layer can be called
  as ``model(x)`` as well as ``model.forward(x)``.

  Args:
    input_dim: Feature size of each input token.
    d_model: Total width of the attention output, split evenly across heads.
    num_heads: Number of attention heads; must be positive and divide
      ``d_model``.

  Raises:
    ValueError: If ``num_heads`` is not positive or does not divide
      ``d_model``.
  """

  def __init__(self, input_dim: int, d_model: int, num_heads: int):
    super().__init__()
    # Fail early with a clear message instead of an opaque reshape error
    # inside forward().
    if num_heads <= 0 or d_model % num_heads != 0:
      raise ValueError(
        f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
      )
    self.input_dim = input_dim
    self.d_model = d_model
    self.num_heads = num_heads

    self.head_dim = d_model // num_heads

    # One projection produces q, k and v for every head at once.
    self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
    self.linear_layer = nn.Linear(d_model, d_model)

  def forward(self, x, mask=None):
    """Apply multi-head self-attention to ``x``.

    Args:
      x: Tensor with three dimensions, unpacked as
        ``(batch_size, sequence_length, input_dim)``.
      mask: Optional mask forwarded unchanged to ``scaled_dot_product``.

    Returns:
      Tensor of shape ``(batch_size, sequence_length, d_model)``.
    """
    batch_size, sequence_length, input_dim = x.size()
    print("x.size : " , batch_size , sequence_length , input_dim)

    qkv = self.qkv_layer(x)
    print("qkv size : " , qkv.size())

    # Split the projection into heads; each head holds its own q, k and v.
    qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
    print("qkv size : " , qkv.size())

    # (batch, seq, heads, 3*head_dim) -> (batch, heads, seq, 3*head_dim)
    qkv = qkv.permute(0, 2, 1, 3)
    print("qkv size : " , qkv.size())

    q, k, v = qkv.chunk(3, dim=-1)
    print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}")

    values, attention = scaled_dot_product(q, k, v, mask)
    print(f"values.size(): {values.size()}, attention.size: {attention.size()}")

    # Bring the sequence axis back in front of the head axis before merging
    # heads. Reshaping the (batch, heads, seq, head_dim) tensor directly would
    # interleave features of different tokens in a single output row.
    values = values.permute(0, 2, 1, 3)
    values = values.reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
    print(f"values.size(): {values.size()}")

    out = self.linear_layer(values)
    print(f"out.size(): {out.size()}")
    return out

In [9]:
# Demo configuration
batch_size, sequence_length = 30, 5   # sequences per batch, tokens per sequence
input_dim = 1024                      # feature size of each input token
d_model = 512                         # model width; must be divisible by num_heads
num_heads = 8

# Random input batch laid out as (batch, sequence, feature)
x = torch.randn(batch_size, sequence_length, input_dim)

# Build the attention layer and run a single forward pass
model = MultiHeadAttention(input_dim, d_model, num_heads)
output = model.forward(x)

x.size :  30 5 1024
qkv size :  torch.Size([30, 5, 1536])
qkv size :  torch.Size([30, 5, 8, 192])
qkv size :  torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64])
values.size(): torch.Size([30, 8, 5, 64]), attention.size: torch.Size([30, 8, 5, 5])
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
