In [1]:
import copy
import math

import torch
import torch.nn as nn

In [2]:
class InputEmbedding(nn.Module):
    """Lookup table that maps integer token ids to learned dense vectors.

    The vectors produced by the underlying ``nn.Embedding`` are returned
    unchanged; no extra scaling factor is applied to them.
    """

    def __init__(self, d_model: int, vocab_size: int):
        """
        Args:
            d_model: Width of every embedding vector.
            vocab_size: Number of distinct token ids (rows of the table).
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Return the embedding vector for each token id in ``x``."""
        token_vectors = self.embedding(x)
        return token_vectors

In [17]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to its input, then applies dropout.

    The table is computed once in ``__init__`` and stored as the buffer ``pe``
    with shape ``(1, seq_len, d_model)``. Even columns hold
    ``sin(pos * freq)`` and odd columns hold ``cos(pos * freq)``, where
    ``freq = exp(-2i * ln(10000) / d_model)`` for column pair ``i``.
    """

    def __init__(self, d_model:int, seq_len:int, dropout:float):
        """
        Args:
            d_model: Size of the last dimension of the inputs to ``forward``.
            seq_len: Number of positions for which encodings are precomputed.
            dropout: Probability passed to ``nn.Dropout``.
        """
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Positions must be a (seq_len, 1) column so they broadcast against
        # the (ceil(d_model / 2),) frequency row into a full angle table.
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        angles = positions * div_term  # (seq_len, ceil(d_model / 2))

        # Positional encodings
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # leading axis of size 1 broadcasts over the batch

        # Registered as a buffer: saved with the module but not a parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` position encodings to ``x`` and apply dropout.

        Args:
            x: Tensor whose dimension 1 indexes positions and whose last
                dimension has size ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


In [4]:
class LayerNormalization(nn.Module):
    """Normalises the last dimension, then applies a learned scale and shift.

    ``alpha`` and ``bias`` are single-element parameters, so the same scale
    and shift are broadcast over every feature.
    """

    def __init__(self, eps=10**-6):
        """
        Args:
            eps: Constant added to the standard deviation in the denominator.
        """
        super().__init__()
        self.eps = eps
        # Multiplicative term, initialised to one.
        self.alpha = nn.Parameter(torch.ones(1))
        # Additive term, initialised to zero.
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self,x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` along the last dimension."""
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        centered = x - mu
        denominator = sigma + self.eps
        return self.alpha * centered / denominator + self.bias


In [5]:
class FeedForwardBlock(nn.Module):
    """Two linear layers with a ReLU and dropout in between.

    Maps ``d_model -> dim_ff -> d_model`` on the last dimension of the input.
    """

    def __init__(self, d_model:int, dim_ff:int,dropout:float):
        """
        Args:
            d_model: Input and output feature size.
            dim_ff: Hidden feature size between the two linear layers.
            dropout: Probability passed to ``nn.Dropout``.
        """
        super().__init__()
        self.ff1 = nn.Linear(in_features=d_model, out_features=dim_ff)
        self.ff2 = nn.Linear(in_features=dim_ff, out_features=d_model)
        self.dropout = nn.Dropout(dropout)
        # Kept as an attribute; forward() applies the functional torch.relu.
        self.relu = nn.ReLU()

    def forward(self,x):
        """Apply linear -> ReLU -> dropout -> linear to ``x``."""
        hidden = torch.relu(self.ff1(x))
        hidden = self.dropout(hidden)
        return self.ff2(hidden)

