In [2]:
import torch 
import torch.nn as nn

In [3]:
# Random demo input of shape (128, 32, 512), uniform in [0, 1); presumably
# (batch, seq_len, d_model) for MultiHeadAttention below — TODO confirm.
x = torch.rand((128, 32, 512))

In [4]:
import torch
import math
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected by their own linear layer.
    Each projection is split into ``num_heads`` heads of size
    ``d_k = d_model // num_heads``. Every head computes
    ``softmax(Q K^T / sqrt(d_k)) V``. The heads are then concatenated and
    passed through an output projection.

    Args:
        d_model: Size of the last (feature) dimension of inputs and output.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape (batch, length, d_model) into (batch, num_heads, length, d_k)."""
        batch_size, length, _ = t.shape
        return t.view(batch_size, length, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape (batch, q_len, d_model).
            k: Key tensor of shape (batch, kv_len, d_model).
            v: Value tensor of shape (batch, kv_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, kv_len). Positions where
                ``mask == 0`` receive zero attention weight.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, q_len, _ = q.shape

        # Split each input using its OWN length: keys/values may be longer or
        # shorter than the queries (cross-attention). Reusing the query length
        # here would make ``.view`` fail whenever kv_len != q_len.
        q = self._split_heads(self.w_q(q))
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))

        # Scaled dot-product scores: (batch, num_heads, q_len, kv_len).
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the most negative finite value of the score's dtype. A fixed
            # -1e9 does not fit in float16, so masked_fill raises there. After
            # softmax the masked weights are 0 either way.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        score = self.softmax(score)
        attention = torch.matmul(score, v)

        # Merge heads back into the feature dimension: (batch, q_len, d_model).
        attention = (
            attention.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        )
        return self.w_o(attention)