In [None]:
import math

import torch
import torch.nn.functional as F
from torch import nn

def scaled_dot_product(q,k,v,mask=None):
    """Compute scaled dot-product attention.

    Args:
        q, k, v: Tensors whose last two axes are (sequence, head_dim). The
            callers in this notebook pass (batch, heads, seq, head_dim).
        mask: Optional additive mask, added in place to the score matrix
            before the softmax (so it must broadcast into the score shape).
            Use ``-inf`` entries to block positions and ``0`` elsewhere.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted sum of ``v``
        and the softmax-normalised attention weights.
    """
    head_dim = q.shape[-1]
    # Dividing by sqrt(head_dim) keeps the logits at a scale where the
    # softmax does not saturate.
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
    if mask is not None:
        scores.add_(mask)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights

class LayerNormalization(nn.Module):
    """Layer normalisation over the trailing ``parameters_shape`` dimensions.

    Each slice covering the last ``len(parameters_shape)`` axes is shifted to
    zero mean and scaled to unit (biased) variance, then passed through a
    learnable element-wise affine transform ``gamma * y + beta``.

    Args:
        parameters_shape: Shape of the normalised trailing dimensions, e.g.
            ``[512]``. A bare int is accepted and treated as a 1-D shape.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self,parameters_shape,eps=1e-5):
        super().__init__()
        # Accept ``LayerNormalization(512)`` as well as ``[512]``; forward()
        # needs len() of the shape.
        if isinstance(parameters_shape, int):
            parameters_shape = [parameters_shape]
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))   # scale
        self.beta = nn.Parameter(torch.zeros(parameters_shape))   # shift

    def forward(self,inputs):
        """Normalise ``inputs`` over its trailing axes and apply the affine map.

        Returns a tensor with the same shape as ``inputs``.
        """
        # Negative indices of the trailing axes, e.g. [-1] for a 1-D shape.
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (inputs - mean) / std
        return self.gamma * y + self.beta

class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied independently at each position.

    Computes ``linear2(dropout(relu(linear1(x))))``.

    Args:
        d_model: Size of the input and output feature dimension.
        hidden: Size of the intermediate (expanded) feature dimension.
        drop_prob: Dropout probability applied after the activation.
    """

    def __init__(self,d_model,hidden,drop_prob=.1):
        super(PositionwiseFeedForward,self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)   # d_model -> hidden
        self.linear2 = nn.Linear(hidden, d_model)   # hidden -> d_model
        # The activation module is spelled ``nn.ReLU``; ``nn.Relu`` does not
        # exist and raised AttributeError at construction time.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self,x):
        """Map ``x`` of shape ``(..., d_model)`` to the same shape."""
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    Queries, keys and values are all projected from the same input by one
    fused linear layer, split into ``num_heads`` heads, attended with
    ``scaled_dot_product`` and merged back through an output projection.

    Args:
        d_model: Feature dimension of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self,d_model,num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)   # fused q/k/v projection
        self.linear_layer = nn.Linear(d_model, d_model)    # output projection

    def forward(self,x,mask=None):
        """Apply self-attention to ``x`` of shape ``(batch, seq, d_model)``.

        Args:
            x: Input tensor ``(batch, seq, d_model)``.
            mask: Optional additive mask broadcastable to
                ``(batch, heads, seq, seq)``.

        Returns:
            Tensor of shape ``(batch, seq, d_model)``.
        """
        batch_size, sequence_length, _ = x.size()
        qkv = self.qkv_layer(x)                                    # (B, S, 3*d_model)
        qkv = qkv.reshape(batch_size, sequence_length,
                          self.num_heads, 3 * self.head_dim)       # (B, S, H, 3*hd)
        qkv = qkv.permute(0, 2, 1, 3)                              # (B, H, S, 3*hd)
        q, k, v = qkv.chunk(3, dim=-1)                             # each (B, H, S, hd)
        values, _ = scaled_dot_product(q, k, v, mask)              # (B, H, S, hd)
        # Move the head axis back next to head_dim before merging. Reshaping
        # (B, H, S, hd) directly would interleave heads with positions.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )                                                          # (B, S, d_model)
        return self.linear_layer(values)

