In [1]:
import math

import torch
import torch.nn as nn

In [3]:
class Decoder(nn.Module):
    """One Transformer decoder layer (post-norm).

    Three sub-layers are applied in sequence, each wrapped as
    ``LayerNorm(residual + Dropout(sublayer(...)))``:

    1. self-attention over the decoder stream ``x`` (uses ``mask``),
    2. cross-attention in which ``x`` attends to ``y``,
    3. a feed-forward network.
    """
    def __init__(self,emb_dim,hidden_dim,num_heads,drop_prob=0.1):
        """
        emb_dim    : size of the last (feature) dimension of x and y
        hidden_dim : inner width of the feed-forward network
        num_heads  : number of attention heads in both attention sub-layers
        drop_prob  : dropout probability applied after every sub-layer
        """
        super().__init__()
        self.multi_attention=MultiHeadAttention(emb_dim,num_heads)
        self.norm1=nn.LayerNorm(emb_dim)
        self.dropout1=nn.Dropout(p=drop_prob)
        self.cross_attention=CrossMultiHeadAttention(emb_dim,num_heads)
        self.norm2=nn.LayerNorm(emb_dim)
        self.dropout2=nn.Dropout(p=drop_prob)
        self.ffn=FeedForwardNetwork(emb_dim,hidden_dim)
        self.norm3=nn.LayerNorm(emb_dim)
        self.dropout3=nn.Dropout(p=drop_prob)

    def forward(self,x,y,mask):
        """
        x    : decoder-side input, [batch_size, x_len, emb_dim]
        y    : sequence attended to by cross-attention (e.g. encoder output),
               [batch_size, y_len, emb_dim]; y_len may differ from x_len
        mask : passed unchanged to the self-attention sub-layer, where it is
               added to the attention scores (None disables masking)

        Returns a tensor with the same shape as x.
        """
        # 1) masked self-attention over the decoder stream
        residual_x=x
        x=self.multi_attention(x, mask=mask)
        x=self.dropout1(x)
        x=self.norm1(x+residual_x)

        # 2) cross-attention.
        # CrossMultiHeadAttention.forward(kv_source, query_source) builds keys/values
        # from its FIRST argument and queries from its SECOND one. The decoder stream
        # must supply the queries so that the result has x_len rows and can be added
        # to the residual below; keys/values come from y.
        residual_x=x
        x=self.cross_attention(y,x,mask=None)
        x=self.dropout2(x)
        x=self.norm2(x+residual_x)

        # 3) feed-forward network
        residual_x=x
        x=self.ffn(x)
        x=self.dropout3(x)
        x=self.norm3(x+residual_x)
        return x


        
        
        
        
        
        

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused q/k/v projection.

    The input is projected to 3*emb_dim features, split into num_heads heads,
    each head runs scaled dot-product attention, and the concatenated heads go
    through a final emb_dim -> emb_dim linear layer.
    """
    def __init__(self,emb_dim,num_heads):
        """
        emb_dim   : feature size of every position of the input
        num_heads : number of attention heads; must divide emb_dim exactly

        Raises ValueError when emb_dim cannot be split evenly across the heads
        (otherwise the reshape in forward() would fail with an unrelated-looking error).
        """
        super().__init__()
        if num_heads<=0 or emb_dim%num_heads!=0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.emb_dim=emb_dim                                # Feature size of each position
        self.num_heads=num_heads                            # Number of attention heads run in parallel
        self.head_dim=emb_dim//num_heads                    # Feature size handled by one head

        self.qkv_layer=nn.Linear(emb_dim,3*emb_dim)         # One fused projection producing q, k and v together
        self.linear=nn.Linear(emb_dim,emb_dim)              # Output projection applied after the heads are concatenated

    def forward(self,x,mask=None):
        """
        x    : [batch_size, seq_len, emb_dim]                       e.g. {32,100,512}
        mask : optional tensor added to the attention scores inside
               scaled_dot_attention (must broadcast to [batch_size, num_heads, seq_len, seq_len])

        Returns a tensor of shape [batch_size, seq_len, emb_dim].
        """
        batch_size,seq_len,_=x.shape
        qkv=self.qkv_layer(x)
        # fused projection                   -> [batch_size, seq_len, 3*emb_dim]            {32,100,3*512}
        qkv=qkv.reshape(batch_size,seq_len,self.num_heads,3*self.head_dim)
        # split the features across heads    -> [batch_size, seq_len, num_heads, 3*head_dim] {32,100,8,3*64}
        qkv=qkv.permute(0,2,1,3)
        # heads become a batch-like axis     -> [batch_size, num_heads, seq_len, 3*head_dim] {32,8,100,3*64}
        query,key,value=qkv.chunk(3,dim=-1)
        # q, k, v each                       -> [batch_size, num_heads, seq_len, head_dim]   {32,8,100,64}
        values=scaled_dot_attention(query,key,value,mask)
        # attention output                   -> [batch_size, num_heads, seq_len, head_dim]   {32,8,100,64}
        values=values.permute(0,2,1,3)
        # back to position-major layout      -> [batch_size, seq_len, num_heads, head_dim]   {32,100,8,64}
        values=values.reshape(batch_size,seq_len,self.emb_dim)
        # concatenate the heads              -> [batch_size, seq_len, emb_dim]               {32,100,512}
        return self.linear(values)
        

In [4]:
def scaled_dot_attention(q,k,v,mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(head_dim) + mask) @ v.

    q    : [..., q_len, head_dim]
    k    : [..., kv_len, head_dim]
    v    : [..., kv_len, v_dim]
    mask : optional tensor ADDED to the scores before the softmax, so it must
           broadcast to [..., q_len, kv_len]; use 0 for allowed positions and
           -inf (or a large negative number) for blocked ones.

    Returns a tensor of shape [..., q_len, v_dim].

    Requires the ``math`` module to be imported at notebook level.
    """
    dim_k=k.size(-1)  # head_dim; dividing by its square root keeps the score scale independent of it
    scaled=torch.matmul(q,k.transpose(-1,-2))/math.sqrt(dim_k)
    if mask is not None:
        scaled=scaled+mask
    attention=torch.softmax(scaled,dim=-1)  # normalise over the key axis
    return torch.matmul(attention,v)

