In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class InputEmbedding(nn.Module):
    """Embedding lookup whose output is multiplied by sqrt(d_model).

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: width of each embedding vector.
    """

    def __init__(self, vocab_size: int, d_model:int):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Look up the embeddings for the indices in ``x`` and scale them."""
        looked_up = self.embedding(x)
        scale = math.sqrt(self.d_model)
        return looked_up * scale

In [3]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to the input, then applies dropout.

    The table is precomputed once for ``seq_len`` positions and stored as a
    buffer named ``pe`` of shape (1, seq_len, d_model).

    Args:
        d_model: size of the last dimension of the inputs.
        seq_len: number of positions to precompute.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model:int , seq_len : int , dropout:float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        # To avoid overfitting we use a dropout layer
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # 1 / 10000^(2i / d_model), computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # Even columns get the sine, odd columns the cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer odd column than even columns,
        # so trim the frequencies to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encodings for the first ``x.shape[1]`` positions to ``x``.

        Raises:
            ValueError: if ``x.shape[1]`` exceeds the precomputed ``seq_len``.
        """
        length = x.shape[1]
        if length > self.seq_len:
            raise ValueError(
                f"Input sequence length {length} exceeds the maximum of {self.seq_len}"
            )
        # ``pe`` is a buffer, not a parameter, so no gradient flows into it.
        x = x + self.pe[:, :length, :]
        return self.dropout(x)


In [4]:
class LayerNormalization(nn.Module):
    """Normalizes the last dimension and applies a learnable scale and shift.

    Args:
        eps: small constant added to the standard deviation to avoid
            division by zero (the default ``10e-6`` equals ``1e-5``).
    """

    def __init__(self, eps : float = 10e-6) :
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))   # multiplicative scale
        # The additive shift starts at zero so the layer is a pure
        # normalization at initialisation.
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + beta`` over the last dim."""
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.beta

In [5]:
class FeedForwardBlock(nn.Module):
    """Two linear layers with a ReLU and dropout in between.

    Args:
        d_model: input and output feature size.
        d_ff: feature size of the hidden layer.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model : int , d_ff: int , dropout:float) -> None:
        super().__init__()
        self.d_ff = d_ff
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply linear_1 -> ReLU -> dropout -> linear_2 to ``x``."""
        # (..., d_model) -> (..., d_ff)
        hidden = self.linear_1(x)
        hidden = torch.relu(hidden)
        hidden = self.dropout(hidden)
        # (..., d_ff) -> (..., d_model)
        return self.linear_2(hidden)

In [18]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: feature size of the inputs and outputs; must be divisible by ``h``.
        h: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    After each ``forward`` call the attention weights are kept on
    ``self.attention_score`` for visualization.
    """

    def __init__(self, d_model: int , h : int , dropout: float ) -> None:
        super().__init__()
        self.h = h
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        assert d_model % h == 0 ,"Can not divide d_model by h"

        self.d_k = d_model // h  # feature size handled by each head
        self.w_q = nn.Linear(d_model, d_model, bias = False )
        self.w_k = nn.Linear(d_model, d_model, bias = False )
        self.w_v = nn.Linear(d_model, d_model, bias = False )
        self.w_o = nn.Linear(d_model, d_model, bias = False )

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Compute softmax(query @ key^T / sqrt(d_k)) @ value.

        Args:
            query, key, value: tensors whose last two dims are (length, d_k).
            mask: optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` are excluded from the softmax.
            dropout: optional dropout module applied to the attention weights.

        Returns:
            Tuple of (weighted values, attention weights).
        """
        d_k = query.shape[-1]
        # (..., q_len, d_k) @ (..., d_k, k_len) -> (..., q_len, k_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # A very negative score becomes ~0 probability after the softmax.
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # Take the attention scores for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, v, k , mask):
        """Project the inputs, attend per head, and merge the heads.

        Note the parameter order is (q, v, k, mask): the value input comes
        before the key input.
        """
        # (batch, len, d_model) -> (batch, len, d_model)
        query = self.w_q(q)
        value = self.w_v(v)
        key = self.w_k(k)

        # Split the embeddings across heads:
        # (batch, len, d_model) -> (batch, len, h, d_k) -> (batch, h, len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        x, self.attention_score = MultiHeadAttentionBlock.attention(
            query, key, value, mask, self.dropout
        )
        # (batch, h, q_len, d_k) -> (batch, q_len, h, d_k) -> (batch, q_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        return self.w_o(x)
    

In [19]:
class ResidualConnection(nn.Module):
    """Skip connection around a sub-layer, with normalization applied first.

    Computes ``x + dropout(sublayer(norm(x)))``.

    Args:
        dropout: dropout probability applied to the sub-layer output.
    """

    def __init__(self,dropout : float ) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalized ``x`` and add the result to ``x``.

        Args:
            x: input tensor.
            sublayer: callable taking the normalized tensor and returning a
                tensor that can be added to ``x``.
        """
        return x + self.dropout(sublayer(self.norm(x)))

In [20]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in its own residual connection.

    Args:
        self_attention_block: attention module called as ``(x, x, x, src_mask)``.
        feed_forward_block: module applied to the attention output.
        dropout: dropout probability passed to both residual connections.
    """

    def __init__(self, self_attention_block : MultiHeadAttentionBlock , feed_forward_block : FeedForwardBlock , dropout: float ) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Index 0 wraps the attention sub-layer, index 1 the feed-forward one.
        self.residual_connexions = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        """Run both sub-layers on ``x`` and return the result."""
        x = self.residual_connexions[0](
            x, lambda normed: self.self_attention_block(normed, normed, normed, src_mask)
        )
        # Each sub-layer has its own residual connection (and its own norm parameters).
        x = self.residual_connexions[1](x, self.feed_forward_block)
        return x
        

In [21]:
class Encoder(nn.Module):
    """Applies a list of layers in sequence, then a final layer normalization.

    Args:
        layers: modules that are each called as ``layer(x, mask)``.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self , x ,mask):
        """Feed ``x`` through every layer with ``mask`` and normalize the output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        normalized = self.norm(hidden)
        return normalized