# Transformer Encoder Block

In [221]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as f

In [222]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last input dimension.

    Computes ``linear2(dropout(relu(linear1(x))))``.

    Args:
        d_model: Size of the last dimension of both the input and the output.
        hidden: Size of the intermediate (expanded) representation.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, hidden, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)  # d_model -> hidden
        self.linear2 = nn.Linear(hidden, d_model)  # hidden -> d_model
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return a tensor shaped like ``x`` (last dimension ``d_model``)."""
        expanded = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(expanded)

In [223]:
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` dimensions.

    Each slice spanned by the trailing dimensions is shifted to zero mean and
    scaled to unit (biased) variance, then transformed by the learnable
    ``gamma`` (scale) and ``beta`` (shift) parameters.

    Args:
        parameters_shape: Shape of the trailing dimensions to normalize, e.g.
            ``[d_model]``. A bare ``int`` is accepted and treated as a
            one-element shape.
        eps: Constant added to the variance for numerical stability.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        if isinstance(parameters_shape, int):
            parameters_shape = [parameters_shape]
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        """Normalize ``inputs`` over its last ``len(parameters_shape)`` axes.

        Returns:
            A tensor with the same shape as ``inputs``.
        """
        # Trailing axes only: (-1,) for a 1-D parameter shape, (-2, -1) for
        # 2-D, and so on. The batch/sequence axes must never be reduced.
        dims = tuple(range(-len(self.parameters_shape), 0))
        mean = inputs.mean(dim=dims, keepdim=True)
        centered = inputs - mean
        var = (centered ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        return self.gamma * (centered / std) + self.beta

In [224]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sublayer output goes through dropout, is added to the sublayer input
    (residual connection) and is then layer-normalized.

    Args:
        d_model: Size of the last dimension of the input.
        ffn_hidden: Hidden size of the feed-forward sublayer.
        nums_heads: Number of attention heads. (The parameter name is kept
            as-is so existing keyword callers keep working.)
        dropout: Dropout probability used after both sublayers and inside the
            feed-forward network.
    """

    def __init__(self, d_model, ffn_hidden, nums_heads, dropout):
        super().__init__()
        # Use the constructor argument; the previous code read a name
        # `num_heads` that is not a parameter of this method.
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=nums_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=dropout)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, dropout=dropout)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply both sublayers to ``x`` and return the result."""
        # Sublayer 1: unmasked self-attention + residual + norm.
        residual_x = x
        x = self.attention(x, mask=None)
        x = self.dropout1(x)
        x = self.norm1(x + residual_x)

        # Sublayer 2: position-wise feed-forward + residual + norm.
        residual_x = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual_x)
        return x

In [225]:
class Encoder(nn.Module):
    """A stack of ``num_layers`` identically configured ``EncoderLayer`` blocks.

    Args:
        d_model: Size of the last dimension of the input.
        ffn_hidden: Hidden size of each layer's feed-forward sublayer.
        num_heads: Number of attention heads per layer.
        num_layers: How many encoder layers to chain.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, num_layers, dropout):
        super().__init__()
        stack = [
            EncoderLayer(d_model, ffn_hidden, num_heads, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through every layer in order and return the final output."""
        return self.layers(x)

In [226]:
def ScaleDotProduct(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor of shape (..., q_len, d_k).
        k: Key tensor of shape (..., kv_len, d_k).
        v: Value tensor of shape (..., kv_len, d_v).
        mask: Optional additive mask broadcastable to (..., q_len, kv_len);
            use ``-inf`` at positions that must receive zero attention.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted sum of ``v``
        with shape (..., q_len, d_v), and the softmax attention weights with
        shape (..., q_len, kv_len).
    """
    d_k = q.size(-1)
    # `transpose` must be *called*; subscripting the bound method raises TypeError.
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add so any broadcast-compatible mask shape is accepted.
        scaled = scaled + mask
    attention = f.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
    

In [227]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        d_model: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)  # q, k, v in one matmul
        self.linear_layer = nn.Linear(d_model, d_model)   # output projection

    def forward(self, x, mask=None):
        """Attend over ``x`` and return a tensor of the same shape.

        Args:
            x: Input of shape (batch, seq_len, d_model).
            mask: Optional additive mask forwarded to ``ScaleDotProduct``.
        """
        batch_size, seq_len, _ = x.size()
        qkv = self.qkv_layer(x)                       # (batch, seq, 3*d_model)
        qkv = qkv.reshape(batch_size, seq_len, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)                 # (batch, heads, seq, 3*head_dim)
        q, k, v = qkv.chunk(3, dim=-1)                # each (batch, heads, seq, head_dim)
        values, _ = ScaleDotProduct(q, k, v, mask)    # (batch, heads, seq, head_dim)
        # Move the head axis back next to head_dim before merging; reshaping
        # straight from (batch, heads, seq, head_dim) would interleave
        # positions and heads.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, seq_len, self.num_heads * self.head_dim
        )
        return self.linear_layer(values)

