In [7]:
import math
import numpy as np

In [8]:
import torch
import torch.nn as nn

In [9]:
class InputEmbedding(nn.Module):
    """Token-embedding lookup whose output is scaled by sqrt(d_model).

    Args:
        d_model: width of every embedding vector.
        vocab_size: number of rows in the lookup table (distinct token ids).

    ``forward`` turns an integer tensor of token ids into embeddings with an
    extra trailing dimension of size ``d_model`` and multiplies them by
    ``sqrt(d_model)``.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Learnable table of shape (vocab_size, d_model).
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale

In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

    Implements
        PE(pos, 2i)     = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding width (last dimension of the input).
        seq_length: maximum sequence length the table is precomputed for.
        dropout: dropout probability applied after the addition (default 0).

    ``forward`` expects ``x`` of shape (batch, seq, d_model) with seq <= seq_length
    and returns a tensor of the same shape.
    """

    def __init__(self, d_model, seq_length, dropout = 0):
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.dropout = nn.Dropout(dropout)

        positions = torch.arange(0, seq_length, dtype=torch.float32).unsqueeze(1)  # (seq_length, 1)
        # arange(0, d_model, 2) already yields the even indices 2i, so the exponent is
        # simply 2i / d_model -- no extra factor of 2.
        even_idx = torch.arange(0, d_model, 2, dtype=torch.float32)
        angles = positions / torch.pow(10000.0, even_idx / d_model)  # (seq_length, ceil(d_model / 2))

        pe = torch.zeros(seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles)  # even feature indices
        # Odd feature indices; for odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer follows the module across .to(device)/.cuda() and never receives gradients.
        # persistent=False keeps it out of state_dict: it is fully determined by (d_model, seq_length).
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, seq_length, d_model)

    def forward(self, x):
        # Slice to the actual sequence length so inputs shorter than seq_length broadcast correctly.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

In [11]:
class LayerNormalization(nn.Module):
    """Layer normalization over the last dimension with learnable scale (gamma) and shift (beta).

    Args:
        d_model: size of the last input dimension; gamma and beta have this length.
        epsilon: constant added to the variance for numerical stability.

    For input ``x`` whose last dimension is ``d_model`` returns
    ``gamma * (x - mean) / sqrt(var + epsilon) + beta``, with mean and population
    (biased) variance taken over the last dimension.
    """

    def __init__(self, d_model, epsilon=1e-5):
        super(LayerNormalization, self).__init__()
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(d_model))  # Scale
        self.beta = nn.Parameter(torch.zeros(d_model))  # Shift

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        # Layer norm uses the population variance, not Tensor.var's unbiased default.
        var = centered.pow(2).mean(dim=-1, keepdim=True)
        return self.gamma * centered / torch.sqrt(var + self.epsilon) + self.beta


