In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math 

In [2]:
class Single_attention(nn.Module):
    """Scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(d_k)) @ value`` where ``d_k`` is the
    size of the last dimension of ``query``.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query, key, value: tensors whose last two dims are (seq, features);
                leading dims must broadcast for ``matmul``.
            mask: optional tensor broadcastable to the score matrix; entries equal
                to 0 are replaced by -1e9 before the softmax.
            dropout: optional callable (e.g. ``nn.Dropout``) applied to the
                attention weights.

        Returns:
            (attended values, attention weights)
        """
        d_k = query.size(-1)
        logits = query @ key.transpose(-2, -1) / math.sqrt(d_k)

        if mask is not None:
            # a huge negative logit drives the softmax weight of masked slots to ~0
            logits = logits.masked_fill(mask == 0, -1e9)

        weights = F.softmax(logits, dim=-1)
        weights = weights if dropout is None else dropout(weights)

        return weights @ value, weights

In [3]:
class Multhead_attention(nn.Module):
    """Multi-head attention built on ``Single_attention``.

    Each of query/key/value is passed through its own linear projection, split
    into ``num_head`` heads of size ``hid_model // num_head``, attended
    independently, concatenated back and passed through ``output_linear``.

    Args:
        num_head: number of attention heads; must divide ``hid_model``.
        hid_model: model (hidden) dimension of the inputs and of the output.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, num_head, hid_model, dropout=0.1):
        super().__init__()

        assert hid_model % num_head == 0, "hid_model must be divisible by num_head"
        self.hid_head = hid_model // num_head
        self.num_head = num_head

        # One projection per role; forward applies each to its own input.
        self.query_matrix = nn.Linear(hid_model, hid_model)
        self.key_matrix = nn.Linear(hid_model, hid_model)
        self.value_matrix = nn.Linear(hid_model, hid_model)

        self.attention = Single_attention()
        self.dropout = nn.Dropout(p=dropout)

        self.output_linear = nn.Linear(hid_model, hid_model)

    def _split_heads(self, x, projection, batch_size):
        """Project ``x`` and reshape it to (batch, num_head, seq, hid_head)."""
        return projection(x).view(batch_size, -1, self.num_head, self.hid_head).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query, key, value: tensors whose last dim is ``hid_model``; presumably
                (batch, seq_len, hid_model), as in the notebook's test.
            mask: optional tensor broadcastable to (batch, num_head, q_len, k_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, hid_model).
        """
        batch_size = query.size(0)

        # Key and value must use their own projections. Reusing query_matrix here
        # would leave key_matrix/value_matrix unused and never trained.
        query = self._split_heads(query, self.query_matrix, batch_size)
        key = self._split_heads(key, self.key_matrix, batch_size)
        value = self._split_heads(value, self.value_matrix, batch_size)

        attention_value, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # (batch, num_head, seq, hid_head) -> (batch, seq, num_head * hid_head)
        attention_value = attention_value.transpose(1, 2).contiguous().view(
            batch_size, -1, self.num_head * self.hid_head)

        return self.output_linear(attention_value)

In [4]:
from feed_forward import PositionwiseFeedForward
from sublayer import SublayerConnection

class Transformer_block(nn.Module):
    """One encoder layer.

    Multi-head self-attention and a position-wise feed-forward network are each
    applied through a ``SublayerConnection`` (from the external ``sublayer``
    module; its exact residual/normalisation behaviour is defined there). A
    final dropout is applied to the result.

    Args:
        hid_model: model (hidden) dimension.
        num_head: number of attention heads; must divide ``hid_model``.
        hid_ff: inner dimension of the feed-forward network.
        dropout: dropout probability shared by all sub-modules.
    """

    def __init__(self, hid_model, num_head, hid_ff, dropout):
        super().__init__()

        # Keep this construction order. It fixes the parameter registration
        # order and the RNG draws used for weight initialisation.
        self.attention = Multhead_attention(num_head=num_head, hid_model=hid_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hid_model, d_ff=hid_ff, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hid_model, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hid_model, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """Apply self-attention, then feed-forward, then dropout to ``x``."""

        def self_attend(hidden):
            # self-attention: the same tensor serves as query, key and value
            return self.attention(hidden, hidden, hidden, mask=mask)

        attended = self.input_sublayer(x, self_attend)
        transformed = self.output_sublayer(attended, self.feed_forward)
        return self.dropout(transformed)

        
        
        

In [14]:
class MultiLayerTransformer(nn.Module):
    """Stack of ``num_layers`` Transformer_block layers applied in sequence.

    The defaults (hidden size 768, 12 layers, 12 heads, dropout 0.1) match
    BERT-base. Following the paper, the feed-forward inner size is 4 * hid_model.

    Args:
        hid_model: model (hidden) dimension; must be divisible by ``num_head``.
        num_layers: number of stacked blocks.
        num_head: attention heads per block.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, hid_model=768, num_layers=12, num_head=12, dropout=0.1):
        super().__init__()

        self.hid_model = hid_model
        self.num_layers = num_layers
        self.num_head = num_head
        self.dropout = dropout

        # Single source of truth for the feed-forward size. The blocks are built
        # from this attribute so the two cannot diverge.
        self.feed_forward_hidden = self.hid_model * 4
        self.transformer_blocks = nn.ModuleList([
            Transformer_block(hid_model=self.hid_model,
                              num_head=self.num_head,
                              hid_ff=self.feed_forward_hidden,
                              dropout=self.dropout)
            for _ in range(self.num_layers)
        ])

    def forward(self, x, mask):
        """Feed ``x`` through every block in order, using the same ``mask`` each time."""
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, mask)
        return x

# Smoke tests: run a single block and the full stack on random input and check the output shape.

In [12]:
# Fixed seed so the random input (and the later weight init / dropout) is reproducible.
torch.manual_seed(0)

BATCH, SEQ_LEN, HID = 5, 5, 300
input_vector = torch.rand(BATCH, SEQ_LEN, HID)

# Mask of shape (batch, 1, seq, seq), broadcast over heads. The last position is
# treated as padding, so its row and column are 0 (excluded); all else is 1.
valid = torch.ones(SEQ_LEN)
valid[-1] = 0
mask = (valid[:, None] * valid[None, :]).repeat(BATCH, 1, 1, 1)




In [6]:
# 300 hidden dims split across 10 heads -> 30 dims per head.
tfb = Transformer_block(hid_model=300, num_head=10, hid_ff=2000, dropout=0.1)

In [7]:
# Forward one batch through the single block.
output = tfb(x=input_vector, mask=mask)

In [8]:
# A transformer block must preserve the (batch, seq_len, hid_model) shape of its input.
assert output.shape == input_vector.shape
output.shape

torch.Size([5, 5, 300])

In [17]:
# Default depth/heads/dropout spelled out; 300 / 12 heads = 25 dims per head.
mtf = MultiLayerTransformer(hid_model=300, num_layers=12, num_head=12, dropout=0.1)

In [19]:
# The stacked encoder must also preserve the input shape.
stacked_output = mtf(input_vector, mask)
assert stacked_output.shape == input_vector.shape
stacked_output.shape

torch.Size([5, 5, 300])