In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
class Deep_NMT(nn.Module):
    """Sequence-to-sequence translation model built from two stacked LSTMs.

    The encoder LSTM reads the embedded source sentence. Its final per-layer
    hidden and cell states initialise the decoder LSTM. A linear layer maps each
    decoder output to target-vocabulary logits.

    Args:
        source_vocab_size: number of tokens in the source vocabulary.
        target_vocab_size: number of tokens in the target vocabulary.
        embedding_size: dimension of both the source and target embeddings.
        source_length: nominal source sequence length (stored for reference;
            the encoder itself accepts any length).
        target_length: nominal target sequence length; also the default number
            of greedy decoding steps in inference mode.
        lstm_size: hidden size of every encoder/decoder LSTM layer.
        num_layers: number of stacked layers in each LSTM (default 4).
    """

    def __init__(self, source_vocab_size, target_vocab_size, embedding_size,
                 source_length, target_length, lstm_size, num_layers=4):
        super().__init__()
        self.source_length = source_length
        self.target_length = target_length
        self.source_embedding = nn.Embedding(source_vocab_size, embedding_size)
        self.target_embedding = nn.Embedding(target_vocab_size, embedding_size)
        # batch_first=True: the sequence input/output is (batch, length, features).
        # (With batch_first=False it would be (length, batch, features).)
        self.encoder = nn.LSTM(input_size=embedding_size, hidden_size=lstm_size,
                               num_layers=num_layers, batch_first=True)
        self.decoder = nn.LSTM(input_size=embedding_size, hidden_size=lstm_size,
                               num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(lstm_size, target_vocab_size)

    def forward(self, source_data, target_data, mode="train", max_decode_len=None):
        """Encode `source_data`, then decode in one of two modes.

        Args:
            source_data: LongTensor of source token ids, (batch, src_len).
            target_data: LongTensor of target token ids. In "train" mode this is
                the full teacher-forcing input, (batch, tgt_len); in any other
                mode it holds the start token(s) for greedy decoding,
                typically (batch, 1).
            mode: "train" for teacher forcing; any other value runs greedy
                decoding.
            max_decode_len: number of greedy decoding steps; defaults to
                `self.target_length`. Ignored in "train" mode.

        Returns:
            In "train" mode, a logits tensor of shape
            (batch, tgt_len, target_vocab_size).
            Otherwise, a list of `max_decode_len` NumPy integer arrays. Each
            array holds the argmax token id chosen at that step and has shape
            (batch,) when the start tokens are (batch, 1).
        """
        # (batch, src_len) -> (batch, src_len, embedding_size)
        source_embedded = self.source_embedding(source_data)
        # The first LSTM output holds only the top layer's hidden state at every
        # step and is unused here. enc_hidden = (h_n, c_n) holds each layer's
        # state at the last time step, and this is what conditions the decoder.
        _, enc_hidden = self.encoder(source_embedded)
        if mode == "train":
            return self._decode_teacher_forced(target_data, enc_hidden)
        steps = self.target_length if max_decode_len is None else max_decode_len
        return self._greedy_decode(target_data, enc_hidden, steps)

    def _decode_teacher_forced(self, target_data, enc_hidden):
        """Decode the whole ground-truth target sequence in one pass; return logits."""
        # (batch, tgt_len) -> (batch, tgt_len, embedding_size)
        target_embedded = self.target_embedding(target_data)
        # dec_output: top-layer hidden state at every step, (batch, tgt_len, lstm_size)
        dec_output, _ = self.decoder(target_embedded, enc_hidden)
        return self.fc(dec_output)  # (batch, tgt_len, target_vocab_size)

    def _greedy_decode(self, start_tokens, enc_hidden, steps):
        """Feed each step's argmax token back as the next input for `steps` steps."""
        outs = []
        hidden = enc_hidden
        tokens = start_tokens  # (batch, 1)
        # Results leave as NumPy arrays, so no gradient can flow back from them.
        # Tracking autograd here would only retain activations for every step.
        with torch.no_grad():
            for _ in range(steps):
                dec_output, hidden = self.decoder(self.target_embedding(tokens), hidden)
                tokens = torch.argmax(self.fc(dec_output), dim=-1)  # (batch, 1)
                # squeeze(1), not squeeze(): keeps the batch axis when batch == 1.
                outs.append(tokens.squeeze(1).cpu().numpy())
        return outs

In [4]:
# Smoke test: push all-zero token ids through the model and check output shapes
# in both training (teacher forcing) and inference (greedy decoding) modes.
BATCH_SIZE = 64
SEQ_LEN = 100        # used for both source_length and target_length
VOCAB_SIZE = 30000   # shared size of the source and target vocabularies

deep_nmt = Deep_NMT(source_vocab_size=VOCAB_SIZE, target_vocab_size=VOCAB_SIZE,
                    embedding_size=256, source_length=SEQ_LEN,
                    target_length=SEQ_LEN, lstm_size=256)
source_data = torch.zeros(BATCH_SIZE, SEQ_LEN, dtype=torch.long)

# Only shapes are inspected, so skip building autograd graphs.
with torch.no_grad():
    # Training mode: the full target sequence is fed in, and logits come out.
    train_target = torch.zeros(BATCH_SIZE, SEQ_LEN, dtype=torch.long)
    train_logits = deep_nmt(source_data, train_target)
    print(train_logits.shape)
    assert train_logits.shape == (BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)

    # Inference mode: greedy decoding from a single start token per sentence.
    target_data = torch.zeros(BATCH_SIZE, 1, dtype=torch.long)
    preds = deep_nmt(source_data, target_data, mode="test")
    print(np.array(preds).shape)
    assert np.array(preds).shape == (SEQ_LEN, BATCH_SIZE)

torch.Size([64, 100, 30000])
(100, 64)
