In [2]:
import torch 
import torch.nn as nn

In [3]:
x = torch.rand((128, 32, 512))  # uniform [0, 1) dummy input; presumably (batch, seq_len, d_model) — TODO confirm

In [4]:
import torch
import math
class MutiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected to ``d_model`` features and split into
    ``num_heads`` heads of width ``d_model // num_heads``. Each head runs
    scaled dot-product attention, and the heads are concatenated and passed
    through an output projection.

    Args:
        d_model: Feature width of q/k/v inputs and of the output. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.

    Note: the class name keeps its original spelling ("Muti") so existing
    callers keep working.
    """

    def __init__(self, d_model, num_heads=8):
        super(MutiHeadAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        # Every projection maps d_model -> d_model. The result is then split
        # into num_heads slices, so it must NOT project to num_heads features.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, head_dim)."""
        batch, seq_len, _ = t.shape
        head_dim = self.d_model // self.num_heads
        return t.view(batch, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape (batch, q_len, d_model).
            k: Key tensor of shape (batch, kv_len, d_model). ``kv_len`` may
                differ from ``q_len`` (cross-attention).
            v: Value tensor of shape (batch, kv_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, kv_len). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch, q_len, _ = q.shape
        head_dim = self.d_model // self.num_heads

        q = self._split_heads(self.w_q(q))
        # Keys/values use their own sequence length, not the query's.
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))

        score = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
        if mask is not None:
            # Use the dtype's most negative finite value instead of -1e9:
            # -1e9 overflows float16. A finite value (not -inf) avoids NaN in
            # rows that are fully masked.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)
        attn = self.softmax(score)

        out = attn @ v
        # Merge heads back: (batch, heads, q_len, head_dim) -> (batch, q_len, d_model).
        out = out.permute(0, 2, 1, 3).contiguous().view(batch, q_len, self.d_model)
        return self.w_o(out)
