In [21]:
# !pip3 install sacrebleu

Collecting sacrebleu
  Downloading https://files.pythonhosted.org/packages/37/51/bffea2b666d59d77be0413d35220022040a1f308c39009e5b023bc4eb8ab/sacrebleu-1.2.12.tar.gz
Collecting typing (from sacrebleu)
  Downloading https://files.pythonhosted.org/packages/4a/bd/eee1157fc2d8514970b345d69cb9975dcd1e42cd7e61146ed841f6e68309/typing-3.6.6-py3-none-any.whl
Building wheels for collected packages: sacrebleu
  Running setup.py bdist_wheel for sacrebleu ... [?25ldone
[?25h  Stored in directory: /home/yc2462/.cache/pip/wheels/ea/0a/7d/ddcbdcd15a04b72de1b3f78e7e754aab415aff81c423376385
Successfully built sacrebleu
Installing collected packages: typing, sacrebleu
Successfully installed sacrebleu-1.2.12 typing-3.6.6


In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import numpy as np

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from sacrebleu import corpus_bleu
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
words_to_load = 50000
PAD_IDX = 0
UNK_IDX = 1
SOS_IDX = 2
EOS_IDX = 3

with open('../cc.zh.300.vec') as f:
    loaded_embeddings_ft = np.zeros((words_to_load+3, 300))
    words_ft = {}
    idx2words_ft = {}
    ordered_words_ft = []
    ordered_words_ft.extend(['<pad>', '<unk>', '<s>'])
    loaded_embeddings_ft[0,:] = np.zeros(300)
    loaded_embeddings_ft[1,:] = np.random.normal(size = 300)
    loaded_embeddings_ft[2,:] = np.random.normal(size = 300)
    for i, line in enumerate(f):
        if i >= words_to_load: 
            break
        s = line.split()
        loaded_embeddings_ft[i+3, :] = np.asarray(s[1:])
        words_ft[s[0]] = i+3
        idx2words_ft[i+3] = s[0]
        ordered_words_ft.append(s[0])
    words_ft['<pad>'] = PAD_IDX
    words_ft['<unk>'] = UNK_IDX
    words_ft['<s>'] = SOS_IDX
    idx2words_ft[PAD_IDX] = '<pad>'
    idx2words_ft[UNK_IDX] = '<unk>'
    idx2words_ft[SOS_IDX] = '<s>'

In [3]:
#English embedding
with open('wiki-news-300d-1M.vec') as f:
    loaded_embeddings_ft_en = np.zeros((words_to_load+4, 300))
    words_ft_en = {}
    idx2words_ft_en = {}
    ordered_words_ft_en = []
    ordered_words_ft_en.extend(['<pad>', '<unk>', '<s>', '</s>'])
    loaded_embeddings_ft_en[0,:] = np.zeros(300)
    loaded_embeddings_ft_en[1,:] = np.random.normal(size = 300)
    loaded_embeddings_ft_en[2,:] = np.random.normal(size = 300)
    loaded_embeddings_ft_en[3,:] = np.random.normal(size = 300)
    for i, line in enumerate(f):
        if i >= words_to_load: 
            break
        s = line.split()
        loaded_embeddings_ft_en[i+4, :] = np.asarray(s[1:])
        words_ft_en[s[0]] = i+4
        idx2words_ft_en[i+4] = s[0]
        ordered_words_ft_en.append(s[0])
    words_ft_en['<pad>'] = PAD_IDX
    words_ft_en['<unk>'] = UNK_IDX
    words_ft_en['<s>'] = SOS_IDX
    words_ft_en['</s>'] = EOS_IDX
    idx2words_ft_en[PAD_IDX] = '<pad>'
    idx2words_ft_en[UNK_IDX] = '<unk>'
    idx2words_ft_en[SOS_IDX] = '<s>'
    idx2words_ft_en[EOS_IDX] = '</s>'

In [4]:
#read in chinese-english pairs
lines_zh = open('iwslt-zh-en/train.tok.zh',encoding = 'utf-8').read().strip().split('\n')
lines_en = open('iwslt-zh-en/train.tok.en',encoding = 'utf-8').read().strip().split('\n')
lines_zh_test = open('iwslt-zh-en/test.tok.zh',encoding = 'utf-8').read().strip().split('\n')
lines_en_test = open('iwslt-zh-en/test.tok.en',encoding = 'utf-8').read().strip().split('\n')
lines_zh_val = open('iwslt-zh-en/dev.tok.zh',encoding = 'utf-8').read().strip().split('\n')
lines_en_val = open('iwslt-zh-en/dev.tok.en',encoding = 'utf-8').read().strip().split('\n')

In [5]:
#add sos and eos in each sentence
def add_sos_eos(lines):  
    train = []
    for l in lines:
        l = '<s> ' + l + '</s>'
        train.append(l)
    return train
zh_train = add_sos_eos(lines_zh)    
en_train = add_sos_eos(lines_en)
zh_test = add_sos_eos(lines_zh_test)
en_test = add_sos_eos(lines_en_test)
zh_val = add_sos_eos(lines_zh_val)
en_val = add_sos_eos(lines_en_val)

In [6]:
# convert token to id in the dataset
def token2index_dataset(tokens_data,eng = False):
    indices_data = []
    for tokens in tokens_data:
        index_list = []
        for token in tokens.split():
            if eng == False:
                try:
                    index_list.append(words_ft[token])
                except KeyError:
                    index_list.append(UNK_IDX)
            else:
                try:
                    index_list.append(words_ft_en[token])
                except KeyError:
                    index_list.append(UNK_IDX)
        indices_data.append(index_list)
    return indices_data

In [7]:
zh_train_indices = token2index_dataset(zh_train)
en_train_indices = token2index_dataset(en_train,eng = True)
zh_test_indices = token2index_dataset(zh_test)
en_test_indices = token2index_dataset(en_test,eng = True)
zh_val_indices = token2index_dataset(zh_val)
en_val_indices = token2index_dataset(en_val,eng = True)

In [8]:
#max_sentence_length
length_of_en = [len(x.split()) for x in en_train]
max_sentence_length_en = sorted(length_of_en)[-int(len(length_of_en)*0.01)]
length_of_zh = [len(x.split()) for x in zh_train]
max_sentence_length_zh = sorted(length_of_zh)[-int(len(length_of_zh)*0.01)]

In [9]:
sorted(length_of_zh)[-int(len(length_of_zh)*0.1)]

37

In [10]:
#Create Data Loader
import torch
from torch.utils.data import Dataset

class load_dataset(Dataset):
    """
    Class that represents a train/validation/test dataset that's readable for PyTorch
    Note that this class inherits torch.utils.data.Dataset
    """
    
    def __init__(self, data_list_s1,data_list_s2):
        """
        @param data_list_zh: list of Chinese tokens 
        @param data_list_en: list of English tokens as TARGETS
        """
        self.data_list_s1 = data_list_s1
        self.data_list_s2 = data_list_s2
        
        assert (len(self.data_list_s1) == len(self.data_list_s2))

    def __len__(self):
        return len(self.data_list_s1)
        
    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        
        token_idx_s1 = self.data_list_s1[key][:max_sentence_length_zh]
        token_idx_s2 = self.data_list_s2[key][:max_sentence_length_en]
        return [token_idx_s1, token_idx_s2, len(token_idx_s1), len(token_idx_s2)]

def collate_func(batch):
    """
    Customized function for DataLoader that dynamically pads the batch so that all 
    data have the same length
    """
    data_list_s1 = []
    data_list_s2 = []
    length_list_s1 = []
    length_list_s2 = []
    for datum in batch:
        length_list_s1.append(datum[2])
        length_list_s2.append(datum[3])
        padded_vec_zh = np.pad(np.array(datum[0]), 
                                pad_width=((0,max_sentence_length_zh-datum[2])), 
                                mode="constant", constant_values=0)
        padded_vec_en = np.pad(np.array(datum[1]), 
                                pad_width=((0,max_sentence_length_en-datum[3])), 
                                mode="constant", constant_values=0)
        data_list_s1.append(padded_vec_zh[:max_sentence_length_zh])
        data_list_s2.append(padded_vec_en[:max_sentence_length_en])
    #print(type(data_list_s1[0]))
    if torch.cuda.is_available and torch.has_cudnn:
        return [torch.from_numpy(np.array(data_list_s1)).cuda(), torch.from_numpy(np.array(data_list_s2)).cuda(),
                torch.LongTensor(length_list_s1).cuda(), torch.LongTensor(length_list_s2).cuda()]
    else:    
        return [torch.from_numpy(np.array(data_list_s1)), torch.from_numpy(np.array(data_list_s2)),
                torch.LongTensor(length_list_s1), torch.LongTensor(length_list_s2)]
    


In [11]:
BATCH_SIZE = 32
EMBEDDING_SIZE = 300 # fixed as from the input embedding data


train_dataset = load_dataset(zh_train_indices, en_train_indices)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=collate_func,
                                           shuffle=True)

val_dataset = load_dataset(zh_val_indices, en_val_indices)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=collate_func,
                                           shuffle=False)

