### Importing the libraries and defining multi-head self-attention


In [None]:
import torch
import torch.nn as nn
from torch import unsqueeze

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding dimension is split evenly across ``heads``: with a
    256-dimensional embedding and 8 heads, each head works on 32 dimensions.
    """

    def __init__(self, embed_size, heads):
        """
        Args:
            embed_size: total embedding dimension of the inputs and the output.
            heads: number of attention heads; must evenly divide ``embed_size``.
        """
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads

        assert self.embed_size % self.heads == 0, "To make sure that embed size is properly divisible by heads"

        self.head_dim = embed_size // heads

        # Query / key / value projections, without bias.
        # NOTE(review): these map head_dim -> head_dim and are applied after the
        # embedding has been split into heads, so all heads share the same
        # projection weights. The original Transformer projects the full
        # embed_size for each input. Changing that would alter parameter shapes
        # (and break existing checkpoints), so it is left as-is here.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        """Attend from ``queries`` over ``keys``/``values``.

        Args:
            values: (N, value_len, embed_size); value_len must equal key_len.
            keys: (N, key_len, embed_size).
            queries: (N, query_len, embed_size).
            mask: None, or a tensor broadcastable to
                (N, heads, query_len, key_len). Positions where ``mask == 0``
                are excluded from attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into self.heads pieces: (N, len, heads, head_dim).
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Batched per-head dot products of queries with keys.
        # energy: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            # A very large negative logit makes the softmax weight effectively 0.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), where d_k is the per-head dimension the dot
        # products were taken over (not the full embedding size). This keeps the
        # logits' variance independent of head_dim.
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)

        # Weighted sum of values; "l" spans key_len == value_len.
        # out: (N, query_len, heads, head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])

        # Concatenate the heads back into a single embedding dimension.
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        return self.fc_out(out)

### Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """
    attention -> add & norm -> feed forward -> add & norm

    Dropout is applied to the output of each add & norm step.
    NOTE(review): the original Transformer applies dropout to each sublayer's
    output *before* the residual add; this block applies it after the norm.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        """
        Args:
            embed_size: embedding dimension.
            heads: number of attention heads.
            dropout: dropout probability.
            forward_expansion: hidden size of the feed-forward network is
                ``forward_expansion * embed_size``.
        """
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size=embed_size, heads=heads)
        # LayerNorm normalizes each token's vector over its last (embedding)
        # dimension, independently of the rest of the batch.
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """
        Args:
            value, key: sequence being attended over.
            query: sequence producing the outputs; also the residual input.
            mask: forwarded to SelfAttention (positions where mask == 0 are
                excluded).

        Returns:
            Tensor shaped like ``query``.
        """
        attention = self.attention(value, key, query, mask)

        # Skip connection: the query is the residual input of the first sublayer.
        x = self.dropout(self.norm1(attention + query))
        ff_out = self.feed_forward(x)
        return self.dropout(self.norm2(ff_out + x))

