In [1]:
import os
import sys
import time
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import re
import unicodedata
import nltk
from nltk.tokenize.toktok import ToktokTokenizer
import json
import pickle
from sklearn.model_selection import train_test_split
%matplotlib inline
from nltk import sent_tokenize
from torch.autograd import Variable
from torch.optim import Adagrad

import math
import os
import random
import string

# Pytorch library for training
import torch
from torch import optim

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# from torchtext.data import Field, BucketIterator, Example


In [2]:
### Set device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# device = "cpu"


##### Dataset contains:
##### data after padding: 
0) text_train_pad, 1) headline_train_pad, 
##### pad mask of data(1=padded, 0=not padded): 
2) text_train_padmask, 3) headline_train_padmask,
##### sentence length: 
4) text_train_len, 5) headline_train_len
##### out of vocabulary
6) text_train_oov, 7) headline_train_oov
#### dataset with unk embedding
8) text_train_no, 9) headline_train_no

In [3]:
traindata_zip = torch.load('./Dataset7/traindata_zip.pt')

devdata_zip = torch.load('./Dataset7/devdata_zip.pt')

testdata_zip = torch.load('./Dataset7/testdata_zip.pt')

embedding = np.load('./Dataset7/embedding.npy')

embedding_headline = np.load('./Dataset7/embedding_headline.npy')


### Load Vocabulary

In [4]:
class Vocabulary:
    PAD_token = 0   # Used for padding short sentences
    SOS_token = 1   # Start-of-sentence token
    EOS_token = 2   # End-of-sentence token
    UNK_token = 3   # Out-of-vocabulary token

    def __init__(self, name):
        self.name = name
        self.word2index = {"pad":0, "sos":1, "eos":2, "unk":3}
        self.word2count = {"pad":0, "sos":0, "eos":0, "unk":0}              
        self.index2word = {0: "pad", 1: "sos", 2: "eos", 3: "unk"}
        self.num_words = 4
        self.num_sentences = 0
        self.longest_sentence = 0

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1
            
    def add_sentence(self, sentence):
        sentence_len = 0
        for word in sentence:           
            sentence_len += 1
            self.add_word(word)
        if sentence_len > self.longest_sentence:
            self.longest_sentence = sentence_len
        self.num_sentences += 1

    def to_word(self, index):
        return self.index2word[index]

    def to_index(self, word):
        return self.word2index[word]

In [5]:
# text_dictionary
with open('./Dataset7/text_vocabulary.pgn', 'rb') as text_dictionary_file: #for S2S and S2S+GAN
    text_vocabulary = pickle.load(text_dictionary_file)
# headline_dictionary
with open('./Dataset7/headline_vocabulary.pgn', 'rb') as headline_dictionary_file: #for S2S and S2S+GAN
    headline_vocabulary = pickle.load(headline_dictionary_file)

In [6]:
## Set batch size and split data after padding to batches
def batch_dataloader(data, Batch_size):
    data_dataloader = torch.utils.data.DataLoader(data, batch_size=Batch_size, shuffle=False, num_workers=0)
    
    return data_dataloader

In [7]:
## Training data batching
trainloader = batch_dataloader(traindata_zip, 100)

devloader = batch_dataloader(devdata_zip, 20)

testloader = batch_dataloader(testdata_zip, 20)

## Attention and Generator

In [8]:
## Prepare dataset for training
text_testrun, hl_testrun, text_train_padmask, headline_train_padmask, text_train_len, \
headline_train_len, text_train_oov, headline_train_oov, text_train_no, headline_train_no= next(iter(trainloader))


In [9]:
## If need to send to cuda
text_testrun = text_testrun.to(device)
hl_testrun = hl_testrun.to(device)

text_train_padmask = text_train_padmask.to(device)
headline_train_padmask = headline_train_padmask.to(device)

text_train_oov = text_train_oov.to(device)
headline_train_oov = headline_train_oov.to(device)

text_train_no = text_train_no.to(device)
headline_train_no = headline_train_no.to(device)

#### Set model parameter

