In [1]:
import torch
import torch.nn as nn
import csv
import json
import math
from nltk import word_tokenize as tokenize
from tqdm import tqdm_notebook as tqdm
from multiprocessing import Pool

In [2]:
import sys
import json

# Path of the word -> index vocabulary (JSON object) used by the model.
vocab_name = "preprocessing-cnn-all/vocab.json"
num_threads = 6

# Read the vocabulary through a context manager so the file handle is closed
# deterministically; JSON text is UTF-8, so do not depend on the locale.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)

# Reverse mapping (index -> word), used to turn predicted ids back into text.
vocab_inv = {ind: word for word, ind in vocab.items()}
print(sys.version)

3.6.8 |Anaconda, Inc.| (default, Dec 30 2018, 01:22:34) 
[GCC 7.3.0]


In [3]:
# --- model dimensions -------------------------------------------------------
EMB_DIM = 50             # size of each word embedding
HIDDEN_DIM = 256         # hidden size of the LSTMs (per direction in the encoder)
VOC_SIZE = len(vocab)    # number of entries in the loaded vocabulary

# --- sequence lengths -------------------------------------------------------
INPUT_MAX = 2000         # length every input document is padded/truncated to
OUTPUT_MAX = 100         # number of tokens generated per summary

# --- training-style settings ------------------------------------------------
num_epochs = 20
save_rate = 4            # how many epochs between model saves
clip_grad_norm = 15.0    # maximum gradient norm

# --- misc -------------------------------------------------------------------
# Checkpoint path to resume from; use None when there is no checkpoint.
continue_from = "models/Model3"
epsilon = 1e-10          # small constant added to the output distribution

