In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
from torch.autograd import Variable

In [2]:
def score_dot(H_e, h_d, fc_layer_1=None, fc_layer_2=None, fc_layer_3=None):
    '''
    Dot-product attention: score every encoder state against the decoder state.

    The fc_layer_* arguments are accepted so that all score functions share one
    call signature; this function does not use them.

    :param H_e: encoder hidden states as batch_size * source_length * hidden_size
    :param h_d: decoder hidden state as batch_size * hidden_size
    :return: attention weights as batch_size * source_length * 1, softmax-normalised
             along the source_length axis
    '''
    # Turn the decoder state into a column vector per batch item.
    query = h_d.unsqueeze(2)  # bs * hidden_size * 1
    # (bs * len * hidden_size) @ (bs * hidden_size * 1) -> bs * len * 1
    raw_scores = H_e @ query
    return F.softmax(raw_scores, dim=1)

In [3]:
def score_generate(H_e, h_d,fc_layer_1=None, fc_layer_2=None, fc_layer_3=None):
    '''
    "General" attention: pass the encoder states through fc_layer_1, then score
    them against the decoder state with score_dot.

    fc_layer_2 and fc_layer_3 are accepted but not used here. fc_layer_1 must be
    a callable even though it defaults to None.

    :param H_e: encoder hidden states, in the layout score_dot expects
    :param h_d: decoder hidden state, in the layout score_dot expects
    :param fc_layer_1: callable applied to H_e before the dot product
    :return: whatever score_dot returns for the projected encoder states
    '''
    projected_states = fc_layer_1(H_e)
    return score_dot(projected_states, h_d)

In [4]:
def score_concat(H_e, h_d, fc_layer_1=None, fc_layer_2=None, fc_layer_3=None):
    '''
    "Concat" attention: pair the decoder state with every encoder state and score
    each pair with fc_layer_3(tanh(fc_layer_2(pair))).

    fc_layer_1 is accepted but not used here. fc_layer_2 and fc_layer_3 must be
    callables even though they default to None.

    :param H_e: encoder hidden states as batch_size * source_length * hidden_size
    :param h_d: decoder hidden state as batch_size * hidden_size
    :param fc_layer_2: callable applied to the concatenated pairs
    :param fc_layer_3: callable applied after the tanh to produce the raw scores
    :return: the raw scores softmax-normalised along dim 1 (source_length)
    '''
    source_length = H_e.size()[1]
    # bs * hidden_size -> bs * 1 * hidden_size -> bs * len * hidden_size
    tiled_query = h_d.unsqueeze(1).repeat([1, source_length, 1])
    # Concatenate along the feature axis: bs * len * 2hidden_size
    paired = torch.cat([H_e, tiled_query], dim=2)
    hidden = F.tanh(fc_layer_2(paired))
    raw_scores = fc_layer_3(hidden)
    return F.softmax(raw_scores, dim=1)

In [5]:
def local_m(h_d, t):
    '''
    Build a per-batch-item column filled with the value t.

    The result is allocated with h_d.new_ones so that it lives on the same device
    and has the same dtype as h_d; a plain torch.ones would always produce a
    default-dtype CPU tensor and fail to combine with tensors on another device.

    :param h_d: tensor whose first dimension is the batch size; only its size,
                dtype and device are used
    :param t: value to fill in (the caller passes a decoding step index)
    :return: tensor of shape batch_size * 1 with every entry equal to t
    '''
    return h_d.new_ones((h_d.size(0), 1)) * t

In [6]:
def local_p(h_d, t, fc_layer_1, fc_layer_2, seq_len):
    '''
    Predict a position from the decoder state:
    seq_len * sigmoid(fc_layer_2(tanh(fc_layer_1(h_d)))).

    Because of the sigmoid, the result is bounded by 0 and seq_len. The step
    index t is accepted but not used.

    :param h_d: decoder hidden state fed to fc_layer_1
    :param t: unused
    :param fc_layer_1: callable applied to h_d
    :param fc_layer_2: callable applied after the tanh
    :param seq_len: scale applied to the sigmoid output
    :return: seq_len times the sigmoid output, with the shape fc_layer_2 produces
    '''
    hidden = F.tanh(fc_layer_1(h_d))
    gate = F.sigmoid(fc_layer_2(hidden))
    return seq_len * gate

