In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math 

In [2]:
class Single_attention(nn.Module):
    """
    Scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(d)) @ value`` where ``d`` is the
    size of the last dimension of ``query``. The module owns no parameters.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """
        Apply attention of ``query`` over ``key`` / ``value``.

        Args:
            query: tensor whose last dimension is the feature size ``d``.
            key: tensor whose last two dimensions are transposed and
                multiplied with ``query``.
            value: tensor multiplied by the attention weights.
            mask: optional tensor broadcastable to the score matrix; every
                position where ``mask == 0`` is filled with ``-1e9`` before
                the softmax so it receives (practically) zero weight.
            dropout: optional callable applied to the attention weights
                after the softmax.

        Returns:
            A tuple ``(attention_value, attention_scores)``: the weighted
            values and the (post-dropout) attention weights.
        """
        scale = math.sqrt(query.size(-1))
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            # A large negative logit turns into a zero weight after softmax.
            logits = logits.masked_fill(mask == 0, -1e9)

        weights = F.softmax(logits, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        return torch.matmul(weights, value), weights

In [3]:
class Multhead_attention(nn.Module):
    """
    Multi-head attention built on top of ``Single_attention``.

    The inputs are projected with three separate linear layers (query, key,
    value), split into ``num_head`` heads of size ``hid_model // num_head``,
    attended per head, merged back and passed through an output projection.
    """

    def __init__(self, num_head, hid_model, dropout=0.1):
        """
        Args:
            num_head: number of attention heads; must divide ``hid_model``.
            hid_model: size of the model (input and output feature) dimension.
            dropout: dropout probability applied to the attention weights.

        Raises:
            AssertionError: if ``hid_model`` is not divisible by ``num_head``.
        """
        super().__init__()

        assert hid_model % num_head == 0, "hid_model must be divisible by num_head"
        self.hid_head = hid_model // num_head
        self.num_head = num_head

        # One independent projection per role.
        self.query_matrix = nn.Linear(hid_model, hid_model)
        self.key_matrix = nn.Linear(hid_model, hid_model)
        self.value_matrix = nn.Linear(hid_model, hid_model)

        self.attention = Single_attention()
        self.dropout = nn.Dropout(p=dropout)

        self.output_linear = nn.Linear(hid_model, hid_model)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, hid_model) into (batch, num_head, seq, hid_head)."""
        return projected.view(batch_size, -1, self.num_head, self.hid_head).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: tensors whose first dimension is the batch and
                whose last dimension is ``hid_model``.
            mask: optional mask forwarded unchanged to ``Single_attention``;
                it must be broadcastable to the per-head score matrix.

        Returns:
            Tensor with last dimension ``hid_model`` (heads merged and passed
            through the output projection).
        """
        batch_size = query.size(0)

        # Each input goes through its own projection. Previously key and value
        # were also sent through ``query_matrix``, which left ``key_matrix`` and
        # ``value_matrix`` unused and untrained.
        query = self._split_heads(self.query_matrix(query), batch_size)
        key = self._split_heads(self.key_matrix(key), batch_size)
        value = self._split_heads(self.value_matrix(value), batch_size)

        attention_value, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # (batch, num_head, seq, hid_head) -> (batch, seq, num_head * hid_head)
        attention_value = (
            attention_value.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_head * self.hid_head)
        )

        return self.output_linear(attention_value)

In [4]:
from feed_forward import PositionwiseFeedForward
from sublayer import SublayerConnection

class Transformer_block(nn.Module):
    """
    One encoder-style block: self-attention followed by a position-wise
    feed-forward network, each wrapped in a ``SublayerConnection``, with a
    final dropout on the block output.

    NOTE(review): ``SublayerConnection`` and ``PositionwiseFeedForward`` come
    from project modules not shown here; presumably the former adds the
    residual connection and normalisation -- confirm in ``sublayer.py``.
    """

    def __init__(self, hid_model, num_head, hid_ff, dropout):
        """
        Args:
            hid_model: model (feature) dimension.
            num_head: number of attention heads.
            hid_ff: hidden size passed to the feed-forward network as ``d_ff``.
            dropout: dropout probability shared by every sub-module.
        """
        super().__init__()

        # Sub-modules are created in a fixed order so parameter
        # initialisation consumes the RNG in the same sequence.
        self.attention = Multhead_attention(
            num_head=num_head, hid_model=hid_model, dropout=dropout
        )
        self.feed_forward = PositionwiseFeedForward(
            d_model=hid_model, d_ff=hid_ff, dropout=dropout
        )
        self.input_sublayer = SublayerConnection(size=hid_model, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hid_model, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """
        Run the block on ``x``, using ``mask`` for the self-attention step.

        Returns:
            The output of the second sublayer after the final dropout.
        """

        def self_attend(hidden):
            # Self-attention: the same tensor is used as query, key and value.
            return self.attention(hidden, hidden, hidden, mask=mask)

        attended = self.input_sublayer(x, self_attend)
        transformed = self.output_sublayer(attended, self.feed_forward)
        return self.dropout(transformed)

        
        
        

In [21]:
# Random stand-in for a batch of 5 sequences of length 5 with 300 features each.
input_vector = torch.rand(5, 5, 300)

# Mask of shape (5, 1, 5, 5): in every sample the first four positions may
# attend to each other (1), while the last position is masked out (0) both
# as a row and as a column.
mask = torch.zeros(5, 1, 5, 5)
mask[:, :, :4, :4] = 1




In [24]:
# Block sized for the 300-feature demo input: 10 heads of 30 features each.
tfb = Transformer_block(hid_model=300, num_head=10, hid_ff=2000, dropout=0.1)

tensor([[[ 0.3810,  0.8385, -0.5023,  ...,  1.1234,  0.6241,  0.1547],
         [ 0.2362,  0.7687,  0.2799,  ...,  1.1560,  0.7709,  1.0626],
         [ 0.1515,  0.0000,  0.0000,  ..., -0.1523,  1.0768,  0.7858],
         [ 0.4177,  0.8347,  0.3368,  ..., -0.0457,  0.1850,  0.7222],
         [ 0.4843,  0.2518,  0.7744,  ...,  0.0000,  1.3551,  0.7600]],

        [[ 0.5796,  1.3251,  0.3713,  ...,  0.4648,  0.0000,  0.9455],
         [ 0.3733, -0.0198,  0.1568,  ...,  1.7718,  2.1061,  0.1078],
         [ 0.6474,  0.5058,  0.2743,  ...,  0.4379,  0.5351,  0.1076],
         [ 1.0242,  0.9305,  1.4672,  ...,  1.3323,  1.2964,  0.4655],
         [ 0.4504,  0.1567,  0.6198,  ...,  0.4826,  0.4744,  0.8823]],

        [[ 0.5228,  0.0000, -0.2845,  ...,  0.6848,  0.0000,  0.2633],
         [ 0.3443, -0.5753,  0.1803,  ...,  0.3606,  0.0000,  0.0000],
         [ 0.0000,  0.1812,  0.0000,  ...,  0.8450,  0.1984,  0.4250],
         [ 0.6007,  1.1472, -0.2138,  ...,  0.4108,  0.0974,  0.0000],
  