In [10]:
## Set parameters for test run ------ CAN DELETE THIS CELL IN FUTURE
input_size = int(len(text_vocabulary.index2word.keys())+1)
output_size = int(len(headline_vocabulary.index2word.keys())+1)

vocab_size = input_size - 1
unk_idx = 3
use_coverage = True
use_p_gen = True

enc_emb_size = 200
dec_emb_size = 200
hid_size = 128

n_layers = 1
enc_dropout = 0.5
dec_dropout = 0.5

beam_size = 50
max_dec_steps = 10
min_dec_steps = 1

### Pointer Generator Encoder

In [11]:
class Encoder(nn.Module):
    def __init__(self, input_size, hid_size, emb_size, embedding):
        super(Encoder, self).__init__()
        self.embedding = embedding

        self.lstm = nn.LSTM(emb_size, hid_size, num_layers=1, batch_first=True, bidirectional=True)
        self.W_h = nn.Linear(hid_size * 2, hid_size * 2, bias=False)

    #seq_lens should be in descending order
    def forward(self, enc_input, enc_input_len, hidden_state):
        embedded = torch.tensor([[self.embedding[i] for i in enc_input[:, seq]] \
                                 for seq in range(enc_input.shape[1])]).permute(1, 0, 2)

        packed = pack_padded_sequence(embedded, enc_input_len, batch_first=True, enforce_sorted= False)
            
        output, hidden = self.lstm(packed, hidden_state)
        
        enc_outputs, _ = pad_packed_sequence(output, batch_first=True)
        enc_outputs = enc_outputs.contiguous()
        
        enc_feature = enc_outputs.view(-1, 2*hid_size)  # B * L x 2*hidden_dim
        enc_feature = self.W_h(enc_feature)

        return enc_outputs, enc_feature, hidden

In [12]:
# Testrun
# Call Encoder
encoder = Encoder(input_size, hid_size, enc_emb_size, embedding).to(device)

text_train_len = torch.sum(torch.where(text_testrun==0,0,1), axis = 1).tolist()

# Check output of Encoder
enc_outputs, enc_feature, final_hidden = encoder(text_testrun.to(device), text_train_len, None)

final_hidden[0].shape, final_hidden[1].shape, enc_outputs.shape

(torch.Size([2, 100, 128]),
 torch.Size([2, 100, 128]),
 torch.Size([100, 174, 256]))

### Reduce Dimension of Encoder

In [13]:
class ReduceState(nn.Module):
    def __init__(self):
        super().__init__()

        self.reduce_h = nn.Linear(hid_size * 2, hid_size)
        self.reduce_c = nn.Linear(hid_size * 2, hid_size)

    def forward(self, hidden):
        h, c = hidden
        
        h_in = h.transpose(0, 1).contiguous().view(-1, hid_size * 2)
        hid_reduced_h = F.relu(self.reduce_h(h_in))
        
        c_in = c.transpose(0, 1).contiguous().view(-1, hid_size * 2)
        hid_reduced_c = F.relu(self.reduce_c(c_in))

        return hid_reduced_h.unsqueeze(0), hid_reduced_c.unsqueeze(0)

In [14]:
# Test run
reduce = ReduceState().to(device)
reduce_hidden = reduce(final_hidden)
reduce_hidden[0].shape, reduce_hidden[1].shape

(torch.Size([1, 100, 128]), torch.Size([1, 100, 128]))

### Attention

