In [1]:
import torch
from torch import nn
import numpy as np

In [2]:
class InputEmbedding(nn.Module):

  def __init__(self,vocab_size,emb_size):
    super().__init__()

    self.emb_size = emb_size
    self.vocab_size= vocab_size
    
    self.input_emb = nn.Embedding(vocab_size,emb_size)

  def forward(self,x):
    return self.input_emb(x) * np.sqrt(self.emb_size)

In [3]:
class PositionalEmbedding(nn.Module):
  def __init__(self,batch_size,seq_len,emb_size):


    super().__init__()

    self.batch_size = batch_size
    self.seq_len = seq_len
    self.emb_size = emb_size


    positions = np.arange(seq_len)[:,np.newaxis]
    depth = np.arange(emb_size)[np.newaxis, :]
    depth = (2*depth//2)/emb_size

    angle_rates = 1 / (10000**depth)

    angle_rads  = positions * angle_rates
    angle_rads[:,0::2] = np.sin(angle_rads[:,0::2])
    angle_rads[:,1::2] = np.sin(angle_rads[:,1::2])


    self.positions = positions * angle_rads

  def forward(self):
    return torch.tensor(np.broadcast_to(self.positions,[self.batch_size,self.seq_len,self.emb_size]),dtype=torch.float32)

In [4]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self,emb_size,batch_size,heads,seq_len,decoder=False,mode='self'):
        super(MultiHeadAttention,self).__init__()
        
        
        """"
        Parameters:
                emb_size (int): Embedding size (e.g 512)
                batch_size (int): Batch Size
                heads (int): Number of heads (e.g 8)
                seq_len (int): Number of words in each sequence
                decoder (bool): False (if inputs comes from decoder side)
                mode (str): self or mask -> attention

                
        Returns:
            Out (Tensor)
        
        """
        self.emb_size= emb_size
        self.head_dim = emb_size//heads
        self.seq_len = seq_len
        self.batch_size = batch_size
        
        # Queries, Keys and Values Matrices Layers
        self.queries = nn.Linear(self.emb_size,self.emb_size)
        self.keys = nn.Linear(self.emb_size,self.emb_size)
        self.values = nn.Linear(self.emb_size,self.emb_size)
        self.out_projection = nn.Linear(self.emb_size,self.emb_size) 
        
        self.softmax = nn.Softmax(dim=-1)
        

        self.mode = mode
        self.decoder = decoder
      
    
    def self_attention(self,queries,keys,values,masked=False):
        """
         queries: (batch_size,seq_len,dim)
         keys: (batch_size,seq_len,dim)
         values: (batch_size,seq_len,dim)
         
        """
        
        scores = torch.matmul(queries,keys.transpose(-2,-1))
        scores = scores/np.sqrt(self.head_dim)

        if masked:
            mask = np.tril(np.ones((self.seq_len,self.seq_len)))
            mask[mask==0] = -np.inf
            scores = scores + torch.tensor(mask,dtype=torch.float32)
        
        
        scores = self.softmax(scores)
        atten = torch.matmul(scores,values)
        
        return atten 
    
     
    def forward(self,x,enc_key=[],enc_value=[]):
        
        
        # As mention in the paper first we multiply each word embedding in our case 512 with 512x512 Matrcis
        # We pass our data through the dense layer
        
        # For Multiheaded Attention
        if self.decoder==False:
            queries = self.queries(x)
            keys = self.keys(x)
            values = self.values(x)
            
        else:
            
            # Multi Headed when keys and values come from encoder part
            queries = self.queries(x)
            keys = self.keys(enc_key)
            values = self.values(enc_value)
        


        
          
        if self.mode == 'self':
            
            # Self Attention
            attention = self.self_attention(queries,keys,values)
            
        # Apply masked multiheaded attention
        if self.mode == 'mask':
            attention = self.self_attention(queries,keys,values,masked=True)
        
        
            
        # Last Projection Matrix 
        
        out = self.out_projection(attention)
                          
        return out


# Encoder

In [5]:
class Encoder(nn.Module):
    def __init__(self,
               batch_size,
               seq_len,
               emb_size=512,
               heads=8,
               forward_expansion=4):
    
        super(Encoder,self).__init__()

        self.emb_size = emb_size
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.heads = heads
        self.expansion_rate = forward_expansion

        self.mha = MultiHeadAttention(
                                  self.emb_size,
                                  self.batch_size,
                                  self.heads,self.seq_len
                                      )
        self.relu = nn.ReLU()
        self.dense_1 = nn.Linear(self.emb_size,int(self.emb_size*self.expansion_rate))
        self.dense_2 =nn.Linear((self.emb_size * self.expansion_rate),self.emb_size)
        
        self.layer_norm = nn.LayerNorm(self.emb_size)

  
    def forward(self,x):

        self_attention = self.mha(x)

        x = x + self_attention

        x = self.layer_norm(x)

        dense = self.dense_1(x)
        dense = self.relu(dense)
        dense = self.dense_2(dense)
        dense = self.relu(dense)
                          
        x = x + dense
        x = self.layer_norm(x)


        return x

