In [1]:
import torch
import torch.nn as nn

In [14]:
class MultiHeadAttention(nn.Module):

    def __init__(self,din,dout,contextlength,dropout,numheads,qkvbias=False):
        super().__init__()
        assert (dout % numheads == 0), \
            "dout must be divisible by numheads"
        
        self.dout=dout
        self.numheads=numheads
        self.headdim=dout // numheads
        
        self.wquery=nn.Linear(din,dout,bias=qkvbias)
        self.wkey=nn.Linear(din,dout,bias=qkvbias)
        self.wvalue=nn.Linear(din,dout,bias=qkvbias)
        self.outproj=nn.Linear(dout,dout)
        self.dropout=nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(contextlength,contextlength),diagonal=1))
        
    
    
    
    def forward (self,x):
        b,numtokens,din=x.shape
        keys=self.wkey(x)
        values=self.wvalue(x)
        queries=self.wvalue(x)

        keys=keys.view(b,numtokens,self.numheads,self.headdim)
        values=values.view(b,numtokens,self.numheads,self.headdim)
        queries=queries.view(b,numtokens,self.numheads,self.headdim)

        keys=keys.transpose(1,2)
        queries=queries.transpose(1,2)
        values=values.transpose(1,2)

        attnscores = queries @ keys.transpose(2,3)

        maskbool=self.mask.bool()[:numtokens,:numtokens]

        attnscores.masked_fill_(maskbool,-torch.inf)

        attnweights=torch.softmax(attnscores/keys.shape[-1]**0.5,dim=-1)
        attnweights=self.dropout(attnweights)

        contextvec=(attnweights @ values). transpose(1,2)

        contextvec=contextvec.contiguous().view(b,numtokens,self.dout)
        contextvec=self.outproj(contextvec)

        return contextvec
        
        

In [12]:
torch.manual_seed(123)

inputs= torch.tensor(
    [[0.43,0.15,0.89,0.55,0.87,0.66],
     [0.57,0.85,0.64,0.22,0.58,0.33],
     [0.77,0.25,0.10,0.05,0.80,0.55]]
)
batch=torch.stack((inputs,inputs),dim=0)
print(batch.shape)

torch.Size([2, 3, 6])


In [15]:
batchsize,contextlength,din=batch.shape
dout=6
mha=MultiHeadAttention(din,dout,contextlength,0.0,numheads=2)
contextvecs=mha(batch)
print(contextvecs)
print("context_vecs.shape:",contextvecs.shape)

tensor([[[-0.3364,  0.0697, -0.1239,  0.0898,  0.3611,  0.3157],
         [-0.2439,  0.0562, -0.0920,  0.0899,  0.3436,  0.2111],
         [-0.1770,  0.0371, -0.1079, -0.0033,  0.3518,  0.1444]],

        [[-0.3364,  0.0697, -0.1239,  0.0898,  0.3611,  0.3157],
         [-0.2439,  0.0562, -0.0920,  0.0899,  0.3436,  0.2111],
         [-0.1770,  0.0371, -0.1079, -0.0033,  0.3518,  0.1444]]],
       grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 3, 6])
