In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [15]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with bias-free linear
    layers and split into ``num_heads`` heads of width ``embed_dim // num_heads``.
    Each head attends independently. The heads are then concatenated and sent
    through a bias-free output projection.

    Args:
        embed_dim: Model width ``D``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads ``H``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, (
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dk = embed_dim // num_heads  # per-head dimension

        # Q/K/V input projections; each head later takes a dk-wide slice.
        self.WQ = nn.Linear(embed_dim, embed_dim, bias=False)
        self.WK = nn.Linear(embed_dim, embed_dim, bias=False)
        self.WV = nn.Linear(embed_dim, embed_dim, bias=False)

        # Output projection applied to the concatenated heads.
        self.WO = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, X, mask=None):
        """Apply multi-head self-attention to ``X``.

        Args:
            X: Tensor of shape ``(B, T, D)`` with ``D == embed_dim``.
            mask: Optional tensor broadcastable to ``(B, H, T, T)``. Examples
                are a ``(T, T)`` causal mask or a ``(B, 1, 1, T)`` padding mask.
                Positions where ``mask`` is ``False``/0 are excluded from
                attention; ``True``/nonzero positions are kept. ``None`` means
                no masking.

        Returns:
            A tuple ``(output, attn_weights)``. ``output`` has shape
            ``(B, T, D)``. ``attn_weights`` has shape ``(B, H, T, T)``, and each
            row sums to 1 over the key axis.
        """
        B, T, D = X.shape

        # Project, then split heads: (B, T, D) -> (B, T, H, dk) -> (B, H, T, dk)
        Q = self.WQ(X).view(B, T, self.num_heads, self.dk).transpose(1, 2)
        K = self.WK(X).view(B, T, self.num_heads, self.dk).transpose(1, 2)
        V = self.WV(X).view(B, T, self.num_heads, self.dk).transpose(1, 2)

        # scores: (B, H, T_query, T_key). Dividing by sqrt(dk) keeps the logits'
        # scale independent of head width, so softmax does not saturate and
        # starve gradients.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.dk ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf.
            # A fully masked row then yields uniform weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Normalise over the key axis so each query's weights sum to 1.
        attn_weights = F.softmax(scores, dim=-1)
        head_output = torch.matmul(attn_weights, V)  # (B, H, T, dk)

        # Merge heads back: (B, H, T, dk) -> (B, T, H, dk) -> (B, T, D)
        concat = head_output.transpose(1, 2).contiguous().view(B, T, D)

        output = self.WO(concat)
        return output, attn_weights

In [16]:
torch.manual_seed(0)  # reproducible inputs and weight init on Restart & Run All

B = 2   # batch size
T = 5   # sequence length
D = 32  # embedding dimension
h = 4   # number of attention heads

X = torch.randn(B, T, D)

mha = MultiHeadAttention(D, h)

output, attn = mha(X)

# Check the expected shapes explicitly rather than only in comments.
assert output.shape == (B, T, D)
assert attn.shape == (B, h, T, T)
# Softmax over the key axis: each query's attention weights sum to 1.
assert torch.allclose(attn.sum(dim=-1), torch.ones(B, h, T))

print(output.shape)   # (2, 5, 32)
print(attn.shape)     # (2, 4, 5, 5)


torch.Size([2, 5, 32])
torch.Size([2, 4, 5, 5])