In [12]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer.

    Computes Linear(d_model -> dff) -> ReLU -> Dropout -> Linear(dff -> d_model).

    Args:
        d_model: input and output feature width.
        dff: hidden width of the inner layer.
        dropout: dropout probability on the hidden activations (default 0.5).
    """

    def __init__(self, d_model, dff, dropout = 0.5):
        super().__init__()
        self.forward1 = nn.Linear(d_model, dff)   # expand to the hidden width
        self.dropout = nn.Dropout(dropout)
        self.forward2 = nn.Linear(dff, d_model)   # project back to d_model

    def forward(self, x):
        hidden = torch.relu(self.forward1(x))
        hidden = self.dropout(hidden)
        return self.forward2(hidden)

In [13]:
# Earlier (unused) variant: every head projects the full d_model-wide embedding, the head outputs are concatenated, and a final linear layer maps the result back to d_model (the input width).
# class MultiHeadAttention(nn.Module):

#     def __init__(self, d_model, heads, dropout = 0.5):
#         super(MultiHeadAttention, self).__init__()
#         self.d_model = d_model
#         self.heads = heads
#         self.dropout = dropout

#         self.w_q = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))
#         self.w_k = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))
#         self.w_v = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))

#         self.w_o = nn.Linear(d_model * heads, d_model)

#         self.softmax = nn.Softmax(dim = -1)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, embeded_layer):

#         attention_outputs = []

#         for head in range(self.heads):
        
#             query = self.w_q[head](embeded_layer)
#             key = self.w_k[head](embeded_layer)
#             value = self.w_v[head](embeded_layer)

#             similarity = torch.matmul(query, torch.transpose(key, -2, -1))  / math.sqrt(self.d_model)

#             sim = self.softmax(similarity)
#             sim = self.dropout(sim)

#             final = torch.matmul(sim, value)

#             attention_outputs.append(final)
            
#         concat_matrix = torch.cat(attention_outputs, -1)
#         print(concat_matrix.shape)
#         print(self.w_o.weight.shape)
#         return self.w_o(concat_matrix)
        

        
        

In [14]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention.

    Queries, keys and values are projected with d_model x d_model linear layers, split into
    ``heads`` heads of width ``d_model // heads``, attended per head, re-concatenated, and
    passed through an output projection.

    Args:
        d_model: model width; must be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights (default 0.5).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``heads``.
    """

    def __init__(self, d_model, heads, dropout = 0.5):
        super(MultiHeadAttention, self).__init__()
        if d_model % heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by heads ({heads})")
        self.d_model = d_model
        self.heads = heads
        self.d_heads = d_model // heads

        # Creation order (q, v, k, o) kept so seeded parameter initialisation is unchanged.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)

        self.w_o = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, d_heads)."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.heads, self.d_heads).transpose(1, 2)

    @staticmethod
    def _broadcastable_mask(mask):
        """Lift a mask to 4-D so it broadcasts against (batch, heads, q_len, k_len) scores."""
        if mask.dim() == 2:   # (q_len, k_len): shared by every batch element and head
            return mask.unsqueeze(0).unsqueeze(0)
        if mask.dim() == 3:   # (batch, q_len, k_len): shared by every head
            return mask.unsqueeze(1)
        return mask           # assumed already 4-D and broadcastable

    def forward(self, x_q, x_k, x_v, mask = None):
        """Attend from ``x_q`` to ``x_k`` / ``x_v``.

        Args:
            x_q: queries, (batch, q_len, d_model).
            x_k: keys, (batch, k_len, d_model).
            x_v: values, (batch, k_len, d_model).
            mask: optional 2-D, 3-D or 4-D mask; positions where ``mask == 0`` receive
                -inf scores and are therefore excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, q_len, _ = x_q.shape

        query = self._split_heads(self.w_q(x_q))
        # Keys/values are split with their own length, which differs from q_len in cross-attention.
        key = self._split_heads(self.w_k(x_k))
        value = self._split_heads(self.w_v(x_v))

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_heads)

        if mask is not None:
            scores = scores.masked_fill(self._broadcastable_mask(mask) == 0, float('-inf'))

        weights = self.dropout(self.softmax(scores))

        context = torch.matmul(weights, value)  # (batch, heads, q_len, d_heads)
        context = context.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        return self.w_o(context)
        

In [15]:
class ResidualConnection(nn.Module):
    """Post-norm residual wrapper ("Add & Norm"): LayerNorm(x1 + Dropout(x2)).

    Args:
        d_model: feature width passed to ``LayerNormalization``.
        dropout: dropout probability applied to the sublayer output ``x2``.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.ln = LayerNormalization(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x1, x2):
        # x1: the sublayer's input (residual path); x2: the sublayer's output.
        summed = x1 + self.dropout(x2)
        return self.ln(summed)

