In [1]:
# Import libraries
import torch.nn as nn
import torch
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sinusoidal positional encoding module
class PositionalEncoding(nn.Module):
    '''Add fixed sinusoidal positional encodings to a batch of token embeddings.

    This preserves the order of the tokens in a sequence by adding a
    position-dependent vector to each embedding:

        PE[pos, 2j]   = sin(pos / 10000 ** (2j / embed_size))
        PE[pos, 2j+1] = cos(pos / 10000 ** (2j / embed_size))

    Args:
        embed_size: dimensionality of each embedding vector.
        max_length: longest sequence length the table is precomputed for.
        device: device the encoding table is created on.
    '''
    def __init__(self, embed_size, max_length, device):
        super(PositionalEncoding, self).__init__()
        self.embed_size = embed_size
        self.max_length = max_length
        self.device = device
        # Build the table once, vectorised, instead of recomputing it (and
        # re-unsqueezing it) on every forward pass.
        position = torch.arange(max_length, dtype=torch.float32).unsqueeze(1)  # (max_length, 1)
        # 1 / 10000 ** (2j / embed_size) for every even column index 2j
        div_term = torch.exp(
            torch.arange(0, embed_size, 2, dtype=torch.float32) * (-math.log(10000.0) / embed_size)
        )
        pe_matrix = torch.zeros(max_length, embed_size)
        pe_matrix[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_size there is one fewer cosine column than sine column.
        pe_matrix[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        # Leading batch dimension for broadcasting. Registered as a non-persistent
        # buffer so module.to(device) moves it without changing the state_dict.
        self.register_buffer("pe_matrix", pe_matrix.unsqueeze(0).to(device), persistent=False)

    def forward(self, embedding):
        '''Return ``embedding`` plus the positional encoding of each position.

        Args:
            embedding: tensor of shape (batch, seq_length, embed_size) with
                seq_length <= max_length.

        Returns:
            Tensor with the same shape as ``embedding``.

        Raises:
            ValueError: if the sequence is longer than ``max_length``.
        '''
        seq_length = embedding.size(1)
        if seq_length > self.max_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_length {self.max_length}"
            )
        pe = self.pe_matrix[:, :seq_length].to(dtype=embedding.dtype, device=embedding.device)
        return embedding + pe

In [3]:
# Code the multihead attention layer
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product attention split across ``num_heads`` parallel heads.

    Args:
        embed_dim: model dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # embed_dim has to split evenly across the heads
        assert self.embed_dim == self.head_dim * self.num_heads

        # Bias-free input projections (created in this order so seeded
        # initialisation is unchanged), then a biased output projection.
        self.key = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value = nn.Linear(embed_dim, embed_dim, bias=False)
        self.query = nn.Linear(embed_dim, embed_dim, bias=False)
        self.dense_out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, tensor):
        """Reshape (batch, length, embed_dim) into (batch, length, num_heads, head_dim)."""
        batch, length = tensor.shape[0], tensor.shape[1]
        return tensor.reshape(batch, length, self.num_heads, self.head_dim)

    def forward(self, v, k, q, mask=None):
        """Attend from the queries ``q`` over the keys ``k`` and values ``v``.

        Args:
            v, k, q: tensors of shape (batch, length, embed_dim); ``k`` and ``v``
                must have the same length.
            mask: optional tensor broadcastable to (batch, num_heads, q_len, k_len);
                score positions where it equals 0 are filled with -1e20 before
                the softmax, i.e. effectively excluded.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        values = self._split_heads(self.value(v))
        keys = self._split_heads(self.key(k))
        queries = self._split_heads(self.query(q))
        batch, q_len = queries.shape[0], queries.shape[1]

        # raw similarity scores, shape (batch, heads, q_len, k_len)
        scores = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-1e20"))
        probs = torch.softmax(scores / math.sqrt(self.head_dim), dim=3)

        # weighted sum of the values, shape (batch, q_len, heads, head_dim)
        context = torch.einsum("nhqk,nkhd->nqhd", probs, values)
        merged = context.reshape(batch, q_len, self.num_heads * self.head_dim)
        return self.dense_out(merged)

In [4]:
# Build a single encoder block: multi-head attention followed by a feed-forward network,
# each wrapped in a residual connection and layer normalization
class EncoderBlock(nn.Module):
    """One encoder layer: attention, then a feed-forward network, each
    followed by a residual connection and layer normalization.

    Args:
        emed_size: model dimension (name kept as-is for keyword callers).
        num_heads: number of attention heads.
    """
    def __init__(self, emed_size, num_heads):
        super(EncoderBlock, self).__init__()
        self.attention = MultiHeadSelfAttention(emed_size, num_heads)
        self.norm1 = nn.LayerNorm(emed_size)
        self.norm2 = nn.LayerNorm(emed_size)
        # Position-wise MLP whose hidden width equals the model width.
        mlp_layers = [
            nn.Linear(emed_size, emed_size),
            nn.ReLU(),
            nn.Linear(emed_size, emed_size),
        ]
        self.feed_forward = nn.Sequential(*mlp_layers)

    def forward(self, v, k, q, mask):
        """Run attention over (v, k, q) and the feed-forward sub-layer.

        ``q`` is also the residual input to the first sub-layer, so the output
        has the same length as the queries. ``mask`` is forwarded unchanged to
        the attention layer.
        """
        attended = self.norm1(self.attention(v, k, q, mask) + q)
        return self.norm2(attended + self.feed_forward(attended))

In [5]:
# Build the encoder: token embedding and positional encoding followed by a stack of encoder blocks
class Encoder(nn.Module):
    """Embed source tokens, add positional encodings, and pass the result
    through ``num_layers`` stacked encoder blocks.

    Args:
        embed_size: model dimension.
        num_heads: attention heads per block.
        vocab_size: size of the source vocabulary.
        max_length: longest sequence the positional encoding supports.
        num_layers: number of encoder blocks.
        device: device the sub-modules are moved to.
    """
    def __init__(self,embed_size, num_heads, vocab_size,max_length, num_layers, device):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_length = max_length
        self.vocab_size = vocab_size
        self.device = device
        self.embedding = nn.Embedding(vocab_size, embed_size).to(device)
        self.positional_encoding = PositionalEncoding(embed_size, max_length, device)
        encoder_blocks = [EncoderBlock(embed_size, num_heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(encoder_blocks).to(device)

    def forward(self, input, mask):
        """Encode token ids ``input``; ``mask`` is handed to every block."""
        hidden = self.positional_encoding(self.embedding(input))
        for block in self.layers:
            # self-attention: values, keys and queries are the same tensor
            hidden = block(hidden, hidden, hidden, mask)
        return hidden

In [6]:
# Build a single decoder block: masked self-attention, then an encoder-style block that attends over the encoder output
class DecoderBlock(nn.Module):
    """One decoder layer.

    Self-attention over the decoder stream (with ``decoder_mask``), a residual
    connection and layer norm, and then an ``EncoderBlock`` in which the
    result acts as the queries against ``value``/``key`` (with ``encoder_mask``).

    Args:
        embed_size: model dimension.
        num_heads: number of attention heads.
    """
    def __init__(self, embed_size, num_heads):
        super(DecoderBlock, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        # attention + feed-forward over the encoder output, reused from EncoderBlock
        self.transformer = EncoderBlock(embed_size, num_heads)
        self.attention = MultiHeadSelfAttention(embed_size, num_heads)
        self.norm = nn.LayerNorm(embed_size)

    def forward(self,input, value, key, encoder_mask, decoder_mask):
        self_attended = self.norm(input + self.attention(input, input, input, decoder_mask))
        return self.transformer(value, key, self_attended, encoder_mask)

In [7]:
# Build the decoder: token embedding and positional encoding, a stack of decoder blocks, and a projection onto the target vocabulary
class Decoder(nn.Module):
    """Embed target tokens, run ``num_layers`` decoder blocks against the
    encoder output, and project to ``decoder_vocab_size`` scores.

    Args:
        embed_size: model dimension.
        num_heads: attention heads per block.
        decoder_vocab_size: size of the target vocabulary.
        max_length: longest sequence the positional encoding supports.
        num_layers: number of decoder blocks.
        device: device the sub-modules are moved to.
    """
    def __init__(self,embed_size, num_heads, decoder_vocab_size, max_length,num_layers,device):
        super(Decoder, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.decoder_vocab_size = decoder_vocab_size
        self.max_length = max_length
        self.num_layers = num_layers
        self.device = device
        self.embedding = nn.Embedding(decoder_vocab_size, embed_size).to(device)
        self.positional_encoding = PositionalEncoding(embed_size, max_length, device)
        decoder_blocks = [DecoderBlock(embed_size, num_heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(decoder_blocks).to(device)
        self.output = nn.Linear(embed_size, decoder_vocab_size).to(device)

    def forward(self, input, encoder_out, encoder_mask, decoder_mask):
        """Return unnormalised per-token vocabulary scores (no softmax applied)."""
        hidden = self.positional_encoding(self.embedding(input))
        for block in self.layers:
            # encoder_out serves as both values and keys for cross-attention
            hidden = block(hidden, encoder_out, encoder_out, encoder_mask, decoder_mask)
        return self.output(hidden)

In [8]:
# Build the transformer block from the encoder and decoder
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from ``Encoder`` and ``Decoder``.

    Args:
        embed_size, num_heads, max_length, num_layers: shared by encoder and decoder.
        encoder_vocab_size / decoder_vocab_size: source / target vocabulary sizes.
        encoder_pad_idx: source padding id, excluded via the encoder mask.
        decoder_pad_idx: target padding id (stored; not used by the masks below).
        device: device masks and sub-modules are placed on.
    """
    def __init__(self,embed_size, num_heads, encoder_vocab_size,decoder_vocab_size, max_length, num_layers, encoder_pad_idx, decoder_pad_idx,device):
        super(Transformer, self).__init__()
        self.encoder_pad_idx = encoder_pad_idx
        self.decoder_pad_idx = decoder_pad_idx
        self.device = device
        self.encoder = Encoder(embed_size, num_heads, encoder_vocab_size, max_length, num_layers, device)
        self.decoder = Decoder(embed_size, num_heads, decoder_vocab_size, max_length, num_layers, device)

    def make_encoder_mask(self,encoder_input):
        """Boolean mask, True for non-padding tokens; (N, L) input gives (N, 1, 1, L)."""
        not_padding = encoder_input.ne(self.encoder_pad_idx)
        return not_padding[:, None, None].to(self.device)

    def make_decoder_mask(self, decoder_input):
        """Float causal mask of shape (N, 1, T, T): row i has ones at columns <= i."""
        batch, target_len = decoder_input.shape
        lower_triangle = torch.ones((target_len, target_len)).tril()
        # expand returns a broadcast view rather than copying per sample
        return lower_triangle.expand(batch, 1, target_len, target_len).to(self.device)

    def forward(self, encoder_input, decoder_input):
        """Encode ``encoder_input`` and decode ``decoder_input`` against it."""
        source_mask = self.make_encoder_mask(encoder_input)
        target_mask = self.make_decoder_mask(decoder_input)
        memory = self.encoder(encoder_input, source_mask)
        return self.decoder(decoder_input, memory, source_mask, target_mask)