In [15]:
class Attention(nn.Module):
    def __init__(self, hid_size, use_coverage):
        super().__init__()
        
        self.use_coverage = use_coverage
        
        # Coverage layer
        if use_coverage:
            self.w_c = nn.Linear(1, hid_size * 2, bias=False)
            
        self.decode_proj = nn.Linear(hid_size * 2, hid_size * 2)
        self.v = nn.Linear(hid_size * 2, 1, bias=False)

    def forward(self, h_c_hat, enc_outputs, enc_feature, enc_padding_mask, coverage):
        """""
        h_c_hat: hidden, cell from decoder
        enc_outputs: first output of encoder
        enc_feature: second output of encoder
        enc_padding_mask: text_padmask
        coverage: initialize: Variable(torch.zeros((batch_size, 2 * hid_size)))
        
        Return:
        context_vec: sum(attention weights)*encoder hidden states
        attn_dist: attention distribution
        coverage: updated coverage
        """""
        b, m, n = list(enc_outputs.size())

        dec_feature = self.decode_proj(h_c_hat) # B x 2*hid_size
        
        dec_feature_expanded = dec_feature.unsqueeze(1).expand(b, m, n).contiguous() # B x m x 2*hid_size
        
        dec_feature_expanded = dec_feature_expanded.view(-1, n)  # (B * m )x 2*hid_size

        attn_feature = enc_feature + dec_feature_expanded # (B * m) x 2*hid_size
        
        if self.use_coverage:
            coverage_input = coverage.view(-1, 1)  # (B * m) x 1
            coverage_feature = self.w_c(coverage_input)  # (B * m) x 2*hid_size
            att_feature = attn_feature + coverage_feature

        scores = torch.tanh(att_feature) # (B * m) x 2*hidden_dim
        scores = self.v(scores)  # (B * m) x 1
        scores = scores.view(-1, m)  # B x m

        attn_dist = F.softmax(scores, dim=1) * (1-enc_padding_mask) # B x m
        normalization_factor = attn_dist.sum(1, keepdim = True)
        
        attn_dist = attn_dist / normalization_factor

        attn_dist = attn_dist.unsqueeze(1)  # B x 1 x m
        context_vec = torch.bmm(attn_dist, enc_outputs)  # B x 1 x n
        context_vec = context_vec.view(-1, hid_size * 2)  # B x 2*hidden_dim

        attn_dist = attn_dist.view(-1, m)  # B x m

        if self.use_coverage:
            coverage = coverage.view(-1, m)
            coverage = coverage + attn_dist

        return context_vec, attn_dist, coverage


In [16]:
# Test run
attn_net = Attention(hid_size, True).to(device)

# Set inputs to attn_net
coverage = Variable(torch.zeros(text_testrun.size())).to(device)

h_decoder, c_decoder = reduce_hidden
h_c_hat = torch.cat((h_decoder.view(-1, hid_size), c_decoder.view(-1, hid_size)), 1).to(device)

# Check outputs of Attention class
context_vec, attn_dist, coverage = attn_net(h_c_hat, enc_outputs, enc_feature, text_train_padmask.to(device), coverage=coverage)

### Pointer Generator Decoder

