In [1]:
import torch
from torch import nn
import math
import torch.nn.functional as F

In [2]:
# --- model hyperparameters ---
d_model = 512        # width of the token embeddings
h = 8                # number of attention heads
d_k = d_model // h   # per-head projection width

# --- shape of the toy input used further down ---
batch_size = 3
seq_length = 5

In [3]:
class AttentionHead(nn.Module):
    """Single causal (masked) scaled dot-product self-attention head.

    Args:
        d_k: Output width of the query/key/value projections.
        d_model: Width of the incoming token embeddings.
        seq_length: Nominal sequence length. Stored for reference; the causal
            mask itself is sized from the actual input in ``forward``.
    """

    def __init__(self, d_k, d_model, seq_length):
        super().__init__()

        # Keep the configuration on the instance so forward() never has to
        # reach for module-level globals that merely share these names.
        self.d_k = d_k
        self.seq_length = seq_length

        self.wq = nn.Linear(in_features=d_model, out_features=d_k, bias=False)
        self.wk = nn.Linear(in_features=d_model, out_features=d_k, bias=False)
        self.wv = nn.Linear(in_features=d_model, out_features=d_k, bias=False)

    def forward(self, x):
        """Apply masked self-attention.

        Args:
            x: Tensor of shape (batch_size, T, d_model).

        Returns:
            Tensor of shape (batch_size, T, d_k).
        """
        # x is (B, T, C); only the sequence length T is needed here.
        _, T, _ = x.shape

        # (batch_size, T, d_model) -> (batch_size, T, d_k)
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # (batch_size, T, d_k) @ (batch_size, d_k, T) -> (batch_size, T, T)
        # Scale by this head's own d_k, not a global of the same name.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

        # Causal mask: True strictly above the diagonal, i.e. "future" keys.
        # Built from the real T and on x's device so the head works for any
        # sequence length and wherever the input lives.
        future = ~torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        scores = scores.masked_fill(future, -torch.inf)

        weights = F.softmax(scores, dim=-1)

        # (batch_size, T, T) @ (batch_size, T, d_k) -> (batch_size, T, d_k)
        return weights @ v

In [4]:
class MultiHeadAttention(nn.Module):
    """Runs ``h`` AttentionHead modules on the same input and mixes their outputs.

    Each head projects to ``d_model // h`` features; the head outputs are
    concatenated along the last dimension and passed through a final
    ``d_model -> d_model`` linear layer.

    Args:
        h: Number of attention heads.
        d_model: Width of the token embeddings.
        seq_length: Sequence length forwarded to every head.

    Raises:
        ValueError: If ``h`` is not a positive divisor of ``d_model``.
    """

    def __init__(self, h, d_model, seq_length):
        super().__init__()

        # h * d_k must equal d_model, otherwise the concatenated head outputs
        # would not fit the input size of ``wo``. Fail here rather than with
        # a shape error on the first forward pass.
        if h < 1 or d_model % h != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive h ({h})"
            )

        self.h = h
        self.d_model = d_model
        self.seq_length = seq_length
        self.d_k = d_model // h

        # Creation order (heads first, then wo) is kept so that seeded
        # parameter initialisation is unchanged.
        self.mheads = nn.ModuleList(
            [AttentionHead(self.d_k, d_model, seq_length) for _ in range(h)]
        )
        self.wo = nn.Linear(in_features=d_model, out_features=d_model)

    def forward(self, x):
        """Apply all heads to ``x`` and project the concatenated result.

        Args:
            x: Input tensor whose last dimension is ``d_model``.

        Returns:
            Output of ``wo`` applied to the concatenated head outputs.
        """
        # One concatenation over all heads instead of growing the tensor
        # pairwise inside a loop (which re-copies the accumulated result).
        heads_out = torch.cat([head(x) for head in self.mheads], dim=-1)
        return self.wo(heads_out)

In [5]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * d_model, apply GELU, project back."""

    def __init__(self, d_model):
        super().__init__()

        hidden = 4 * d_model
        expand = nn.Linear(d_model, hidden, bias=True)
        contract = nn.Linear(hidden, d_model, bias=True)

        # Kept as a Sequential named ``ff`` so parameter names stay ff.0.* / ff.2.*
        self.ff = nn.Sequential(expand, nn.GELU(), contract)

    def forward(self, x):
        """Run ``x`` through the expand -> GELU -> contract stack."""
        out = self.ff(x)
        return out

In [6]:
class TransformerBlock(nn.Module):
    """
    This class groups together MultiHead Attention and the
    FeedForward NN, so that the pair can be repeated in a Transformer.

    Args:
        h: Number of attention heads.
        d_model: Width of the token embeddings.
        seq_length: Sequence length forwarded to the attention module.
        seed: Value passed to ``torch.manual_seed`` before the sub-modules are
            created. The default (3) keeps the original behaviour. Pass
            ``None`` to leave the global RNG untouched, which is what you
            want when stacking several blocks, since reseeding with the same
            value would give every block identical initial weights.
    """

    def __init__(self, h, d_model, seq_length, seed=3):
        super().__init__()

        # Reseeding mutates the *global* torch RNG, so it is opt-out.
        if seed is not None:
            torch.manual_seed(seed)

        self.sa = MultiHeadAttention(
            h=h,
            d_model=d_model,
            seq_length=seq_length,
        )
        self.ffwd = FeedForward(d_model=d_model)
        # add the layer normalization
        self.ln1 = nn.LayerNorm([d_model])
        self.ln2 = nn.LayerNorm([d_model])

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with a residual add."""
        # "x +" is the skip (or residual) connection
        # it helps with optimization
        # also we apply layer normalization before self-attention
        # and feed-forward (a reshuffle from the original paper)
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [7]:
class TransformerBlock2(nn.Module):
    """Independent re-implementation of the attention + feed-forward block.

    LayerNorm is applied before each sub-layer and each sub-layer's output is
    added back onto its input.

    Args:
        h: Number of attention heads.
        d_model: Width of the token embeddings.
        seq_length: Sequence length forwarded to the attention module.
        seed: Value handed to ``torch.manual_seed`` before the sub-modules are
            built. Defaults to 3 (the original hard-coded value); ``None``
            skips reseeding so the global RNG state is left alone.
    """
    # my own implementation

    def __init__(self, h, d_model, seq_length, seed=3):
        super().__init__()

        # Only touch the global RNG when a seed was actually requested.
        if seed is not None:
            torch.manual_seed(seed)

        # Construction order determines which random numbers each layer
        # receives, so it is kept: attention, feed-forward, then the norms.
        self.mha = MultiHeadAttention(h, d_model, seq_length)
        self.ff = FeedForward(d_model)
        self.ln1 = nn.LayerNorm([d_model])
        self.ln2 = nn.LayerNorm([d_model])

    def forward(self, x):
        """Normalise, attend, add; then normalise, feed-forward, add."""
        x = x + self.mha(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

In [8]:
# Random toy input shaped (batch_size, seq_length, d_model).
x = torch.rand((batch_size, seq_length, d_model))

In [9]:
# Build both block implementations from the same hyperparameters.
tb = TransformerBlock(h=h, d_model=d_model, seq_length=seq_length)
tb2 = TransformerBlock2(h=h, d_model=d_model, seq_length=seq_length)

In [10]:
# Mean absolute difference between the two implementations on the same input.
torch.mean(torch.abs(tb(x) - tb2(x)))

tensor(0., grad_fn=<MeanBackward0>)