### With Attention

In [12]:
class EncoderRNN(nn.Module):
    def __init__(self, emb_dim, hidden_size, embed= torch.from_numpy(loaded_embeddings_ft).float(),num_layers=1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.emb_dim = emb_dim
        self.num_layers = num_layers 

        self.embedding = nn.Embedding.from_pretrained(embed, freeze=True)
        self.gru = nn.GRU(emb_dim, hidden_size,num_layers=num_layers,batch_first=True)

    def forward(self, data, hidden):
        
        batch_size, seq_len = data.size()
        
        embed = self.embedding(data)
        
        output, hidden = self.gru(embed,hidden)
        #hidden = [n layers * n directions =1 , batch_size, hidden_size ]
        
        return output, hidden

    # initialize the hidden with random numbers
    def initHidden(self,batch_size):
        return torch.randn(self.num_layers, batch_size, self.hidden_size,device=device)

In [13]:
class AttnDecoderRNN(nn.Module):
    def __init__(self,emb_dim,hidden_size, output_size, embed= torch.from_numpy(loaded_embeddings_ft_en).float(),num_layers=1,
                 dropout_p=0.1, max_length=max_sentence_length_zh):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers 
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding.from_pretrained(embed, freeze=True)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)

        self.gru = nn.GRU(emb_dim, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, data, hidden,encoder_outputs):
        
        ### embed: [1 * batch size * emb_dim = 300 ] ###
        ### hidden: [1 * batch size * hidden_size = 300 ] ###
        ### encoder_outputs: [batch size * max_sentence_length_zh * hidden_size = 300 ] ###
        ### 因为这里concat之后，attn layer 他给的是 hidden size *2 
        ### 所以我这儿的hidden size就只能写300了 
        
        embed = self.embedding(data)
        embed = self.dropout(embed)    
        ### torch.cat((embed, hidden), 2)  
        ### [1 * batch size * (emb_dim + hidden_size) ]
        
        ### attn_weights: [1 * batch size * max_sentence_length_zh ]###
        ### attn_weights[0].unsqueeze(1): [batch size * 1 * max_sentence_length_zh ]###
        
        ### softmax dim=2 因为最后一个dimension是 词组什么的，不能是1，1的话就是
        ### 不同batch间这样比较了？
        attn_weights = F.softmax(
            self.attn(torch.cat((embed, hidden), 2)), dim=2)
        

        ### torch.bmm(attn_weights[0].unsqueeze(1),encoder_outputs).squeeze(1) :
        ### [batch size * 1 * hidden_size ]###

        ### attn_applied: [batch size * hidden_size (= 300) ] ###
        attn_applied = torch.bmm(attn_weights[0].unsqueeze(1), encoder_outputs).squeeze(1)
        
        ### output: [batch size * hidden_size (= 300) ] ###
        ### embed[0]: [batch size * hidden_size (= 300) ] ###

        output = torch.cat((embed[0], attn_applied), 1)
 
        ### output: [1 * batch size * hidden_size (= 300) ] ###
        output = self.attn_combine(output).unsqueeze(0)
        
        ### output: [1 * batch size * hidden_size (= 300) ] ###
        output = F.relu(output)
        
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        
        return output, hidden, attn_weights

    def initHidden(self,batch_size):
        return torch.randn(self.num_layers, batch_size, self.hidden_size,device=device)

