## Multi Head Attention

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Toy input: one sequence of four tokens, each a 512-dimensional vector.
batch_size, sequence_length, d_model = 1, 4, 512
x = torch.randn(batch_size, sequence_length, d_model)

In [None]:
x.size()

torch.Size([1, 4, 512])

In [None]:
# Single fused projection: maps each d_model-wide token to 3 * d_model features,
# which the following cells split into queries, keys and values.
qkv_layer = nn.Linear(d_model , 3 * d_model)

In [None]:
# Apply the fused projection to every token of x.
qkv = qkv_layer(x)

In [None]:
qkv.shape

torch.Size([1, 4, 1536])

In [None]:
num_heads = 8
head_dim = d_model // num_heads  # feature width of q, k and v inside one head
# Split the feature axis into heads. Each head keeps 3 * head_dim features,
# which are later chunked into that head's q, k and v.
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * head_dim)

In [None]:
qkv.shape

torch.Size([1, 4, 8, 192])

In [None]:
# Swap the sequence and head axes so that attention can be computed per head:
# [batch_size, num_heads, sequence_length, 3*head_dim]
qkv = qkv.transpose(1, 2)
qkv.shape

torch.Size([1, 8, 4, 192])

In [None]:
# Cut the last axis into three equal slices: queries, keys and values.
q, k, v = torch.chunk(qkv, 3, dim=-1)
q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

## Self Attention for multiple heads

For a single head:
$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$

In [None]:
import math

In [None]:
d_k = q.shape[-1]
# Query/key dot products for every pair of positions, divided by sqrt(d_k).
scaled = (q @ k.transpose(-1, -2)) / math.sqrt(d_k)

scaled.shape

torch.Size([1, 8, 4, 4])

In [None]:
# Additive causal mask: -inf strictly above the diagonal, 0 on and below it.
mask = torch.triu(torch.full(scaled.size(), float('-inf')), diagonal=1)
mask[0][1] # mask for input to a single head

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [None]:
(scaled + mask)[0][0]

tensor([[-0.0701,    -inf,    -inf,    -inf],
        [-0.4234, -0.1145,    -inf,    -inf],
        [ 0.3848, -0.4978, -0.0802,    -inf],
        [-0.1478,  0.1420, -0.2592,  0.2360]], grad_fn=<SelectBackward0>)

In [None]:
# Add the mask in place: masked positions become -inf, so the softmax below
# gives them zero weight.
scaled += mask

In [None]:
# Normalise the scores over the key axis so every query row sums to 1.
attention = torch.softmax(scaled, dim=-1)

In [None]:
attention.shape

torch.Size([1, 8, 4, 4])

In [None]:
attention[0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4234, 0.5766, 0.0000, 0.0000],
        [0.4897, 0.2026, 0.3076, 0.0000],
        [0.2128, 0.2844, 0.1904, 0.3124]], grad_fn=<SelectBackward0>)

In [None]:
# Attention-weighted sum of the value vectors, computed independently per head.
values = attention @ v
values.shape

torch.Size([1, 8, 4, 64])

## Function

In [None]:
import math

def scaled_dot_product(q, k, v, mask=None):
    """Return attention-weighted values and the attention weights.

    The scores are ``q @ k^T / sqrt(d_k)``, where ``d_k`` is the size of the
    last dimension of ``q``. When ``mask`` is given it is added to the scores
    before the softmax over the last dimension.

    Returns:
        A tuple ``(values, attention)``: the attention weights multiplied by
        ``v``, and the attention weights themselves.
    """
    key_width = q.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(key_width)
    if mask is not None:
        # Additive mask: -inf entries end up with zero weight after softmax.
        scores += mask
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

In [None]:
# Recompute the masked attention through the helper, reusing q, k, v and the
# causal mask built in the earlier cells.
values, attention = scaled_dot_product(q, k, v, mask=mask)

In [None]:
attention.shape

torch.Size([1, 8, 4, 4])

