# Fall2018 DS-GA 1011 NLP Final Project
------
Team: Weicheng Zhu, Yiyi Zhang, Zhengyuan Ding, Zihao Zhao

Topic
----
The goal of this project is for you to build a neural machine translation system and to see first-hand how recent advances in the field have made their way into practice. Each team will build the following sequence of neural translation systems for two language pairs, Vietnamese (Vi)→English (En) and Chinese (Zh)→En (prepared corpora will be provided):

- Recurrent neural network based encoder-decoder without attention
- Recurrent neural network based encoder-decoder with attention
- Replace the recurrent encoder with either convolutional or self-attention based encoder.
- [Optional] Build either or both fully self-attention translation system or/and multilingual translation system.

You are expected to implement these on your own (if necessary), experiment with them on both language pairs, report their performance (measured in terms of automatic evaluation metrics), and analyze their behaviours and properties.


# Self Attention Model

In [1]:
#import related packages
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from sacrebleu import corpus_bleu
from tensorboardX import SummaryWriter
import math, copy, time
from torch.autograd import Variable

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocess Data

In [2]:
# Reserved vocabulary indices, in order: padding, start-of-sentence,
# end-of-sentence and unknown word.
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN = range(4)

# Surface strings stored in the vocabulary for the four reserved indices.
PAD_TAG, SOS_TAG, EOS_TAG, UNK_TAG = "<pad>", "<sos>", "<eos>", "<unk>"

# Length threshold (in space-separated tokens) used when filtering sentence pairs.
MAX_LEN = 200

class Lang:
    """Vocabulary of one language, grown incrementally from sentences.

    Attributes:
        name: label of the language
        word2index: dict mapping every added word to its integer index
        word2count: dict mapping every added word to the number of times it was added
        index2word: dict mapping indices back to words; pre-seeded with the four special tags
        n_words: vocabulary size, special tags included
    """

    def __init__(self, name):
        """Create a vocabulary that contains only the special tags"""
        special_tags = (PAD_TAG, SOS_TAG, EOS_TAG, UNK_TAG)
        self.name = name
        self.word2index = {}
        self.word2count = {}
        # Indices 0..3 are reserved for pad, sos, eos and unk
        self.index2word = dict(enumerate(special_tags))
        # The four special tags are already counted
        self.n_words = 4

    def addSentence(self, sentence):
        """Split the sentence on single spaces and register every token"""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Register one word: bump its count, or give it the next free index"""
        if word in self.word2index:
            # Known word: only the frequency changes
            self.word2count[word] += 1
            return
        # New word: it takes the next unused index
        new_index = self.n_words
        self.word2index[word] = new_index
        self.word2count[word] = 1
        self.index2word[new_index] = word
        self.n_words = new_index + 1
            
def unicodeToAscii(s):
    """
    Strip combining marks from a Unicode string, approach taken from
    http://stackoverflow.com/a/518232/2809427

    The string is NFD-decomposed and every character whose Unicode
    category is 'Mn' is dropped; all other characters are kept.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def normalizeString(s):
    """Lowercase and trim the string, strip accents, and keep only letters and . ! ?

    A space is inserted in front of each '.', '!' or '?', and every run of
    other non-letter characters is collapsed into a single space.
    """
    cleaned = s.lower().strip()
    cleaned = unicodeToAscii(cleaned)
    # Detach sentence punctuation from the preceding word
    cleaned = re.sub(r"([.!?])", r" \1", cleaned)
    # Everything that is not a letter or . ! ? becomes one space
    return re.sub(r"[^a-zA-Z.!?]+", r" ", cleaned)

def removeSpace(s):
    """Collapse every run of spaces into one space and trim both ends"""
    collapsed = re.sub(' +', ' ', s)
    return collapsed.strip()

def readLangs(lang1, lang2, datasettype, reverse=False):
    """Read two parallel token files and create empty Lang objects for them

    The files are looked up in the current working directory under the
    names '<datasettype>.tok.<lang1>' and '<datasettype>.tok.<lang2>'.
    Lines of the first file go through normalizeString, lines of the second
    file through removeSpace. The returned Lang objects are still empty;
    prepareData fills them.

    Args:
        lang1, lang2: two language names (also used as file suffixes)
        datasettype: a string ('train'/'dev'/'test') selecting the data split
        reverse: if True, swap each pair so lang2 becomes the input language

    Returns:
        input_lang: empty Lang object for the input language
        output_lang: empty Lang object for the output language
        pairs: list of [input sentence, output sentence] pairs
    """

    print("Reading lines...")

    filename1 = '%s.tok.%s' % (datasettype, lang1)
    filename2 = '%s.tok.%s' % (datasettype, lang2)

    # Read each file and split it into lines; the context managers make sure
    # the file handles are closed even when reading fails
    with open(filename1, encoding='utf-8') as file1:
        lines_1 = file1.read().strip().split('\n')
    with open(filename2, encoding='utf-8') as file2:
        lines_2 = file2.read().strip().split('\n')

    # Pair up line i of both files and clean each side
    pairs = [[normalizeString(line), removeSpace(lines_2[i])]
             for i, line in enumerate(lines_1)]
    # Show one sample pair; skipped for corpora with fewer than two pairs
    if len(pairs) > 1:
        print('Pair1:', pairs[1])

    # If reverse is True, input language is lang2 and output language is lang1
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