In [14]:
teacher_forcing_ratio = 0.5
#input_tensor: list of sentence tensor
def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer,
          criterion, eee):
    
    ### target_tensor [batch size, max_sentence_length_en = 377] ###
    ### target_tensor [batch size, max_sentence_length_zh = 220] ###
    batch_size_1, input_length = input_tensor.size()
    batch_size_2, target_length = target_tensor.size()
    
    
    encoder_hidden = encoder.initHidden(batch_size_1)

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    loss = 0
    ### encoder_hidden: 1 * batch * hidden size ### 
    ### encoder_output: batch size * max_sentence_length_zh * hidden size ### 
    encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)

    decoder_input = torch.tensor(np.array([[SOS_IDX]]*batch_size_1).reshape(1,batch_size_1),device=device)
    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    #print(use_teacher_forcing)
    if use_teacher_forcing:
        # Teacher forcing: Feed the target as the next input
        for di in range(target_length):
            
            ### decoder_output: [batchsize,5000] ###
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden,encoder_output)
            
            loss += criterion(decoder_output, target_tensor[:,di])
            decoder_input = target_tensor[:,di].unsqueeze(0)  # Teacher forcing
            
    else:
        # Without teacher forcing: use its own predictions as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden,encoder_output)
                        
            ### decoder_output [batch size, 50003]  ###
            
            ### topi is a [batch size, 1] tensor first we remove the size 1
            ### demension then we add it at the beginning using squeeze
            ### 有点脑残诶，做个转置不就好了？
            
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input
            
            ### decoder_input [1, batch size]  ###
            decoder_input = decoder_input.unsqueeze(0)
 
            loss += criterion(decoder_output, target_tensor[:,di])

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [15]:
def trainIters(encoder, decoder, n_iters, print_every=1, plot_every=100, learning_rate=0.001):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every
    
    #--------------------------------------------	
    #	
    #    LOAD MODELS	
    #	
    #--------------------------------------------	
    folder = '.'	
    if not os.path.exists(folder):	
        os.makedirs(folder)	

    if os.path.exists('./attentation_model/Encoder_b'):	
        print('---------------------------------------------------------------------')	
        print('----------------Readind trained model---------------------------------')	
        print('---------------------------------------------------------------------')	
        	
        #read trained models	
        encoder.load_state_dict(torch.load(folder+"/Encoder_b"))
        decoder.load_state_dict(torch.load(folder+"/Decoder_b"))	
    
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    
    criterion = nn.CrossEntropyLoss()

    for iter in range(1, n_iters + 1):
        for i, (data_s1, data_s2, lengths_s1, lengths_s2) in enumerate(train_loader):
            input_tensor = data_s1
            target_tensor = data_s2
            #print("train",target_tensor.size())
            loss = train(input_tensor, target_tensor, encoder,
                         decoder, encoder_optimizer, decoder_optimizer, criterion,i)
            print_loss_total += loss
            plot_loss_total += loss

            if i % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                             iter, iter / n_iters * 100, print_loss_avg))

            if i % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0
                
        # Save the model for every epoch
        print('---------------------------------------------------------------------')	
        print('----------------Saving trained model---------------------------------')	
        print('---------------------------------------------------------------------')	
      
        torch.save(encoder.state_dict(),folder +"/Encoder_b")
        torch.save(decoder.state_dict(),folder +"/Decoder_b")

    
    return plot_losses



