In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [21]:
# Seed the global RNG so the random input below (and any layers initialised
# afterwards) are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Toy problem dimensions.
batch_size = 1000
seq_len = 60
d_model = hidden_size = 128      # output width of the attention block
input_dim = feature_num = 23     # size of the last axis of the input

# Random stand-in for an input batch: (batch_size, seq_len, input_dim)
x = torch.randn(batch_size, seq_len, input_dim)
x.shape


torch.Size([1000, 60, 23])

In [22]:
# One fused linear projection yields queries, keys and values in a single pass.
qkv_layer = nn.Linear(in_features=input_dim, out_features=3 * d_model)
qkv = qkv_layer(x)  # (batch_size, seq_len, 3 * d_model)

In [23]:
head_num = 8
head_dim = d_model // head_num
assert head_dim * head_num == d_model, "d_model must be divisible by head_num"

# Bind the per-head view to a new name instead of overwriting `qkv`: the cell
# then always starts from the (batch_size, seq_len, 3 * d_model) projection and
# can be re-run without first re-running the projection cell.
qkv_heads = qkv.view(batch_size, seq_len, head_num, 3 * head_dim)
qkv_heads = qkv_heads.permute(0, 2, 1, 3)  # (batch_size, head_num, seq_len, 3 * head_dim)
# Putting head_num ahead of seq_len makes it easier to perform parallel
# computation over the trailing "seq_len" and "3 * head_dim" axes.

q, k, v = qkv_heads.split([head_dim, head_dim, head_dim], dim=-1)
# each output: (batch_size, head_num, seq_len, head_dim)

Self Attention

In [27]:
# SELF ATTENTION
d_k = q.shape[-1]
scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
# scaled: (batch_size, head_num, seq_len, seq_len)

# Causal mask: -inf strictly above the diagonal, 0 on and below it, so that
# after softmax each position gets zero weight on later positions.
# It is built once at (seq_len, seq_len) and broadcast over the batch and head
# axes rather than materialised at the full size of `scaled`.
mask = torch.triu(
    torch.full(scaled.shape[-2:], float('-inf'), dtype=scaled.dtype, device=scaled.device),
    diagonal=1,
)

# Out-of-place add keeps the unmasked scores tensor untouched.
scaled = scaled + mask

attention = torch.softmax(scaled, dim=-1)  # attention along the last seq_len dimension

In [28]:
# Attention-weighted sum of the value vectors, computed per head.
values = attention @ v  # (batch_size, head_num, seq_len, head_dim)
values.shape

torch.Size([1000, 8, 60, 16])

In [29]:
# The steps above, packaged as a reusable function.
def scaled_dot_attentio(q, k, v, mask=None):
    """Scaled dot-product attention.

    Args:
        q, k, v: tensors of shape (batch_size, head_num, seq_len, head_dim).
        mask: optional additive mask (used for the decoder); when given it is
            added in place to the scaled scores before the softmax.

    Returns:
        A ``(values, attention)`` tuple where
        values is (batch_size, head_num, seq_len, head_dim) and
        attention is (batch_size, head_num, seq_len, seq_len).
    """
    head_width = q.size(-1)  # head_dim
    scores = (q @ k.transpose(-1, -2)) / math.sqrt(head_width)
    if mask is not None:
        scores.add_(mask)
    weights = scores.softmax(dim=-1)  # normalise over the key positions
    return weights @ v, weights

In [30]:
# Run the helper with the causal mask, then show the shape of the weights and
# the weight matrix of the first head of the first sample.
values, attention = scaled_dot_attentio(q, k, v, mask=mask)
print(attention.shape, attention[0][0], sep="\n")

torch.Size([1000, 8, 60, 60])
tensor([[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.4422, 0.5578, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.2279, 0.2571, 0.5149,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0107, 0.0124, 0.0127,  ..., 0.0255, 0.0000, 0.0000],
        [0.0078, 0.0177, 0.0128,  ..., 0.0134, 0.0091, 0.0000],
        [0.0181, 0.0275, 0.0166,  ..., 0.0114, 0.0074, 0.0141]],
       grad_fn=<SelectBackward0>)


In [31]:
# Merge the heads back together. `values` is laid out as
# (batch_size, head_num, seq_len, head_dim), so head_num has to be moved next
# to head_dim before flattening; reshaping directly would pack several time
# steps of one head into a single output row.
values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_len, head_num * head_dim)
values.shape
# values: (batch_size, seq_len, d_model)

torch.Size([1000, 60, 128])

In [35]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention over a batch of sequences.

    One linear layer projects the input to queries, keys and values for all
    heads at once; the attention itself is delegated to ``scaled_dot_attentio``;
    the per-head results are concatenated and passed through an output linear
    layer.

    Args:
        input_dim: size of the last axis of the input tensor.
        d_model: total width of the attention representation (all heads).
        num_heads: number of heads; must divide ``d_model`` evenly.

    Raises:
        AssertionError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"
        # Fused projection: q, k and v for every head come out of one matmul.
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        self.output_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask = None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (batch_size, seq_len, input_dim).
            mask: optional additive mask forwarded to ``scaled_dot_attentio``.

        Returns:
            A ``(output, attention)`` tuple where
            output is (batch_size, seq_len, d_model) and
            attention is (batch_size, head_num, seq_len, seq_len).

        Raises:
            AssertionError: if the last axis of ``x`` differs from ``input_dim``.
        """
        batch_size, seq_len, input_dim = x.shape
        assert input_dim == self.input_dim, "input_dim must be equal to self.input_dim"
        qkv = self.qkv_layer(x)
        qkv = qkv.view(batch_size, seq_len, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # (batch_size, head_num, seq_len, 3 * head_dim)
        q,k,v = qkv.split([self.head_dim, self.head_dim, self.head_dim], dim=-1)
        values, attention = scaled_dot_attentio(q,k,v, mask = mask)
        # values: (batch_size, head_num, seq_len, head_dim). Bring seq_len back
        # in front of head_num before flattening the heads; a direct reshape
        # would mix different time steps (and heads) into one output row.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
        output = self.output_layer(values)
        return output, attention


In [38]:
# Smoke-test the module on the toy batch with the causal mask and report the
# shapes of both results.
model = MultiheadAttention(input_dim=input_dim, d_model=d_model, num_heads=head_num)
output, attention = model(x, mask=mask)
print(output.shape, attention.shape, sep="\n")

torch.Size([1000, 60, 128])
torch.Size([1000, 8, 60, 60])
