In [3]:
import torch
import torch.nn as nn
import torch.nn.init as init
from Modules import BottleLinear as Linear
from Modules import ScaledDotProductAttention,LayerNormalization
import numpy as np
from torch.autograd import Variable
import Constants
from random import randint

# sublayers

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention sub-layer with residual connection and layer norm.

    Every head owns its own query/key/value projection matrix (``w_qs``,
    ``w_ks``, ``w_vs``). The per-head attention outputs are concatenated on
    the feature axis, projected back to ``d_model``, passed through dropout,
    added to the original query and layer-normalised.

    Args:
        n_head: number of attention heads.
        d_model: size of the last dimension of the q/k/v inputs.
        d_k: per-head size of the projected queries and keys.
        d_v: per-head size of the projected values.
        dropout: dropout probability applied to the projected output.
    """
    def __init__(self,n_head,d_model,d_k,d_v,dropout=0.1):
        super(MultiHeadAttention,self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # One projection matrix per head; allocated uninitialised and filled
        # by the Xavier initialisation below.
        self.w_qs = nn.Parameter(torch.FloatTensor(n_head,d_model,d_k))
        self.w_ks = nn.Parameter(torch.FloatTensor(n_head,d_model,d_k))
        self.w_vs = nn.Parameter(torch.FloatTensor(n_head,d_model,d_v))

        # The class imported from Modules is ``ScaledDotProductAttention``.
        self.attention = ScaledDotProductAttention(d_model)
        self.layer_norm = LayerNormalization(d_model)
        self.proj = Linear(n_head*d_v,d_model)

        self.dropout = nn.Dropout(dropout)

        init.xavier_normal(self.w_qs)
        init.xavier_normal(self.w_ks)
        init.xavier_normal(self.w_vs)

    def forward(self,q,k,v,attn_mask=None):
        """Run multi-head attention.

        Args:
            q: 3-D tensor (mb_size, len_q, d_model).
            k: 3-D tensor (mb_size, len_k, d_model).
            v: 3-D tensor (mb_size, len_v, d_model).
            attn_mask: optional mask, presumably (mb_size, len_q, len_k);
                it is tiled ``n_head`` times along dim 0 to line up with the
                head-stacked batch. ``None`` means no mask is forwarded.

        Returns:
            Tuple ``(outputs, attns)``: the layer-normalised
            ``outputs + q`` and whatever attention weights the attention
            module returns.
        """
        d_k,d_v = self.d_k,self.d_v
        n_head = self.n_head

        residual = q

        mb_size, len_q, d_model = q.size()
        mb_size, len_k, d_model = k.size()
        mb_size, len_v, d_model = v.size()

        # Stack n_head copies of each input: (n_head, mb_size*len, d_model).
        q_s = q.repeat(n_head,1,1).view(n_head,-1,d_model)
        k_s = k.repeat(n_head,1,1).view(n_head,-1,d_model)
        v_s = v.repeat(n_head,1,1).view(n_head,-1,d_model)

        # Per-head projection, then fold heads into the batch axis:
        # (n_head*mb_size, len, d_k|d_v).
        q_s = torch.bmm(q_s,self.w_qs).view(-1,len_q,d_k)
        k_s = torch.bmm(k_s,self.w_ks).view(-1,len_k,d_k)
        v_s = torch.bmm(v_s,self.w_vs).view(-1,len_v,d_v)

        # Only tile the mask when one was supplied; the default is None.
        if attn_mask is not None:
            attn_mask = attn_mask.repeat(n_head,1,1)
        # The mask is passed positionally so the call does not depend on the
        # exact spelling of the keyword in Modules.
        # NOTE(review): assumes the mask is the 4th parameter of
        # ScaledDotProductAttention.forward and that it accepts None - confirm.
        outputs,attns = self.attention(q_s,k_s,v_s,attn_mask)

        # Undo the head folding: n_head chunks of mb_size rows are joined on
        # the feature axis before the output projection.
        outputs = torch.cat(torch.split(outputs,mb_size,dim=0),dim=-1)
        outputs = self.proj(outputs)
        outputs = self.dropout(outputs)

        return self.layer_norm(outputs + residual),attns

In [3]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward block with residual and layer norm.

    Both layers are kernel-size-1 ``nn.Conv1d`` modules, so every position is
    transformed independently.

    Args:
        d_hid: feature size of the input and of the output.
        d_inner_hid: feature size of the hidden layer.
        dropout: dropout probability applied before the residual sum.
    """
    def __init__(self,d_hid,d_inner_hid,dropout=0.1):
        super(PositionwiseFeedForward,self).__init__()
        # The hidden width is the ``d_inner_hid`` argument.
        self.w_1 = nn.Conv1d(d_hid,d_inner_hid,1)
        self.w_2 = nn.Conv1d(d_inner_hid,d_hid,1)
        self.layer_norm = LayerNormalization(d_hid)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self,x):
        """Apply the feed-forward block.

        Args:
            x: 3-D tensor whose dims 1 and 2 are swapped before the
                convolutions - presumably (batch, length, d_hid).

        Returns:
            ``layer_norm(ffn(x) + x)``, same layout as ``x``.
        """
        residual = x
        # Conv1d expects channels on dim 1, hence the transpose in and out.
        outputs = self.relu(self.w_1(x.transpose(1,2)))
        outputs = self.w_2(outputs).transpose(2,1)
        outputs = self.dropout(outputs)
        return self.layer_norm(outputs + residual)

