In [1]:
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
sequence_length = 4
batch_size = 1
input_dim = 512
d_model = 512

# Seed the global RNG so `x` (and any random initialisation in later cells) is reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(0)
x = torch.randn(batch_size, sequence_length, input_dim)

In [3]:
x.shape  # (batch_size, sequence_length, input_dim)

torch.Size([1, 4, 512])

In [4]:
# A single projection produces the concatenated q/k/v features: input_dim -> 3 * d_model.
qkv_layer = nn.Linear(in_features=input_dim, out_features=3 * d_model)

In [5]:
# Joint q/k/v projection: (batch_size, sequence_length, input_dim) -> (batch_size, sequence_length, 3 * d_model).
qkv = qkv_layer(x)

In [7]:
qkv.shape  # (batch_size, sequence_length, 3 * d_model)

torch.Size([1, 4, 1536])

In [8]:
num_head = 8
# head_dim is an integer split of d_model; an uneven split would make the reshape fail.
assert d_model % num_head == 0, "d_model must be divisible by num_head"
head_dim = d_model // num_head
# Split the 3 * d_model features into num_head groups of 3 * head_dim:
# (batch_size, sequence_length, num_head, 3 * head_dim).
# Rebuilt from qkv_layer(x) rather than reshaping `qkv` in place, so re-running this cell
# after the permute cell cannot silently reshape an already-permuted tensor.
qkv = qkv_layer(x).reshape(batch_size, sequence_length, num_head, 3 * head_dim)

In [9]:
qkv.shape  # (batch_size, sequence_length, num_head, 3 * head_dim)

torch.Size([1, 4, 8, 192])

In [15]:
# Move the head axis ahead of the sequence axis: (batch_size, num_head, sequence_length, 3 * head_dim).
# Rebuilt from qkv_layer(x) instead of permuting `qkv` in place: the execution counts show
# this cell was run out of order, and permuting an already-permuted tensor gives the wrong layout.
qkv = qkv_layer(x).reshape(batch_size, sequence_length, num_head, 3 * head_dim).permute(0, 2, 1, 3)
qkv.size()

torch.Size([1, 8, 4, 192])

In [13]:
# Cut the last axis into three equal parts: queries, keys and values per head.
q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

In [19]:
import math

In [24]:
# Attention scores q . k^T / sqrt(d_k): one (sequence_length x sequence_length) matrix per head.
d_k = q.shape[-1]
scaled = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
scaled.shape

torch.Size([1, 8, 4, 4])

In [25]:
# Additive causal mask shaped like the scores: -inf strictly above the diagonal, 0 on and below it.
mask = torch.full(scaled.shape, float("-inf")).triu(diagonal=1)
mask[0, 1]

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [28]:
# Preview only: one (seq x seq) slice of the mask broadcasts across batch and heads.
scaled + mask[0, 0]

