In [1]:
import torch as pt
import torch.nn as nn

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input to queries, keys and values with three independent
    linear layers and returns ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        d_m: size of the last (feature) dimension of the input.
        d_k: size of the query/key/value projections, and therefore of the
            output feature dimension.
    """

    def __init__(self, d_m, d_k) -> None:
        super().__init__()
        self.d_m = d_m
        self.d_k = d_k
        self.Wq = nn.Linear(in_features=self.d_m, out_features=self.d_k)
        self.Wk = nn.Linear(in_features=self.d_m, out_features=self.d_k)
        self.Wv = nn.Linear(in_features=self.d_m, out_features=self.d_k)

    def forward(self, x):
        """Let every position of ``x`` attend to every position of ``x``.

        Args:
            x: tensor of shape ``(..., seq_len, d_m)``, typically
                ``(batch, seq_len, d_m)``.

        Returns:
            Tensor of shape ``(..., seq_len, d_k)``; row ``i`` is a convex
            combination of the value vectors, weighted by how well query ``i``
            matches each key.
        """
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        # (..., seq, d_k) @ (..., d_k, seq) -> (..., seq, seq): one score for
        # every (query position, key position) pair, scaled by sqrt(d_k) to
        # keep the softmax out of its saturated region.
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5
        # Normalise over the key axis so each query's weights sum to 1.
        weights = nn.functional.softmax(scores, dim=-1)
        # Weighted sum of value vectors: (..., seq, seq) @ (..., seq, d_k).
        return weights @ v

In [10]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention built from independent ``SelfAttention`` heads.

    Each of the ``num_heads`` heads maps the ``embed_dim``-wide input to
    ``embed_dim // num_heads`` features. The head outputs are concatenated
    along the last axis (giving ``embed_dim`` features again) and mixed by
    the output projection ``Wo``.

    Args:
        embed_dim: feature size of the input and of the output.
        num_heads: number of heads; must divide ``embed_dim`` evenly.

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0, 'embed_dim must be divisible by num_heads'
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList([SelfAttention(embed_dim, head_dim) for _ in range(num_heads)])
        self.Wo = nn.Linear(in_features=embed_dim, out_features=embed_dim)

    def forward(self, x):
        """Run all heads on ``x``, concatenate their features and project.

        Args:
            x: input whose last dimension is ``embed_dim``.
                NOTE(review): assumes each head returns
                ``(..., seq_len, embed_dim // num_heads)`` — confirm against
                ``SelfAttention.forward``.

        Returns:
            ``self.Wo`` applied to the concatenated head outputs (last
            dimension ``embed_dim``).
        """
        # num_heads * head_dim == embed_dim, so the concatenation matches Wo's input size.
        Z = pt.cat([head(x) for head in self.heads], dim=-1)
        return self.Wo(Z)
 
 
# Example: a batch of 2 sequences, each 3 tokens with 10 features, passed
# through 5 heads of size 10 // 5 = 2. The output keeps the input's shape.
pt.manual_seed(0)  # fixed seed so the displayed output is reproducible on re-run
a = pt.randn(2, 3, 10)
m = MultiHeadSelfAttention(10, 5)
b = m(a)
assert b.shape == a.shape, b.shape
b

In [None]:
# Reference: PyTorch's built-in multi-head attention on the same
# (batch, seq, feature) layout as the custom example. batch_first=True makes
# dim 0 the batch; with the default batch_first=False, the 2 would be read as
# the sequence length and the 3 as the batch size.
pt.manual_seed(0)  # reproducible weights and input
m = nn.MultiheadAttention(256, 1, batch_first=True)
a = pt.randn(2, 3, 256)
x, _ = m(a, a, a)  # self-attention: query = key = value; attention weights discarded
x