In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, seq_length:int, d_model:int, n = 10000, dtype = torch.float32):
        super(PositionalEncoding,self).__init__()
        self.seq_length = seq_length
        self.d_model = d_model
        self.n = n
        self.dtype = dtype
        self.encode_table = nn.Parameter(self._create_table())

    def _create_table(self) -> torch.tensor:
        table = torch.zeros((self.seq_length, self.d_model),dtype=self.dtype)

        for pos in torch.arange(self.seq_length):
            # Here d_model is divide by 2 because the 2*i and 2*i + 1
            # The interaction foward two by two 
            for i in torch.arange(self.d_model//2):
                denominator = 2*i/self.d_model
                calculation = pos/torch.pow(self.n, denominator)
                table[pos,2*i]      = torch.sin(calculation) 
                table[pos,2*i + 1]  = torch.cos(calculation)

        return table

    def forward(self,x:torch.tensor) -> torch.tensor:
        x += self.encode_table

        return x

In [4]:
def scaled_dot_product_attention(query:torch.tensor, key:torch.tensor, value:torch.tensor) -> [torch.tensor,torch.tensor]:
    factor = 1/torch.sqrt(torch.tensor(key.size(-1)))
    attn = F.softmax(torch.matmul(query,key.transpose(-2,-1))*factor,dim=-1)

    x = torch.matmul(attn,value)

    return x, attn

In [5]:
class HeadAttention(nn.Module):
    def __init__(self, d_model:int, d_k:int, d_v:int):
        super(HeadAttention,self).__init__()
        self.weights_query = nn.Parameter(torch.randn(d_model,d_k))
        self.weights_key = nn.Parameter(torch.randn(d_model,d_k))
        self.weights_value = nn.Parameter(torch.randn(d_model,d_v))

    def forward(self, query:torch.tensor, key:torch.tensor, value:torch.tensor) -> torch.tensor:
        q = torch.matmul(query,self.weights_query)
        k = torch.matmul(key,self.weights_key)
        v = torch.matmul(value,self.weights_value)

        x, _ = scaled_dot_product_attention(q,k,v)
        return x
        

In [6]:
class MultiHead(nn.Module):
    def __init__(self, d_model:int, d_k:int, d_v:int, h:int):
        super(MultiHead,self).__init__()
        self.weights_concat = nn.Parameter(torch.randn(d_v*h,d_model))

        self.multi_head = nn.ModuleList([
            HeadAttention(d_model=d_model, d_k=d_k, d_v=d_v) for _ in range(h)
        ])

        self.norm_layer = nn.LayerNorm(d_model)

    def forward(self, query:torch.tensor, key:torch.tensor, value:torch.tensor) -> torch.tensor:
        x = torch.concat([module_(query, key, value) for module_ in self.multi_head],dim=-1)
        x = torch.matmul(x,self.weights_concat)

        x += query

        x = self.norm_layer(x)
        
        return x

In [7]:
class PositionWiseFeedForwardNetworks(nn.Module):
    def __init__(self, d_model:int, d_ff:int):
        super(PositionWiseFeedForwardNetworks,self).__init__()
        self.linear_inner_layer = nn.Linear(d_model,d_ff)
        self.linear_layer = nn.Linear(d_ff,d_model)

        self.norm_layer = nn.LayerNorm(d_model)

    def forward(self, x:torch.tensor) -> torch.tensor:
        x = self.linear_inner_layer(x)
        x = x.relu()
        x = self.linear_layer(x)

        x += x
        x = self.norm_layer(x)

        return x

In [8]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model:int, d_k:int, d_v:int, h:int, d_ff:int):
        super(EncoderLayer,self).__init__()
        self.multi_head_attention = MultiHead(d_model=d_model,d_k=d_k,d_v=d_v,h=h)
        self.position_wise_feed_forward_networks = PositionWiseFeedForwardNetworks(d_model=d_model,d_ff=d_ff)

    def forward(self, query:torch.tensor, key:torch.tensor, value:torch.tensor) -> torch.tensor:
        x = self.multi_head_attention(query=value, key=key, value=value)
        x = self.position_wise_feed_forward_networks(x)
        return x

In [9]:
class Encoder(nn.Module):
    def __init__(self, seq_length: int, n_layers:int, d_model:int, d_k:int, d_v:int, h:int, d_ff:int):
        super(Encoder,self).__init__()
        self.positional_encoding = PositionalEncoding(seq_length=seq_length,d_model=d_model)
        self.encoder_layer = nn.ModuleList([
            EncoderLayer(d_model=d_model,d_k=d_k,d_v=d_v,h=h,d_ff=d_ff) for _ in range(n_layers)
        ])

    def forward(self, src:torch.tensor) -> torch.tensor:
        src = self.positional_encoding(src)
        for module_ in self.encoder_layer:
            src += module_(src,src,src)

        return src
        

In [19]:
from torchinfo import summary

In [21]:
summary(encoder, input_size=(64, 725, 512))

Layer (type:depth-idx)                                  Output Shape              Param #
Encoder                                                 [64, 725, 512]            --
├─PositionalEncoding: 1-1                               [64, 725, 512]            371,200
├─ModuleList: 1-2                                       --                        --
│    └─EncoderLayer: 2-1                                [64, 725, 512]            --
│    │    └─MultiHead: 3-1                              [64, 725, 512]            525,312
│    │    └─PositionWiseFeedForwardNetworks: 3-2        [64, 725, 512]            2,100,736
│    └─EncoderLayer: 2-2                                [64, 725, 512]            --
│    │    └─MultiHead: 3-3                              [64, 725, 512]            525,312
│    │    └─PositionWiseFeedForwardNetworks: 3-4        [64, 725, 512]            2,100,736
│    └─EncoderLayer: 2-3                                [64, 725, 512]            --
│    │    └─MultiHead: 3-5     

In [10]:
encoder = Encoder(725,4,512,64,64,4,2048)

In [11]:
encoder = encoder.to(torch.device("cuda"))

In [13]:
input_tensor = torch.randn(64,725,512).to(torch.device("cuda"))