# translation

In [1]:
import os
import shutil
import torch
import h5py
import numpy as np
from torch import nn, optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

from dataset import TextDataset                          # custom module
from model.seq2seq import AttnDecoderRNN, DecoderRNN, EncoderRNN


# use_cuda = torch.cuda.is_available()                     # gpu可用
use_cuda = False
device = torch.device('cuda' if use_cuda else 'cpu')     # 优先使用gpu

  from ._conv import register_converters as _register_converters


## 1.读入数据，设置参数

In [2]:
SOS_token = 0
EOS_token = 1
MAX_LENGTH = 10
SOS_token = 0                                           # 词汇表中的起止符
EOS_token = 1

In [3]:
lang_dataset = TextDataset()

Reading lines...
Read 135842 sentence pairs
Trimmed to 10853 sentence pairs
Counting words...
Counted words:
fra 4489
eng 2925
['elles ne sont pas du tout interessees .', 'they are not at all interested .']


In [4]:
lang_loader = DataLoader(lang_dataset, shuffle=True)   # use num_workers=4 can be slower???

In [5]:
%%time
for data in lang_loader:
    x, y = data
    print('source language:', x)
    print('target language:', y)
    break

source language: tensor([[   6, 2737,  299,  264,    5,    1]])
target language: tensor([[  2,   3, 798, 753,  42, 796,   4,   1]])
Wall time: 9.03 ms


### 注意数据集批次数据末尾的1表示`<EOS>`结束符，每个数据对都有

In [6]:
input_size = lang_dataset.input_lang_words
output_size = lang_dataset.output_lang_words
print('source language vocabulary size:', input_size)
print('target language vocabulary size:', output_size)

source language vocabulary size: 4489
target language vocabulary size: 2925


## 2. 训练模型

In [7]:
# Define hyperparameters
hidden_size = 128
epochs = 20
batch_size = 1
use_attn = False       # the sign whether to use attention

In [8]:
# define the model
encoder = EncoderRNN(vocab_size=input_size, hidden_size=hidden_size, n_layers=2)         # prefer using gpu
decoder = DecoderRNN(vocab_size=output_size, hidden_size=hidden_size, n_layers=2)
attn_decoder = AttnDecoderRNN(vocab_size=output_size, hidden_size=hidden_size, n_layers=2)

In [17]:
# train the translation model
import time
plot_losses = []
def train(encoder, decoder, epochs, use_attn):
    """
        func: train the seq2seq translation
        encoder: RNN encoding model
        decoder: RNN decoding model
        epochs: the number of iteration
        use_attn: whether to use attention, True represents use 
    """
    param = list(encoder.parameters()) + list(decoder.parameters())         # the parameters preparaed to optimize
    optimizer = optim.Adam(param, lr=1e-3)                                  # define the opimizer
    criterion = nn.NLLLoss()                                                # define the negative log likelihood loss
    for epoch in range(epochs):
        start = time.time()                                                 # calculate consumption time
        running_loss = 0                                                    # loss value during training
        total_loss = 0                                                      # total loss after training
        plt_loss_total = 0                                                  # loss value for plot
        print('{}epoch:{}{}'.format('-'*15, epoch+1, '-'*15))
        for i, data in enumerate(lang_loader):
            in_lang, out_lang = data
#             print(in_lang.shape, out_lang.shape)
            GPU = lambda x:x.to(device)                                     # prefer using gpu
            in_lang, out_lang = map(GPU, [in_lang, out_lang])               
            encoder, decoder = map(GPU, [encoder, decoder])
            
            # 1.encode source language to a context
            # encoder_outputs is used to attention decoder，该变量只用于注意力解码中
            encoder_outputs = torch.zeros([batch_size, MAX_LENGTH, hidden_size]).to(device)  # create a zero output
            encoder_hidden  = encoder.initHidden().to(device)               # initialize h0 state
            # in_lang:(N,seq)
            
            for seq_idx in range(in_lang.shape[0]):
                encoder_output, encoder_hidden = encoder(in_lang[:, seq_idx:seq_idx+1], encoder_hidden)
                encoder_outputs[:,seq_idx:seq_idx+1] = encoder_output[:,:]           # [1, 1, hidden]
#                 print(encoder_outputs[:,seq_idx:seq_idx+1].shape,  encoder_output[:,:].shape)            
#             print(encoder_outputs.shape)                                 # (1,10,hidden)
            
            # 2.decode context to seqence
            decoder_input = torch.LongTensor([[SOS_token]]).to(device)     # note [[]] =>  get (1,1)
            decoder_hidden = encoder_hidden                                # (1,1,hidden) context tensor !!!
            loss = 0
            # use attention to decode context
            if use_attn:
                for seq_idx in range(out_lang.shape[1]):
                    # decoder_output: (1,1,vocab_size) decoder_hidden: (1,1,hidden)
                    # decoder_attention: (1, 1, MAX_LENGTH)
                    decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden,
                                                                               encoder_outputs)
                    loss += criterion(decoder_output.view(1,-1), out_lang[:,seq_idx:seq_idx+1][0])
                    topv,topi = decoder_output.data.topk(1)
                    if topi.item() == EOS_token:
                        break
                    decoder_input = out_lang[:,seq_idx:seq_idx+1]      # (1,1)
            # use no attention to decode context                    
            else:      
                for seq_idx in range(out_lang.shape[1]):
                    # decoder_output: (1,1,vocab_size) decoder_hidden: (1,1,hidden)
                    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                                        
                    # predict:(1,vocab_size)   label:(1)
                    loss += criterion(decoder_output.view(1,-1), out_lang[:,seq_idx:seq_idx+1][0])
                    topv, topi = decoder_output.data.topk(1)          # select the maximum value
                    if topi.item() == EOS_token:                      # finish decoding
                        break
                    # teacher forcing指在训练过程中直接使用正确的标签来进行解码器的训练，当然可以取当前时刻输出的
                    # 最大值作为下一个时刻的输入，但teacher forcing可以加速收敛，但在测试时下一个时刻的输入必须是当前
                    # 时刻输出的值(输出字典大小的向量，取最大的值对应的索引当做输出的词)
                    decoder_input = out_lang[:,seq_idx:seq_idx+1]   # use teacher forcing to accelerate convergence!!!
                    
            optimizer.zero_grad()           # clear the gradient
            loss.backward()                 # backpropagation
            optimizer.step()                # update the parameters
            running_loss += loss.item()
            plt_loss_total += loss.item()
            total_loss += loss.item()
            if (i + 1) % 500 == 0:
                print('{}/{}, Loss:{:.6f}'.format(i+1, len(lang_loader), running_loss/5000))
                running_loss = 0
            if (i + 1) % 100 == 0:
                plot_loss = plt_loss_total / 100
                plot_losses.append(plot_loss)
                plt_loss_total = 0

        epoch_time = time.time() - start                            # calculate consumption time
        print('Finish {}/{}, Loss:{:.6f}, Time:{:.0f}s'.format(epoch+1, epochs, 
                                                               total_loss/len(lang_loader), epoch_time))
    attn = 'attn_' if use_attn else ''
    torch.save(encoder.state_dict(), os.path.join('snapshot', 'encoder.pth'))
    torch.save(decoder.state_dict(), os.path.join('snapshot', attn+'decoder.pth'))

# train(encoder, decoder, 1, use_attn=False)
train(encoder, attn_decoder, 1, use_attn=True)

---------------epoch:1---------------
Finish 1/1, Loss:0.005845, Time:0s