# Decoder

In [6]:
class Decoder(nn.Module):
    
    def __init__(self,
               batch_size,
               seq_len,
               emb_size=512,
               heads=8,
               forward_expansion=4):
        
        
        super(Decoder,self).__init__()

        self.emb_size = emb_size
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.heads = heads
        self.expansion_rate = forward_expansion
        
        
        #Mask Multihead Attention
        self.causal_attention = MultiHeadAttention(self.emb_size,
                                                self.batch_size,
                                                self.heads,
                                                self.seq_len,
                                                mode = 'mask')
        
        # MultiHeaded Attention
        self.mha = MultiHeadAttention(self.emb_size,
                                      self.batch_size,
                                      self.heads,
                                      self.seq_len,decoder=True)
        
        

        self.dense_1 = nn.Linear(self.emb_size,int(self.emb_size*self.expansion_rate))
        self.dense_2 = nn.Linear(int(self.emb_size*self.expansion_rate),self.emb_size)
        self.relu = nn.ReLU()

        self.layer_norm = nn.LayerNorm(self.emb_size)

    def forward(self,x,enc_key,enc_value):
        
       
        # Apply mask Attention
        mask_attention = self.causal_attention(x)
        
        x = x + mask_attention

        x = self.layer_norm(x)
        
        # Self Attention
        self_attention = self.mha(x,enc_key,enc_value)
        
        x = x + self_attention
        x = self.layer_norm(x)
        
        
        
        dense = self.dense_1(x)
        dense = self.relu(dense)
        dense = self.dense_2(dense)
        dense = self.relu(dense)

        x = x + dense
        x = self.layer_norm(x)
        
        return x


# Transformer


In [7]:
class Transformer(nn.Module):
    
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 seq_len,
                 batch_size,
                 emb_size=512,
                 heads=8,
                 expansion_rate=4,
                 num_modules=6
                 ):
        
        
        super(Transformer,self).__init__()
        
        
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size
        self.emb_size = emb_size
        self.heads = heads
        self.expansion_rate = expansion_rate
        self.num_modules = num_modules
        self.batch_size = batch_size
        self.seq_len = seq_len
        
        self.encoder_layers = [Encoder(batch_size,seq_len) for _ in range(num_modules)]
        
        self.decoder_layers = [Decoder(batch_size,seq_len) for _ in range(num_modules)]
        
        self.linear = nn.Linear(self.emb_size,trg_vocab_size)
        self.softmax = nn.Softmax(dim=-1)
   
    
    def forward(self,input1,input2):
        
        
        # Encoder part
        
        # input embeddings
        input_embeddings = InputEmbedding(self.src_vocab_size,self.emb_size)(input1)
        
        #positional encoding
        
        positional_encodings = PositionalEmbedding(self.batch_size,self.seq_len,self.emb_size)()
        
        enc_out = input_embeddings + positional_encodings
        


        for layer in self.encoder_layers:
            enc_out = layer(enc_out)
            
        
        
        # Decoder Part
        
         # input embeddings
        input_embeddings = InputEmbedding(self.trg_vocab_size,self.emb_size)(input2)
        
        #positional encoding
        
        positional_encodings = PositionalEmbedding(self.batch_size,self.seq_len,self.emb_size)()
        
        dec_out = input_embeddings + positional_encodings
        
        for layer in self.decoder_layers:
            dec_out = layer(dec_out,enc_out,enc_out)
        
        # linear Layer
        
        out = self.linear(dec_out)
        
        # apply softmax
        out = self.softmax(out)

        
        return out


#Define Model

In [8]:
batch_size = 10
seq_len = 5
src_vocab_size = 100
trg_vocab_size = 150

transformer = Transformer(src_vocab_size ,trg_vocab_size,seq_len,batch_size)
transformer

Transformer(
  (linear): Linear(in_features=512, out_features=150, bias=True)
  (softmax): Softmax(dim=-1)
)

In [9]:
src_tokens = torch.randint(1, 100, size=(10, 5))
trg_tokens = torch.randint(1, 150, size=(10, 5))

out = transformer(src_tokens,trg_tokens)
out.shape

torch.Size([10, 5, 150])