In [6]:
from SubLayers import MultiHeadAttention,PositionwiseFeedForward

In [7]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a position-wise FFN.

    Args:
        d_model: feature size of the layer input/output.
        d_inner_hid: hidden size of the feed-forward block.
        n_head: number of attention heads.
        d_k: per-head query/key size.
        d_v: per-head value size.
        dropout: dropout probability forwarded to both sub-layers.
    """
    def __init__(self,d_model,d_inner_hid,n_head,d_k,d_v,dropout=0.1):
        super(EncoderLayer,self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head,d_model,d_k,d_v,dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model,d_inner_hid,dropout=dropout)

    # The mask parameter is named ``slf_attn_mask``: that is the name the body
    # reads and the keyword Encoder.forward passes.
    def forward(self,enc_input,slf_attn_mask=None):
        """Encode ``enc_input``.

        Args:
            enc_input: tensor used as query, key and value of the
                self-attention.
            slf_attn_mask: optional attention mask handed to the
                self-attention sub-layer.

        Returns:
            Tuple ``(enc_output, enc_slf_attn)``: the feed-forward output and
            the attention weights returned by the self-attention sub-layer.
        """
        enc_output,enc_slf_attn = self.slf_attn(
            enc_input,enc_input,enc_input,attn_mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output,enc_slf_attn

In [8]:
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, encoder attention, then an FFN.

    Args:
        d_model: feature size of the layer input/output.
        d_inner_hid: hidden size of the feed-forward block.
        n_head: number of attention heads (shared by both attentions).
        d_k: per-head query/key size.
        d_v: per-head value size.
        dropout: dropout probability forwarded to all three sub-layers.
    """
    def __init__(self,d_model,d_inner_hid,n_head,d_k,d_v,dropout=0.1):
        super(DecoderLayer,self).__init__()
        # Both attention sub-layers are configured identically.
        attn_args = (n_head,d_model,d_k,d_v)
        self.slf_attn = MultiHeadAttention(*attn_args,dropout=dropout)
        self.enc_attn = MultiHeadAttention(*attn_args,dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model,d_inner_hid,dropout=dropout)

    def forward(self,dec_input,enc_output,slf_attn_mask=None,dec_enc_attn_mask=None):
        """Decode one layer.

        Args:
            dec_input: decoder-side tensor; query, key and value of the
                self-attention.
            enc_output: encoder-side tensor; key and value of the second
                attention.
            slf_attn_mask: optional mask for the self-attention.
            dec_enc_attn_mask: optional mask for the encoder attention.

        Returns:
            Tuple ``(dec_output, dec_slf_attn, dec_enc_attn)``.
        """
        # Step 1: the decoder input attends to itself.
        self_attended,self_weights = self.slf_attn(
            dec_input,dec_input,dec_input,attn_mask=slf_attn_mask)
        # Step 2: the result queries the encoder output.
        cross_attended,cross_weights = self.enc_attn(
            self_attended,enc_output,enc_output,attn_mask=dec_enc_attn_mask)
        # Step 3: position-wise feed-forward on the attended features.
        return self.pos_ffn(cross_attended),self_weights,cross_weights