In [None]:
class Encoder(nn.Module):
    """Stack of TransformerBlocks over word + learned positional embeddings."""

    def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        """
        Args:
            src_vocab_size: size of the source vocabulary.
            embed_size: dimension of the embeddings.
            num_layers: number of TransformerBlocks in the encoder.
            heads: number of attention heads per block.
            device: device the model is meant to run on. Stored on the module;
                position indices are created on the input tensor's device.
            forward_expansion: factor by which the feed-forward hidden layer
                expands ``embed_size``.
            dropout: dropout probability.
            max_length: maximum source sentence length, i.e. the size of the
                learned positional embedding table. Attention itself ignores
                token order. The positional embedding injects order without
                any recurrent unit, which keeps the model parallelizable.
        """
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.positional_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size=embed_size, heads=heads, dropout=dropout, forward_expansion=forward_expansion)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        Args:
            x: (N, seq_length) token indices.
            mask: source mask forwarded to every block's attention.

        Returns:
            (N, seq_length, embed_size) encoded representation.

        Raises:
            ValueError: if seq_length exceeds the positional table size.
        """
        N, seq_length = x.shape
        max_length = self.positional_embedding.num_embeddings
        if seq_length > max_length:
            # Fail clearly instead of an opaque embedding index error
            # (or a device-side assert on CUDA).
            raise ValueError(f"source sequence length {seq_length} exceeds max_length {max_length}")

        # Build positions directly on the input's device so they always match
        # the embeddings, whatever `device` the module was constructed with.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)

        out = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        # Encoder self-attention: value, key and query are all the same sequence.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then cross-attention + feed-forward.

    The cross-attention and feed-forward sublayers are delegated to a
    TransformerBlock whose query is the output of the masked self-attention.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        # `device` is accepted for signature compatibility; it is not used here.
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """
        Args:
            x: target-side input representation.
            value, key: sequence attended over by the cross-attention sublayer
                (the Decoder in this file passes the encoder output for both).
            src_mask: mask over source positions, so that padded source tokens
                are not attended to.
            trg_mask: mask over target positions used by the masked
                self-attention, so that a position cannot attend to later ones.

        Returns:
            Tensor shaped like ``x``.
        """
        self_attended = self.attention(x, x, x, trg_mask)
        # Residual + LayerNorm, then dropout; this becomes the query for the
        # attention over `value`/`key`.
        decoder_query = self.dropout(self.norm(self_attended + x))
        return self.transformer_block(value, key, decoder_query, src_mask)


In [None]:
class Decoder(nn.Module):
    """Stack of DecoderBlocks over target embeddings, then a projection to vocabulary logits."""

    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
        """
        Args:
            trg_vocab_size: size of the target vocabulary (output dimension).
            embed_size: dimension of the embeddings.
            num_layers: number of DecoderBlocks.
            heads: number of attention heads per block.
            forward_expansion: feed-forward expansion factor.
            dropout: dropout probability.
            device: stored on the module; position indices are created on the
                input tensor's device.
            max_length: size of the learned positional embedding table.
        """
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.positional_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, forward_expansion, dropout, device) for _ in range(num_layers)]
        )

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """
        Args:
            x: (N, seq_length) target token indices.
            enc_out: encoder output, used as value and key in every block.
            src_mask: source mask for the cross-attention.
            trg_mask: target mask for the masked self-attention.

        Returns:
            (N, seq_length, trg_vocab_size) unnormalized logits.

        Raises:
            ValueError: if seq_length exceeds the positional table size.
        """
        N, seq_length = x.shape
        max_length = self.positional_embedding.num_embeddings
        if seq_length > max_length:
            raise ValueError(f"target sequence length {seq_length} exceeds max_length {max_length}")

        # Positions live on the input's device, independent of self.device.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(x)


In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing target-vocabulary logits.

    Args:
        src_vocab_size: source vocabulary size.
        trg_vocab_size: target vocabulary size.
        src_pad_idx: source padding token id; used to build the source mask.
        trg_pad_idx: target padding token id. Stored, but not currently used
            when building the target mask.
        embed_size, num_layers, forward_expansion, heads, dropout, max_length:
            hyperparameters forwarded to the Encoder and Decoder.
        device: forwarded to the Encoder/Decoder and stored. Masks are created
            on the device of the input tensors.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, embed_size = 256, num_layers = 6, forward_expansion = 4, heads = 8, dropout = 0, device = "cuda", max_length = 100):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length)
        self.decoder = Decoder(trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

        self.device = device

    def make_src_mask(self, src):
        """Return an (N, 1, 1, src_len) bool mask that is False at padding tokens."""
        # The two singleton dims broadcast over heads and query positions.
        # Built from `src` itself, so it is already on the input's device.
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return an (N, 1, trg_len, trg_len) lower-triangular mask of ones.

        Position i may attend only to positions <= i.
        """
        N, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        """
        Args:
            src: (N, src_len) source token indices.
            trg: (N, trg_len) target token indices.

        Returns:
            (N, trg_len, trg_vocab_size) logits.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, src_mask, trg_mask)




In [None]:
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 8, 7, 3, 4, 5, 6, 2]]).to(device)

    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10

    # Pass the detected device explicitly. Transformer defaults to
    # device="cuda", which fails on machines without a GPU.
    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)

    # Feed the target without its last token (shifted-target setup).
    out = model(x, trg[:, :-1])
    print(out.shape)  # expected: torch.Size([2, 7, 10])

RuntimeError: ignored