def filterPair(p):
    """Tell whether both sentences of a pair are shorter than MAX_LEN tokens

    Arg:
        p: a sentence pair; p[0] and p[1] are space-separated strings

    Returns:
        True when both sides have fewer than MAX_LEN space-separated tokens
        (the pair is kept), False otherwise
    """
    return all(len(p[side].split(' ')) < MAX_LEN for side in (0, 1))


def filterPairs(pairs):
    """Keep only the sentence pairs accepted by filterPair

    Arg:
        pairs: list of sentence pairs

    Returns:
        a new list with the accepted pairs, in their original order
    """
    return list(filter(filterPair, pairs))


def prepareData(lang1, lang2, datasettype, reverse=False):
    """Load one data split, drop over-long pairs and build both vocabularies

    Args:
        lang1, lang2: two language names
        datasettype: a string ('train'/'dev'/'test') selecting the data split
        reverse: whether to swap the two languages

    Returns:
        input_lang: Lang object filled with the input-side words
        output_lang: Lang object filled with the output-side words
        pairs: list of sentence pairs that passed the length filter
    """
    # Load the raw pairs together with two still-empty vocabularies
    input_lang, output_lang, pairs = readLangs(lang1, lang2, datasettype, reverse)
    print("Read %s sentence pairs" % len(pairs))

    # Drop pairs rejected by the length filter
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))

    # Register every remaining sentence in the vocabulary of its side
    print("Counting words...")
    for pair in pairs:
        source_sentence, target_sentence = pair[0], pair[1]
        input_lang.addSentence(source_sentence)
        output_lang.addSentence(target_sentence)

    print("Counted words:")
    for vocab in (input_lang, output_lang):
        print(vocab.name, vocab.n_words)

    return input_lang, output_lang, pairs

# input_lang, output_lang, pairs = prepareData('en', 'zh', 'train', True)
# print(random.choice(pairs))

# Construct Dataset Pipeline

In [3]:
# Number of sentence pairs per mini-batch handed to the DataLoaders below
BATCH_SIZE = 16