In [17]:
class Decoder(nn.Module):
    def __init__(self, input_size, hid_size, vocab_size, emb_size, embedding_headline, use_coverage, use_p_gen):
        super().__init__()
        
        self.attn_net = Attention(hid_size, use_coverage)
        
        self.use_coverage = use_coverage   #True/False
        
        self.use_p_gen = use_p_gen  #True/False

        self.embedding = embedding_headline 

        self.x_context = nn.Linear(hid_size * 2 + emb_size, emb_size)

        self.lstm = nn.LSTM(emb_size, hid_size, num_layers=1, batch_first=True, bidirectional=False)

        if use_p_gen:
            self.p_gen_linear = nn.Linear(hid_size * 4 + emb_size, 1)

        #p_vocab
        self.out1 = nn.Linear(hid_size * 3, hid_size)
        self.out2 = nn.Linear(hid_size, vocab_size)
        

    def forward(self, target, h_c_1, enc_outputs, enc_feature, enc_padding_mask,
                cont_v, enc_oov_len, enc_batch, coverage, step):
        """
        target: headline batch
        h_c_1: reduced_state(enc_hidden)
        h_c_hat: updated hidden for attn_net
        enc_outputs: first output of encoder
        enc_feature: second output of encoder
        enc_padding_mask: 
        cont_v: context vector input to decoder (initialization: Variable(torch.zeros((batch_size, 2 * hid_size))))
        enc_oov_len: text_batch_oov
        enc_batch: text_train_pad (OOV has index)
        coverage: initialization: Variable(torch.zeros(text_batch.size()))
        
        extro_zeros: initialization: Variable(torch.zeros((batch_size, max_oov_len)))

        """
        if not self.training and step == 0:
            h_decoder, c_decoder = h_c_1
            h_c_hat = torch.cat((h_decoder.view(-1, hid_size),
                                 c_decoder.view(-1, hid_size)), 1)  # B x 2*hid_size
            context_vec, _, coverage_new = self.attn_net(h_c_hat, enc_outputs, enc_feature,
                                                  enc_padding_mask, coverage)
            coverage = coverage_new
        
        max_oov_len = torch.max(enc_oov_len)
        
        target_embbed = torch.tensor([self.embedding[i] for i in target]).float()
        
        x = self.x_context(torch.cat((cont_v, target_embbed), 1))
        lstm_out, h_c = self.lstm(x.unsqueeze(1), h_c_1)

        h_decoder, c_decoder = h_c
        h_c_hat = torch.cat((h_decoder.view(-1, hid_size),
                             c_decoder.view(-1, hid_size)), 1)  # B x 2*hid_size
        context_vec, attn_dist, coverage_new = self.attn_net(h_c_hat, enc_outputs, enc_feature,
                                                          enc_padding_mask, coverage)

        if self.training or step > 0:
            coverage = coverage_new

        p_gen = None
        
        if self.use_p_gen:
            p_gen_input = torch.cat((cont_v, h_c_hat, x), 1)  # B x (2*2*hid_size + emb_size)
            p_gen = self.p_gen_linear(p_gen_input)
            p_gen = torch.sigmoid(p_gen)

        output = torch.cat((lstm_out.view(-1, hid_size), cont_v), 1) # B x hid_size * 3
        output = self.out1(output) # B x hid_size

        output = self.out2(output) # B x vocab_size
        vocab_dist = F.softmax(output, dim=1)
        
        dist_size = vocab_dist.size(0)
        extra_zeros = torch.zeros((dist_size, max_oov_len), device=vocab_dist.device)

        if self.use_p_gen:
            vocab_dist_p = p_gen * vocab_dist
            attn_dist_p = (1 - p_gen) * attn_dist

            if extra_zeros is not None:
                vocab_dist_p = torch.cat([vocab_dist_p, extra_zeros], 1)

            final_dist = vocab_dist_p.scatter_add(1, enc_batch, attn_dist_p)
        else:
            final_dist = vocab_dist

        return final_dist, h_c, cont_v, attn_dist, p_gen, coverage


In [18]:
# Test run
decoder = Decoder(input_size, hid_size, vocab_size, dec_emb_size, embedding_headline, use_coverage=True, use_p_gen=True).to(device)


# Set inputs to decoder
cont_v =  Variable(torch.zeros((100, 2 * hid_size))).to(device)

max_oov_len = torch.max(text_train_oov).to(device)
h_c_1 = reduce_hidden

extra_zeros = Variable(torch.zeros((100, max_oov_len))).to(device)

# Check outputs from Decoder class
final_dist, h_c, cont_v, attn_dist, p_gen, coverage = decoder(hl_testrun[:,0].to(device), h_c_1, enc_outputs, enc_feature, 
                                                              text_train_padmask.to(device), cont_v, text_train_oov.to(device), 
                                                              text_testrun.to(device), coverage, step=1)

### Pointer Generator Model

In [19]:
class Parameters():
    ## Set parameter
    input_size = int(len(text_vocabulary.index2word.keys())+1)
    output_size = int(len(headline_vocabulary.index2word.keys())+1)

    vocab_size = input_size - 1
    unk_idx = 3
    use_coverage = True
    use_p_gen = True

    enc_emb_size = 200
    dec_emb_size = 200
    hid_size = 128

    n_layers = 1
    enc_dropout = 0.5
    dec_dropout = 0.5

    beam_size = 50
    max_dec_steps = 10
    min_dec_steps = 1
    
    cov_weight = 1.0
    
    # training and optimizer
    lr = 0.1
    opt_acc = 0.1

