In [None]:
import torch
from torch import nn
import math

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding is projected into queries, keys and values, split into
    ``n_head`` heads of size ``embedding_dim // n_head``, attended per head,
    then merged back and passed through a final linear layer.

    Args:
        embedding_dim: Size of the last dimension of the inputs and outputs.
        n_head: Number of attention heads; must divide ``embedding_dim``.

    Raises:
        AssertionError: If ``embedding_dim`` is not divisible by ``n_head``.
    """

    def __init__(self, embedding_dim, n_head):
        super().__init__()
        self.n_head = n_head
        self.embedding_dim = embedding_dim
        # The original attribute name was misspelled; it is kept as an alias so
        # that any code reading ``embeddimg_dim`` keeps working.
        self.embeddimg_dim = embedding_dim
        # Per-head feature size.
        self.model_dim = embedding_dim // n_head

        assert (
            self.model_dim * self.n_head == self.embedding_dim
        ), "Embedding size needs to be divisible by heads"

        self.w_q = nn.Linear(embedding_dim, embedding_dim)
        self.w_k = nn.Linear(embedding_dim, embedding_dim)
        self.w_v = nn.Linear(embedding_dim, embedding_dim)
        self.w_combine = nn.Linear(embedding_dim, embedding_dim)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, x):
        """Reshape (batch, length, embedding_dim) -> (batch, n_head, length, model_dim)."""
        batch_size, length, _ = x.shape
        return x.view(batch_size, length, self.n_head, self.model_dim).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Apply attention of ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape (batch, q_len, embedding_dim).
            k: Key tensor of shape (batch, kv_len, embedding_dim).
            v: Value tensor of shape (batch, kv_len, embedding_dim).
                ``kv_len`` may differ from ``q_len`` (cross-attention).
            mask: Optional boolean tensor broadcastable to
                (batch, n_head, q_len, kv_len). Positions where it is True are
                set to ``-inf`` before the softmax, i.e. they receive zero
                attention weight. A query row that is fully masked yields NaN.

        Returns:
            Tensor of shape (batch, q_len, embedding_dim).
        """
        # Open question from the original author: why are q, k, v not normalized
        # before use? This module applies no normalization itself; if it is
        # wanted, the caller has to do it around this layer.
        batch_size, q_len, emb_dim = q.shape

        # Each input is split using its OWN sequence length, so keys/values may
        # be longer or shorter than the queries.
        q = self._split_heads(self.w_q(q))
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))

        # (batch, n_head, q_len, kv_len), scaled by sqrt of the per-head size.
        score = q @ k.transpose(2, 3) / math.sqrt(self.model_dim)

        if mask is not None:
            # Out-of-place fill: never mutates a tensor something else may hold.
            score = score.masked_fill(mask, -math.inf)

        context = self.softmax(score) @ v

        # permute() returns a non-contiguous view and view() needs compatible
        # strides, so contiguous() is required before merging the heads.
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, q_len, emb_dim)
        return self.w_combine(context)

In [None]:
# Scratch check: nn.Linear acts on the last dimension only, so a
# (128, 32, 512) input comes out as (128, 32, 10).
x = torch.rand((128, 32, 512))
w_q = nn.Linear(in_features=512, out_features=10)
w_q(x).shape

In [None]:
# Smoke test: self-attention-sized inputs (batch=128, seq=32, emb=512) through
# an 8-head layer; the displayed output shape should equal the input shape.
attn = MultiHeadAttention(embedding_dim=512, n_head=8)
# Drawn in the order q, k, v so the random stream is consumed as before.
q, k, v = (torch.rand(128, 32, 512) for _ in range(3))
attn(q=q, k=k, v=v).shape