In [None]:
attention[0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4234, 0.5766, 0.0000, 0.0000],
        [0.4897, 0.2026, 0.3076, 0.0000],
        [0.2128, 0.2844, 0.1904, 0.3124]], grad_fn=<SelectBackward0>)

In [None]:
values.size()

torch.Size([1, 8, 4, 64])

In [None]:
# `values` is [batch_size, num_heads, sequence_length, head_dim]. Move the head
# axis next to head_dim before flattening, so that each token's row is the
# concatenation of that token's own per-head outputs. Reshaping without the
# permute would mix values from different sequence positions into one row.
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, num_heads * head_dim)
values.size()

torch.Size([1, 4, 512])

In [None]:
# Output projection applied after the heads are merged; keeps the feature size at d_model.
linear_layer = nn.Linear(d_model, d_model)

In [None]:
# Final multi-head attention output for the toy input.
out = linear_layer(values)

In [None]:
out.shape

torch.Size([1, 4, 512])

In [None]:
out

tensor([[[ 0.0553,  0.2046, -0.1234,  ...,  0.1395, -0.1228, -0.0310],
         [-0.1912,  0.0805, -0.1417,  ..., -0.2382,  0.0612,  0.1027],
         [-0.0061, -0.2826, -0.0949,  ...,  0.5048,  0.5449,  0.0107],
         [-0.3642,  0.1160,  0.2601,  ...,  0.0895, -0.1580,  0.1855]]],
       grad_fn=<ViewBackward0>)

## Class

In [None]:
import torch
import torch.nn as nn
import math

def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention over the last two dimensions.

    Args:
        q: Query tensor; the size of its last dimension is used as ``d_k``.
        k: Key tensor; its last two dimensions are transposed before the
            product with ``q``.
        v: Value tensor that is weighted by the attention matrix.
        mask: Optional additive mask, added to the scaled scores before the
            softmax (``-inf`` entries receive zero attention weight).

    Returns:
        A tuple ``(values, attention)``: the attention-weighted values and the
        attention weights (softmax over the last dimension).
    """
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask
    # torch.softmax is equivalent to F.softmax here and keeps this cell
    # runnable with only the imports listed directly above it.
    attention = torch.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention


class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a single fused q/k/v projection.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        d_model: Total feature width across all heads (and of the output).
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            # Otherwise the reshape in forward() fails with an opaque size error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(input_dim , 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of size ``(batch_size, sequence_length, input_dim)``.
            mask: Optional additive mask passed to ``scaled_dot_product``.

        Returns:
            Tensor of size ``(batch_size, sequence_length, d_model)``.

        The intermediate tensor sizes are printed at every step.
        """
        batch_size, sequence_length, input_dim = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        # Each head owns a contiguous slice of 3 * head_dim features: [q | k | v].
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        # -> (batch_size, num_heads, sequence_length, 3 * head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        # Undo the earlier permute before merging heads. Reshaping directly from
        # (batch, heads, seq, head_dim) would put values from different sequence
        # positions into the same output row.
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out


## Input

In [None]:
# Demo run: 30 sequences of 5 tokens, each token a 1024-dimensional vector.
d_model = 1024
num_heads = 8

batch_size = 30
sequence_length = 5
x = torch.randn( (batch_size, sequence_length, d_model) )

model = MultiheadAttention(d_model, d_model, num_heads)
# Call the module itself rather than .forward() so that nn.Module.__call__
# (and any registered hooks) runs.
out = model(x)

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 3072])
qkv.size(): torch.Size([30, 5, 8, 384])
qkv.size(): torch.Size([30, 8, 5, 384])
q size: torch.Size([30, 8, 5, 128]), k size: torch.Size([30, 8, 5, 128]), v size: torch.Size([30, 8, 5, 128]), 
values.size(): torch.Size([30, 8, 5, 128]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 1024])
out.size(): torch.Size([30, 5, 1024])