In [4]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# (To force CPU execution, assign torch.device("cpu") unconditionally.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [5]:
class Inference(nn.Module):
    """Greedy-decoding encoder/decoder summarizer with a copy (pointer) mechanism.

    A single-layer bidirectional LSTM encodes the source token ids. A
    single-layer unidirectional LSTM decoder then produces ``output_len``
    tokens, one per step. At each step the decoder output is multiplied with
    the encoder states to obtain an attention distribution over source
    positions; that distribution is used to build a context vector and,
    scattered onto the vocabulary by source token id, a copy distribution.
    The final distribution is ``pgen * vocab_softmax + (1 - pgen) * copy``
    and its arg-max is fed back as the next decoder input.

    The class reads two notebook-level globals: ``vocab`` (for the id of
    ``'<bos>'``) and ``epsilon`` (added to the final distribution).

    Args:
        hidden_dim: hidden size of each encoder direction; the decoder uses
            ``2 * hidden_dim``.
        emb_dim: embedding size.
        input_len: nominal source length. Stored for reference only; the
            forward pass takes the real length from the input tensor.
        output_len: number of decoding steps.
        voc_size: number of rows in the embedding / size of the output
            distribution.
    """

    def __init__(self, hidden_dim, emb_dim, input_len, output_len, voc_size):
        super(Inference, self).__init__()

        self.hidden_dim = hidden_dim
        self.emb_dim = emb_dim
        self.input_len = input_len
        self.output_len = output_len
        self.voc_size = voc_size

        # Encoder and decoder share one embedding table.
        self.emb_layer = nn.Embedding(voc_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)

        # Decoder hidden size is 2*hidden_dim so it can be initialised from the
        # concatenated forward/backward encoder states.
        self.decoder = nn.LSTM(emb_dim, hidden_dim*2, num_layers=1, batch_first=True)

        # Normalises attention scores over the source positions (dim=1).
        self.attention_softmax = nn.Softmax(dim=1)

        # Vocabulary distribution from [decoder output ; context vector].
        self.pro_layer = nn.Sequential(
            nn.Linear(hidden_dim*4, voc_size, bias=True),
            nn.Softmax(dim=-1)
        )
        # Generate-vs-copy gate from [context vector ; decoder output ; input embedding].
        self.pgen_layer = nn.Sequential(
            nn.Linear(4*hidden_dim+emb_dim, 1, bias=True),
            nn.Sigmoid()
        )

    @staticmethod
    def _merge_directions(state):
        """Concatenate the two direction states (2, batch, H) into (1, batch, 2H)."""
        return torch.cat((state[0], state[1]), dim=-1).unsqueeze(0)

    def forward(self, x):
        """Greedily decode ``output_len`` tokens for a batch of source sequences.

        Args:
            x: integer tensor of token ids with shape (batch, seq_len). The
                sequence length is read from ``x`` and does not have to equal
                ``input_len``.

        Returns:
            A tuple ``(ans_seq, decoder_attn)`` where ``ans_seq`` has shape
            (batch, output_len) and holds the chosen token ids, and
            ``decoder_attn`` has shape (batch, output_len, seq_len) and holds
            the attention distribution of every decoding step.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        # ---- encoder ----
        x_emb = self.emb_layer(x)
        x_bilstm, (h, c) = self.encoder(x_emb)
        out_h = self._merge_directions(h)
        out_c = self._merge_directions(c)

        # ---- decoder ----
        # First decoder input is <bos> for every batch element; new_full keeps
        # the dtype and device of x instead of relying on a global device.
        bos_ids = x.new_full((batch_size, 1), vocab['<bos>'])
        decoder_input_emb = self.emb_layer(bos_ids)

        step_tokens = []  # one (batch, 1) id tensor per step
        step_attn = []    # one (batch, seq_len) attention tensor per step
        for _ in range(self.output_len):
            w = decoder_input_emb
            out, (out_h, out_c) = self.decoder(w, (out_h, out_c))

            # Dot-product attention of the decoder output over encoder states.
            attention = torch.bmm(x_bilstm, out.transpose(1, 2)).view(batch_size, seq_len)
            attention = self.attention_softmax(attention)

            # Copy distribution: add each position's attention mass onto the
            # vocabulary slot of the token at that position.
            pointer_prob = attention.new_zeros(batch_size, self.voc_size)
            pointer_prob = pointer_prob.scatter_add_(dim=1, index=x, src=attention).unsqueeze(1)

            context_vector = torch.bmm(attention.unsqueeze(1), x_bilstm)

            feature = torch.cat((out, context_vector), -1)
            pgen_feat = torch.cat((context_vector, out, w), -1)

            distri = self.pro_layer(feature)
            pgen = self.pgen_layer(pgen_feat)
            final_dis = pgen*distri + (1.-pgen)*pointer_prob + epsilon

            # Greedy choice; its embedding becomes the next decoder input.
            ans_indices = torch.argmax(final_dis, dim=-1, keepdim=False)
            decoder_input_emb = self.emb_layer(ans_indices)

            step_tokens.append(ans_indices)
            step_attn.append(attention)

        if not step_tokens:
            # output_len == 0: return correctly shaped empty results.
            return x.new_zeros((batch_size, 0)), x_bilstm.new_zeros((batch_size, 0, seq_len))

        ans_seq = torch.cat(step_tokens, 1)
        decoder_attn = torch.stack(step_attn, 1)
        return ans_seq, decoder_attn

In [6]:
if continue_from is None:
    raise ValueError("continue_from must point to a saved checkpoint to run inference")

# map_location moves the stored tensors onto the active device, so a
# checkpoint written on a GPU still loads when only a CPU is available.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
saved_model = torch.load(continue_from, map_location=device)

# This notebook only runs inference, so put the module in evaluation mode.
inf_model = Inference(HIDDEN_DIM, EMB_DIM, INPUT_MAX, OUTPUT_MAX, VOC_SIZE).to(device).eval()
# The weights are read from the checkpoint's 'model' entry.
inf_model.load_state_dict(saved_model['model'])

In [7]:
def preprocess(s):
    """Convert a raw document string into a (1, INPUT_MAX) tensor of token ids.

    The text is lower-cased, tokenized, wrapped in '<bos>' / '<eos>' and mapped
    through ``vocab``; words missing from ``vocab`` become '<unk>'. Short
    sequences are left-padded with '<pad>' up to INPUT_MAX.

    NOTE(review): sequences longer than INPUT_MAX keep their first INPUT_MAX
    ids, which drops the trailing '<eos>' -- confirm this matches training.
    """
    tokens = ['<bos>'] + tokenize(s.lower()) + ['<eos>']
    ids = [vocab[tok] if tok in vocab else vocab["<unk>"] for tok in tokens]

    # A negative count yields an empty list, i.e. no padding for long inputs.
    padding = [vocab['<pad>']] * (INPUT_MAX - len(ids))
    return torch.tensor([(padding + ids)[:INPUT_MAX]])

In [8]:
def readable(sent):
    """Render a list of token strings as a display sentence.

    Tokens from the first '<eos>' onward are discarded, the '<bos>', '<eos>'
    and '<pad>' markers are removed, runs of whitespace are collapsed, the
    first character is upper-cased and '<unk>' is shown as '-UNK-'.

    Args:
        sent: list of token strings.

    Returns:
        The cleaned-up sentence as a single string ('' for no tokens).
    """
    try:
        end = sent.index('<eos>')
    except ValueError:
        end = len(sent)

    text = " ".join(sent[:end])
    for marker in ("<bos>", "<eos>", "<pad>"):
        text = text.replace(marker, '')

    # Collapse whitespace before capitalizing: removing a leading '<bos>' or
    # '<pad>' leaves a space in front, and capitalizing that space is a no-op.
    text = " ".join(text.split()).capitalize()

    # Substitute the unknown-word marker last, because str.capitalize()
    # lower-cases everything after the first character ('-UNK-' -> '-unk-').
    return text.replace("<unk>", '-UNK-')
def tensor2sent(t):
    """Map each id in tensor ``t`` to its word via ``vocab_inv``; returns a list."""
    words = []
    for token_id in t:
        words.append(vocab_inv[token_id.item()])
    return words

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt



def print_attn(a, doc, summ):
    """Draw an attention matrix as a heat map and show it.

    Args:
        a: tensor of attention weights; it is moved to the CPU and converted
            to a NumPy array. Presumably shaped (len(summ), len(doc)) --
            TODO confirm against the caller.
        doc: labels for the x axis (one tick per entry).
        summ: labels for the y axis (one tick per entry).
    """
    scale = 0.6
    label_size = 28 * scale
    weights = a.cpu().numpy()

    fig, ax = plt.subplots(figsize=(30 * scale, 100 * scale))
    ax.imshow(weights, cmap='hot')

    # One tick per label on each axis, labelled with the given entries.
    ax.set_xticks(np.arange(len(doc)))
    ax.set_xticklabels(doc, fontsize=label_size)
    ax.set_yticks(np.arange(len(summ)))
    ax.set_yticklabels(summ, fontsize=label_size)

    # Tilt the x labels and anchor them at their right end.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    ax.set_title("Attention matrix", fontsize=20)
    fig.tight_layout()
    plt.show()

In [None]:
# Interactive summarization loop. Submit an empty line (or send EOF /
# interrupt the prompt) to stop; gradients are not needed for inference.
with torch.no_grad():
    while True:
        try:
            rawdoc = input("Document to summarize: ")
        except (EOFError, KeyboardInterrupt):
            break
        if not rawdoc.strip():
            break

        wordseq = preprocess(rawdoc).to(device)

        # Echo the document as the model sees it (after tokenization / <unk> mapping).
        summ_sent = tensor2sent(wordseq.view(INPUT_MAX))
        print('\n[input]', readable(summ_sent))

        # Call the module itself rather than .forward() so that any registered
        # hooks are run as well.
        predict, attn = inf_model(wordseq)
        predict = predict.view(OUTPUT_MAX)

        sent = tensor2sent(predict)
        print('\n[output]\n', readable(sent))
        # Optional: visualize the attention of this example.
        # print_attn(attn[0], summ_sent, sent)

Document to summarize: on wednesday , united states president donald trump signed an executive order escalating his administration 's campaign against chinese telecoms giant huawei , raising pressure on allies to follow suit in banning the company from their 5g and other networks . the us claims huawei , one of china 's most important companies , poses a spying risk to western technology infrastructure . the latest move against the firm comes amid a worsening trade war between beijing and washington , after talks expected to bring a breakthrough fell apart , resulting in billions of dollars in further tariffs from both sides . while some us allies -- notably australia and new zealand -- have followed trump 's lead on huawei , others have been more reticent . europe in particular is split over whether to ban the company , a market leader on 5g technology which is expected to be the lifeblood of the new economy . the huawei issue cuts to the heart of tensions between security and economi