In [20]:
class Pointer_Generator(nn.Module):
    def __init__(self, para, encoder: Encoder, reduced_net: ReduceState, decoder: Decoder, device: torch.device):
        super().__init__()
        self.encoder = encoder
        self.reduced_net = reduced_net
        self.decoder = decoder
        self.device = device
        self.para = para
        

    def forward(self, text_batch, text_batch_len, text_batch_padmask, text_batch_oov, \
                headline_batch, headline_batch_len, headline_batch_no, headline_batch_padmask, hidden_state):
        """""
        text_batch: text batch with oov index (eg. text_train_pad)
        text_batch_len: len of sentence in the batch before padding (eg. text_train_len)
        
        text_batch_padmask: padding mask of each sentence in text batch (padded:1, no pad: 0), (eg. text_train_padmask)
        text_batch_oov: number of oov in each sentence in text batch (eg. text_train_oov)

        headline_batch: headline batch with oov index (eg. headline_train_pad)
        headline_batch_len: len of sentence in the batch before padding (eg. headline_train_len)
        
        headline_batch_no: headline batch with oov index == unk_idx == 3 (eg. headline_train_no)
        headline_batch_padmask: padding mask of each sentence in hl batch (padded:1, no pad: 0), (eg. headline_train_padmask)
        
        hidden_state: hidden state for GAN (hidden_state = None, if not specified)
        """""
        
        batch_size, max_len = headline_batch.shape
        headline_vocab_size = self.para.output_size - 1

        # tensor to store decoder's output
        outputs = torch.zeros(max_len, batch_size, headline_vocab_size)

        # last hidden & cell state of the encoder is used as the decoder's initial hidden state
        enc_outputs, enc_feature, enc_hidden = self.encoder(text_batch, text_batch_len, hidden_state)
        
        h_c_1 = self.reduced_net(enc_hidden)
        
        c_t_1 = Variable(torch.zeros((100, 2 * self.para.hid_size)))
        
        coverage = Variable(torch.zeros((batch_size, max(text_batch_len))))
        print(coverage.shape)
#         coverage = Variable(torch.zeros(text_batch.size()))
        
        dist = torch.zeros((batch_size, max_len, self.para.vocab_size))
        
        step_loss_rec = []
        
        for i in range(max_len):
            final_dist, h_c, c_t, attn_dist, p_gen, coverage_new = self.decoder(headline_batch_no[:,i], h_c_1, enc_outputs, enc_feature, 
                                                                  text_batch_padmask, c_t_1, text_batch_oov, 
                                                                  text_batch, coverage, i)
            # Record distribution for GAN
            dist[:, i, :] = final_dist
            
            
            # Calculate loss for this batch
            headline = headline_batch[:, i]
            total_dist = torch.gather(final_dist, 1, headline.unsqueeze(1)).squeeze()
            step_loss = - torch.log(total_dist + 1e-10)
            
            if self.para.use_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                
                step_loss = step_loss + step_coverage_loss
                coverage = coverage_new
                
            step_mask = headline_batch_padmask[:, i]
            step_loss = step_loss + 1.0*step_mask
            step_loss_rec.append(step_loss)

        total_loss = torch.sum(torch.stack(step_loss_rec, 1), 1)
        
        batch_loss_average = total_loss/headline_batch_len
        
        # Final loss is the average of batch_loss_average
        loss = torch.mean(batch_loss_average)

        return dist, loss           


In [21]:
# Test run
para = Parameters()
encoder = Encoder(para.input_size, para.hid_size, para.enc_emb_size, embedding).to(device)
reduce = ReduceState().to(device)
decoder = Decoder(para.input_size, para.hid_size, para.vocab_size, para.dec_emb_size, embedding_headline, \
                  para.use_coverage, para.use_p_gen).to(device)

PG = Pointer_Generator(para, encoder, reduce, decoder, device).to(device)
dist, loss = PG(text_testrun.to(device), text_train_len, text_train_padmask.to(device), text_train_oov.to(device), \
          hl_testrun.to(device), headline_train_len, headline_train_no.to(device), \
                headline_train_padmask.to(device), hidden_state = None)

torch.Size([100, 174])


In [22]:
loss

