In [52]:
import torch
from torch import nn
import math
import torch.nn.functional as F

In [53]:
# Toy hyperparameters shared by the cells below.
batch_size, seq_length = 3, 5  # input tensor is (batch_size, seq_length, d_model)
d_model, h = 512, 8            # embedding width and number of attention heads
d_k = d_model // h             # per-head projection width

In [54]:
class AttentionHead(nn.Module):
    """
    One head of a causally-masked self-attention layer.

    The input is projected to queries, keys and values of width ``d_k``.
    Scaled dot-product scores are masked so that a position can only attend
    to itself and to earlier positions, then used to weight the values.

    Args:
        d_k: Output width of each of the query/key/value projections.
        d_model: Width of the input embeddings (last dimension of ``x``).
        seq_length: Nominal sequence length. Stored for reference only; the
            causal mask is sized from the actual input at call time.
    """

    def __init__(self, d_k, d_model, seq_length):
        super().__init__()

        # Keep this head's own hyperparameters so that forward() does not
        # depend on module-level variables that happen to share these names.
        self.d_k = d_k
        self.seq_length = seq_length

        # Creation order (wq, wk, wv) is kept so seeded initialisation is
        # unchanged.
        self.wq = nn.Linear(in_features=d_model, out_features=d_k, bias=False)
        self.wk = nn.Linear(in_features=d_model, out_features=d_k, bias=False)
        self.wv = nn.Linear(in_features=d_model, out_features=d_k, bias=False)

    def forward(self, x):
        """
        Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_length, d_model).

        Returns:
            Tensor of shape (batch_size, seq_length, d_k).
        """
        # B (batch_size), T (seq_length), C (d_model); only T is needed here.
        _, T, _ = x.shape

        # (batch_size, seq_length, d_model) -> (batch_size, seq_length, d_k)
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # (batch_size, seq_length, d_k) @ (batch_size, d_k, seq_length)
        # -> (batch_size, seq_length, seq_length)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

        # True strictly above the diagonal, i.e. at "future" positions.
        # Built from the input's own length and on the input's device so the
        # head works for any sequence length and after .to(device).
        visible = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        scores = scores.masked_fill(~visible, float("-inf"))

        weights = F.softmax(scores, dim=-1)

        # (batch_size, seq_length, seq_length) @ (batch_size, seq_length, d_k)
        # -> (batch_size, seq_length, d_k)
        return weights @ v

In [55]:
class MultiHeadAttention(nn.Module):
    """
    Multiple heads of self-attention in parallel.

    Each head is an ``AttentionHead``; their outputs are concatenated along
    the last dimension and passed through a linear projection.

    Args:
        num_heads: Number of attention heads.
        head_size: Output width of each head.
        num_embed: Width of the input embeddings and of the final output.
        block_size: Sequence length forwarded to each head.
    """

    def __init__(self, num_heads, head_size, num_embed, block_size):
        super().__init__()

        # NOTE: this reseeds the *global* torch RNG on every construction, so
        # modules built with the same shapes start from the same weights.
        torch.manual_seed(3)
        self.heads = nn.ModuleList(
            [
                AttentionHead(
                    d_k=head_size,
                    d_model=num_embed,
                    seq_length=block_size,
                )
                for _ in range(num_heads)
            ]
        )
        # The projection consumes the concatenated head outputs, whose width
        # is num_heads * head_size. That equals num_embed in the usual setup,
        # but sizing it explicitly keeps forward() valid when it does not.
        self.proj = nn.Linear(num_heads * head_size, num_embed)

    def forward(self, x):
        """
        Run every head on ``x`` and project the concatenated result.

        Args:
            x: Tensor of shape (batch_size, seq_length, num_embed).

        Returns:
            Tensor of shape (batch_size, seq_length, num_embed).
        """
        # (batch_size, seq_length, num_heads * head_size)
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        # Linear projection back to the embedding width.
        return self.proj(concatenated)

In [61]:
# my own implementation
class MultiHeadAttention2(nn.Module):
    """
    Multi-head self-attention built from ``h`` independent ``AttentionHead``s.

    The per-head width is ``d_model // h``. Head outputs are concatenated
    along the last dimension and mapped back to ``d_model`` by ``wo``.

    Args:
        h: Number of attention heads.
        d_model: Width of the input embeddings and of the final output.
        seq_length: Sequence length forwarded to each head.
    """

    def __init__(self, h, d_model, seq_length):
        super().__init__()

        self.h = h
        self.d_model = d_model
        self.seq_length = seq_length
        self.d_k = d_model // h  # floor division already yields an int

        # NOTE: this reseeds the *global* torch RNG on every construction, so
        # modules built with the same shapes start from the same weights.
        torch.manual_seed(3)
        self.mheads = nn.ModuleList(
            [AttentionHead(self.d_k, d_model, seq_length) for _ in range(h)]
        )
        # The concatenated heads have width h * d_k. That equals d_model when
        # d_model is divisible by h; sizing it explicitly keeps forward()
        # valid when it is not.
        self.wo = nn.Linear(in_features=self.h * self.d_k, out_features=d_model)

    def forward(self, x):
        """
        Run every head on ``x`` and project the concatenated result.

        Args:
            x: Tensor of shape (batch_size, seq_length, d_model).

        Returns:
            Tensor of shape (batch_size, seq_length, d_model).
        """
        # One concatenation over all heads, instead of growing the tensor
        # head by head and re-copying the accumulated result each time.
        yh_cat = torch.cat([head(x) for head in self.mheads], dim=-1)

        return self.wo(yh_cat)

In [62]:
# Build both implementations from the same hyperparameters so their outputs
# can be compared directly.
mha1 = MultiHeadAttention(
    num_heads=h,
    head_size=d_model // h,
    num_embed=d_model,
    block_size=seq_length,
)
mha2 = MultiHeadAttention2(h=h, d_model=d_model, seq_length=seq_length)

In [63]:
# Random input batch of shape (batch_size, seq_length, d_model).
x = torch.rand((batch_size, seq_length, d_model))

In [64]:
# Largest element-wise absolute difference between the two implementations.
torch.max(torch.abs(mha1(x) - mha2(x)))

tensor(0., grad_fn=<MaxBackward1>)

In [65]:
# Output shape: (batch_size, seq_length, d_model).
mha2(x).size()

torch.Size([3, 5, 512])