https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853

https://nn.labml.ai/

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [19]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
sequence_length = 4
batch_size = 1
input_dim = 512
d_model = 512
x = torch.randn( (batch_size, sequence_length, input_dim) )

In [7]:
print(x.shape)

torch.Size([1, 4, 512])


In [4]:
x

tensor([[[-1.3894,  1.8079, -0.4922,  ...,  0.4141,  0.7742, -1.3017],
         [ 0.5475, -0.6181,  0.4773,  ...,  0.4435,  0.0295,  0.9851],
         [-0.5884,  0.3026, -0.5085,  ..., -0.2530,  1.3438, -0.5324],
         [ 1.8149,  0.8070,  0.2444,  ...,  0.9461, -1.0576,  1.7348]]])

In [11]:
qkv_layer = nn.Linear(input_dim , 3 * d_model)
qkv_layer

Linear(in_features=512, out_features=1536, bias=True)

In [12]:
qkv = qkv_layer(x)
qkv

tensor([[[-1.7449, -0.1229, -0.5324,  ..., -0.2585, -0.3609,  0.8873],
         [ 0.3557, -0.1613, -0.5249,  ...,  0.0855,  0.0497,  0.1466],
         [-0.3019,  0.3314, -0.1162,  ...,  0.3522, -0.4374,  0.2474],
         [ 0.2902, -0.0559, -0.0318,  ..., -0.2233,  0.5550, -1.1974]]],
       grad_fn=<ViewBackward0>)

In [13]:
qkv.size()

torch.Size([1, 4, 1536])

In [14]:
num_heads = 8
head_dim = d_model // num_heads
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * head_dim)
qkv

tensor([[[[-1.7449e+00, -1.2295e-01, -5.3236e-01,  ..., -1.8445e-01,
            1.2201e-01,  6.9133e-01],
          [ 1.7298e-01,  2.5862e-01, -5.3273e-01,  ..., -5.8945e-01,
            5.6178e-01,  3.4953e-01],
          [-1.3525e-01,  8.2513e-01, -7.3156e-01,  ...,  8.5626e-02,
           -7.6213e-02, -7.1592e-01],
          ...,
          [-5.4403e-01, -3.5871e-01,  8.7709e-01,  ...,  4.3994e-01,
            6.0564e-01, -1.0475e+00],
          [ 5.3648e-01, -2.6786e-01,  1.4037e-01,  ..., -1.6509e-01,
           -3.5366e-01, -7.4077e-01],
          [-2.8201e-01, -5.6094e-01, -6.3094e-01,  ..., -2.5846e-01,
           -3.6089e-01,  8.8730e-01]],

         [[ 3.5571e-01, -1.6133e-01, -5.2491e-01,  ..., -3.4181e-01,
            4.4692e-01,  7.6768e-01],
          [-4.2524e-01,  7.3394e-01, -7.7988e-02,  ..., -7.8987e-01,
           -1.2612e-02, -2.1245e-02],
          [-3.3983e-01,  5.3461e-01, -1.5202e-01,  ..., -2.1632e-01,
            3.1865e-01,  8.4330e-01],
          ...,
     

In [15]:
qkv.size()

torch.Size([1, 4, 8, 192])

In [16]:
qkv = qkv.permute(0, 2, 1, 3) # [batch_size, num_heads, sequence_length, 3*head_dim]
qkv.shape

torch.Size([1, 8, 4, 192])

In [17]:
q, k, v = qkv.chunk(3, dim=-1)
q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

In [20]:
d_k = q.size()[-1]
scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scaled.shape

torch.Size([1, 8, 4, 4])

In [21]:
k.T.shape

  """Entry point for launching an IPython kernel.


torch.Size([64, 4, 8, 1])

In [22]:
y = torch.randn(2, 3)
torch.transpose(y, 0, 1)

tensor([[-1.6494,  0.0590],
        [-0.1545, -0.8527],
        [ 1.1263,  1.5299]])

In [23]:
k.transpose(-1, -2) == k.transpose(-2, -1)

tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         ...,

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, 

In [24]:
k.transpose(-1, -2).shape

torch.Size([1, 8, 64, 4])

In [25]:
mask = torch.full(scaled.size() , float('-inf'))
mask = torch.triu(mask, diagonal=1)
mask[0][1] # mask for input to a single head

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [26]:
(scaled + mask)[0][0]

tensor([[ 0.2927,    -inf,    -inf,    -inf],
        [ 0.1736, -0.5931,    -inf,    -inf],
        [-0.0556, -0.0105,  0.1946,    -inf],
        [ 0.3081, -0.1581,  0.2623, -0.1884]], grad_fn=<SelectBackward0>)

In [27]:
scaled += mask

In [28]:
np.exp(0.5596) / (np.exp(0.5596) + np.exp(0.0404))

0.6269606805367254

In [29]:
attention = F.softmax(scaled, dim=-1)

In [30]:
attention.shape

torch.Size([1, 8, 4, 4])

In [31]:
values = torch.matmul(attention, v)
values.shape

torch.Size([1, 8, 4, 64])

#### Function

In [32]:
import math

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [33]:
values, attention = scaled_dot_product(q, k, v, mask=mask)

In [34]:
attention.shape

torch.Size([1, 8, 4, 4])

In [35]:
attention[0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.6828, 0.3172, 0.0000, 0.0000],
        [0.3003, 0.3141, 0.3856, 0.0000],
        [0.3134, 0.1966, 0.2993, 0.1907]], grad_fn=<SelectBackward0>)

In [36]:
values.size()

torch.Size([1, 8, 4, 64])

In [37]:
values = values.reshape(batch_size, sequence_length, num_heads * head_dim)
values.size()

torch.Size([1, 4, 512])

In [38]:
linear_layer = nn.Linear(d_model, d_model)

In [39]:
out = linear_layer(values)

In [40]:
out.shape

torch.Size([1, 4, 512])

In [41]:
out

tensor([[[ 0.0141, -0.1967,  0.1734,  ..., -0.1613, -0.1859, -0.2678],
         [ 0.1112, -0.0805, -0.1931,  ...,  0.0825,  0.0424, -0.2536],
         [-0.0491, -0.0923,  0.1491,  ...,  0.1260,  0.1025,  0.0531],
         [-0.5752, -0.2584,  0.2120,  ..., -0.0178,  0.0558, -0.2604]]],
       grad_fn=<ViewBackward0>)

### Class

In [42]:
import torch
import torch.nn as nn
import math

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(input_dim , 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        batch_size, sequence_length, input_dim = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        values = values.reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out

### Input

In [43]:
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5
x = torch.randn( (batch_size, sequence_length, input_dim) )

model = MultiheadAttention(input_dim, d_model, num_heads)
out = model.forward(x)

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