In [16]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np
%matplotlib inline

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


In [17]:
# beam search + bleu score
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup:tup[1])
        # select top best
        sequences = ordered[:1]
    return sequences

In [18]:
def index2token_sentence(sentence_batch):
    return ' '.join(idx2words_ft_en[sent.item()] for sent in sentence_batch if sent.item()!=PAD_IDX)

In [19]:
class decoder_output_node:
    def __init__(self,parent, word_idx, prob_sum, isroot=False):
        self.parent = parent
        self.isroot = isroot
        self.children = []
        self.word_idx = word_idx
        self.prob_sum = prob_sum
    
    def get_children(self):
        '''
        return children
        '''
        return self.children
    
    def add_children(self, child):
        '''
        child: node
        '''
        self.children.append(child)
        return
    
    def get_parent(self):
        '''
        get parent of children
        '''
        return self.parent
    
    def get_word_idx(self):
        
        return self.word_idx
    
    def get_prob_sum(self):
        
        return self.prob_sum
    
    def is_root(self):
        return self.isroot


In [55]:
def return_sentence_sequence(child_node):
    if child_node.is_root():
        return [child_node.get_word_idx()]
    
    return return_sentence_sequence(child_node.get_parent())+[child_node.get_word_idx()]

In [83]:
def beam_search(beam_k, decoder_output, prob_sum = None, parent_node_list=None, vocab_size = len(idx2words_ft_en)-1):
    '''
    params:
    beam_k
    decoder_output: previous round decoder output
    parent_node_list: previous candidate word list (for only one candidate)
    
    return:
    list_of_best_k_nodes: best k nodes found in this iteration, list of list, first dim batch, second dim best k
    prob_with_sum: probabilistic matrix after sum+sortee 
    '''
    # if first word
    if parent_node_list is None:
        # initialize result
        prob_with_sum_sorted, word_idx_sorted = decoder_output.data.topk(beam_k)
        
        # add initialize tree list
        list_of_best_k_nodes = []
        batchsize = prob_with_sum_sorted.shape[0]
        for batch_i in range(batchsize):
            batch_i_tree_list = []
            for beam_i in range(beam_k):
                # add tree root node to list
                batch_i_tree_list.append(decoder_output_node(parent=None, word_idx=word_idx_sorted[batch_i, beam_i].item(), 
                                                            prob_sum= prob_with_sum_sorted[batch_i, beam_i].item(), isroot=True))
                
            list_of_best_k_nodes.append(batch_i_tree_list)
   
    # if not first word
    else:
        # get sorted results for all outputs
        prob, word_idx = decoder_output.data.topk(vocab_size)
        #print(decoder_output.data.data.topk(vocab_size))
        #print(word_idx)
        
        
        # find top beam k words options
        prob_with_sum = prob+prob_sum
        prob_with_sum_sorted, word_idx_sorted = torch.sort(prob_with_sum, dim=1, descending=True)
        
        # add top beam k words options into tree
        batchsize = prob_with_sum_sorted.shape[0]
        
        list_of_best_k_nodes = []
        for batch_i in range(batchsize):
            batch_i_tree_list = []
            for beam_i in range(beam_k):
                #print(word_idx_sorted[batch_i, beam_i])
                #print(parent_node_list[batch_i].get_word_idx())
                child_node = decoder_output_node(parent=parent_node_list[batch_i], word_idx= word_idx[batch_i,word_idx_sorted[batch_i,beam_i]].item(), prob_sum=prob_with_sum_sorted[batch_i,beam_i].item())
                # update parent node's child
                parent_node_list[batch_i].add_children(child_node)
                #save child to new list
                batch_i_tree_list.append(child_node)
            # add batch tree list to best k
            list_of_best_k_nodes.append(batch_i_tree_list)
                
    return list_of_best_k_nodes, prob_with_sum_sorted[:,:beam_k], word_idx_sorted[:,:beam_k]