class MultiHeadCrossAttention(nn.Module):
    """Multi-head encoder-decoder (cross) attention.

    Keys and values are projected from ``x`` and queries from ``y``; the
    result has one output row per query position.

    Args:
        d_model: Feature dimension of both inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self,d_model,num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model, 2 * d_model)    # fused k/v projection
        self.q_layer = nn.Linear(d_model, d_model)         # query projection
        self.linear_layer = nn.Linear(d_model, d_model)    # output projection

    def forward(self,x,y,mask=None):
        """Attend from ``y`` (queries) over ``x`` (keys/values).

        The original signature omitted ``y`` although the body read it, so
        every call raised NameError. The sibling ``DecoderLayer`` already
        calls this method as ``(x, y, mask=None)``.

        Args:
            x: Key/value source, ``(batch, kv_seq, d_model)``.
            y: Query source, ``(batch, q_seq, d_model)``.
            mask: Optional additive mask broadcastable to
                ``(batch, heads, q_seq, kv_seq)``.

        Returns:
            Tensor of shape ``(batch, q_seq, d_model)``.
        """
        batch_size, kv_length, _ = x.size()
        # The two sources may have different lengths, so track them separately.
        q_length = y.size(1)
        kv = self.kv_layer(x)                                       # (B, Sk, 2*d_model)
        q = self.q_layer(y)                                         # (B, Sq, d_model)
        kv = kv.reshape(batch_size, kv_length,
                        self.num_heads, 2 * self.head_dim)          # (B, Sk, H, 2*hd)
        q = q.reshape(batch_size, q_length,
                      self.num_heads, self.head_dim)                # (B, Sq, H, hd)
        kv = kv.permute(0, 2, 1, 3)                                 # (B, H, Sk, 2*hd)
        q = q.permute(0, 2, 1, 3)                                   # (B, H, Sq, hd)
        k, v = kv.chunk(2, dim=-1)                                  # each (B, H, Sk, hd)
        values, _ = scaled_dot_product(q, k, v, mask)               # (B, H, Sq, hd)
        # Restore (B, Sq, H, hd) before merging heads; a direct reshape of
        # (B, H, Sq, hd) would mix heads with sequence positions.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, q_length, self.num_heads * self.head_dim
        )                                                           # (B, Sq, d_model)
        return self.linear_layer(values)

class DecoderLayer(nn.Module):
    """One Transformer decoder layer.

    Three sub-layers run in order, each followed by dropout, a residual
    connection and layer normalisation:

    1. masked self-attention over the decoder stream ``y``;
    2. cross-attention with queries from ``y`` and keys/values from ``x``;
    3. a position-wise feed-forward network.

    Args:
        d_model: Feature dimension of both streams.
        ffn_hidden: Hidden size of the feed-forward network.
        num_heads: Number of attention heads.
        drop_prob: Dropout probability used after each sub-layer.
    """

    def __init__(self,d_model,ffn_hidden,num_heads,drop_prob):
        super(DecoderLayer,self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        # dropout2 was used in forward() but never created.
        self.dropout2 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self,x,y,decoder_mask):
        """Run the decoder layer.

        Args:
            x: Source-side tensor used as keys/values in cross-attention
                (presumably the encoder output, ``(batch, seq, d_model)``).
            y: Decoder-side tensor, ``(batch, seq, d_model)``.
            decoder_mask: Additive mask for the self-attention step (the
                notebook passes an upper-triangular ``-inf`` look-ahead mask).

        Returns:
            Updated decoder tensor with the same shape as ``y``.
        """
        # 1) Masked self-attention + residual + norm.
        _y = y
        y = self.self_attention(y, mask=decoder_mask)
        y = self.dropout1(y)
        y = self.norm1(y + _y)

        # 2) Encoder-decoder attention + residual + norm. No mask is applied
        #    here, matching the original call.
        _y = y
        y = self.encoder_decoder_attention(x, y, mask=None)
        y = self.dropout2(y)
        y = self.norm2(y + _y)

        # 3) Position-wise feed-forward + residual + norm.
        _y = y
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.norm3(y + _y)
        return y

In [None]:
class SequentialDecoder(nn.Sequential):
    """``nn.Sequential`` variant for modules that take ``(x, y, mask)``.

    Only ``y`` is threaded through the stack; ``x`` and ``mask`` are passed
    unchanged to every contained module.
    """

    def forward(self, *inputs):
        """Run every contained module in order and return the final ``y``.

        Args:
            *inputs: Exactly three values ``(x, y, mask)``.
        """
        x, y, mask = inputs
        # nn.Module keeps its children in ``_modules``; ``_module`` does not
        # exist and raised AttributeError.
        for module in self._modules.values():
            y = module(x, y, mask)
        return y


# The class was originally declared as ``SequentialDecode`` while ``Decoder``
# instantiates ``SequentialDecoder``; keep the old name as an alias.
SequentialDecode = SequentialDecoder

In [None]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` decoder layers applied one after another.

    Args:
        d_model: Feature dimension of the inputs.
        ffn_hidden: Hidden size of each layer's feed-forward network.
        num_heads: Number of attention heads per layer.
        drop_prob: Dropout probability used inside each layer.
        num_layers: How many ``DecoderLayer`` instances to stack (default 1).
    """

    def __init__(self,d_model,ffn_hidden,num_heads,drop_prob,num_layers=1):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob))
        self.layers = SequentialDecoder(*stack)

    def forward(self,x,y,mask):
        """Pass ``(x, y, mask)`` through the layer stack and return the final ``y``.

        The notebook calls this with ``x`` and ``y`` shaped
        ``(batch, seq, d_model)`` and a ``(seq, seq)`` additive mask.
        """
        return self.layers(x, y, mask)

In [None]:
# Hyper-parameters for the demo run.
d_model=512
num_heads=8
drop_prob=.1
batch_size=30
max_sequence_length=200
ffn_hidden=2048
# Was defined as ``num_layer`` while the Decoder call below reads
# ``num_layers``, which raised NameError.
num_layers=5

x=torch.randn((batch_size,max_sequence_length,d_model))# English sentence positional encoding
y=torch.randn((batch_size,max_sequence_length,d_model))# Hindi sentence positional encoding
# Look-ahead mask: -inf strictly above the diagonal, 0 on and below it, so a
# position cannot attend to later positions once the mask is added to the
# attention scores.
mask=torch.full([max_sequence_length,max_sequence_length],float('-inf'))
mask=torch.triu(mask,diagonal=1)
decoder=Decoder(d_model,ffn_hidden,num_heads,drop_prob,num_layers)
out=decoder(x,y,mask)