In [228]:
# Hyperparameters for the encoder demo.
d_model = 512      # passed to Encoder as d_model
num_heads = 8      # passed to Encoder as num_heads
dropout = 0.1      # passed to Encoder as dropout
batch_size = 32    # not used when constructing the encoder
max_seq_len = 200  # not used when constructing the encoder
ffn_hidden = 2048  # passed to Encoder as ffn_hidden
num_layers = 5     # passed to Encoder as num_layers

encoder = Encoder(
    d_model=d_model,
    ffn_hidden=ffn_hidden,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout=dropout,
)

## Transformer Decoder Block

In [229]:
class PositionwiseFeedForwarnetwork(nn.Module):
    """Feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Size of the last dimension of both the input and the output.
        hidden: Width of the intermediate layer.
        dropout: Dropout probability applied to the activated hidden layer.
    """

    def __init__(self, d_model, hidden, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)  # d_model -> hidden
        self.linear2 = nn.Linear(hidden, d_model)  # hidden -> d_model
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Transform ``x`` of shape (..., d_model) into the same shape."""
        hidden_states = self.relu(self.linear1(x))
        hidden_states = self.dropout(hidden_states)
        return self.linear2(hidden_states)

In [230]:
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` dimensions.

    Every slice spanned by the trailing dimensions is normalized to zero mean
    and unit (biased) variance and then scaled by the learnable ``gamma`` and
    shifted by the learnable ``beta``.

    Args:
        parameters_shape: Shape of the trailing dimensions to normalize, e.g.
            ``[d_model]``. A bare ``int`` is accepted and treated as a
            one-element shape.
        eps: Constant added to the variance for numerical stability.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        if isinstance(parameters_shape, int):
            parameters_shape = [parameters_shape]
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        """Normalize ``inputs`` over its last ``len(parameters_shape)`` axes.

        Returns:
            A tensor with the same shape as ``inputs``.
        """
        # The reduction axes are the trailing ones: (-1,), (-2, -1), ...
        # The earlier expression `-i+1` evaluated to axis 1 for a 1-D shape,
        # which normalized across the sequence instead of the features.
        dims = tuple(range(-len(self.parameters_shape), 0))
        mean = inputs.mean(dim=dims, keepdim=True)
        centered = inputs - mean
        var = (centered ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        return self.gamma * (centered / std) + self.beta

In [231]:
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention.

    Args:
        q: Queries, shape (..., q_len, d_k).
        k: Keys, shape (..., kv_len, d_k).
        v: Values, shape (..., kv_len, d_v).
        mask: Optional additive mask, added in place to the attention scores
            before the softmax.

    Returns:
        ``(values, attention)`` where ``attention`` is the softmax over the
        last axis of the scaled scores and ``values`` is ``attention @ v``.
    """
    scale = math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    if mask is not None:
        scores += mask
    weights = f.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights

In [232]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        d_model: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)  # q, k, v in one matmul
        self.linear_layer = nn.Linear(d_model, d_model)   # output projection

    def forward(self, c, mask=None):
        """Attend over ``c`` and return a tensor of the same shape.

        Args:
            c: Input of shape (batch, seq_len, d_model). (The parameter name
                is kept as-is so existing keyword callers keep working.)
            mask: Optional additive mask forwarded to ``scaled_dot_product``.
        """
        # Read the argument itself; the previous body referenced `x`, which
        # is not defined in this method and could silently resolve to a
        # module-level tensor.
        batch_size, seq_len, _ = c.size()
        qkv = self.qkv_layer(c)                          # (batch, seq, 3*d_model)
        qkv = qkv.reshape(batch_size, seq_len, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)                    # (batch, heads, seq, 3*head_dim)
        q, k, v = qkv.chunk(3, dim=-1)                   # each (batch, heads, seq, head_dim)
        values, _ = scaled_dot_product(q, k, v, mask)    # (batch, heads, seq, head_dim)
        # Bring the head axis back beside head_dim before merging heads.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, seq_len, self.num_heads * self.head_dim
        )
        return self.linear_layer(values)