In [89]:
# beam search temp eval
beam_k = 3
with torch.no_grad():
    for i, (data_s1, data_s2, lengths_s1, lengths_s2) in enumerate(val_loader):
        print(i)
        input_tensor = data_s1
        input_length = input_tensor.size()[0]
        #sentence_length to the output length
        sentence_length = data_s2.size()[1]
        encoder_hidden = encoder1.initHidden(input_length)

        encoder_output, encoder_hidden = encoder1(input_tensor, encoder_hidden)

        #decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
        decoder_input = torch.tensor(np.array([[SOS_IDX]]*input_length).reshape(1,input_length),device=device)

        decoder_hidden = encoder_hidden

        decoder_attentions = torch.zeros(sentence_length, sentence_length)
        decoded_words_eval = []
        list_of_best_k_nodes = []
        
        for di in range(sentence_length):
            
            ############################################beam search###################################################
            if di == 0:
                decoded_words_sub = []
                
                
                decoder_output, decoder_hidden, decoder_attention = decoder1(
                                decoder_input, decoder_hidden, encoder_output)
                
                # find top k candidates
                list_of_best_k_nodes,prob_with_sum_sorted ,word_idx_sorted = beam_search(beam_k, decoder_output, parent_node_list=None)
                #print(word_idx_sorted)
                #print(list_of_best_k_nodes[0][0].get_word_idx())
                #print(list_of_best_k_nodes[0][1].get_word_idx())
                
            else:
                # keep track of all new nodes
                new_nodes = []
                nodes_prob = None
                #nodes_word_idx = None
                
                # store index in previous candidate to locate position in new nodes, repeats=beam_size*beam_size
                prev_candidate_idx = np.repeat(range(beam_k), repeats=beam_k*beam_k)
                #print(prev_candidate_idx)
                
                # iterate through each node candidate from last iterations to find new candidates
                for beam_i in range(beam_k):
                    #print(word_idx_sorted.shape)
                    topi = word_idx_sorted[:,beam_i]
                    #print(topi)
                    prob_sum = prob_with_sum_sorted[:,beam_i].view((input_length,1))
                    
                    #change the dimension
                    decoder_input = topi.squeeze().detach()
                    decoder_input = decoder_input.unsqueeze(0)
                
                    # get decoder output
                    decoder_output, decoder_hidden, decoder_attention = decoder1(
                                decoder_input, decoder_hidden, encoder_output)
                    
                    # get beam search output
                    best_k_curr_node, prob_sum_curr_node, _ = beam_search(beam_k, decoder_output, prob_sum=prob_sum, parent_node_list=[ls[beam_i] for ls in list_of_best_k_nodes])
                    #print(word_idx_curr_node)
                    
                    # keep track of beam search output
                    new_nodes.append(best_k_curr_node)
                    
                    if beam_i == 0:
                        nodes_prob = prob_sum_curr_node
                        
                        #nodes_word_idx = word_idx_curr_node
                    else:
                        nodes_prob = torch.cat((nodes_prob, prob_sum_curr_node),dim=1)
                        #nodes_word_idx = torch.cat((nodes_word_idx, word_idx_curr_node),dim=1)
                
                _, sorted_idx = torch.sort(nodes_prob, dim=1, descending=True)
                #print("length",nodes_prob.shape)
                
                # update 
                #print(sorted_idx.shape)
                for batch_i in range(input_length):
                    for beam_i in range(beam_k):
                        # find the index of which candidate it descended from
                        st_idx = sorted_idx[batch_i][beam_i].item()
                        # find the corresponding node, st_idx gives parent node id, batch_i gives which example, st_idx%beam_k gives which node in the existing node list
                        #print(st_idx)
                        update_node = new_nodes[prev_candidate_idx[st_idx]][batch_i][st_idx%beam_k]
                        
                        list_of_best_k_nodes[batch_i][beam_i] = update_node
                        #print(batch_i)
                        #print(beam_i)
                        #print(list_of_best_k_nodes[0][0].parent.get_word_idx())
                        
                        # update word idex, prob sum correspondingly for next iteration
                        #word_idx_sorted[batch_i][beam_i] = nodes_word_idx[batch_i][st_idx] 
                        word_idx_sorted[batch_i][beam_i] = update_node.get_word_idx()
                        prob_with_sum_sorted[batch_i][beam_i] = update_node.get_prob_sum()
            
        # find the best and get index
        listed_predictions = []
        for batch_i in range(input_length):
            best_sequence_last_node = list_of_best_k_nodes[batch_i][0]
            batch_i_word_idx = return_sentence_sequence(best_sequence_last_node)
            
            listed_predictions.append(' '.join(idx2words_ft_en[token_idx] for token_idx in batch_i_word_idx if token_idx!=PAD_IDX))
            #print(batch_i_word_idx)
        listed_reference = []
        for ele in data_s2:
            sent = index2token_sentence(ele)
            
            listed_reference.append(sent)
            
        print(listed_predictions)
        bleu_score = corpus_bleu(listed_predictions,[listed_reference])
        print('BLEU Score is %s' % (str(bleu_score.score)))

        ############################################beam search###################################################