In [None]:
class FeedForwardNetwork(nn.Module):
    """Feed-forward block applied to the last dimension: Linear -> ReLU -> Dropout -> Linear.

    emb_dim    : input and output feature size
    hidden_dim : width of the intermediate layer
    drop_prob  : dropout probability used between the two linear layers
    """
    def __init__(self,emb_dim,hidden_dim,drop_prob=0.1):
        super(FeedForwardNetwork,self).__init__()
        # Expansion to the hidden width ...
        self.linear1=nn.Linear(in_features=emb_dim,out_features=hidden_dim)
        self.relu=nn.ReLU()
        self.dropout=nn.Dropout(drop_prob)
        # ... and projection back to the embedding width.
        self.linear2=nn.Linear(in_features=hidden_dim,out_features=emb_dim)

    def forward(self,x):
        """Map x of shape [..., emb_dim] to a tensor of the same shape."""
        hidden=self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)
        

In [None]:
class CrossMultiHeadAttention(nn.Module):
    """Multi-head cross-attention between two sequences.

    Argument roles in forward(x, y):
      * x supplies the KEYS and VALUES (fused projection ``kv_layer``),
      * y supplies the QUERIES (``q_layer``).
    The output therefore has one row per position of y.
    """
    def __init__(self,emb_dim,num_heads):
        """
        emb_dim   : feature size of every position of x and y
        num_heads : number of attention heads; must divide emb_dim exactly

        Raises ValueError when emb_dim cannot be split evenly across the heads
        (otherwise the reshape in forward() would fail with an unrelated-looking error).
        """
        super().__init__()
        if num_heads<=0 or emb_dim%num_heads!=0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.emb_dim=emb_dim                                # Feature size of each position
        self.num_heads=num_heads                            # Number of attention heads run in parallel
        self.head_dim=emb_dim//num_heads                    # Feature size handled by one head

        self.q_layer=nn.Linear(emb_dim,emb_dim)             # Query projection (applied to y)
        self.kv_layer=nn.Linear(emb_dim,2*emb_dim)          # Fused key/value projection (applied to x)
        self.linear=nn.Linear(emb_dim,emb_dim)              # Output projection applied after the heads are concatenated

    def forward(self,x,y,mask=None):
        """
        x    : key/value source, [batch_size, kv_len, emb_dim]
        y    : query source,     [batch_size, q_len, emb_dim]  (q_len may differ from kv_len)
        mask : optional tensor added to the attention scores inside
               scaled_dot_attention (must broadcast to [batch_size, num_heads, q_len, kv_len])

        Returns a tensor of shape [batch_size, q_len, emb_dim].
        """
        batch_size,kv_len,_=x.shape
        _,q_len,_=y.shape

        kv=self.kv_layer(x)
        # fused k/v projection               -> [batch_size, kv_len, 2*emb_dim]
        query=self.q_layer(y)
        # query projection                   -> [batch_size, q_len, emb_dim]

        kv=kv.reshape(batch_size,kv_len,self.num_heads,2*self.head_dim)
        # split the features across heads    -> [batch_size, kv_len, num_heads, 2*head_dim]
        query=query.reshape(batch_size,q_len,self.num_heads,self.head_dim)
        #                                    -> [batch_size, q_len, num_heads, head_dim]

        kv=kv.permute(0,2,1,3)
        # heads become a batch-like axis     -> [batch_size, num_heads, kv_len, 2*head_dim]
        query=query.permute(0,2,1,3)
        #                                    -> [batch_size, num_heads, q_len, head_dim]

        key,value=kv.chunk(2,dim=-1)
        # k, v each                          -> [batch_size, num_heads, kv_len, head_dim]
        values=scaled_dot_attention(query,key,value,mask)
        # attention output                   -> [batch_size, num_heads, q_len, head_dim]
        values=values.permute(0,2,1,3)
        # back to position-major layout      -> [batch_size, q_len, num_heads, head_dim]
        values=values.reshape(batch_size,q_len,self.emb_dim)
        # concatenate the heads              -> [batch_size, q_len, emb_dim]
        return self.linear(values)
        