In [78]:
import torch as t
import torch.nn as nn

In [79]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected with three learnable (embedding_dims x embedding_dims)
    matrices into queries, keys and values. Each projection is split into ``heads``
    heads of size ``embedding_dims // heads``. Every head computes
    ``softmax(Q K^T / sqrt(d_head)) V``, and the heads are merged back into a
    ``(batch, seq_len, embedding_dims)`` tensor. There is no output projection.

    ``forward(X)`` (or ``module(X)``) returns that tensor. It also caches the per-head
    ``Q``, ``K``, ``V`` and the input shape ``Xshape`` on the module, so the two-step
    usage ``module.forward(X); module.product()`` keeps working.

    Args:
        embedding_dims: size of the input's last dimension (default 512).
        heads: number of attention heads; must be positive and divide ``embedding_dims``.
        dtype: dtype of the projection weights (default ``torch.float64``); the input
            passed to ``forward`` must use the same dtype.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide ``embedding_dims``.
    """

    def __init__(self, embedding_dims=512, heads=8, dtype=t.float64):
        super().__init__()
        if heads <= 0 or embedding_dims % heads != 0:
            raise ValueError(
                f"embedding_dims ({embedding_dims}) must be divisible by a positive "
                f"number of heads (got heads={heads})"
            )
        self.embedding_dims = embedding_dims
        self.heads = heads
        self.feature_map = embedding_dims // heads  # per-head dimension d_head
        # Scale the N(0, 1) init by 1/sqrt(embedding_dims) so projections of a unit-variance
        # input stay roughly unit-variance. Unscaled randn weights make the attention logits
        # grow with embedding_dims, which saturates the softmax into a near one-hot.
        init_scale = embedding_dims ** -0.5
        # nn.Parameter registers each matrix as a learnable parameter of the module.
        self.Wq = nn.Parameter(t.randn(embedding_dims, embedding_dims, dtype=dtype) * init_scale)
        self.Wk = nn.Parameter(t.randn(embedding_dims, embedding_dims, dtype=dtype) * init_scale)
        self.Wv = nn.Parameter(t.randn(embedding_dims, embedding_dims, dtype=dtype) * init_scale)

    def _split_heads(self, projected: t.Tensor) -> t.Tensor:
        """Reshape (batch, seq_len, embedding_dims) -> (batch, heads, seq_len, d_head)."""
        batch, seq_len = projected.shape[0], projected.shape[1]
        return projected.view(batch, seq_len, self.heads, self.feature_map).transpose(1, 2)

    def forward(self, X: t.Tensor) -> t.Tensor:
        """Run multi-head self-attention over ``X`` and return the contextual embedding.

        Args:
            X: tensor of shape (batch, seq_len, embedding_dims) with the weights' dtype.
               The matmuls need the last dim to be ``embedding_dims``, and the head
               split uses the first two dims as batch and sequence length.

        Returns:
            Tensor of shape (batch, seq_len, embedding_dims).
        """
        self.Xshape = X.shape
        # Queries, keys and values are all projections of the same X (self-attention).
        # matmul does not modify X, so no defensive copies are needed.
        self.Q = self._split_heads(t.matmul(X, self.Wq))
        self.K = self._split_heads(t.matmul(X, self.Wk))
        self.V = self._split_heads(t.matmul(X, self.Wv))
        return self.product()

    def product(self) -> t.Tensor:
        """Apply scaled dot-product attention to the Q/K/V cached by the last ``forward``.

        Returns:
            Tensor of shape (batch, seq_len, embedding_dims) with the heads merged.
        """
        scores = t.matmul(self.Q, self.K.transpose(-2, -1)) / (self.feature_map ** 0.5)
        weights = t.nn.functional.softmax(scores, dim=-1)  # normalise over the key axis
        context = t.matmul(weights, self.V)  # (batch, heads, seq_len, d_head)
        # Merge heads: (batch, heads, seq_len, d_head) -> (batch, seq_len, embedding_dims).
        context = context.transpose(1, 2).contiguous()
        return context.view(self.Xshape[0], self.Xshape[1], self.embedding_dims)

In [80]:
SEED = 0
# Both the weights and X are drawn with t.randn; seed so Restart & Run All reproduces them.
t.manual_seed(SEED)
attention = MultiHeadAttention()
# Toy input: batch=1, seq_len=4, last dim tied to the module's embedding size.
X = t.randn(1, 4, attention.embedding_dims, dtype=t.float64)

In [81]:
# Call the module (not .forward) so nn.Module hooks run; forward caches Q/K/V on `attention`
# for product(). The trailing `;` keeps any returned tensor from being dumped as cell output.
attention(X);

In [82]:
# Attention output computed from the Q/K/V that the previous cell cached on `attention`.
contextual_embedding: t.Tensor = attention.product()

In [83]:
# Self-attention keeps the input shape (batch, seq_len, embedding_dims); fail loudly if not.
assert contextual_embedding.shape == X.shape, (contextual_embedding.shape, X.shape)
contextual_embedding.shape

torch.Size([1, 4, 512])