0
['<s> And <unk> <unk>', '<s> And <unk>', '<s> And , <unk> <unk>', '<s> <s> <unk>', '<s> And <unk> <unk>', '<s> And <unk> <unk>', '<s> And <unk> <unk>', '<s> And <unk> <unk>', '<s> And <unk> <unk>', '<s> <s> , <unk> <unk>', '<s> And <unk> <unk>', '<s> And <unk> <unk>', '<s> And , <unk> <unk>', '<s> And <unk> <unk>', '<s> And <unk>', '<s> And <unk> <unk>', '<s> And <unk> <unk>', '<s> And <unk> <unk>', '<s> <s> <unk> <unk>', '<s> And <unk>', '<s> And <unk>', '<s> And <unk> <unk>', '<s> And <unk> <unk>', '<s> And , <unk> <unk>', '<s> And <unk> <unk>', '<s> <s> <unk>', '<s> <s> , , <unk>', '<s> And , , <unk>', '<s> And , , <unk> <unk>', '<s> And <unk> <unk>', '<s> And , , <unk> <unk>', '<s> And <unk> <unk> <unk>']
BLEU Score is 6.370453494395023
1
['<s> And <unk> <unk>', '<s> <s> , , <unk>', '<s> <s> <unk>', '<s> And <unk>', '<s> And <unk> <unk>', '<s> And , , <unk>', '<s> And <unk> <unk>', '<s> And , , <unk>', '<s> And , , <unk> <unk>', '<s> And <unk>', '<s> And <unk> <unk>', '<s> And , 

KeyboardInterrupt: 

In [44]:
list_of_best_k_nodes[0][0].parent.get_word_idx()

tensor(171, device='cuda:0')

In [64]:
import operator
#loader can be test_loader or val_loader
def evaluate(loader, encoder, decoder, beam = False, beam_k = 1, threshold = 0.5):
    bleu_score_list = []
    with torch.no_grad():
        for i, (data_s1, data_s2, lengths_s1, lengths_s2) in enumerate(loader):
            print(i)
            input_tensor = data_s1
            input_length = input_tensor.size()[0]
            #sentence_length to the output length
            sentence_length = data_s2.size()[1]
            encoder_hidden = encoder.initHidden(input_length)

            encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)
            
            #decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
            decoder_input = torch.tensor(np.array([[SOS_IDX]]*input_length).reshape(1,input_length),device=device)

            decoder_hidden = encoder_hidden

            decoder_attentions = torch.zeros(sentence_length, sentence_length)
            decoded_words_eval = []
            # EOS_IDX tensor matrix
            test_matrix = torch.ones(input_length, beam_k)*EOS_IDX
            out_sequences = []
            for di in range(sentence_length):
                decoded_words_sub = []
                if beam == True:
#                     pass
                    #input length is the batch_size 
                    last_word_matrix = torch.zeros([input_length, beam_k])
                    # if all last word are the EOS_IDX then break
                    if torch.nonzero(last_word_matrix==test_matrix).size(0) >= threshold*input_length*beam_k:
                        break
                    # initiate
                    if di == 0:
                        decoder_output, decoder_hidden, decoder_attention = decoder(
                            decoder_input, decoder_hidden, encoder_output)
                        prob, elements = decoder_output.data.topk(beam_k)
                        #batch loop
                        for idx, ind in enumerate(elements):
                            sequences = []
                            for idx2 in range(beam_k):
                                # ind[idx2] is the index of vocab
                                sequences.append(([ind[idx2].item()], prob[idx][idx2].item()))
                                last_word_matrix[idx, idx2] = ind[idx2]
                            out_sequences.append(sequences)
                    else:
                        # shape of decoder output is 1 less than english vocab size, why?
                        prob, elements = decoder_output.data.topk(len(idx2words_ft_en)-1)
                        #batch loop
                        for idx, ind in enumerate(elements):
                            #score_list : (1*vocab_size)
                            '''
                            ? whether score_list make sense
                            last_word_matrix every element needs to be tensor
                            '''
                            score_list = list(prob[idx].cpu())
                            updated_dic = {}
                            for idx2 in range(beam_k):
                                # vocab size list (log)prob + the (log)prob 
                                updated_score_list = np.array(score_list) + np.array([out_sequences[idx][idx2][1]]*len(score_list))
                                #length of vocab_size
                                for idx3, ele in enumerate(updated_score_list):
                                    # key is the tuple of two indices of vocab # out_sequences[idx][idx2][0]
                                    updated_dic[(idx2, idx3)] = ele
                            # sort all the dict values and output the keys (tuple of two indices)
                            optimal = dict(sorted(updated_dic.items(), key=operator.itemgetter(1), reverse=True)[:beam_k])
                            
                            #change a dictionary to list of tuples and override it to the out_sequences
                            for index, (k, v) in enumerate(optimal.items()):
                                (k1, k2) = k
                                out_sequences[idx] = [(out_sequences[idx][k1][0].append(k2), v)]
                                last_word_matrix[idx][index] = k2
                    
                    #######################temp##############
                    #need to iterate through topi for all choi
                    topi = last_word_matrix[:,0]
                    #######################temp##############
                    # select the first column of out_sequence (which has the hightest softmax value) to do the corpus_blue
                    '''
                    TO DO
                    '''
                else:
                    decoder_output, decoder_hidden, decoder_attention = decoder(
                        decoder_input, decoder_hidden, encoder_output)
                    # topk(1) - softmax probability maximum
                    topv, topi = decoder_output.data.topk(1) 
                    #print(topi)
                    #batch loop
                
                for ind in topi:
                    if ind.item() == EOS_IDX:
                        decoded_words_sub.append('</s>')
                        break
                    else:
                        decoded_words_sub.append(idx2words_ft_en[ind.item()])
                decoded_words_eval.append(decoded_words_sub)

                #change the dimension
                decoder_input = topi.squeeze().detach()
                decoder_input = decoder_input.unsqueeze(0)

            pred_num = 0
            listed_predictions = []
            #swap dimensions of decoded_words to [batch_size * 377]
            decoded_words_new = [[i for i in ele] for ele in list(zip(*decoded_words_eval))]
           # print(decoded_words_new)
            for token_list in decoded_words_new:
                sent = ' '.join(token for token in token_list if token!="<pad>")
                #print(len(token_list))
                #print (sent)
                listed_predictions.append(sent)
                pred_num += 1
                
            ref_num = 0
            listed_reference = []
            for ele in data_s2:
                sent = index2token_sentence(ele)
                #print (tokens)
                #sent = ' '.join(tokens)
                #print (sent)
                listed_reference.append(sent)
                ref_num += 1
            print(listed_predictions)
            bleu_score = corpus_bleu(listed_predictions,[listed_reference])
            print('BLEU Score is %s' % (str(bleu_score.score)))
        bleu_score_list.append(bleu_score)
        return bleu_score_list, decoded_words_new, decoder_attentions[:di + 1]