In [233]:
class MultiHeadCrossAttention(nn.Module):
    """Multi-head attention where keys/values and queries come from different inputs.

    Keys and values are projected from ``x``; queries are projected from ``y``.

    Args:
        d_model: Size of the last dimension of both inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model, 2 * d_model)  # k and v in one matmul
        self.q_layer = nn.Linear(d_model, d_model)
        self.linear_layer = nn.Linear(d_model, d_model)  # output projection

    def forward(self, x, y, mask=None):
        """Let the positions of ``y`` attend over the positions of ``x``.

        Args:
            x: Key/value source of shape (batch, kv_len, d_model).
            y: Query source of shape (batch, q_len, d_model).
            mask: Optional additive mask forwarded to ``scaled_dot_product``.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, kv_len, _ = x.size()
        # Queries come from `y`, so they must be reshaped with y's own length.
        q_len = y.size(1)
        kv = self.kv_layer(x)                            # (batch, kv_len, 2*d_model)
        q = self.q_layer(y)                              # (batch, q_len, d_model)
        kv = kv.reshape(batch_size, kv_len, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, q_len, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)                      # (batch, heads, kv_len, 2*head_dim)
        q = q.permute(0, 2, 1, 3)                        # (batch, heads, q_len, head_dim)
        k, v = kv.chunk(2, dim=-1)                       # each (batch, heads, kv_len, head_dim)
        values, _ = scaled_dot_product(q, k, v, mask)    # (batch, heads, q_len, head_dim)
        # Bring the head axis back beside head_dim before merging heads.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, q_len, self.num_heads * self.head_dim
        )
        # Apply the output projection that was constructed but never used.
        return self.linear_layer(values)

In [234]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, then feed-forward.

    Each sublayer output goes through dropout, is added to the sublayer input
    (residual connection) and is then layer-normalized.

    Args:
        d_model: Size of the last dimension of the inputs.
        ffn_hidden: Hidden size of the feed-forward sublayer.
        num_heads: Number of attention heads in both attention sublayers.
        dropout: Dropout probability used after every sublayer and inside the
            feed-forward network.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, dropout):
        super().__init__()

        self.self_attention = MultiheadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=dropout)

        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=dropout)

        self.ffn = PositionwiseFeedForwarnetwork(d_model, hidden=ffn_hidden, dropout=dropout)
        self.norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=dropout)

    def forward(self, x, y, decoder_mask):
        """Update the decoder state ``y`` using the context ``x``.

        Args:
            x: Tensor handed to the cross-attention as its key/value source.
            y: Decoder-side tensor that is transformed and returned.
            decoder_mask: Additive mask applied in the self-attention over ``y``.

        Returns:
            Tensor with the same shape as ``y``.
        """
        # Sublayer 1: masked self-attention over y.
        _y = y
        y = self.self_attention(y, mask=decoder_mask)
        y = self.dropout1(y)
        y = self.norm1(y + _y)

        # Sublayer 2: cross-attention, queries from y, keys/values from x.
        _y = y
        y = self.encoder_decoder_attention(x, y, mask=None)
        y = self.dropout2(y)
        y = self.norm2(y + _y)

        # Sublayer 3: position-wise feed-forward. These modules were built in
        # __init__ but previously never applied.
        _y = y
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.norm3(y + _y)
        return y

In [235]:
class SequentialDecoder(nn.Sequential):
    """``nn.Sequential`` variant whose child modules take ``(x, y, mask)``.

    ``x`` and ``mask`` are passed unchanged to every child; only ``y`` is
    replaced by each child's output.
    """

    def forward(self, *inputs):
        """Thread ``y`` through every child module and return the final ``y``.

        Args:
            *inputs: Exactly three positional values ``(x, y, mask)``.
        """
        context, state, attn_mask = inputs
        for layer in self:
            state = layer(context, state, attn_mask)
        return state

In [236]:
class Decoder(nn.Module):
    """A stack of ``num_layers`` identically configured ``DecoderLayer`` blocks.

    Args:
        d_model: Size of the last dimension of the inputs.
        ffn_hidden: Hidden size of each layer's feed-forward sublayer.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability passed to every layer.
        num_layers: How many decoder layers to chain (default 1).
    """

    def __init__(self, d_model, ffn_hidden, num_heads, dropout, num_layers=1):
        super().__init__()
        stack = [
            DecoderLayer(d_model, ffn_hidden, num_heads, dropout)
            for _ in range(num_layers)
        ]
        self.layers = SequentialDecoder(*stack)

    def forward(self, x, y, mask):
        """Run ``(x, y, mask)`` through every layer and return the final ``y``."""
        return self.layers(x, y, mask)

In [237]:
# Hyperparameters for the decoder demo.
d_model = 512
num_heads = 8
dropout = 0.1
batch_size = 30
max_seq_len = 200
ffn_hidden = 2048  # was 20148, a typo; matches the encoder configuration
num_layers = 5

# Random stand-ins for the two decoder inputs.
x = torch.randn((batch_size, max_seq_len, d_model))
y = torch.randn((batch_size, max_seq_len, d_model))

# Additive look-ahead mask: -inf strictly above the diagonal, 0 elsewhere, so
# the softmax gives zero weight to later positions. The previous fill value
# of 1e-9 left the attention scores effectively unchanged.
mask = torch.full([max_seq_len, max_seq_len], float('-inf'))
mask = torch.triu(mask, diagonal=1)

decoder = Decoder(d_model, ffn_hidden, num_heads, dropout, num_layers)
out = decoder(x, y, mask)