In [1]:
import torch
from torch import nn
import math
import torch.nn.functional as F

In [2]:
# Toy sizes used to exercise the modules defined below.
batch_size, seq_length = 3, 5
d_model, h = 512, 8      # embedding width and number of attention heads
d_k = d_model // h       # per-head width of the query/key/value projections

In [3]:
class AttentionHead(nn.Module):
    """
    One causal (masked) head of the self-attention layer.

    Projects the input into query/key/value spaces of width ``d_k`` and
    returns the attention-weighted values, where each position may only
    attend to itself and to earlier positions.

    Args:
        d_k: output width of the query/key/value projections.
        d_model: width of the input embeddings (last dimension of ``x``).
        seq_length: nominal sequence length. Kept for interface
            compatibility; the causal mask is sized from the actual input
            length in ``forward``, so any sequence length works.
    """

    def __init__(self, d_k, d_model, seq_length):
        super().__init__()

        # Stored on the instance so forward() does not depend on notebook globals.
        self.d_k = d_k
        self.seq_length = seq_length

        self.wq = nn.Linear(in_features=d_model, out_features=d_k, bias=False)
        self.wk = nn.Linear(in_features=d_model, out_features=d_k, bias=False)
        self.wv = nn.Linear(in_features=d_model, out_features=d_k, bias=False)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_k).
        """
        # B (batch_size), T (sequence length), C (d_model)
        _, T, _ = x.shape

        # (B, T, d_model) -> (B, T, d_k)
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # (B, T, d_k) @ (B, d_k, T) -> (B, T, T), scaled by sqrt(d_k) of *this* head
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

        # Causal mask built from the real length T and on x's device, so the
        # head works for any T and on GPU. True = position may be attended to.
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))

        weights = F.softmax(scores, dim=-1)

        # (B, T, T) @ (B, T, d_k) -> (B, T, d_k)
        return weights @ v

In [4]:
class MultiHeadAttention(nn.Module):
    """
    Causal multi-head self-attention built from ``h`` AttentionHead modules.

    Each head produces an output of width ``d_model // h``; the head outputs
    are concatenated back to width ``d_model`` and mixed by the output
    projection ``wo``.

    Args:
        h: number of heads; must divide ``d_model``.
        d_model: input/output embedding width.
        seq_length: passed through to each AttentionHead.
        seed: value passed to ``torch.manual_seed`` before the parameters are
            created, so the initial weights are reproducible (default 3, as
            before). Pass ``None`` to leave the global torch RNG untouched.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``h``.
    """

    def __init__(self, h, d_model, seq_length, seed=3):
        super().__init__()

        # Otherwise the concatenated heads would be narrower than d_model and
        # wo would fail with a shape error only at forward time.
        if d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by h ({h})")

        self.h = h
        self.d_model = d_model
        self.seq_length = seq_length
        self.d_k = d_model // h

        # NOTE: reseeds the *global* torch RNG as a side effect.
        if seed is not None:
            torch.manual_seed(seed)
        self.mheads = nn.ModuleList(
            [AttentionHead(self.d_k, d_model, seq_length) for _ in range(h)]
        )
        self.wo = nn.Linear(in_features=d_model, out_features=d_model)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        # One concatenation over all heads instead of growing the tensor with
        # torch.cat inside a loop (which re-copies it on every iteration).
        yh_cat = torch.cat([head(x) for head in self.mheads], dim=-1)
        return self.wo(yh_cat)

In [5]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward block: Linear -> GELU -> Linear.

    Expands the last dimension from ``num_embed`` to ``4 * num_embed`` and
    projects it back, so the output has the same shape as the input.

    Note: ``torch.manual_seed(3)`` is called before the layers are created.
    This makes the initial weights reproducible, but it also resets the
    global torch RNG as a side effect.

    Args:
        num_embed: embedding width of the input (last dimension of ``x``).
    """

    def __init__(self, num_embed):
        super().__init__()
        torch.manual_seed(3)
        self.net = nn.Sequential(
            # In "Attention Is All You Need" the inner feed-forward size is
            # 2048 for a model width of 512, i.e. an expansion factor of 4.
            nn.Linear(num_embed, 4 * num_embed),
            nn.GELU(),
            # project back down to the embedding width
            nn.Linear(4 * num_embed, num_embed),
        )

    def forward(self, x):
        return self.net(x)

In [6]:
class FeedForward2(nn.Module):
    """
    Hand-written position-wise feed-forward block (Linear -> GELU -> Linear).

    The layers are created in forward order right after
    ``torch.manual_seed(3)``, so the initial weights are reproducible.
    Constructing the module resets the global torch RNG as a side effect.

    Args:
        num_embed: embedding width of the input (last dimension of ``x``).
    """

    def __init__(self, num_embed):
        super().__init__()
        torch.manual_seed(3)

        hidden_width = 4 * num_embed
        # Creation order matters: each Linear draws its init from the seeded RNG.
        expand = nn.Linear(num_embed, hidden_width, bias=True)
        activation = nn.GELU()
        project = nn.Linear(hidden_width, num_embed, bias=True)
        self.ff = nn.Sequential(expand, activation, project)

    def forward(self, x):
        out = self.ff(x)
        return out

In [7]:
torch.manual_seed(0)  # make the sample input reproducible across kernel restarts
x = torch.rand(batch_size, seq_length, d_model)

In [8]:
# Build both implementations; each constructor seeds the RNG with 3 before creating its layers.
ff, ff2 = FeedForward(d_model), FeedForward2(d_model)

In [9]:
# Output shapes of both implementations (expected to match x.shape).
tuple(module(x).shape for module in (ff, ff2))

(torch.Size([3, 5, 512]), torch.Size([3, 5, 512]))

In [10]:
# Both modules build the same layers after the same seed, so outputs must match exactly.
assert torch.equal(ff(x), ff2(x)), "FeedForward and FeedForward2 outputs differ"
(ff(x) - ff2(x)).abs().mean()

tensor(0., grad_fn=<MeanBackward0>)