In [7]:
def local_score(attention_score, pt, seq_len, sigma):
    '''
    Re-weight attention with a Gaussian centred on pt and renormalise.

    Each source position s in 0 .. seq_len-1 has its weight multiplied by
    exp(-(s - pt)^2 / (2 * sigma^2)); the weights are then divided by their sum
    along dim 1 so that they add up to 1 again.

    The position grid is built with torch.arange (torch.range is deprecated) on
    the device and dtype of attention_score, and is broadcast against pt instead
    of being physically repeated per batch item.

    :param attention_score: attention weights as batch_size * seq_len * 1
    :param pt: window centres as batch_size * 1
    :param seq_len: number of source positions (int)
    :param sigma: standard deviation of the Gaussian
    :return: re-weighted attention as batch_size * seq_len * 1, summing to 1 along dim 1
    '''
    # bs * 1 -> bs * 1 * 1, moved next to the scores so the subtraction below is valid.
    centre = pt.unsqueeze(2).to(attention_score.device)
    # 1 * seq_len * 1 grid of source positions; broadcasting supplies the batch axis.
    positions = torch.arange(seq_len, dtype=attention_score.dtype,
                             device=attention_score.device).view(1, seq_len, 1)
    gaussian = torch.exp(-(positions - centre) ** 2 / (2 * sigma ** 2))
    attention_score = attention_score * gaussian
    return attention_score / torch.sum(attention_score, dim=1, keepdim=True)