tensor(99.4610, grad_fn=<MeanBackward0>)

### Pretrain pointer generator

In [23]:
para = Parameters()
encoder = Encoder(para.input_size, para.hid_size, para.enc_emb_size, embedding).to(device)
reduce = ReduceState().to(device)
decoder = Decoder(para.input_size, para.hid_size, para.vocab_size, para.dec_emb_size, embedding_headline, \
                  para.use_coverage, para.use_p_gen).to(device)

model_params = list(encoder.parameters()) + list(decoder.parameters()) + list(reduce.parameters())

optimizer = Adagrad(model_params, lr = para.lr, initial_accumulator_value = para.opt_acc)

In [24]:
def Pretrain_PGN(PG: Pointer_Generator, dataloader, optimizer, hidden_state, loss_decay = 0.9):
    
    PG.train()
    
    epoch_loss = 0
    
    for text_batch, hl_batch, text_batch_padmask, headline_batch_padmask, text_batch_len, \
    headline_batch_len, text_batch_oov, headline_batch_oov, \
    text_batch_no, headline_batch_no in dataloader:

        ## send to cuda
        text_batch = text_batch.to(device)
        hl_batch = hl_batch.to(device)
        
        text_batch_padmask = text_batch_padmask.to(device)
        headline_batch_padmask = headline_batch_padmask.to(device)
        
        text_batch_len = text_batch_len.to(device)
        headline_batch_len = headline_batch_len.to(device)
        
        text_batch_oov = text_batch_oov.to(device)
        headline_batch_oov = headline_batch_oov.to(device)
        
        text_batch_no = text_batch_no.to(device)
        headline_batch_no = headline_batch_no.to(device)
        
        _, loss = PG(text_batch, text_batch_len, text_batch_padmask, text_batch_oov, \
          hl_batch, headline_batch_len, headline_batch_no, headline_batch_padmask, hidden_state)

#         if epoch_loss == 0:
#             epoch_loss = loss
        
#         else:
#             epoch_loss = epoch_loss * loss_decay + (1 - loss_decay) * loss 
        
        loss.backward()
        optimizer.step()
        
#         epoch_loss += loss.item()

        if epoch_loss == 0:
            epoch_loss = loss.item()
        
        else:
            epoch_loss = epoch_loss * loss_decay + (1 - loss_decay) * loss.item()
   

    return epoch_loss / len(dataloader)

In [25]:
def Pretrain_evaluation(PG: Pointer_Generator, dataloader, hidden_state, loss_decay = 0.9):
    
    PG.train()
    
    epoch_loss = 0
    with torch.no_grad():
        for text_batch, hl_batch, text_batch_padmask, headline_batch_padmask, text_batch_len, \
        headline_batch_len, text_batch_oov, headline_batch_oov, \
        text_batch_no, headline_batch_no in dataloader:

            ## send to cuda
            text_batch = text_batch.to(device)
            hl_batch = hl_batch.to(device)

            text_batch_padmask = text_batch_padmask.to(device)
            headline_batch_padmask = headline_batch_padmask.to(device)

            text_batch_len = text_batch_len.to(device)
            headline_batch_len = headline_batch_len.to(device)

            text_batch_oov = text_batch_oov.to(device)
            headline_batch_oov = headline_batch_oov.to(device)

            text_batch_no = text_batch_no.to(device)
            headline_batch_no = headline_batch_no.to(device)

            _, loss = PG(text_batch, text_batch_len, text_batch_padmask, text_batch_oov, \
            hl_batch, headline_batch_len, headline_batch_no, headline_batch_padmask, hidden_state)

#             epoch_loss += loss.item()

        if epoch_loss == 0:
            epoch_loss = loss.item()
        
        else:
            epoch_loss = epoch_loss * loss_decay + (1 - loss_decay) * loss.item()

    return epoch_loss / len(dataloader)

### Beam search to output index

