In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [14]:
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) + mask) @ v.

    Args:
        q: query tensor whose last two dims are (seq_q, d_k); any leading dims
            are treated as batch dims (``MultiheadAttention`` passes
            (batch, heads, seq, head_dim)).
        k: key tensor, (..., seq_k, d_k).
        v: value tensor, (..., seq_k, d_v).
        mask: optional *additive* mask broadcastable against the
            (..., seq_q, seq_k) scores; it is added before the softmax, so
            ``-inf`` entries receive zero attention weight and ``0`` entries
            are left unchanged.

    Returns:
        (values, attention): ``values`` is (..., seq_q, d_v) and ``attention``
        is the (..., seq_q, seq_k) matrix of softmax weights.
    """
    d_k = q.size(-1)
    scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add: an in-place ``+=`` fails when the mask broadcasts
        # to a larger shape than the scores (e.g. batched mask, unbatched q/k).
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Projects ``input_dim`` input features to queries, keys and values for
    ``num_heads`` heads of width ``d_model // num_heads``, runs
    ``scaled_dot_product`` per head, concatenates the heads for every token
    and applies a final ``d_model -> d_model`` linear layer.

    Args:
        input_dim: feature size of the input ``x``.
        d_model: total model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        verbose: when True (the default, matching the original behaviour),
            print intermediate tensor shapes during ``forward``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads, verbose=True):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.verbose = verbose
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def _log(self, message):
        # Shape tracing for the walkthrough; silenced with verbose=False.
        if self.verbose:
            print(message)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention.

        Args:
            x: input of shape (batch, seq, input_dim).
            mask: optional additive mask broadcastable to
                (batch, num_heads, seq, seq).

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        batch_size, sequence_length, _ = x.size()
        self._log(f"x.size():{x.size()}")
        qkv = self.qkv_layer(x)
        self._log(f"qkv.size(): {qkv.size()}")
        # Each head owns a contiguous [q | k | v] slice of width 3 * head_dim.
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        self._log(f"qkv.size(): {qkv.size()}")
        # (batch, seq, heads, 3*head_dim) -> (batch, heads, seq, 3*head_dim)
        # so attention is computed independently per head.
        qkv = qkv.permute(0, 2, 1, 3)
        self._log(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        self._log(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()} ")
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        self._log(f"value.size(): {values.size()}, attention.size: {attention.size()}")
        # Bring the heads axis back next to the features before merging:
        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim).
        # Reshaping straight from the heads-first layout would mix different
        # sequence positions into the same output token.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        self._log(f"value.size(): {values.size()}")
        out = self.linear_layer(values)
        self._log(f"out.size(): {out.size()}")
        return out
        

In [15]:
# Demo configuration: project 1024-d inputs into a 512-d model split across
# 8 heads (head_dim = 512 // 8 = 64).
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5

# Seed so the random input and the layer initialisation are reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(0)

x = torch.randn((batch_size, sequence_length, input_dim))
model = MultiheadAttention(input_dim, d_model, num_heads)
# Call the module instance rather than .forward() so nn.Module hooks run.
out = model(x)
assert out.shape == (batch_size, sequence_length, d_model)

x.size():torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]) 
value.size(): torch.Size([30, 8, 5, 64]), attention.size: torch.Size([30, 8, 5, 5])
value.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
