In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each linearly projected, split into
    ``heads`` slices of width ``embed_size // heads``, attended per head, and
    the concatenated head outputs are passed through a final linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if embed_size % heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) needs to be divisible by "
                f"heads ({heads})"
            )

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Creation order (k, q, v, fc_out) is kept so seeded initialisation
        # stays reproducible.
        self.k = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.q = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.v = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size)

    def forward(self, value, key, query, mask=None):
        """Apply attention of ``query`` over ``key`` / ``value``.

        Args:
            value: Tensor of shape ``(batch, value_len, embed_size)``.
            key: Tensor of shape ``(batch, key_len, embed_size)``;
                ``key_len`` must equal ``value_len``.
            query: Tensor of shape ``(batch, query_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(batch, heads, query_len, key_len)``. Positions where
                ``mask == 0`` are excluded from attention. ``None`` disables
                masking.

        Returns:
            Tensor of shape ``(batch, query_len, embed_size)``.
        """
        batch = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Project, then split the embedding axis into (heads, head_dim).
        q = self.q(query).reshape(batch, query_len, self.heads, self.head_dim)
        k = self.k(key).reshape(batch, key_len, self.heads, self.head_dim)
        v = self.v(value).reshape(batch, value_len, self.heads, self.head_dim)

        # Scaled dot-product scores, laid out as (batch, heads, query, key).
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(self.head_dim)

        if mask is not None:
            # Mask after scaling, using the most negative finite value of the
            # score dtype: a hard-coded -1e20 is not representable in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge heads back into embed_size.
        out = torch.einsum("bhqk,bkhd->bqhd", weights, v)
        out = out.reshape(batch, query_len, self.embed_size)
        return self.fc_out(out)