In [27]:
class Beam(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens
        self.log_probs = log_probs
        self.state = state
        self.context = context
        self.coverage = coverage

    def extend(self, token, log_prob, state, context, coverage):
        return Beam(tokens = self.tokens + [token],
                          log_probs = self.log_probs + [log_prob],
                          state = state,
                          context = context,
                          coverage = coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        return sum(self.log_probs) / len(self.tokens)

In [28]:
class Beam_Search(object):
    def __init__(self, para: Parameters, encoder: Encoder, reduced_net: ReduceState, decoder: Decoder, device: torch.device):
        
        
        self.para = Parameters()
        self.unk_idx = 3
        
        # Call Encoder, ReducedState and Decoder
        self.encoder = encoder
        self.reduced_net = reduced_net
        self.decoder = decoder

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    
    def beam_search(self, text, text_len, text_padmask, text_oov, beam_batch, hidden_state):
        """""
        text: one text data from the batch (eg. text_train_pad[1])
        text_len: length of text (witout padding) (eg. text_train_len[1])
        text_padmask: padding mask of the text, padded position = 1, not padded = 0 (eg. text_train_padmask[1])
        text_oov: number of oov in a text (eg. text_train_oov[1])
        
        hidden_state: hidden state for discriminator (=None, if not specified)
        """""
        ## Reapeat this text to form a batch containing this text only
        text_batch = torch.transpose(text.unsqueeze(1).repeat(1, beam_batch),0,1)
        
        text_len = torch.sum(torch.where(text==0,0,1)).tolist()
        
        text_batch_len = [text_len] * beam_batch
        
        text_batch_padmask = torch.transpose(text_padmask.unsqueeze(1).repeat(1, beam_batch),0,1)
        
        text_batch_oov =  torch.tensor([text_oov] * beam_batch).unsqueeze(1)
        
        # Initialize c_t and coverage
        c_t_0 = Variable(torch.zeros((100, 2 * self.para.hid_size)))
        coverage_t_0 = Variable(torch.zeros(text_batch.size()))
        
        
        # Call encoder
        enc_outputs, enc_feature, enc_hidden = self.encoder(text_batch, text_batch_len, hidden_state)
        
        h_c_0 = self.reduced_net(enc_hidden)

        dec_h, dec_c = h_c_0 # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # Beam hypothesis
        beams = [Beam(tokens=[1],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context = c_t_0[0],
                      coverage=(coverage_t_0[0])) for _ in range(beam_batch)]
        
        results = []
        steps = 0
        while steps < self.para.max_dec_steps and len(results) < self.para.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.para.vocab_size else self.unk_idx for t in latest_tokens]
            target = Variable(torch.LongTensor(latest_tokens))

            all_state_h =[]
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            h_c_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)
 
            coverage_t_1 = None
            if use_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)
            
            final_dist, h_c, c_t, attn_dist, p_gen, coverage_t = self.decoder(target, h_c_1, enc_outputs, enc_feature, 
                                                                  text_batch_padmask, c_t_1, text_batch_oov, 
                                                                  text_batch, coverage_t_1, steps)
            
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs, self.para.beam_size*2)

            dec_h, dec_c = h_c
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []

            if steps == 0:
                num_orig_beams = 1
            else:
                num_orig_beams = len(beams)
                
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i])

                for j in range(self.para.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                   log_prob=topk_log_probs[i, j].item(),
                                   state=state_i,
                                   context=context_i,
                                   coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == [2]:
                    if steps >= self.para.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == beam_batch or len(results) == self.para.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]

In [312]:
# Test run
para = Parameters()
encoder = Encoder(para.input_size, para.hid_size, para.enc_emb_size).to(device)
reduce = ReduceState()
decoder = Decoder(para.input_size, para.hid_size, para.vocab_size, para.dec_emb_size, \
                  para.use_coverage, para.use_p_gen)
BS = Beam_Search(para, encoder, reduce, decoder, device)

# Check result
result = BS.beam_search(text_testrun[1], text_train_len[1], text_train_padmask[1], text_train_oov[1], beam_batch=3, hidden_state = None)

RuntimeError: The size of tensor a (585) must match the size of tensor b (591) at non-singleton dimension 0

In [200]:
result.tokens

[1, 12, 12, 8, 24, 11]