<a href="https://colab.research.google.com/github/AyoubMDL/transformers_from_scratch/blob/main/multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [12]:
# q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

In [24]:
# mask = torch.full(scaled.size(), float('-inf'))
# mask = torch.triu(mask, diagonal=1)

In [10]:
import math

def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor; its last dimension is used as the scaling size d_k.
        k: Key tensor; its last two dimensions are swapped before being
            multiplied with ``q``.
        v: Value tensor that is combined using the attention weights.
        mask: Optional additive mask that broadcasts against the score
            tensor (e.g. ``-inf`` above the diagonal for causal attention).
            ``None`` disables masking.

    Returns:
        A tuple ``(new_v, attention)`` where ``attention`` is the softmax of
        the (optionally masked) scaled scores over the last dimension and
        ``new_v`` is ``attention @ v``.
    """
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    # Compare against None explicitly: truth-testing a tensor with more than
    # one element raises a RuntimeError, so `if mask:` breaks on real masks.
    if mask is not None:
        # Out-of-place add keeps broadcasting rules simple and never mutates
        # a tensor the caller might still hold.
        scores = scores + mask

    attention = F.softmax(scores, dim=-1)
    new_v = torch.matmul(attention, v)

    return new_v, attention

In [11]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention layer.

    The input is projected once into concatenated queries, keys and values,
    split into ``num_heads`` heads, attended per head with
    ``scaled_dot_product``, then the heads are concatenated and passed
    through a final linear layer.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        d_model: Total feature size across all heads (also the output size).
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # A single projection produces q, k and v for every head at once.
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_length, input_dim).
            mask: Optional additive mask forwarded to ``scaled_dot_product``.

        Returns:
            Tensor of shape (batch_size, seq_length, d_model).
        """
        batch_size, seq_length, _ = x.size()

        # (batch_size, seq_length, 3 * d_model)
        qkv = self.qkv_layer(x)

        # (batch_size, seq_length, num_heads, 3 * head_dim)
        # Use the module's own num_heads, not a same-named notebook global.
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)

        # (batch_size, num_heads, seq_length, 3 * head_dim): heads become a
        # batch-like dimension so every head is attended in parallel.
        qkv = qkv.permute(0, 2, 1, 3)

        # Each of q, k, v: (batch_size, num_heads, seq_length, head_dim)
        q, k, v = qkv.chunk(3, dim=-1)

        # values: (batch_size, num_heads, seq_length, head_dim)
        values, _ = scaled_dot_product(q, k, v, mask)

        # Move heads back next to head_dim BEFORE flattening; reshaping the
        # (batch, heads, seq, head_dim) layout directly would interleave
        # features from different sequence positions.
        values = values.permute(0, 2, 1, 3)

        # (batch_size, seq_length, d_model)
        values = values.reshape(batch_size, seq_length, self.d_model)

        return self.linear_layer(values)

In [3]:
# Model hyperparameters used by the demo cells.
input_dim, d_model, num_heads = 1024, 512, 8

# Dimensions of the dummy input batch.
batch_size, seq_length = 10, 5

# Random input of shape (batch_size, seq_length, input_dim).
x = torch.randn(batch_size, seq_length, input_dim)

In [4]:
x.shape

torch.Size([10, 5, 1024])

In [12]:
# Build the attention module from the hyperparameters defined in the earlier cell.
model = MultiHeadAttention(input_dim, d_model, num_heads)

In [13]:
# Run a forward pass without a mask (mask defaults to None).
out = model(x)

In [14]:
out.shape

torch.Size([10, 5, 512])