In [13]:
if __name__ == "__main__":
    # Smoke test: build a small model and run one forward pass on toy data.
    torch.manual_seed(0)  # reproducible weight initialisation
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Source / target token ids; 0 is the padding index.
    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 5, 7, 3, 4, 5, 6, 7, 2]]).to(device)
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)
    encoder_pad_idx = 0
    decoder_pad_idx = 0
    encoder_vocab_size = 10
    decoder_vocab_size = 10
    max_length = 9
    num_layers = 6
    embed_size = 64
    num_heads = 8

    model = Transformer(embed_size, num_heads, encoder_vocab_size, decoder_vocab_size, max_length, num_layers, encoder_pad_idx, decoder_pad_idx, device)
    print(model)
    # Teacher forcing: feed the target without its last token; the matching
    # labels for a loss would be trg[:, 1:]. (Previously this slice was
    # computed but never used, and the full target was fed instead.)
    decoder_input = trg[:, :-1]
    out = model(x, decoder_input)
    print(out.shape)  # (batch, target_len - 1, decoder_vocab_size)

Transformer(
  (encoder): Encoder(
    (embedding): Embedding(10, 64)
    (positional_encoding): PositionalEncoding()
    (layers): ModuleList(
      (0): EncoderBlock(
        (attention): MultiHeadSelfAttention(
          (key): Linear(in_features=64, out_features=64, bias=False)
          (value): Linear(in_features=64, out_features=64, bias=False)
          (query): Linear(in_features=64, out_features=64, bias=False)
          (dense_out): Linear(in_features=64, out_features=64, bias=True)
        )
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=True)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (1): EncoderBlock(
        (attention): MultiHeadSelfAttention(
          (key): Linear(in_features=64, out_features=64, bias=False)
    