In [16]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each followed by Add & Norm.

    Args:
        dff: hidden width of the feed-forward sublayer.
        d_model: model width.
        heads: number of attention heads.
        dropout: dropout probability used by every sublayer of this block.
    """

    def __init__(self,  dff, d_model, heads, dropout):

        super(EncoderBlock, self).__init__()

        self.multi_attention = MultiHeadAttention(d_model, heads, dropout)
        self.residual_connections = nn.ModuleList([ResidualConnection(d_model, dropout) for _ in range(2)])

        # Pass the block's dropout so the feed-forward sublayer does not silently fall back
        # to FeedForward's own default of 0.5.
        self.feed_forward = FeedForward(d_model, dff, dropout)

    def forward(self, x, mask = None):
        """Encode ``x``; an optional ``mask`` (e.g. source padding) is forwarded to self-attention."""
        attended = self.multi_attention(x, x, x, mask)
        x = self.residual_connections[0](x, attended)
        return self.residual_connections[1](x, self.feed_forward(x))
        
        

In [17]:
class Encoder(nn.Module):
    """Token embedding + positional encoding followed by ``n`` stacked encoder blocks.

    Args:
        vocab_size: source vocabulary size.
        dff: feed-forward hidden width.
        seq_length: maximum source length (size of the positional-encoding table).
        d_model: model width.
        heads: number of attention heads.
        dropout: dropout probability.
        n: number of encoder blocks (default 6).
    """

    def __init__(self, vocab_size, dff, seq_length, d_model, heads, dropout, n = 6):

        super(Encoder, self).__init__()
        self.embedding_layer = InputEmbedding(d_model, vocab_size)
        self.positional_encoding = PositionalEncoding(d_model, seq_length, dropout)

        self.encoder_blocks = nn.ModuleList([EncoderBlock(dff, d_model, heads, dropout) for _ in range(n)])

    def forward(self, x):
        """Map source token ids ``x`` to contextual embeddings of width ``d_model``."""
        x = self.positional_encoding(self.embedding_layer(x))

        for block in self.encoder_blocks:
            x = block(x)
        return x

In [18]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the encoder output,
    and feed-forward, each followed by Add & Norm.

    Args:
        dff: hidden width of the feed-forward sublayer.
        d_model: model width.
        heads: number of attention heads.
        dropout: dropout probability used by every sublayer of this block.
    """

    def __init__(self, dff, d_model, heads, dropout):

        super(DecoderBlock, self).__init__()
        # Every sublayer receives the block's dropout instead of its own 0.5 default,
        # consistent with EncoderBlock.
        self.masked_attention = MultiHeadAttention(d_model, heads, dropout)
        self.residual_connections = nn.ModuleList([ResidualConnection(d_model, dropout) for _ in range(3)])
        self.cross_attention = MultiHeadAttention(d_model, heads, dropout)
        self.feed_forward = FeedForward(d_model, dff, dropout)

    def forward(self, x, x_enc, mask, src_mask = None):
        """Decode one layer.

        Args:
            x: decoder input (embedded target sequence).
            x_enc: encoder output attended to by cross-attention.
            mask: self-attention mask (e.g. causal); positions where it is 0 are hidden.
            src_mask: optional cross-attention mask over ``x_enc`` (e.g. source padding).
        """
        attended = self.masked_attention(x, x, x, mask)
        x = self.residual_connections[0](x, attended)

        crossed = self.cross_attention(x, x_enc, x_enc, src_mask)
        x = self.residual_connections[1](x, crossed)

        return self.residual_connections[2](x, self.feed_forward(x))

In [19]:
class Decoder(nn.Module):
    """Target embedding + positional encoding, ``n`` decoder blocks, and a vocabulary projection.

    Args:
        vocab_size: target vocabulary size (width of the output scores).
        dff: feed-forward hidden width.
        seq_length: maximum target length (size of the positional table and causal mask).
        d_model: model width.
        heads: number of attention heads.
        dropout: dropout probability.
        n: number of decoder blocks (default 6).
    """

    def __init__(self, vocab_size, dff, seq_length, d_model, heads, dropout, n = 6):

        super(Decoder, self).__init__()

        self.embedding_layer = InputEmbedding(d_model, vocab_size)
        self.positional_encoding = PositionalEncoding(d_model, seq_length, dropout)

        self.decoder_blocks = nn.ModuleList([DecoderBlock(dff, d_model, heads, dropout) for _ in range(n)])

        # Lower-triangular causal mask: position i may attend only to positions <= i.
        # Registered as a buffer so it moves with the module (e.g. to GPU); non-persistent
        # because it is fully determined by seq_length.
        self.register_buffer("mask", torch.tril(torch.ones(seq_length, seq_length)), persistent=False)

        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x, x_enc):
        """Decode target token ids ``x`` against encoder output ``x_enc``.

        Returns unnormalised vocabulary scores (no softmax applied) for every target position.
        """
        # Crop the precomputed mask to the actual target length so shorter inputs work.
        tgt_len = x.size(1)
        causal_mask = self.mask[:tgt_len, :tgt_len]

        x = self.positional_encoding(self.embedding_layer(x))
        for block in self.decoder_blocks:
            x = block(x, x_enc, causal_mask)
        return self.linear(x)
        
        
        

In [20]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Args:
        vocab_size_source: source vocabulary size.
        vocab_size_target: target vocabulary size.
        seq_length_source: maximum source sequence length.
        seq_length_target: maximum target sequence length.
        d_model, heads, dropout, dff: model hyperparameters (defaults match the base model
            of "Attention Is All You Need").
    """

    def __init__(self, vocab_size_source, vocab_size_target, seq_length_source, seq_length_target, d_model = 512, heads = 8, dropout = 0.1, dff = 2048):

        super(Transformer, self).__init__()

        self.encoder = Encoder(vocab_size_source, dff, seq_length_source, d_model, heads, dropout)
        self.decoder = Decoder(vocab_size_target, dff, seq_length_target, d_model, heads, dropout)

        # Xavier-initialise every Linear weight. Iterate over modules, not parameters:
        # parameters are tensors and are never instances of nn.Linear.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)

    def forward(self, x_in, x_op):
        """Encode source ids ``x_in``, decode target ids ``x_op``; returns the decoder's
        vocabulary scores."""
        x_enc = self.encoder(x_in)
        return self.decoder(x_op, x_enc)