In [11]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value inputs are each projected with their own linear
    layer, split into ``heads`` slices of width ``d_k = d_model // heads``,
    attended per head, merged again and passed through an output projection.
    """

    def __init__(self, d_model, heads:int=8):
        """
        Args:
            d_model: Feature size of the inputs and of the output.
            heads: Number of attention heads; must divide ``d_model``.
        """
        super().__init__()
        assert d_model % heads == 0, "d_model should be dividable by heads"
        self.heads = heads
        self.d_k = d_model // self.heads

        # Initialize Q, K, V and output projection matrices
        self.query_weights = nn.Linear(d_model,d_model)
        self.key_weights = nn.Linear(d_model, d_model)
        self.value_weights = nn.Linear(d_model, d_model)
        self.output_weights = nn.Linear(d_model, d_model)

    def _split_heads(self, projected):
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, d_k)."""
        # -1 lets queries and keys/values have different sequence lengths.
        return projected.view(projected.size(0), -1, self.heads, self.d_k).transpose(1, 2)

    def forward(self,query, key, value, mask):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (batch, seq_q, d_model).
            key: Tensor of shape (batch, seq_k, d_model).
            value: Tensor of shape (batch, seq_k, d_model).
            mask: ``None`` or a tensor broadcastable to
                (batch, heads, seq_q, seq_k); positions where it equals 0
                are masked out.

        Returns:
            A tuple ``(output, attention_scores)`` where ``output`` has shape
            (batch, seq_q, d_model) and ``attention_scores`` holds the scaled,
            masked, pre-softmax scores of shape (batch, heads, seq_q, seq_k).
        """
        batch_size = query.size(0)

        queries = self._split_heads(self.query_weights(query))  # (batch, head, seq_q, d_k)
        keys = self._split_heads(self.key_weights(key))          # (batch, head, seq_k, d_k)
        values = self._split_heads(self.value_weights(value))    # (batch, head, seq_k, d_k)

        # Attention
        context_vectors_heads, attention_scores = MultiHeadAttention.attention_score(
            queries, keys, values, mask, self.d_k
        )

        # (batch, head, seq_q, d_k) -> (batch, seq_q, head, d_k) -> (batch, seq_q, d_model)
        # contiguous() is required because view() cannot operate on the transposed layout.
        context_vectors = (
            context_vectors_heads.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.heads * self.d_k)
        )

        return self.output_weights(context_vectors), attention_scores

    @staticmethod
    def attention_score(queries, keys, values, mask, d_k=None):
        """Compute scaled dot-product attention.

        Args:
            queries: Tensor of shape (..., seq_q, d_k).
            keys: Tensor of shape (..., seq_k, d_k).
            values: Tensor of shape (..., seq_k, d_v).
            mask: ``None`` or a tensor broadcastable to the score shape;
                entries equal to 0 are filled with -1e9 before the softmax.
            d_k: Scaling dimension; defaults to the last dimension of
                ``queries`` when omitted.

        Returns:
            A tuple ``(context, scores)``: the softmax-weighted sum of
            ``values`` and the scaled, masked scores before the softmax.
        """
        if d_k is None:
            d_k = queries.size(-1)

        # math.sqrt accepts a Python int, whereas torch.sqrt requires a tensor.
        attention_score = (queries @ keys.transpose(-1,-2)) / math.sqrt(d_k)
        if mask is not None:
            # masking with a very small value instead of '-inf'
            attention_score = attention_score.masked_fill(mask == 0, -1e9)

        # Context vectors
        attention_score_softmax = attention_score.softmax(dim=-1)
        context_vectors_heads = attention_score_softmax @ values

        return context_vectors_heads, attention_score

        

        

In [None]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, dropout:float):
        """
        Args:
            dropout: Probability passed to ``nn.Dropout``; applied to the
                sublayer output before it is added back to the input.
        """
        super().__init__()
        self.norm = LayerNormalization()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Run ``sublayer`` on the normalised input and add the result to ``x``.

        Args:
            x: Input tensor; also the residual branch.
            sublayer: Callable taking the normalised tensor. It may return a
                tensor, or a tuple whose first element is the tensor to add
                (e.g. an attention module returning ``(output, scores)``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        sublayer_out = sublayer(self.norm(x))
        # A tuple cannot be added to a tensor; keep only the primary output.
        if isinstance(sublayer_out, tuple):
            sublayer_out = sublayer_out[0]
        # Dropout on the sublayer branch only, so the identity path stays intact.
        return x + self.dropout(sublayer_out)


In [16]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a residual wrapper."""

    def __init__(self, multihead_attention:MultiHeadAttention, feedforward:FeedForwardBlock, dropout:float):
        """
        Args:
            multihead_attention: Attention module used for self-attention.
            feedforward: Position-wise feed-forward module.
            dropout: Dropout probability handed to each ``ResidualConnection``.
        """
        super().__init__()
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])
        self.multihead_attention = multihead_attention
        self.feedforward = feedforward

    def forward(self,x, src_mask):
        """Apply self-attention and the feed-forward block to ``x``.

        Args:
            x: Encoder input tensor.
            src_mask: Mask forwarded to the attention module (may be ``None``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention module returns (output, attention_scores); the residual
        # connection adds its sublayer result to x, so pass only the output tensor.
        x = self.residual_connections[0](
            x,
            lambda normed: self.multihead_attention(normed, normed, normed, src_mask)[0],
        )
        x = self.residual_connections[1](x, self.feedforward)
        return x


In [None]:
class Encoder(nn.Module):
    """Stack of ``n_encoder`` encoder blocks followed by a final layer normalisation."""

    def __init__(self,multihead_attention:MultiHeadAttention,
                 feedforward:FeedForwardBlock, 
                 dropout:float, 
                 n_encoder:int=6):
        """
        Args:
            multihead_attention: Template attention module; every layer gets
                its own deep copy of it.
            feedforward: Template feed-forward module; every layer gets its
                own deep copy of it.
            dropout: Dropout probability passed to each ``EncoderBlock``.
            n_encoder: Number of stacked encoder blocks.
        """
        super().__init__()

        self.n_encoder = n_encoder
        self.norm = LayerNormalization()
        # Passing the same module instances to every block would tie the
        # weights of all layers together; deep-copy so each layer owns its
        # parameters (all copies start from the template's initial values).
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(copy.deepcopy(multihead_attention), copy.deepcopy(feedforward), dropout)
            for _ in range(n_encoder)
        ])
    
    def forward(self,x, src_mask):
        """Run ``x`` through every encoder block in order, then normalise.

        Args:
            x: Encoder input tensor.
            src_mask: Mask forwarded to every block (may be ``None``).

        Returns:
            The normalised output of the last block.
        """
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return self.norm(x)

        



In [18]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each of the three sublayers is wrapped in its own ``ResidualConnection``.
    """

    def __init__(self, self_attention: MultiHeadAttention,cross_attention:MultiHeadAttention, feedforward:FeedForwardBlock, dropout:float):
        """
        Args:
            self_attention: Attention module applied to the decoder input itself.
            cross_attention: Attention module whose keys/values are the encoder output.
            feedforward: Position-wise feed-forward module.
            dropout: Dropout probability handed to each ``ResidualConnection``.
        """
        super().__init__()
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.feedforward = feedforward
        self.dropout = nn.Dropout(dropout)
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, dec_query, encoder_output,tar_mask, src_mask=None):
        """Apply the three decoder sublayers to ``dec_query``.

        Args:
            dec_query: Decoder input tensor.
            encoder_output: Encoder output used as keys and values in cross-attention.
            tar_mask: Mask for the decoder self-attention (may be ``None``).
            src_mask: Mask for cross-attention over the encoder positions;
                ``None`` (the default) leaves every encoder position visible.
                The target mask is not reused here because cross-attention
                scores are indexed by (target position, source position).

        Returns:
            Tensor with the same shape as ``dec_query``.
        """
        x = dec_query
        # The attention modules return (output, attention_scores); only the
        # output tensor can be added back by the residual connection.
        x = self.residual_connections[0](
            x,
            lambda normed: self.self_attention(normed, normed, normed, tar_mask)[0],
        )
        # Queries come from the normalised decoder state, keys/values from the encoder.
        x = self.residual_connections[1](
            x,
            lambda normed: self.cross_attention(normed, encoder_output, encoder_output, src_mask)[0],
        )
        x = self.residual_connections[2](x, self.feedforward)
        return x