In [8]:
class Loung_NMT(nn.Module):
    '''
    Single-layer LSTM encoder-decoder with a configurable attention mechanism
    (presumably modelled on Luong et al.'s attention-based NMT - confirm).

    Configuration:
      * score_f     - "dot", "general" or "concat": which score function
                      (score_dot / score_generate / score_concat) produces the
                      attention weights.
      * attention_c - "global" uses the weights as they are; "local_m" and
                      "local_p" re-weight them with local_score around a position
                      obtained from local_m / local_p.
      * feed_input  - when true, the previous attentional state is concatenated to
                      the decoder input embedding at every step.
      * window_size - local_score is called with sigma = window_size / 2.
      * reverse     - for "local_m" only: mirror the step index t to
                      source_length - 1 - t before using it as the window centre.
    '''
    # Accepted configuration values; anything else makes attention_forward raise.
    _SCORE_FUNCTIONS = ("dot", "general", "concat")
    _ATTENTION_CLASSES = ("global", "local_m", "local_p")

    def __init__(self,source_vocab_size,target_vocab_size,embedding_size,
                 lstm_size, score_f, attention_c, feed_input, window_size=10, reverse=True):
        super(Loung_NMT,self).__init__()
        self.score_f = score_f
        self.attention_c = attention_c
        self.window_size = window_size
        self.feed_input = feed_input
        self.reverse = reverse
        self.source_embedding =nn.Embedding(source_vocab_size,embedding_size)
        self.target_embedding = nn.Embedding(target_vocab_size,embedding_size)
        # batch_first=True: inputs/outputs are batch_size * seq_len * feature_size
        self.encoder = nn.LSTM(input_size=embedding_size,hidden_size=lstm_size,num_layers=1,
                               batch_first=True)
        # With input feeding the decoder sees [embedding ; previous attentional state].
        if not feed_input:
            self.decoder = nn.LSTM(input_size=embedding_size, hidden_size=lstm_size, num_layers=1,
                                   batch_first=True)
        else:
            self.decoder = nn.LSTM(input_size=embedding_size+lstm_size,hidden_size=lstm_size,num_layers=1,
                                   batch_first=True)
        self.class_fc_1 = nn.Linear(lstm_size+lstm_size, lstm_size) # classification fully connected layer 1
        self.class_fc_2 = nn.Linear(lstm_size, target_vocab_size) # classification fully connected layer 2

        # Layers handed to the score functions (score_generate uses fc_1, score_concat fc_2 and fc_3).
        self.attention_fc_1 = nn.Linear(lstm_size, lstm_size)
        self.attention_fc_2 = nn.Linear(2*lstm_size, lstm_size)
        self.attention_fc_3 = nn.Linear(lstm_size, 1)

        # Layers handed to local_p.
        self.local_fc_1 = nn.Linear(lstm_size, lstm_size)
        self.local_fc_2 = nn.Linear(lstm_size, 1)

    def attention_forward(self,input_embedding, feed_input_h, dec_prev_hidden, enc_output, t):
        '''
        Run one decoder step and attend over the encoder outputs.

        :param input_embedding: current target embedding, bs * 1 * embedding_size
        :param feed_input_h: previous attentional state, bs * 1 * lstm_size
                             (only read when feed_input is enabled)
        :param dec_prev_hidden: (h, c) state passed to the decoder LSTM
        :param enc_output: encoder outputs, bs * source_length * lstm_size
        :param t: index of the current decoding step
        :return: (atten_output, dec_output, dec_hidden) where atten_output is the
                 attention-weighted sum of enc_output as bs * 1 * lstm_size
        :raises ValueError: if score_f or attention_c holds an unsupported value
        '''
        # Fail with an exception rather than print + exit(): exit() would take the
        # whole interpreter / notebook kernel down with it.
        if self.score_f not in self._SCORE_FUNCTIONS:
            raise ValueError("Attention score function input error! Got {!r}; expected one of {}".format(
                self.score_f, self._SCORE_FUNCTIONS))
        if self.attention_c not in self._ATTENTION_CLASSES:
            raise ValueError("Attention class input error! Got {!r}; expected one of {}".format(
                self.attention_c, self._ATTENTION_CLASSES))

        if not self.feed_input:
            dec_lstm_input = input_embedding
        else:
            dec_lstm_input = torch.cat([input_embedding, feed_input_h], dim=2) # bs * 1 * (embed_size+hidden_size)
        dec_output, dec_hidden  = self.decoder(dec_lstm_input, dec_prev_hidden)
        # dec_output: bs*1*lstm_size, dec_hidden:(1*bs*lstm_size, 1*bs*lstm_size)

        # squeeze(0) removes only the layer axis; a bare squeeze() would also
        # remove the batch axis whenever the batch size is 1.
        query = dec_hidden[0].squeeze(0) # bs * lstm_size
        source_length = enc_output.size()[1]

        if self.score_f == "dot":
            attention_weights = score_dot(enc_output, query, self.attention_fc_1, self.attention_fc_2, self.attention_fc_3)
        elif self.score_f == "general":
            attention_weights = score_generate(enc_output, query, self.attention_fc_1, self.attention_fc_2, self.attention_fc_3)
        else: # "concat"
            attention_weights = score_concat(enc_output, query, self.attention_fc_1, self.attention_fc_2, self.attention_fc_3)

        if self.attention_c == "local_m":
            if self.reverse:
                t = source_length-1-t
            pt = local_m(query, t)
            attention_weights = local_score(attention_weights, pt, source_length, self.window_size/2)
        elif self.attention_c == "local_p":
            pt = local_p(query, t, self.local_fc_1, self.local_fc_2, source_length)
            attention_weights = local_score(attention_weights, pt, source_length, self.window_size / 2)
        # "global": the weights are used unchanged.

        atten_output = torch.sum(attention_weights * enc_output, dim=1).unsqueeze(1) # bs*1*hidden_size
        return atten_output,dec_output,dec_hidden

    def forward(self, source_data,target_data, mode = "train",is_gpu=True, max_len=100):
        '''
        Encode the source and decode either with teacher forcing or greedily.

        :param source_data: source token ids, bs * source_length
        :param target_data: in "train" mode the target token ids (bs * target_length),
                            consumed one column per step; in any other mode the
                            first decoder input (bs * 1)
        :param mode: "train" for teacher forcing; any other value decodes greedily
        :param is_gpu: kept for backward compatibility and ignored; the work
                       buffers are always created on the encoder output's device
        :param max_len: number of greedy decoding steps outside "train" mode
        :return: "train": scores as bs * target_length * target_vocab_size;
                 otherwise a list of max_len numpy arrays of predicted ids, one
                 array of length bs per step
        '''
        source_data_embedding = self.source_embedding(source_data)
        enc_output, enc_hidden = self.encoder(source_data_embedding)
        # enc_output: bs*len*hidden_size, enc_hidden: (1*bs*hidden_size, 1*bs*hidden_size)
        batch_size, target_len = target_data.shape[0], target_data.shape[1]
        # new_zeros keeps the buffers on enc_output's device and dtype, so the model
        # works wherever its parameters live.
        self.atten_outputs = enc_output.new_zeros(batch_size, target_len, enc_output.shape[2])
        self.dec_outputs = enc_output.new_zeros(batch_size, target_len, enc_hidden[0].shape[2])

        # The decoder starts from the encoder's final (h, c) state.
        dec_prev_hidden = (enc_hidden[0], enc_hidden[1])
        # 1*bs*hidden_size -> bs*hidden_size -> bs*1*hidden_size
        feed_input_h = enc_hidden[0].squeeze(0).unsqueeze(1)

        if mode=="train":
            target_data_embedding = self.target_embedding(target_data)
            # Step over the actual target length instead of a hard-coded 100.
            for i in range(target_len):
                input_embedding = target_data_embedding[:,i,:].unsqueeze(1)  # bs *1 *embedding_size
                atten_output, dec_output, dec_hidden = self.attention_forward(input_embedding,
                                                                              feed_input_h,
                                                                              dec_prev_hidden,
                                                                              enc_output, i)
                # squeeze(1) removes only the step axis (safe for batch size 1).
                self.atten_outputs[:,i] = atten_output.squeeze(1)
                self.dec_outputs[:,i] = dec_output.squeeze(1)
                dec_prev_hidden = dec_hidden
                feed_input_h = F.tanh(self.class_fc_1(torch.cat([atten_output,dec_output],dim=2)))
            outs = self.class_fc_2(F.tanh(self.class_fc_1(torch.cat([self.atten_outputs,self.dec_outputs],dim=2))))
        else:
            outs = []
            # Greedy decoding only returns numpy ids, so no autograd graph is needed.
            with torch.no_grad():
                input_embedding = self.target_embedding(target_data)
                for i in range(max_len):
                    atten_output, dec_output, dec_hidden = self.attention_forward(input_embedding,
                                                                                  feed_input_h,
                                                                                  dec_prev_hidden,
                                                                                  enc_output, i)

                    feed_input_h = F.tanh(self.class_fc_1(torch.cat([atten_output,dec_output],dim=2)))
                    pred = self.class_fc_2(feed_input_h)
                    pred = torch.argmax(pred,dim=-1) # bs * 1
                    outs.append(pred.squeeze(1).cpu().numpy())
                    dec_prev_hidden = dec_hidden
                    # The predicted token becomes the next decoder input.
                    input_embedding = self.target_embedding(pred)
        return outs

In [9]:
# Smoke test on all-zero token ids: build the model with dot-score global
# attention and input feeding, then run it once in each mode.
deep_nmt = Loung_NMT(
    source_vocab_size=30000,
    target_vocab_size=30000,
    embedding_size=256,
    lstm_size=256,
    score_f="dot",
    attention_c="global",
    feed_input=True,
    window_size=10,
    reverse=True,
)
source_data = torch.zeros(64, 100, dtype=torch.long)

# Training mode: a full target sequence goes in, per-step vocabulary scores come out.
target_data = torch.zeros(64, 100, dtype=torch.long)
preds = deep_nmt(source_data, target_data, is_gpu=False)
print (preds.shape)

# Test mode: only the first target token is supplied; predictions come back as a
# list with one array of token ids per decoding step.
target_data = torch.zeros(64, 1, dtype=torch.long)
preds = deep_nmt(source_data, target_data, mode="test", is_gpu=False)
print(np.array(preds).shape)

torch.Size([64, 100, 30000])
(100, 64)