tensor([[[[ 1.0154e-01,        -inf,        -inf,        -inf],
          [ 3.6130e-01, -1.4847e-01,        -inf,        -inf],
          [-1.8076e-01,  3.2821e-01,  1.7455e-01,        -inf],
          [ 3.2415e-02,  1.1816e-01,  2.4722e-01, -2.9844e-02]],

         [[-2.0668e-01,        -inf,        -inf,        -inf],
          [-2.7394e-01,  2.1942e-01,        -inf,        -inf],
          [-1.5957e-01, -2.8030e-01, -3.9850e-02,        -inf],
          [ 1.4530e-01, -2.1599e-01,  1.1453e-01, -4.2530e-01]],

         [[-6.7421e-02,        -inf,        -inf,        -inf],
          [-6.6072e-01,  2.7708e-02,        -inf,        -inf],
          [ 2.2837e-01,  1.7095e-01, -2.4681e-01,        -inf],
          [-1.9716e-01,  7.1135e-02, -2.2222e-01,  3.1593e-01]],

         [[ 2.9641e-01,        -inf,        -inf,        -inf],
          [ 4.9337e-01, -8.1680e-01,        -inf,        -inf],
          [-3.3092e-01, -1.4306e-01, -1.1627e-01,        -inf],
          [-1.5345e-01, -3.0604e-0

In [31]:
# Add the causal mask to the scores in place (exactly what `scaled += mask` does);
# the trailing semicolon stops Jupyter from echoing the tensor.
scaled.add_(mask);

In [32]:
# Normalise each row of scores over the key positions.
attention = scaled.softmax(dim=-1)

In [33]:
# Attention weights per head: each row is a softmax over key positions, and the causal
# mask zeroes every entry above the diagonal.
attention

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.6248, 0.3752, 0.0000, 0.0000],
          [0.2445, 0.4067, 0.3488, 0.0000],
          [0.2343, 0.2552, 0.2904, 0.2201]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.3791, 0.6209, 0.0000, 0.0000],
          [0.3318, 0.2941, 0.3740, 0.0000],
          [0.3094, 0.2156, 0.3001, 0.1749]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.3344, 0.6656, 0.0000, 0.0000],
          [0.3897, 0.3680, 0.2423, 0.0000],
          [0.2019, 0.2640, 0.1969, 0.3372]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.7875, 0.2125, 0.0000, 0.0000],
          [0.2902, 0.3502, 0.3597, 0.0000],
          [0.2595, 0.2934, 0.1779, 0.2693]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.3664, 0.6336, 0.0000, 0.0000],
          [0.3888, 0.3093, 0.3019, 0.0000],
          [0.2919, 0.2649, 0.2588, 0.1844]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.3885, 0.6115, 0.0000, 0.0000],
          [0.1716, 0.4

In [36]:
# Attention-weighted sum of value vectors: (batch_size, num_head, sequence_length, head_dim).
values = attention @ v
values.shape

torch.Size([1, 8, 4, 64])

Function: scaled dot-product attention


In [37]:
import torch
import torch.nn as nn
import math

def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention over the last two axes.

    Computes ``softmax(q @ k^T / sqrt(d_k) + mask) @ v``, where ``d_k`` is the
    size of the last axis of ``q`` and ``k`` is transposed over its last two axes.

    Args:
        q, k, v: query, key and value tensors.
        mask: optional additive mask (e.g. 0 / -inf entries). It is added in place
            to the scores, so it must broadcast to the scores' shape.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted values and the
        softmax weights.
    """
    scale = math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-1, -2)) / scale
    if mask is not None:
        scores.add_(mask)  # same in-place update as `scores += mask`
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights



Class: multi-head attention module

In [38]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention built on ``scaled_dot_product``.

    One linear layer projects every token to concatenated query/key/value
    features. These are split across ``num_heads`` heads and attended per head,
    then the heads are merged back and passed through an output projection.

    Args:
        input_dim: size of the last axis of the input ``x``.
        d_model: total attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # One projection yields q, k and v for every head: input_dim -> 3 * d_model.
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        # Output projection applied after the heads are merged back together.
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None, verbose=True):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (batch_size, sequence_length, input_dim).
            mask: optional additive mask passed unchanged to ``scaled_dot_product``,
                which adds it to the (batch, heads, seq, seq) scores; e.g. a
                (seq, seq) causal mask of 0 / -inf.
            verbose: print intermediate tensor sizes. Defaults to True, which
                reproduces the original shape-tracing output.

        Returns:
            Tensor of shape (batch_size, sequence_length, d_model).
        """
        batch_size, sequence_length, _ = x.size()

        def log(message):
            if verbose:
                print(message)

        log(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        log(f"qkv.size(): {qkv.size()}")
        # Split the 3 * d_model features into per-head groups of 3 * head_dim.
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        log(f"qkv.size(): {qkv.size()}")
        # Move heads ahead of the sequence axis: (batch, heads, seq, 3 * head_dim).
        qkv = qkv.permute(0, 2, 1, 3)
        log(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        log(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        log(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        # values is (batch, heads, seq, head_dim). Swap heads and seq back BEFORE
        # flattening, so that each output row concatenates all heads for the same
        # position. Reshaping directly would splice different positions into one row.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.d_model)
        log(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        log(f"out.size(): {out.size()}")
        return out


In [39]:
# Note: this cell rebinds input_dim, d_model, batch_size, sequence_length and x
# from the step-by-step exploration above.
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5
torch.manual_seed(0)  # reproducible input and layer initialisation
x = torch.randn((batch_size, sequence_length, input_dim))

model = MultiheadAttention(input_dim, d_model, num_heads)
# Call the module itself rather than .forward() so registered nn.Module hooks run.
out = model(x)
assert out.shape == (batch_size, sequence_length, d_model)

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
