In [1]:
import torch
import torch.nn as nn


In [None]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embedDimension, numHeads, dropout = 0.2):
        super().__init__()

        assert embedDimension%numHeads == 0, "Embedding Dimension is Not Divisible By NumHeads"
        self.embedDimension = embedDimension
        self.numHeads = numHeads
        self.headDim = embedDimension//numHeads

        self.queryKeyValue = nn.Linear(embedDimension, embedDimension * 3, bias=False)
        self.drop = nn.Dropout(dropout)
        self.scale = self.headDim ** -0.5 
        self.outProjection = nn.Linear(embedDimension, embedDimension)

        nn.init.zeros_(self.queryKeyValue.weight)
        nn.init.zeros_(self.outProjection.weight)

    def forward(self, x):
        BatchSize, N, EmbedDim = x.shape

        qkv = self.queryKeyValue(x)
        qkv = qkv.reshape(BatchSize, N, 3, self.numHeads, EmbedDim // self.numHeads)
        q, k, v = qkv.unbind(2)
        print(q.shape, k.shape, v.shape)
        attentionScore = (q @ k.transpose(-2, -1)) * self.scale
        att = attentionScore.softmax(dim=-1)
        out = att @ v 
        print(out.shape)
        out = out.transpose(1, 2).reshape(BatchSize, N, EmbedDim)
        out = self.outProjection(out)
        out = self.drop(out)
        return out
    
mhsa = MultiHeadSelfAttention(embedDimension = 1024, numHeads = 8)
x = torch.randn(2, 64, 1024)
out = mhsa(x)
out.shape

torch.Size([2, 64, 8, 128]) torch.Size([2, 64, 8, 128]) torch.Size([2, 64, 8, 128])
torch.Size([2, 64, 8, 8])


torch.Size([2, 64, 1024])

# Key and Value from Encoder and Query from Decoder

In [None]:
# Key