In [9]:
from Layers import EncoderLayer,DecoderLayer

In [10]:
def position_encoding_init(n_position,d_pos_vec):
    """Build the fixed sinusoidal position-encoding table.

    Entry ``[pos, j]`` starts as ``pos / 10000 ** (2*(j//2)/d_pos_vec)``; for
    every ``pos >= 1`` the even columns are then replaced by their sine and
    the odd columns by their cosine. Row 0 is left as all zeros.

    Args:
        n_position: number of rows (positions) in the table.
        d_pos_vec: number of columns (encoding size).

    Returns:
        A ``torch.FloatTensor`` of shape ``(n_position, d_pos_vec)``.
    """
    # Built with NumPy broadcasting (the array constructor lives on ``np``).
    # This keeps the result 2-D even for degenerate sizes.
    positions = np.arange(n_position,dtype=np.float64).reshape(-1,1)
    columns = np.arange(d_pos_vec)
    # Columns 2i and 2i+1 share the same exponent; 2.0 forces true division.
    exponents = 2.0 * (columns // 2) / d_pos_vec
    position_enc = positions / np.power(10000.0,exponents)
    # Row 0 is 0 / x == 0 and is skipped below, so it stays all-zero.
    position_enc[1:,0::2] = np.sin(position_enc[1:,0::2])
    position_enc[1:,1::2] = np.cos(position_enc[1:,1::2])
    return torch.from_numpy(position_enc).type(torch.FloatTensor)

In [11]:
def get_attn_padding_mask(seq_q,seq_k):
    """Mask the key positions that hold the padding value.

    Args:
        seq_q: 2-D query-side sequence batch; only its size is used.
        seq_k: 2-D key-side sequence batch; compared with ``Constants.PAD``.

    Returns:
        A tensor of shape ``(mb_size, len_q, len_k)`` that is non-zero
        wherever the corresponding key entry equals ``Constants.PAD``
        (the same key row is shared by every query position).
    """
    assert seq_q.dim() == 2 and seq_k.dim() == 2
    batch,n_queries = seq_q.size()
    # The batch size is taken from the key side, as it is read last.
    batch,n_keys = seq_k.size()
    is_pad = seq_k.data.eq(Constants.PAD)
    # (batch, len_k) -> (batch, 1, len_k) -> broadcast over the query axis.
    return is_pad.unsqueeze(1).expand(batch,n_queries,n_keys)

In [48]:
def get_attn_subsequent_mask(seq):
    """Build a mask that hides later positions from each position.

    Args:
        seq: 2-D tensor ``(batch, length)``; only its size is used.

    Returns:
        A uint8 tensor of shape ``(batch, length, length)`` whose entry
        ``[b, i, j]`` is 1 when ``j > i`` and 0 otherwise.
    """
    assert seq.dim() == 2
    batch,length = seq.size(0),seq.size(1)
    # k=1 keeps only the strictly-upper triangle of every (length, length) slice.
    upper = np.triu(np.ones((batch,length,length)),k=1)
    return torch.from_numpy(upper.astype('uint8'))

In [27]:
# Ten random rows followed by three all-zero rows.
test_data = np.vstack([np.random.random((10,3)),np.zeros((3,3))])

In [28]:
# Convert the NumPy array to a float tensor and wrap it in an autograd Variable.
test_data = Variable(torch.from_numpy(test_data).type(torch.FloatTensor))

In [33]:
test_data.size()

torch.Size([13, 3])

In [32]:
# Smoke test: the 13x3 matrix is treated as 13 sequences of length 3.
# In the output below only the three all-zero rows (10-12) are fully masked,
# presumably because Constants.PAD is 0.
get_attn_padding_mask(test_data,test_data)


(0 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(1 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(2 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(3 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(4 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(5 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(6 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(7 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(8 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(9 ,.,.) = 
  0  0  0
  0  0  0
  0  0  0

(10,.,.) = 
  1  1  1
  1  1  1
  1  1  1

(11,.,.) = 
  1  1  1
  1  1  1
  1  1  1

(12,.,.) = 
  1  1  1
  1  1  1
  1  1  1
[torch.ByteTensor of size 13x3x3]

In [49]:
# Only the size of test_data matters here: the result is the same
# strictly-upper-triangular 3x3 pattern repeated for each of the 13 rows.
get_attn_subsequent_mask(test_data)


(0 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(1 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(2 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(3 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(4 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(5 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(6 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(7 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(8 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(9 ,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(10,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(11,.,.) = 
  0  1  1
  0  0  1
  0  0  0

(12,.,.) = 
  0  1  1
  0  0  1
  0  0  0
[torch.ByteTensor of size 13x3x3]

In [37]:
# Random 13x3x3 float tensor used as stand-in attention scores.
test_data2 = torch.from_numpy(np.random.random((13,3,3))).type(torch.FloatTensor)

In [38]:
test_data2


(0 ,.,.) = 
  0.3283  0.2818  0.1895
  0.8391  0.9610  0.6921
  0.8609  0.7171  0.9459

(1 ,.,.) = 
  0.0230  0.6590  0.4776
  0.7684  0.3062  0.8785
  0.3999  0.7001  0.4939

(2 ,.,.) = 
  0.6848  0.0378  0.0957
  0.9445  0.5699  0.9185
  0.4937  0.5345  0.8517

(3 ,.,.) = 
  0.8327  0.2719  0.1924
  0.2072  0.3360  0.3084
  0.6014  0.9138  0.4349

(4 ,.,.) = 
  0.0465  0.2373  0.9371
  0.3756  0.4465  0.1945
  0.2218  0.7685  0.5120

(5 ,.,.) = 
  0.1172  0.6778  0.7712
  0.0508  0.2874  0.6622
  0.6880  0.0408  0.1986

(6 ,.,.) = 
  0.2820  0.4468  0.9389
  0.3088  0.8669  0.7539
  0.2834  0.9564  0.4440

(7 ,.,.) = 
  0.1776  0.1531  0.9330
  0.0240  0.9338  0.4544
  0.5160  0.1940  0.1774

(8 ,.,.) = 
  0.0095  0.9395  0.8145
  0.6572  0.3415  0.8222
  0.3853  0.3665  0.0445

(9 ,.,.) = 
  0.6685  0.2727  0.2630
  0.4169  0.2593  0.9923
  0.4314  0.0356  0.3390

(10,.,.) = 
  0.0848  0.0101  0.2678
  0.9473  0.6298  0.0969
  0.3271  0.4015  0.6972

(11,.,.) = 
  0.7665  0.4391  0

In [41]:
# masked_fill_ works in place: test_data2 itself is overwritten with -inf at
# the masked entries, so re-running this cell (or any later cell that uses
# test_data2) operates on the already-masked tensor.
test_data2.masked_fill_(get_attn_padding_mask(test_data,test_data),-float('inf'))


(0 ,.,.) = 
  0.3283  0.2818  0.1895
  0.8391  0.9610  0.6921
  0.8609  0.7171  0.9459

(1 ,.,.) = 
  0.0230  0.6590  0.4776
  0.7684  0.3062  0.8785
  0.3999  0.7001  0.4939

(2 ,.,.) = 
  0.6848  0.0378  0.0957
  0.9445  0.5699  0.9185
  0.4937  0.5345  0.8517

(3 ,.,.) = 
  0.8327  0.2719  0.1924
  0.2072  0.3360  0.3084
  0.6014  0.9138  0.4349

(4 ,.,.) = 
  0.0465  0.2373  0.9371
  0.3756  0.4465  0.1945
  0.2218  0.7685  0.5120

(5 ,.,.) = 
  0.1172  0.6778  0.7712
  0.0508  0.2874  0.6622
  0.6880  0.0408  0.1986

(6 ,.,.) = 
  0.2820  0.4468  0.9389
  0.3088  0.8669  0.7539
  0.2834  0.9564  0.4440

(7 ,.,.) = 
  0.1776  0.1531  0.9330
  0.0240  0.9338  0.4544
  0.5160  0.1940  0.1774

(8 ,.,.) = 
  0.0095  0.9395  0.8145
  0.6572  0.3415  0.8222
  0.3853  0.3665  0.0445

(9 ,.,.) = 
  0.6685  0.2727  0.2630
  0.4169  0.2593  0.9923
  0.4314  0.0356  0.3390

(10,.,.) = 
    -inf    -inf    -inf
    -inf    -inf    -inf
    -inf    -inf    -inf

(11,.,.) = 
    -inf    -inf   

In [50]:
# Applies the subsequent-position mask in place, on top of whatever is already
# stored in test_data2. The output below still shows slice 10 as all -inf from
# the earlier padding-mask cell.
test_data2.masked_fill_(get_attn_subsequent_mask(test_data),-float('inf'))


(0 ,.,.) = 
  0.3283    -inf    -inf
  0.8391  0.9610    -inf
  0.8609  0.7171  0.9459

(1 ,.,.) = 
  0.0230    -inf    -inf
  0.7684  0.3062    -inf
  0.3999  0.7001  0.4939

(2 ,.,.) = 
  0.6848    -inf    -inf
  0.9445  0.5699    -inf
  0.4937  0.5345  0.8517

(3 ,.,.) = 
  0.8327    -inf    -inf
  0.2072  0.3360    -inf
  0.6014  0.9138  0.4349

(4 ,.,.) = 
  0.0465    -inf    -inf
  0.3756  0.4465    -inf
  0.2218  0.7685  0.5120

(5 ,.,.) = 
  0.1172    -inf    -inf
  0.0508  0.2874    -inf
  0.6880  0.0408  0.1986

(6 ,.,.) = 
  0.2820    -inf    -inf
  0.3088  0.8669    -inf
  0.2834  0.9564  0.4440

(7 ,.,.) = 
  0.1776    -inf    -inf
  0.0240  0.9338    -inf
  0.5160  0.1940  0.1774

(8 ,.,.) = 
  0.0095    -inf    -inf
  0.6572  0.3415    -inf
  0.3853  0.3665  0.0445

(9 ,.,.) = 
  0.6685    -inf    -inf
  0.4169  0.2593    -inf
  0.4314  0.0356  0.3390

(10,.,.) = 
    -inf    -inf    -inf
    -inf    -inf    -inf
    -inf    -inf    -inf

(11,.,.) = 
    -inf    -inf   

In [5]:
torch.LongTensor([[0,1,2,0,0],[2,0,0,1,2]])


 0  1  2  0  0
 2  0  0  1  2
[torch.LongTensor of size 2x5]

In [57]:
eye = torch.eye(10).long().unsqueeze(2).repeat(1,1,2)

In [59]:
inputs = torch.arange(0,10).long().unsqueeze(0).expand(10,-1).unsqueeze(2).repeat(1,1,2)

In [72]:
linear = nn.Linear(2,3)

In [73]:
# Feed the (10,10,2) tensor built from torch.eye through nn.Linear(2,3).
# The layer acts on the last axis, so each slice in the output has 3 columns.
linear(Variable(eye).float())

Variable containing:
(0 ,.,.) = 
  0.3405  0.4525  0.0260
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136

(1 ,.,.) = 
  0.4785 -0.3143  0.2136
  0.3405  0.4525  0.0260
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136

(2 ,.,.) = 
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.3405  0.4525  0.0260
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136

(3 ,.,.) = 
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.3405  0.4525  0.0260
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0.4785 -0.3143  0.2136
  0

In [11]:
# Zero-filled float buffer of shape (256, 38, 38).
step_idx = torch.zeros(256,38,38)

In [41]:
torch.arange(0,38).long().unsqueeze(0).expand(128,-1).unsqueeze(2)


( 0 ,.,.) = 
   0
   1
   2
 ⋮  
  35
  36
  37

( 1 ,.,.) = 
   0
   1
   2
 ⋮  
  35
  36
  37

( 2 ,.,.) = 
   0
   1
   2
 ⋮  
  35
  36
  37
... 

(125,.,.) = 
   0
   1
   2
 ⋮  
  35
  36
  37

(126,.,.) = 
   0
   1
   2
 ⋮  
  35
  36
  37

(127,.,.) = 
   0
   1
   2
 ⋮  
  35
  36
  37
[torch.LongTensor of size 128x38x1]

In [29]:
# In-place scatter along dim 2: the index tensor is arange(38) expanded over
# the batch, so a 1 is written at [b, i, i] for every batch b and row i.
# Each 38x38 slice of step_idx becomes an identity matrix (see output below).
step_idx.scatter_(2,torch.arange(0,38).long().unsqueeze(0).expand(256,-1).unsqueeze(2),1)


( 0 ,.,.) = 
   1   0   0  ...    0   0   0
   0   1   0  ...    0   0   0
   0   0   1  ...    0   0   0
     ...       ⋱       ...    
   0   0   0  ...    1   0   0
   0   0   0  ...    0   1   0
   0   0   0  ...    0   0   1

( 1 ,.,.) = 
   1   0   0  ...    0   0   0
   0   1   0  ...    0   0   0
   0   0   1  ...    0   0   0
     ...       ⋱       ...    
   0   0   0  ...    1   0   0
   0   0   0  ...    0   1   0
   0   0   0  ...    0   0   1

( 2 ,.,.) = 
   1   0   0  ...    0   0   0
   0   1   0  ...    0   0   0
   0   0   1  ...    0   0   0
     ...       ⋱       ...    
   0   0   0  ...    1   0   0
   0   0   0  ...    0   1   0
   0   0   0  ...    0   0   1
... 

(253,.,.) = 
   1   0   0  ...    0   0   0
   0   1   0  ...    0   0   0
   0   0   1  ...    0   0   0
     ...       ⋱       ...    
   0   0   0  ...    1   0   0
   0   0   0  ...    0   1   0
   0   0   0  ...    0   0   1

(254,.,.) = 
   1   0   0  ...    0   0   0
   0   1   0  ...    0   0

In [4]:
class Encoder(nn.Module):
    """Stack of ``n_layers`` encoder layers over position-augmented inputs.

    Args:
        n_max_seq: largest position index; the position table has
            ``n_max_seq + 1`` rows.
        d_model: feature size handed to every ``EncoderLayer``.
        n_layers: number of stacked layers.
        n_heads: number of attention heads per layer.
        d_k: per-head query/key size.
        d_v: per-head value size.
        d_pos: width of the position-encoding table; it is added to
            ``src_seq``, so it has to be compatible with its last dimension.
        d_inner_hid: hidden size of each layer's feed-forward block.
            NOTE(review): the default of 2 differs from Decoder's 1024 -
            confirm it is intended.
        dropout: dropout probability forwarded to every layer.
    """
    def __init__(self,n_max_seq,d_model,n_layers=6,n_heads=8,d_k=64,d_v=64,
                 d_pos=512,d_inner_hid=2,dropout=0.1):
        super(Encoder,self).__init__()
        n_position = n_max_seq + 1
        self.position_enc = nn.Embedding(
            n_position,d_pos,padding_idx=Constants.PAD)
        # Replace the random embedding weights with the sinusoidal table.
        self.position_enc.weight.data = position_encoding_init(
            n_position,d_pos)
        # The head count comes from the ``n_heads`` argument.
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model,d_inner_hid,n_heads,d_k,d_v,dropout=dropout)
            for _ in range(n_layers)
        ])

    def forward(self,src_seq,src_pos,return_attns=False):
        """Encode a batch of source sequences.

        Args:
            src_seq: 3-D source features; channel 0 of the last axis is
                used to derive the padding mask.
            src_pos: position indices looked up in the position table.
            return_attns: when true, also return the per-layer attention
                weights.

        Returns:
            ``enc_output``, or ``(enc_output, enc_slf_attns)`` when
            ``return_attns`` is true.
        """
        # Out-of-place sum: the caller's src_seq is left untouched, and the
        # mask below is computed from the raw input features.
        enc_output = src_seq + self.position_enc(src_pos)
        enc_slf_attn_mask = get_attn_padding_mask(
            src_seq[:,:,0],src_seq[:,:,0])
        enc_slf_attns = []
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,slf_attn_mask=enc_slf_attn_mask)
            if return_attns:
                enc_slf_attns.append(enc_slf_attn)
        if return_attns:
            return enc_output,enc_slf_attns
        return enc_output

In [None]:
class Decoder(nn.Module):
    """Stack of ``n_layers`` decoder layers over position-augmented targets.

    Args:
        n_max_seq: largest position index; the position table has
            ``n_max_seq + 1`` rows.
        d_model: feature size handed to every ``DecoderLayer``.
        n_layers: number of stacked layers.
        n_head: number of attention heads per layer.
        d_k: per-head query/key size.
        d_v: per-head value size.
        d_pos: width of the position-encoding table; it is added to
            ``tgt_seq``, so it has to be compatible with its last dimension.
        d_inner_hid: hidden size of each layer's feed-forward block.
        dropout: dropout probability forwarded to every layer.
    """
    def __init__(self,n_max_seq,d_model,n_layers=6,n_head=8,d_k=64,d_v=64,
                d_pos=512,d_inner_hid=1024,dropout=0.1):
        super(Decoder,self).__init__()
        n_position = n_max_seq + 1
        self.position_enc = nn.Embedding(
            n_position,d_pos,padding_idx=Constants.PAD)
        # Replace the random embedding weights with the sinusoidal table.
        self.position_enc.weight.data = position_encoding_init(
            n_position,d_pos)
        # NOTE(review): this dropout module is created but never applied in
        # forward() - confirm whether it should be used on the summed input.
        self.dropout = nn.Dropout(dropout)
        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model,d_inner_hid,n_head,d_k,d_v,dropout=dropout)
            for _ in range(n_layers)
        ])

    def forward(self,tgt_seq,tgt_pos,src_seq,enc_output,return_attns=False):
        """Decode a batch of target sequences against the encoder output.

        Args:
            tgt_seq: 3-D target features; channel 0 of the last axis is
                used to derive the padding masks.
            tgt_pos: position indices looked up in the position table.
            src_seq: 3-D source features; channel 0 supplies the key side of
                the decoder-encoder padding mask.
            enc_output: encoder output attended to by every layer.
            return_attns: when true, also return the per-layer attention
                weights.

        Returns:
            ``dec_output``, or ``(dec_output, dec_slf_attns, dec_enc_attns)``
            when ``return_attns`` is true.
        """
        # Out-of-place sum: the caller's tgt_seq is left untouched, and the
        # masks below are computed from the raw input features.
        dec_output = tgt_seq + self.position_enc(tgt_pos)
        # NOTE(review): the self-attention mask is padding-only;
        # get_attn_subsequent_mask is not applied here - confirm that
        # attending to later target positions is intended.
        dec_slf_attn_mask = get_attn_padding_mask(
            tgt_seq[:,:,0],tgt_seq[:,:,0])
        dec_enc_attn_pad_mask = get_attn_padding_mask(
            tgt_seq[:,:,0],src_seq[:,:,0])
        dec_slf_attns,dec_enc_attns = [],[]
        for dec_layer in self.layer_stack:
            dec_output,dec_slf_attn,dec_enc_attn = dec_layer(
                dec_output,enc_output,
                slf_attn_mask=dec_slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_pad_mask)
            if return_attns:
                dec_slf_attns.append(dec_slf_attn)
                dec_enc_attns.append(dec_enc_attn)
        if return_attns:
            return dec_output,dec_slf_attns,dec_enc_attns
        return dec_output