def indexesFromSentence(lang, sentence):
    """Convert a sentence into a list of vocabulary indices

    The sentence is split on single spaces; a word missing from
    lang.word2index is mapped to UNK_TOKEN.

    Args:
        lang: Lang object whose word2index is used for the lookup
        sentence: space-separated string

    Returns:
        a list with one index per token
    """
    # Single lookup per word, falling back to the named unknown-word index
    return [lang.word2index.get(word, UNK_TOKEN) for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    """Turn a sentence into a 1-D long tensor of indices terminated by EOS_TOKEN"""
    # The end-of-sentence marker always closes the sequence
    token_ids = indexesFromSentence(lang, sentence) + [EOS_TOKEN]
    return torch.tensor(token_ids, dtype=torch.long, device=device)
    
def tensorsFromPair(pair):
    """Convert a sentence pair into (input tensor, target tensor)

    The module-level input_lang and output_lang vocabularies are used
    for the two sides respectively.
    """
    source_sentence, target_sentence = pair[0], pair[1]
    source_tensor = tensorFromSentence(input_lang, source_sentence)
    return (source_tensor, tensorFromSentence(output_lang, target_sentence))
    

class Dataset(Dataset):
    """Sentence-pair dataset usable with a pytorch DataLoader

    Note that the class deliberately keeps the name of the imported
    torch.utils.data.Dataset base class, because the rest of the notebook
    instantiates it as Dataset(...).

    Attributes:
        input_lang: a Lang object of input language vocabulary for this split
        output_lang: a Lang object of output language vocabulary for this split
        pairs: list of sentence pairs
    """
    def __init__(self, datasettype):
        """Load the English/Chinese split named by datasettype ('train'/'dev'/'test')"""
        # reverse=True: the second language ('zh') is the input side
        input_lang, output_lang, pairs = prepareData('en', 'zh', datasettype, True)
        self.input_lang = input_lang
        self.output_lang = output_lang
        self.pairs = pairs

    def __len__(self):
        """Number of sentence pairs"""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return the (source tensor, target tensor) for the pair at index

        The conversion goes through tensorsFromPair, which indexes words
        with the module-level vocabularies rather than self.input_lang /
        self.output_lang.
        """
        # Convert the pair once and unpack, instead of converting it twice
        src, trg = tensorsFromPair(self.pairs[index])
        return src, trg

def collate_fn(data):
    """Collate function for pytorch dataloader

    Args:
        data: list of (source sequence, target sequence) samples

    Returns:
        src_seqs: zero-padded source batch, transposed to (seq_len, batch)
        src_lens: list of source lengths, longest first
        trg_seqs: zero-padded target batch, transposed to (seq_len, batch)
        trg_lens: list of target lengths, in the same sample order
    """

    def _pad_sequences(seqs):
        """pad sentences with zeros to the maximum length

        Args:
            seqs: given list of sentences

        Returns:
            padded_seqs: (len(seqs), max length) tensor of padded sentences
            lens: a list storing lengths of the sentences
        """
        lens = [len(seq) for seq in seqs]  # store lengths in a list
        padded_seqs = torch.zeros(len(seqs), max(lens)).to(device)  # pad by zero
        for i, seq in enumerate(seqs):
            end = lens[i]
            # as_tensor accepts tensors and plain lists and, unlike
            # torch.cuda.LongTensor, also works when `device` is the CPU
            padded_seqs[i, :end] = torch.as_tensor(
                seq[:end], dtype=padded_seqs.dtype, device=padded_seqs.device)
        return padded_seqs, lens

    # Order samples by source length, longest first, without mutating
    # the list handed in by the caller
    data = sorted(data, key=lambda x: len(x[0]), reverse=True)
    src_seqs, trg_seqs = zip(*data)
    # pad sentences
    src_seqs, src_lens = _pad_sequences(src_seqs)
    trg_seqs, trg_lens = _pad_sequences(trg_seqs)

    # (batch, seq_len) => (seq_len, batch)
    src_seqs = src_seqs.transpose(0, 1)
    trg_seqs = trg_seqs.transpose(0, 1)

    return src_seqs, src_lens, trg_seqs, trg_lens

# Build the module-level vocabularies from the training split; they are the
# ones tensorsFromPair uses when converting sentences to indices
input_lang, output_lang, pairs = prepareData('en', 'zh', 'train', True)

# Validation split, served in shuffled mini-batches
val_dataset = Dataset('dev')
val_loader = DataLoader(val_dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=True,
                        collate_fn=collate_fn)

# Quick sanity report on the validation split
print("Number of source-target pairs:", len(val_dataset))
print("Input language: ", val_dataset.input_lang.name,
      '(', str(val_dataset.input_lang.n_words), ')', sep='')
print("Output language: ", val_dataset.output_lang.name,
      '(', str(val_dataset.output_lang.n_words), ')', sep='')

# Training split
train_data = Dataset('train')

Reading lines...
Pair1: ['with vibrant video clips captured by submarines david gallo takes us to some of earth apos s darkest most violent toxic and beautiful habitats the valleys and volcanic ridges of the oceans apos depths where life is bizarre resilient and shockingly abundant .', '大卫 盖罗 通过 潜水 潜水艇 拍下 的 影片 把 我们 带到 了 地球 最 黑暗 最 险恶 同时 也 最美 美丽 的 生物 栖息 栖息地 这里 是 海洋 深处 的 峡谷 和 火山 山脊 这里 怪诞 适应 适应力 应力 强 而且 数量 惊人 的 生命']
Read 213376 sentence pairs
Trimmed to 213334 sentence pairs
Counting words...
Counted words:
zh 88784
en 50875
Reading lines...
Pair1: ['my father was listening to bbc news on his small gray radio .', '我 的 父亲 在 用 他 的 灰色 小 收音 收音机 听 BBC 新闻']
Read 1261 sentence pairs
Trimmed to 1261 sentence pairs
Counting words...
Counted words:
zh 6133
en 3671
Number of source-target pairs: 1261
Input language: zh(6133)
Output language: en(3671)
Reading lines...
Pair1: ['with vibrant video clips captured by submarines david gallo takes us to some of earth apos s darkest most violent toxic and be

# Models

In [4]:
# Model hyper-parameters
attn_model = 'dot'
hidden_size = embed_size = 256
n_layers, dropout = 1, 0.1
batch_size = 16
checkpoint_dir = "checkpoints"

# Training / optimisation hyper-parameters
clip, learning_rate = 50, 0.001
decoder_learning_ratio = 5.0
n_epochs = 10

In [None]:
class Attn(nn.Module):
    """Concat-style attention of one decoder state over the encoder outputs

    The score of an encoder position is v . tanh(W [decoder state; encoder output]).

    Attributes:
        hidden_size: feature size of decoder state and encoder outputs
        attn: linear layer W applied to the concatenated pair
        v: learned (1, hidden_size) scoring vector
    """
    def __init__(self, hidden_size):
        super(Attn, self).__init__()

        self.hidden_size = hidden_size

        self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(1, hidden_size))
        # NOTE(review): v has shape (1, hidden_size), so size(0) is 1 and
        # stdv is always 1.0; size(1) may have been intended - left
        # unchanged to keep the initialisation as it was.
        stdv = 1. / math.sqrt(self.v.size(0))
        self.v.data.normal_(mean=0, std=stdv)

    def forward(self, last_hidden, encoder_outputs, src_len=None):
        """Compute the context vector and attention weights

        Args:
            last_hidden: decoder state, repeated along dim 0 to the source
                length, so it is expected as (1, batch, hidden)
            encoder_outputs: (length, batch, hidden) encoder states
            src_len: optional list with the true source length of each batch
                element; positions at or beyond that length get (numerically)
                zero attention

        Returns:
            context_vector: (batch, 1, hidden) weighted sum of encoder outputs
            attn_weights: (batch, 1, length) softmax-normalised weights
        """
        length = encoder_outputs.shape[0]
        batch_size = encoder_outputs.shape[1]
        # Pair the decoder state with every encoder position
        last_hidden = last_hidden.repeat(length, 1, 1)
        energy = torch.tanh(self.attn(torch.cat([last_hidden, encoder_outputs], dim=2)))
        energy = energy.transpose(0, 1).transpose(1, 2)  # [B,H,T]
        v = self.v.repeat(batch_size, 1, 1)  # [B,1,H]
        score = torch.bmm(v, energy)  # [B,1,T]
        if src_len is not None:
            # Boolean padding mask built directly on the device of the scores,
            # so it works on CPU as well as on GPU
            lengths = torch.as_tensor(src_len, device=score.device)[:batch_size].view(-1, 1)
            positions = torch.arange(length, device=score.device).unsqueeze(0)
            mask = (positions >= lengths).unsqueeze(1)  # [B,1,T]
            score = score.masked_fill(mask, -1e9)
        # Normalise scores into weights and take the weighted sum
        attn_weights = F.softmax(score, dim=2)
        context_vector = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1))
        return context_vector, attn_weights

def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention'

    Args:
        query, key, value: tensors whose last two dims are (positions, features)
        mask: optional 4-D tensor; its dims 0 and 3 are swapped and it is
            tiled over dims 1 and 2 of the score tensor, then every position
            where the mask equals 0 is suppressed
        dropout: optional callable applied to the attention weights

    Returns:
        the attention-weighted values and the attention weights
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # The mask is only reshaped when one was given; doing this before the
        # None check made every unmasked call fail
        mask = mask.transpose(0, 3).repeat(1, scores.size(1), scores.size(2), 1)
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on the module-level attention() function

    Attributes:
        d_k: feature size of one head (d_model // h)
        h: number of heads
        linears: four d_model x d_model projections - query, key, value
            and the final output projection
        attn: attention weights of the most recent forward pass
        dropout: dropout applied to the attention weights
    """
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # Exactly four projections are needed regardless of the head count:
        # creating `h` of them broke the layer for h < 4 and left unused
        # layers for h > 4
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head and merge the heads again

        Args:
            query, key, value: tensors whose dim 0 is the batch and whose
                last dim is d_model
            mask: optional mask; a head axis is inserted at dim 1 before it
                is handed to attention()

        Returns:
            tensor of shape (batch, positions, d_model)
        """
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # 1) Do all the linear projections in batch from d_model => h x d_k
        #    (zip stops after the first three layers: query, key, value)
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
    
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: d_model -> d_ff -> d_model

    A ReLU and dropout sit between the two linear layers.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply expand -> ReLU -> dropout -> project to the last dim of x"""
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings

    Attributes:
        dropout: dropout applied after the encoding is added
        pe: (1, max_len, d_model) buffer with the precomputed encodings
    """
    def __init__(self, d_model, dropout, max_len=200):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one cosine column fewer than sine
        # columns; the slice is a no-op when d_model is even
        pe[:, 1::2] = torch.cos(position * div_term)[:, :d_model // 2]
        pe = pe.unsqueeze(0)
        # A buffer follows the module across devices but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Swap the first two dims of x, add the encodings and apply dropout

        Args:
            x: tensor whose dim 0 is the position axis and whose last dim is
                d_model; at most max_len positions are supported

        Returns:
            tensor with dims 0 and 1 swapped relative to the input
        """
        x = x.transpose(0, 1)
        # The (1, positions, d_model) slice broadcasts over the batch dim;
        # buffers never require grad, so no Variable wrapper is needed
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class Embeddings(nn.Module):
    """Token embedding lookup whose output is scaled by sqrt(d_model)"""
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up the indices in x and multiply the vectors by sqrt(d_model)"""
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale
    

def clones(module, N):
    "Return a ModuleList holding N independent deep copies of module."
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return copies

class SelfAttnEncoder(nn.Module):
    """Self-attention encoder: embedding + positional encoding, then N layers

    Attributes:
        layers: N deep copies of one EncoderLayer
        norm: LayerNorm applied after the last layer
        src_embed: Embeddings followed by PositionalEncoding
    """
    def __init__(self, N, h, d_model, d_ff, src_vocab, dropout):
        """
        Args:
            N: number of encoder layers
            h: number of attention heads
            d_model: model / embedding size
            d_ff: inner size of the feed-forward block
            src_vocab: source vocabulary size
            dropout: dropout rate for feed-forward, sublayers and positions
        """
        super(SelfAttnEncoder, self).__init__()
        attn = MultiHeadedAttention(h, d_model)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        # copy.deepcopy is called directly instead of through a global alias
        # `c`, which is only defined in a later cell
        layer = EncoderLayer(d_model, copy.deepcopy(attn), copy.deepcopy(ff), dropout)
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        position = PositionalEncoding(d_model, dropout)
        self.src_embed = nn.Sequential(Embeddings(d_model, src_vocab), copy.deepcopy(position))

    def forward(self, x, mask):
        """Pass the input (and mask) through each layer in turn.

        Args:
            x: tensor of source token indices (cast to long before embedding)
            mask: mask forwarded unchanged to every layer

        Returns:
            the normalised output of the last layer with dims 0 and 1
            swapped, undoing the swap done by the positional-encoding stage
        """
        x = self.src_embed(x.long())
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x).transpose(0, 1)
    
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learned gain and bias

    Attributes:
        a_2: learned gain, initialised to ones
        b_2: learned bias, initialised to zeros
        eps: constant added to the standard deviation before dividing
    """
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        gain = torch.ones(features, device=device)
        bias = torch.zeros(features, device=device)
        self.a_2 = nn.Parameter(gain)
        self.b_2 = nn.Parameter(bias)
        self.eps = eps

    def forward(self, x):
        """Centre x on its last-dim mean, divide by (std + eps), then scale and shift"""
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2
    

class SublayerConnection(nn.Module):
    """
    Residual wrapper around a sublayer.
    The input is normalised first, passed through the sublayer and dropout,
    and the result is added back onto the un-normalised input.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Return x plus the dropped-out output of sublayer applied to norm(x)."
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update
    
class EncoderLayer(nn.Module):
    "One encoder layer: self-attention then feed-forward, each in a residual wrapper"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Run x through the self-attention sublayer, then the feed-forward sublayer."
        attn_block, ff_block = self.sublayer[0], self.sublayer[1]

        def _self_attend(states):
            # Query, key and value are all the same normalised states
            return self.self_attn(states, states, states, mask)

        x = attn_block(x, _self_attend)
        return ff_block(x, self.feed_forward)

def save_checkpoint(encoder, decoder, checkpoint_dir):
    """Save the state dicts of encoder and decoder under one shared timestamp

    The files are written as '<checkpoint_dir>/enc-<stamp>.pth' and
    '<checkpoint_dir>/dec-<stamp>.pth', where <stamp> is the local time
    formatted as ddmmyy-HHMMSS.

    Args:
        encoder, decoder: objects providing state_dict()
        checkpoint_dir: target directory path; created if it does not exist
    """
    import os

    # torch.save does not create missing directories
    os.makedirs(checkpoint_dir, exist_ok=True)
    # One timestamp for both files, so a pair can always be matched by name
    stamp = time.strftime("%d%m%y-%H%M%S")
    enc_filename = "{}/enc-{}.pth".format(checkpoint_dir, stamp)
    dec_filename = "{}/dec-{}.pth".format(checkpoint_dir, stamp)
    torch.save(encoder.state_dict(), enc_filename)
    torch.save(decoder.state_dict(), dec_filename)
    print("Model saved.")

### Decoder

In [None]:
class DecoderRNN(nn.Module):
    """LSTM decoder that emits one target token per call, optionally with attention

    Attributes:
        hidden_size: LSTM hidden size
        n_layers: number of LSTM layers
        embedding_size: target embedding size
        embedding: target-token embedding (PAD_TOKEN is the padding index)
        embedding_dropout: dropout applied to the embedded input token
        lstm: unidirectional LSTM
        concat: linear layer merging LSTM output and attention context
        out: projection onto the target vocabulary
        attn: Attn module, or None when built with attention=False
    """
    def __init__(self, embedding_size, hidden_size, output_size, n_layer=1, dropout=0.1,attention=True):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        # Use the constructor argument; the module-level `n_layers` global
        # was read here before, silently ignoring `n_layer`
        self.n_layers = n_layer
        self.embedding_size = embedding_size
        self.softmax = nn.LogSoftmax(dim=1)
        self.embedding = nn.Embedding(output_size, embedding_size, padding_idx=PAD_TOKEN)
        self.embedding_dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embedding_size, hidden_size, n_layer, dropout=dropout, bidirectional=False)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        # Always defined, so forward can test for it instead of failing with
        # AttributeError when attention is disabled
        self.attn = Attn(hidden_size) if attention else None

    def _fit_state(self, state):
        """Sum a recurrent state over dim 0 unless it already has n_layers layers"""
        if state.size(0) == self.n_layers:
            return state
        return torch.sum(state, dim=0).unsqueeze(0)

    def forward(self, input_seq, last_hidden=None, src_len = None, encoder_outputs=None):
        """Run one decoding step

        Args:
            input_seq: previous target tokens, one per batch element
            last_hidden: (h, c) tuple from the previous step, or None for
                the first step
            src_len: optional list of true source lengths used to mask the
                attention
            encoder_outputs: encoder states attended over

        Returns:
            output: (batch, vocabulary) log-probabilities of the next token
            hidden: (h, c) tuple to feed into the next step
            attn_weights: attention weights, or None without attention
        """
        batch_size = input_seq.size(0)
        embedded = self.embedding(input_seq.long())
        embedded = embedded.view(1, batch_size, self.embedding_size) # 1*B*E
        embedded = self.embedding_dropout(embedded)
        # Explicit None test: truthiness of a tuple says nothing about its content
        if last_hidden is not None:
            prev_h, prev_c = last_hidden
            # The hidden state is taken from h and the cell state from c;
            # the cell state used to be passed in for both
            last_hidden = (self._fit_state(prev_h), self._fit_state(prev_c))
        rnn_output, hidden = self.lstm(embedded, last_hidden)
        if self.attn is None:
            # Without attention the prediction is made from the LSTM output alone
            output = F.log_softmax(self.out(rnn_output.squeeze(0)), dim=1)
            return output, hidden, None
        context, attn_weights = self.attn(rnn_output, encoder_outputs, src_len)
        rnn_output = rnn_output.squeeze(0) # S=1 x B x N -> B x N
        context = context.squeeze(1)       # B x S=1 x N -> B x N
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Log-probabilities over the target vocabulary
        output = F.log_softmax(self.out(concat_output), dim=1)
        # Return final output, hidden state, and attention weights (for visualization)
        return output, hidden, attn_weights

    def initHidden(self):
        """Return a zero tensor of shape (1, 1, hidden_size) on `device`"""
        return torch.zeros(1, 1, self.hidden_size, device=device)

# Training Step

In [None]:
def train_step(src_batch, src_lens, trg_batch, trg_lens, encoder, decoder,
               encoder_optimizer, decoder_optimizer, criterion):
    """Run one teacher-forced optimisation step on a single batch

    Args:
        src_batch: source batch, indexed as (position, batch element)
        src_lens: list of source lengths, passed on to the decoder
        trg_batch: target batch, indexed as (position, batch element)
        trg_lens: list of target lengths; the longest sets the number of steps
        encoder, decoder: the two models
        encoder_optimizer, decoder_optimizer: their optimizers
        criterion: loss applied to each step's log-probabilities

    Returns:
        (loss value, exact-match accuracy, edit distance); the last two are
        placeholders that are always 0.0
    """
    # Start from clean gradients
    for optimizer in (encoder_optimizer, decoder_optimizer):
        optimizer.zero_grad()
    loss = 0.0
    em_accuracy, edit_distance = 0.0, 0.0

    # Encode the source batch
    batch_size = src_batch.shape[1]
    src_batch = src_batch.long()
    src_mask = (src_batch != PAD_TOKEN).unsqueeze(-2)
    trg_mask = (trg_batch.to(device) != PAD_TOKEN).float()
    encoder_outputs = encoder(src_batch, src_mask)

    # Every sequence starts from the start-of-sentence token with no state
    decoder_input = torch.LongTensor([SOS_TOKEN] * batch_size).to(device)
    decoder_hidden = None
    n_steps = max(trg_lens)

    # Decode one time step at a time using TEACHER FORCING=1.0
    TEACHER_FORCING = 1
    for t in range(n_steps):
        step_result = decoder(decoder_input, decoder_hidden, src_lens, encoder_outputs)
        decoder_output, decoder_hidden = step_result[0], step_result[1]
        gold = trg_batch[t]
        if TEACHER_FORCING:
            # The gold token is fed as the next input
            decoder_input = gold
        else:
            # The model's own best guess is fed as the next input
            decoder_input = decoder_output.topk(1)[1]
        step_loss = criterion(decoder_output, gold.long()) * trg_mask[t]
        loss += step_loss.mean()
    loss = loss / n_steps
    loss.backward()

    # Clip gradient norms of both models
    torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Update parameters
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item(), em_accuracy, edit_distance

def train(dataset, batch_size, n_epochs, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, 
          checkpoint_dir=None, save_every=3000):
    """Train encoder and decoder, logging to tensorboard and evaluating each epoch

    After every epoch the models are checkpointed (when checkpoint_dir is
    given) and BLEU is computed on the module-level val_loader and on a small
    fixed subset of the training data.

    Args:
        dataset: training dataset yielding (source tensor, target tensor)
        batch_size: mini-batch size
        n_epochs: number of passes over the dataset
        encoder, decoder: the two models
        encoder_optimizer, decoder_optimizer: their optimizers
        criterion: loss handed to train_step
        checkpoint_dir: directory for checkpoints; None disables saving
        save_every: currently unused
    """
    writer = SummaryWriter('./tensorboard/self-attention-ZHEN')
    global_step = 0

    # The dataset and batch size passed in are used, rather than the
    # module-level train_data / BATCH_SIZE globals
    train_iter = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             collate_fn=collate_fn,
                                             shuffle=True)

    # Fixed training subset used to monitor BLEU on seen data; capped so
    # that small datasets do not raise IndexError
    n_probe = min(len(dataset), batch_size * 10)
    fake_test = torch.utils.data.DataLoader(dataset=[dataset[i] for i in range(n_probe)],
                                            batch_size=batch_size,
                                            collate_fn=collate_fn,
                                            shuffle=False)
    try:
        for i in range(n_epochs):
            tick = time.process_time()
            print("Epoch {}/{}".format(i+1,  n_epochs))
            losses, accs, eds = [], [], []

            for batch_idx, batch in enumerate(train_iter):

                global_step += 1

                input_batch, input_lengths, target_batch, target_lengths = batch
                loss, accuracy, edit_distance = train_step(input_batch, input_lengths, target_batch, target_lengths,
                                                           encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)

                losses.append(loss)
                accs.append(accuracy)
                eds.append(edit_distance)

                writer.add_scalar('loss', loss, global_step)

                if batch_idx % 100 == 0 and batch_idx != 0:
                    print("batch: {}, loss: {}, accuracy: {}, edit distance: {}".format(batch_idx, loss, accuracy, 
                                                                                       edit_distance))

            tock = time.process_time()
            print("Time: {} Avg loss: {} Avg acc: {} Edit Dist.: {}".format(
                tock-tick, np.mean(losses), np.mean(accs), np.mean(eds)))
            # Saving is optional: the default checkpoint_dir is None
            if checkpoint_dir is not None:
                save_checkpoint(encoder, decoder, checkpoint_dir)
            bleu = evaluate(encoder, decoder, val_loader)
            writer.add_scalar('val_bleu', bleu, global_step)
            print('real:', bleu)

            bleu_1 = evaluate(encoder, decoder, fake_test)
            writer.add_scalar('train_bleu', bleu_1, global_step)
            print('train_bleu:',bleu_1)
    finally:
        # Flush and release the event file even if training is interrupted
        writer.close()

# Evaluation

In [None]:
def evaluate(encoder, decoder, test_loader, k=1, max_length=None):
    """Decode every sentence of test_loader and return the mean sentence BLEU

    Sentences are decoded one at a time with a beam of width k, for at most
    three times the source length. Both models are switched to eval mode for
    the duration of the call (so dropout is inactive) and restored to their
    previous mode afterwards.

    Args:
        encoder, decoder: the two models
        test_loader: iterable of (src_batch, src_lens, trg_batch, trg_lens)
        k: beam width
        max_length: ignored; overwritten by each sentence's source length

    Returns:
        the average BLEU score over all sentences, 0.0 if there were none
    """
    score = 0
    count = 0
    # Remember the current modes so the caller's training setup is restored
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()
    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(test_loader):
                src_batch, src_lens, trg_batch, trg_lens = batch
                batch_size = src_batch.shape[1]
                src_batch = src_batch.long()
                for b in range(batch_size):
                    count += 1
                    max_length = src_lens[b]
                    # Reference sentence: target tokens without padding
                    trg_sentence = [output_lang.index2word[int(token)] for token in trg_batch[:,b] if token != PAD_TOKEN]
                    # Encode only the unpadded part of this source sentence
                    src_input = src_batch[:src_lens[b],b].unsqueeze(1)
                    src_mask  = (torch.ones(src_lens[b],device = device)).unsqueeze(1).unsqueeze(-1)
                    encoder_outputs = encoder(src_input, src_mask)
                    decoder_input = torch.LongTensor([[SOS_TOKEN]]).to(device)
                    decoder_hidden = None
                    decoded_words = []
                    decoder_attentions = torch.zeros(batch_size, max_length, max_length)
                    # Each beam entry: [input token, hidden state, encoder outputs,
                    # attention placeholder, accumulated score, index of parent candidate]
                    priors = [[decoder_input, decoder_hidden, encoder_outputs,decoder_attentions,0, 0]]
                    sent_cand = ['' for i in range(k)]
                    for di in range(3 * max_length):
                        curr = {}
                        possible = []
                        # Expand every beam entry by its k best next tokens
                        for prior_data in priors:
                            decoder_input, decoder_hidden, encoder_outputs, decoder_attentions, v, source_idx = prior_data
                            decoder_output, decoder_hidden, decoder_attention = decoder(
                                decoder_input, decoder_hidden, [src_lens[b]], encoder_outputs)
                            topv, topi = decoder_output.data.topk(k)
                            for i in range(k):
                                possible.append(int(topi[:,i].squeeze().detach()))
                                curr[topv[0,i]+v] = [topi[:,i], decoder_hidden, encoder_outputs, 
                                                     decoder_attentions, topv[0,i]+v, source_idx]

                        # Keep the k expansions with the highest accumulated score
                        sorted_v = sorted(curr.keys(),reverse=True)
                        top_k = sorted_v[:k]
                        temp = [x for x in sent_cand]
                        for i, index in enumerate(top_k):
                            token = int(curr[index][0])
                            source_idx = curr[index][-1]
                            curr[index][-1] = i
                            if token == EOS_TOKEN:
                                sent_cand[i] = temp[source_idx] + '<eos>'
                                break
                            else:
                                sent_cand[i] = temp[source_idx] + (output_lang.index2word[token] + " " )
                        if EOS_TOKEN == possible[0]:
                            decoded_words = sent_cand[0]
                            break
                        priors = [curr[index] for index in top_k]
                    if not decoded_words:
                        decoded_words = '<eos>'
                    trg_sentence = ' '.join(trg_sentence)
                    # sacrebleu takes a list of hypotheses and a list of
                    # reference streams; newer releases reject bare strings
                    s = corpus_bleu([decoded_words], [[trg_sentence]]).score
                    score += s
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

    # An empty loader yields 0.0 instead of a ZeroDivisionError
    return score / count if count else 0.0

# Train

In [None]:
# Hyper-parameters of the self-attention encoder
c = copy.deepcopy
h = 4
d_model = 256
d_ff = 512
dropout = 0.1
src_vocab = input_lang.n_words
N = 4
learning_rate = 0.0005

# Build both models on the selected device and put them in training mode
encoder = SelfAttnEncoder(N=N, h=h, d_model=d_model, d_ff=d_ff,
                          src_vocab=src_vocab, dropout=dropout).to(device)
decoder = DecoderRNN(embedding_size=embed_size, hidden_size=hidden_size,
                     output_size=output_lang.n_words, dropout=dropout).to(device)
encoder.train()
decoder.train()

# One Adam optimizer per model, sharing the learning rate
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
# Negative log-likelihood on the decoder's log-probabilities, padding ignored
criterion = nn.NLLLoss(ignore_index=PAD_TOKEN)

# Run training
train(dataset=train_data, batch_size=batch_size, n_epochs=n_epochs,
      encoder=encoder, decoder=decoder,
      encoder_optimizer=encoder_optimizer, decoder_optimizer=decoder_optimizer,
      criterion=criterion, checkpoint_dir=checkpoint_dir)