In [22]:
hidden_size = 300
encoder1 = EncoderRNN(EMBEDDING_SIZE,hidden_size).to(device)
decoder1 = AttnDecoderRNN(EMBEDDING_SIZE,hidden_size, len(ordered_words_ft)).to(device)

# ##UNCOMMENT TO TRAIN THE MODEL
trainIters(encoder1, decoder1, 1, print_every=50)
#encoder1.load_state_dict(torch.load("encoder.pt"))
#decoder1.load_state_dict(torch.load("decoder.pt"))

0m 0s (- 0m 0s) (1 100%) 0.2166
0m 11s (- 0m 0s) (1 100%) 3.2614
0m 21s (- 0m 0s) (1 100%) 2.2029
0m 32s (- 0m 0s) (1 100%) 2.0147
0m 43s (- 0m 0s) (1 100%) 1.9650


KeyboardInterrupt: 

In [65]:
score_list, output_words, attentions = evaluate(val_loader, encoder1, decoder1, beam=False)

0
['<s> And , , <unk> <unk>', '<s> And <unk> <unk>', '<s> And , , , , , , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> And , , , , , , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk>', '<s> And , <unk> <unk>', '<s> And <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk>', '<s> And , , , , , , , <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And , , , <unk> <unk

['<s> <s> And , <unk> <unk>', '<s> And , , , , <unk> <unk> <unk> <unk> <unk>', '<s> And , , , , <unk> <unk> <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk>', '<s> And , , , , , , , , <unk> <unk> <unk>', '<s> And , , , , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> <s> And , <unk> <unk>', '<s> And , , , , , , , <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk> <unk>', '<s> <s> And <unk> <unk>', '<s> <s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , , , , <unk> <unk> <unk> <unk> <unk>', '<s> <s> And , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> <s> <s> <s> I , , , , , <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk> <unk> <unk> <unk>', '<s> <s> And , , , , , <unk> <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk> <unk> <unk>', '<s> And , , , ,

['<s> And <unk> <unk> <unk>', '<s> And , , , , , <unk> <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> <s> And , <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk> <unk>', '<s> And , , , , , , , , , , , , <unk>', '<s> And , , <unk> <unk> <unk> <unk> <unk> <unk>', '<s> <s> And <unk> <unk> <unk> <unk>', '<s> <s> And <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And , , , , , , , <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk>', '<s> And <unk> <unk>', '<s> And , , , , , , , , <unk> <unk> <unk>', '<s> <s> And , , , , <unk> <unk> <unk> <unk> <unk>', '<s> And , , , , , , , , , , <unk>', '<s> <s> And , <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> <s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> <s> And , , , , , <unk> <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk>', '<s> And , , , , , <unk> <unk> <unk>', '<s> And , <unk> <unk>', '<s> And , , <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk>', '<s> And , , , , , , <unk> <unk> <unk

['<s> And , , , , , , <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> <s> And , , , , , , , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , , , , , , , , <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk>', '<s> <s> <s> I <unk> <unk>', '<s> And <unk> <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>', '<s> <s> <s> <s> I , , , , , , , <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk> <unk> <unk> <unk>', '<s> <s> And , , , ,

['<s> And <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk> <unk>', '<s> And , , , , , , , , , <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> <s> And <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk>', '<s> And , , , , <unk> <unk> <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk>', '<s> <s> And I , , , <unk> <unk> <unk> <unk>', '<s> <s> And , , , , , , , , <unk> <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk>', '<s> And , <unk> <unk> <unk> <unk> <unk>', '<s> And , <unk> <unk>', '<s> And , , , , , , <unk> <unk> <unk>', '<s> <s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , , , , , , <unk> <unk> <unk> <unk>', '<s> <s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>',

['<s> And , , , <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> <s> And <unk> <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And , , , , , <unk> <unk> <unk> <unk>', '<s> <s> And <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , , , <unk> <unk> <unk> <unk> <unk> <unk>', '<s> <s> <s> I <unk> <unk> <unk>', '<s> And , , , , , , <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , , , , , , , , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , <unk> <unk>', '<s> <s> And , , <unk> <unk> <unk> <unk> <unk>', '<s> And , , , , <unk> <unk> <unk> <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> <s> And , <unk> <unk> <unk>', '<s> And <unk> <unk> <unk>', '<s> And , , <unk> <unk> <unk> <unk> <unk>', '<s> <s> And , , , , , , , <unk> <unk>', '<s> And <unk> <unk> <unk> <unk>', '<s> <s>