In [1]:
import torch
import torch.nn as nn
import math

In [3]:
class InputEmbeddings(nn.Module):
    """Embedding lookup for token ids, with the result scaled by sqrt(d_m).

    Args:
        d_m: size of each embedding vector.
        vocab_size: number of rows in the embedding table.
    """

    def __init__(self,d_m: int, vocab_size: int) -> None:
        super().__init__()
        self.d_m = d_m
        self.embeddings = nn.Embedding(vocab_size, d_m)

    def forward(self,x):
        """Map ids of shape (batch, seq_len) to vectors of shape (batch, seq_len, d_m)."""
        ## scaling factor applied to every looked-up vector
        scale = math.sqrt(self.d_m)
        looked_up = self.embeddings(x)
        return looked_up * scale

In [7]:
class PositionalEncoddings(nn.Module):
    """Fixed sinusoidal positional encodings added to the input, followed by dropout.

    The table is computed once in ``__init__`` and stored as a buffer named
    ``pe`` with shape (1, seq_len, d_m), so it moves with the module across
    devices and is saved in the state dict, but is not a trainable parameter.

    Args:
        d_m: size of the last dimension of the input.
        seq_len: number of positions to precompute (longest supported sequence).
        dropout: dropout probability applied to the output of ``forward``.
    """

    def __init__(self,d_m: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        ## create a matrix of shape (seq_len,d_m)
        pe = torch.zeros(seq_len,d_m)
        ## column vector of positions so it broadcasts against div_term
        position = torch.arange(0,seq_len,dtype=torch.float).unsqueeze(1) ## (seq_len,1)
        ## 1 / 10000^(2i/d_m), computed in log space
        div_term = torch.exp(torch.arange(0,d_m,2).float() * (-math.log(10000.0)/d_m)) ## (ceil(d_m/2))
        angles = position * div_term ## (seq_len, ceil(d_m/2))
        ## apply sine to even indices
        pe[:,0::2] = torch.sin(angles)
        ## apply cosine to odd indices; when d_m is odd there is one fewer odd column
        pe[:,1::2] = torch.cos(angles[:, : d_m // 2])
        ## add a batch dimension to the pe
        pe = pe.unsqueeze(0) # (1,seq_len,d_m)
        ## register the positional encoding as buffer
        self.register_buffer('pe',pe)

    def forward(self,x):
        """Add the first ``x.shape[1]`` positional rows to ``x`` and apply dropout.

        Args:
            x: tensor whose dimension 1 is the sequence axis and whose last
               dimension is d_m.

        Raises:
            ValueError: if ``x.shape[1]`` exceeds the precomputed ``seq_len``.
        """
        length = x.shape[1]
        if length > self.pe.shape[1]:
            raise ValueError(
                f"sequence length {length} exceeds the precomputed maximum {self.pe.shape[1]}"
            )
        ## the buffer is not a Parameter, so no gradient flows into it
        x = x + self.pe[:, :length, :] ## (batch,seq_len,d_m)
        return self.dropout(x)


In [8]:
class LayerNormalization(nn.Module):
    """Normalize over the last dimension with a learnable scale and shift.

    Args:
        features: length of the learnable ``alpha`` and ``bias`` vectors.
        eps: constant added to the standard deviation in the denominator.
    """

    def __init__(self,features: int,eps:float=10**-6):
        super().__init__()
        self.eps = eps
        ## learnable scale, initialised to ones
        self.alpha = nn.Parameter(torch.ones(features))
        ## learnable shift, initialised to zeros
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self,x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` along the last dimension."""
        ## statistics keep their last dimension (size 1) so they broadcast against x
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        centered = x - mu
        ## eps keeps the denominator away from zero when sigma is tiny
        denominator = sigma + self.eps
        scaled = self.alpha * centered
        return scaled / denominator + self.bias

In [10]:
class FeedForwardBlock(nn.Module):
    """Two linear layers with a ReLU and dropout in between.

    Args:
        d_m: input and output feature size.
        d_ff: hidden feature size between the two linear layers.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self,d_m: int, d_ff: int,dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_m,d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff,d_m)

    def forward(self,x):
        """Compute ``linear_2(dropout(relu(linear_1(x))))``."""
        ## x -> (batch,seq_len,d_m) -> (batch,seq_len,d_ff) -> (batch,seq_len,d_m)
        ## torch.relu applies the activation; nn.ReLU(...) would only build a module
        hidden = torch.relu(self.linear_1(x))
        return self.linear_2(self.dropout(hidden))


In [11]:
class ResidualConnect(nn.Module):
    """Residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    Args:
        features: feature size passed to the ``LayerNormalization`` applied to
            the input before the sublayer.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self,features: int, dropout: float):
        super().__init__()
        self.norm = LayerNormalization(features)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,sublayer):
        """Apply ``sublayer`` to the normalized input and add the result back to ``x``.

        Args:
            x: input tensor.
            sublayer: callable taking the normalized tensor and returning a
                tensor that can be added to ``x``.

        Returns:
            The sum of ``x`` and the dropped-out sublayer output.
        """
        ## x -> (batch,seq_len,d_m)
        return x + self.dropout(sublayer(self.norm(x)))

In [12]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected with bias-free linear layers, split into
    ``heads`` slices of width ``d_k = d_m // heads``, attended independently,
    concatenated, and projected once more with ``w_O``.

    Args:
        d_m: feature size of the inputs and the output.
        heads: number of attention heads; must be positive and divide ``d_m``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide ``d_m``.
    """

    def __init__(self, d_m: int, heads : int, dropout: float) -> None:
        super().__init__()
        if heads <= 0 or d_m % heads != 0:
            raise ValueError(
                f"d_m ({d_m}) must be divisible by a positive number of heads ({heads})"
            )
        self.d_m = d_m
        self.heads = heads
        ## integer division: d_k is used as a size in view()
        self.d_k = d_m // heads
        self.dropout = nn.Dropout(dropout)
        self.w_Q = nn.Linear(d_m,d_m, bias=False) ## Wq
        self.w_K = nn.Linear(d_m,d_m, bias=False) ## Wk
        self.w_V = nn.Linear(d_m,d_m, bias=False) ## Wv
        self.w_O = nn.Linear(d_m,d_m, bias=False) ## Wo

    @staticmethod
    def attention(query,key,value,mask,dropout: nn.Dropout):
        """Scaled dot-product attention.

        Args:
            query, key, value: tensors whose last two dimensions are
                (sequence, d_k).
            mask: optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` receive a score of -1e9 before softmax.
            dropout: optional module applied to the attention weights.

        Returns:
            A tuple ``(weights @ value, weights)``.
        """
        d_k = query.shape[-1]
        ## (batch,heads,seq_len,d_k) -> (batch,heads,seq_len,seq_len)
        attention_score = (query @ key.transpose(-2,-1)) / math.sqrt(d_k)
        if mask is not None:
            attention_score = attention_score.masked_fill(mask == 0, -1e9)
        ## normalise over the key axis (last dim) so each query row sums to 1
        attention_score = attention_score.softmax(dim=-1) ## (batch,heads,seq_len,seq_len)
        if dropout is not None:
            attention_score = dropout(attention_score)
        ## (batch,heads,seq_len,seq_len) -> (batch,heads,seq_len,d_k)
        return (attention_score @ value), attention_score

    def _split_heads(self, t):
        """Reshape (batch,seq_len,d_m) -> (batch,heads,seq_len,d_k)."""
        return t.view(t.shape[0], t.shape[1], self.heads, self.d_k).transpose(1,2)

    def forward(self,q,k,v,mask):
        """Project ``q``, ``k``, ``v``, attend per head, and merge the heads.

        The attention weights of the latest call are stored on
        ``self.attention_score``.

        Returns:
            Tensor of shape (batch, seq_len of ``q``, d_m).
        """
        query = self._split_heads(self.w_Q(q))
        key = self._split_heads(self.w_K(k))
        value = self._split_heads(self.w_V(v))

        x,self.attention_score = MultiHeadAttention.attention(query,key,value,mask,self.dropout)

        # combine all heads together
        # (batch,heads,seq_len,d_k) -> (batch,seq_len,heads,d_k) -> (batch,seq_len,d_m)
        x = x.transpose(1,2).contiguous().view(x.shape[0],-1,self.heads*self.d_k)

        return self.w_O(x)
        

In [13]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each wrapped in a ResidualConnect.

    Args:
        features: feature size handed to both ``ResidualConnect`` wrappers.
        self_attention_block: attention module called as ``(q, k, v, mask)``.
        feedforward_block: module applied after the attention sublayer.
        dropout: dropout probability handed to both ``ResidualConnect`` wrappers.
    """

    def __init__(self, features: int, self_attention_block: MultiHeadAttention, feedforward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feedforward_block
        wrappers = [ResidualConnect(features, dropout) for _ in range(2)]
        self.residual_connection = nn.ModuleList(wrappers)

    def forward(self,x, src_mask):
        """Run ``x`` through the attention sublayer and then the feed-forward sublayer."""
        attention_wrapper = self.residual_connection[0]
        feed_forward_wrapper = self.residual_connection[1]

        def self_attend(normed):
            ## query, key and value are all the same tensor
            return self.self_attention_block(normed, normed, normed, src_mask)

        attended = attention_wrapper(x, self_attend)
        return feed_forward_wrapper(attended, self.feed_forward_block)

In [14]:
class Encoder(nn.Module):
    """Apply a list of layers in order, then a final LayerNormalization.

    Args:
        features: feature size passed to the final ``LayerNormalization``.
        layers: modules each called as ``layer(x, mask)``.
    """

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self,x,mask):
        """Thread ``x`` through every layer with the same ``mask`` and normalize the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        normalized = self.norm(hidden)
        return normalized

In [15]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sublayers is wrapped in its own ``ResidualConnect``.

    Args:
        features: feature size handed to the three ``ResidualConnect`` wrappers.
        self_attention_block: attention module used with ``tgt_mask``.
        cross_attention_block: attention module whose keys and values are the
            encoder output, used with ``src_mask``.
        feed_forward_block: module applied as the last sublayer.
        dropout: dropout probability handed to the three wrappers.
    """

    def __init__(self,features: int,self_attention_block: MultiHeadAttention,cross_attention_block:MultiHeadAttention,feed_forward_block:FeedForwardBlock,dropout:float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        wrappers = [ResidualConnect(features, dropout) for _ in range(3)]
        self.residual_connections = nn.ModuleList(wrappers)

    def forward(self,x,encoder_output,src_mask,tgt_mask):
        """Run ``x`` through the three wrapped sublayers in order."""
        self_wrapper = self.residual_connections[0]
        cross_wrapper = self.residual_connections[1]
        feed_forward_wrapper = self.residual_connections[2]

        def self_attend(normed):
            ## query, key and value all come from the decoder input
            return self.self_attention_block(normed, normed, normed, tgt_mask)

        def cross_attend(normed):
            ## query from the decoder; key and value from the encoder output
            return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

        after_self = self_wrapper(x, self_attend)
        after_cross = cross_wrapper(after_self, cross_attend)
        return feed_forward_wrapper(after_cross, self.feed_forward_block)


In [16]:
class Decoder(nn.Module):
    """Apply a list of decoder layers in order, then a final LayerNormalization.

    Args:
        features: feature size passed to the final ``LayerNormalization``.
        layers: modules each called as
            ``layer(x, encoder_output, src_mask, tgt_mask)``.
    """

    def __init__(self,features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self,x,encoder_output,src_mask,tgt_mask):
        """Thread ``x`` through every layer with the same encoder output and masks, then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        normalized = self.norm(hidden)
        return normalized

In [18]:
class ProjectionLayer(nn.Module):
    """Single linear layer mapping features of size d_m to vocab_size outputs.

    Args:
        d_m: input feature size.
        vocab_size: output feature size.
    """

    def __init__(self, d_m,vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(in_features=d_m, out_features=vocab_size)

    def forward(self,x):
        """Return the linear projection of ``x`` (no softmax is applied here)."""
        projected = self.proj(x)
        return projected

In [19]:
class Transformer(nn.Module):
    """Container tying together embeddings, positional encodings, encoder, decoder and projection.

    The three stages are exposed separately (``encode``, ``decode``,
    ``project``) so the encoder output can be computed once and reused.

    Args:
        encoder: called as ``encoder(src, src_mask)``.
        decoder: called as ``decoder(tgt, encoder_output, src_mask, tgt_mask)``.
        src_embd: applied to the source input first.
        tgt_embd: applied to the target input first.
        src_pos: applied to the embedded source.
        tgt_pos: applied to the embedded target.
        proj_layer: applied by ``project``.
    """

    ## Annotations are quoted so defining this class does not require the
    ## annotated classes to already exist in the namespace.
    def __init__(self, encoder: "Encoder", decoder: "Decoder", src_embd: "InputEmbeddings",
                 tgt_embd: "InputEmbeddings", src_pos: "PositionalEncoddings",
                 tgt_pos: "PositionalEncoddings", proj_layer: "ProjectionLayer") -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embd = src_embd
        self.tgt_embd = tgt_embd
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.proj_layer = proj_layer

    def encode(self,src,src_mask):
        """Embed ``src``, add positional information, and run the encoder."""
        ## (batch,seq_len,d_m)
        src = self.src_embd(src)
        src = self.src_pos(src)
        return self.encoder(src,src_mask)

    def decode(self,tgt,tgt_mask,encoder_output=None,src_mask=None):
        """Embed ``tgt``, add positional information, and run the decoder.

        Args:
            tgt: target input passed to ``tgt_embd``.
            tgt_mask: mask forwarded to the decoder for the target.
            encoder_output: result of ``encode``; required.
            src_mask: mask forwarded to the decoder for the source.

        Raises:
            ValueError: if ``encoder_output`` is not provided.
        """
        if encoder_output is None:
            raise ValueError("decode() requires encoder_output (the result of encode())")
        ## (batch,seq_len,d_m)
        tgt = self.tgt_embd(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt,encoder_output,src_mask,tgt_mask)

    def project(self,x):
        """Apply the projection layer to ``x``."""
        return self.proj_layer(x)

In [20]:
def build_transformer(src_vocab_size,tgt_vocab_size,
                      src_seq_len,tgt_seq_len,
                      d_m = 512, N = 6, h = 8, dropout = 0.1,
                      d_ff = 2048):
    """Assemble a Transformer from its building blocks and initialise its weights.

    Args:
        src_vocab_size: vocabulary size for the source embedding table.
        tgt_vocab_size: vocabulary size for the target embedding table and the
            output size of the projection layer.
        src_seq_len: number of positions precomputed for the source positional encoding.
        tgt_seq_len: number of positions precomputed for the target positional encoding.
        d_m: feature size used throughout the model.
        N: number of encoder blocks and number of decoder blocks.
        h: number of attention heads in every attention block.
        dropout: dropout probability passed to every sub-module that takes one.
        d_ff: hidden size of every feed-forward block.

    Returns:
        The constructed ``Transformer``; every parameter with more than one
        dimension has been re-initialised with Xavier-uniform.
    """
    ## Create embedding layers
    src_embd = InputEmbeddings(d_m,src_vocab_size)
    tgt_embd = InputEmbeddings(d_m,tgt_vocab_size)

    ## Positional layers
    src_pos = PositionalEncoddings(d_m,src_seq_len,dropout)
    tgt_pos = PositionalEncoddings(d_m,tgt_seq_len,dropout)

    ## create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttention(d_m,h,dropout)
        feed_forward_block = FeedForwardBlock(d_m,d_ff,dropout)
        encoder_block = EncoderBlock(d_m,encoder_self_attention_block,feed_forward_block,dropout)
        encoder_blocks.append(encoder_block)

    ## create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttention(d_m,h,dropout)
        decoder_cross_attention_block = MultiHeadAttention(d_m,h,dropout)
        feed_forward_block = FeedForwardBlock(d_m,d_ff,dropout)
        decoder_block = DecoderBlock(d_m,decoder_self_attention_block,decoder_cross_attention_block,feed_forward_block,dropout)
        decoder_blocks.append(decoder_block)

    ## create the encoder and decoder
    encoder = Encoder(d_m,nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_m,nn.ModuleList(decoder_blocks))

    ## create the projection layer
    projection_layer = ProjectionLayer(d_m,tgt_vocab_size)

    ## create transformer Model
    transformer = Transformer(encoder,decoder,src_embd,tgt_embd,src_pos,tgt_pos,projection_layer)

    ## initialize parameters: parameters() is a method and must be called;
    ## 1-D parameters (biases, norm scale/shift) keep their default init
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer