In [None]:
import pickle
DATA_PATH = "../../../Data/DMQA/cnn_tokenized.pickle"
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by a trusted preprocessing step.
# Presumably an indexable collection of dicts with "story_tokens",
# "summary_tokens" and "pointers" keys (that is what genBatch reads) — confirm.
with open(DATA_PATH, "rb") as data_file:  # closes the handle even if unpickling fails
    data = pickle.load(data_file)
print(len(data))

In [None]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from Attentions import *

In [None]:
from pytorch_pretrained_bert import BertTokenizer
from pytorch_pretrained_bert import BertModel
# Fixed token-sequence lengths (including the leading CLS id) that genBatch
# truncates / zero-pads documents and summaries to.
max_doc_length, max_summary_length = 100, 20

In [None]:
def genBatch(bs = 5, dataset = None, doc_length = None, summary_length = None):
    """Sample a random batch of tokenized (story, summary) pairs as tensors.

    Examples are drawn uniformly with replacement. Every story and summary is
    prefixed with token id 101 (the CLS token), then truncated and
    right-padded with 0 to a fixed length. Pointer targets are 0/1 vectors over
    summary positions, set to 1 at every index in the example's "pointers"
    list that is smaller than the summary length.

    The stored token lists are never modified. (The original inserted the CLS
    id and appended padding directly into the dataset's lists, so every
    re-sampled example accumulated another leading CLS token.)

    :param bs: number of examples to sample.
    :param dataset: indexable collection of dicts with "story_tokens",
        "summary_tokens" and "pointers" keys; defaults to the global ``data``.
    :param doc_length: padded story length; defaults to ``max_doc_length``.
    :param summary_length: padded summary length; defaults to
        ``max_summary_length``.
    :return: (documents, segments, mask, summaries, pointers) where
        documents is a (bs, doc_length) LongTensor, segments is an all-zero
        tensor shaped like documents, mask is True where documents > 0,
        summaries is a (bs, summary_length) LongTensor and pointers is a
        (bs, summary_length) float32 tensor of 0/1 targets.
    """
    if dataset is None:
        dataset = data
    if doc_length is None:
        doc_length = max_doc_length
    if summary_length is None:
        summary_length = max_summary_length

    def _cls_pad_truncate(tokens, length):
        # Build a fresh list so the dataset entry is left untouched.
        seq = ([101] + list(tokens))[:length]  # 101 is the token id for the CLS token
        return seq + [0] * (length - len(seq))

    indices = np.random.randint(0, len(dataset), (bs,))
    documents = []
    summaries = []
    pointers = np.zeros((len(indices), summary_length), dtype=np.float32)
    for row, index in enumerate(indices):
        example = dataset[index]
        documents.append(_cls_pad_truncate(example["story_tokens"], doc_length))
        summaries.append(_cls_pad_truncate(example["summary_tokens"], summary_length))
        # NOTE(review): pointer indices are not shifted for the prepended CLS
        # token (same as before) — confirm whether they index the raw summary.
        positions = np.asarray(example["pointers"], dtype=np.int64)
        pointers[row, positions[positions < summary_length]] = 1

    documents = torch.LongTensor(documents)
    summaries = torch.LongTensor(summaries)
    segments = torch.zeros_like(documents)
    mask = documents > 0

    return documents, segments, mask, summaries, torch.from_numpy(pointers)
    
# Smoke test: draw one default-sized batch and show every tensor's shape.
sample_batch = genBatch()
d, se, m, su, po = sample_batch
print(*(tensor.size() for tensor in sample_batch))

In [None]:
def resolvePreviouslyGeneratedText(arr, innerAttentionMatrix, resolutionMatrix):
    """Pool all previously generated outputs into one projected context vector.

    :param arr: list of tensors concatenated along dim 1 (the decoder passes
        its per-step word-logit tensors here).
    :param innerAttentionMatrix: weight handed to ``InnerAttention``.
    :param resolutionMatrix: projection applied to the pooled result.
    :return: the dim-1 sum of the attended history, matrix-multiplied by
        ``resolutionMatrix``.
    """
    history = torch.cat(arr, dim=1)
    attended = InnerAttention(history, innerAttentionMatrix)
    # InnerAttention may return a 2-D tensor; give it an explicit dim 1 so the
    # reduction below removes the same axis in both cases.
    attended = attended.unsqueeze(1) if attended.dim() == 2 else attended
    pooled = attended.sum(dim=1)
    return pooled @ resolutionMatrix

In [None]:
class Summarizer(torch.nn.Module):
    """Abstractive summarizer: pretrained BERT encoder + hand-written GRU-style decoder.

    Each decoding step attends over the BERT document encodings, pools every
    previously generated word-logit vector, updates a GRU-style hidden state
    and emits vocabulary logits plus a single pointer logit.
    """

    def __init__(self, bert_model = "bert-base-uncased", vocab_size = 30000):
        """
        :param bert_model: name passed to ``BertModel.from_pretrained``; a name
            containing "-large-" selects a hidden width of 1024, otherwise 768.
        :param vocab_size: width of the output word-logit layer. Defaults to
            the original hard-coded 30000. NOTE(review): the bert-base-uncased
            vocabulary has 30522 ids, so target ids >= 30000 would be out of
            range for a cross-entropy loss over these logits — confirm.
        """
        super(Summarizer, self).__init__()
        self.bert_width = 768
        self.bert_model = bert_model
        self.vocab_size = vocab_size
        if ("-large-" in self.bert_model):
            self.bert_width = 1024
        width = self.bert_width

        # GRU update-gate, reset-gate and candidate projections.
        self.wz = torch.nn.Parameter(torch.empty((width*2, width)))
        self.wr = torch.nn.Parameter(torch.empty((width*2, width*2)))
        self.w_cand = torch.nn.Parameter(torch.empty((width*4, width)))

        self.bert = BertModel.from_pretrained(bert_model)
        self.innerXAttention = torch.nn.Parameter(torch.empty((width, 512)))

        # Attention over, and projection of, previously generated word logits.
        self.innerPrevAttention = torch.nn.Parameter(torch.empty((vocab_size, 512)))
        self.prevToWidth = torch.nn.Parameter(torch.empty((vocab_size, width)))
        self.attention_weights = torch.nn.Parameter(torch.empty((width, width)))
        # Hidden state -> vocabulary logits.
        self.output_ = torch.nn.Parameter(torch.empty(width, vocab_size))

        self.pointer_out = torch.nn.Linear(width * 3, 1)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the hand-written weight matrices.

        The original filled them with zeros, so every gate started at
        sigmoid(0) = 0.5, every candidate state at tanh(0) = 0 and every word
        logit at exactly 0.
        """
        for weight in (self.wz, self.wr, self.w_cand, self.innerXAttention,
                       self.innerPrevAttention, self.prevToWidth,
                       self.attention_weights, self.output_):
            torch.nn.init.xavier_uniform_(weight)

    def init_hidden_state(self, size):
        """Return a random initial hidden state and the seed "previous word".

        :param size: hidden-state shape; ``forward`` passes
            (batch, 1, bert_width).
        :return: (hidden, [prev_word]) where hidden is drawn from U(-1, 1) on
            the parameters' device/dtype, and prev_word is row 101 of
            ``output_`` repeated to (size[0], 1, vocab_size).
        """
        _prev_word = self.output_[101]  # 101 is the CLS id; row used as a stand-in for the CLS marker
        _prev_word = _prev_word.repeat(size[0], 1).unsqueeze(1)
        hidden = torch.empty(size, device=self.output_.device,
                             dtype=self.output_.dtype).uniform_(-1, 1)
        return hidden, [_prev_word]

    def forward(self, docs, segments, masks, output_ts = 75):
        """Decode ``output_ts`` steps for a batch of documents.

        :param docs: (batch, doc_len) token ids.
        :param segments: token-type ids passed through to BERT.
        :param masks: (batch, doc_len) mask, nonzero at real (non-padding) tokens.
        :param output_ts: number of decoding steps.
        :return: tuple of
            - word logits, (batch, output_ts, vocab_size);
            - coverages, (batch, output_ts, doc_len): entry t is the sum of the
              attention vectors of steps 0..t-1 (all zeros at t = 0);
            - pointer logits, (batch, output_ts);
            - attentions, stacked per step along dim 1.
        """
        coverages = []
        pointers = []
        atts = []
        hs, generated_words = self.init_hidden_state((docs.size(0), 1, self.bert_width))

        _docs, _ = self.bert(docs, segments, masks, output_all_encoded_layers = False)
        # Zero the encodings at padding positions.
        _docs = _docs * masks.unsqueeze(-1).to(_docs.dtype)
        _x = InnerAttention(_docs, self.innerXAttention).unsqueeze(1)

        # Attention mass accumulated over the steps decoded so far; created on
        # the encoder output's device/dtype so the model also runs on GPU.
        coverage = _docs.new_zeros((docs.size(0), docs.size(1)))
        for i in range(output_ts):
            # Self-attention and context vector over all previously generated words.
            _generatedContext = resolvePreviouslyGeneratedText(generated_words,
                                                               self.innerPrevAttention,
                                                               self.prevToWidth).unsqueeze(1)

            # GRU gating.
            _gru_in = torch.cat([_x, hs], dim=-1)
            z = torch.sigmoid(torch.matmul(_gru_in, self.wz))
            r = torch.sigmoid(torch.matmul(_gru_in, self.wr))

            # Context vector over the document encodings.
            att = dotProductAttention(_docs, hs, self.attention_weights)
            doc_context_vector = torch.sum(_docs * att, dim=1).unsqueeze(1)

            # Candidate and final hidden state of the GRU.
            _cand_in = torch.cat([_gru_in*r, doc_context_vector, _generatedContext], dim=-1)
            h_cand = torch.tanh(torch.matmul(_cand_in, self.w_cand))
            hs = (1-z)*hs + z*h_cand

            # Pointer logit for this step.
            _pointer_in = torch.cat([hs, doc_context_vector, _generatedContext], dim=-1)
            pointers.append(self.pointer_out(_pointer_in))

            # Vocabulary logits; also fed back as history for later steps.
            generated_words.append(torch.matmul(hs, self.output_))

            # Record coverage BEFORE adding this step's attention. The original
            # added it first, so (for non-negative attention) min(att, coverage)
            # was always att and the coverage loss was a constant.
            coverages.append(coverage.unsqueeze(1))
            coverage = coverage + att.squeeze(-1)
            atts.append(att.transpose(-2,-1))

        return (torch.cat(generated_words[1:], dim=1),
                torch.cat(coverages, dim=1),
                torch.cat(pointers, dim=1).squeeze(-1),
                torch.cat(atts, dim=1))

In [None]:
# Binary cross-entropy on raw pointer logits.
pointerCriterion = torch.nn.BCEWithLogitsLoss()
# Token cross-entropy. genBatch right-pads summaries with id 0, so padding
# positions are excluded instead of training the model to predict padding.
wordCriterion = torch.nn.CrossEntropyLoss(ignore_index=0)

def CoverageLoss(attentions, coverages):
    """Coverage penalty: overlap between each step's attention and its coverage.

    :param attentions: (b, yt, xt) attention weights per decoding step.
    :param coverages: (b, yt, xt) coverage vectors aligned with ``attentions``.
    :return: scalar tensor: element-wise minimum summed over xt, then over yt,
        then averaged over the batch.
    """
    overlap = torch.min(attentions, coverages)        # (b, yt, xt)
    per_example = overlap.sum(dim=-1).sum(dim=-1)    # (b,)
    batch_size = per_example.size()[0]
    return per_example.sum() / batch_size

def PointerLoss(yPointers, y_Pointers):
    """Pointer loss via the module-level ``pointerCriterion``.

    :param yPointers: target 0/1 pointer labels.
    :param y_Pointers: predicted pointer logits (same shape as the targets).
    :return: scalar loss tensor.
    """
    # Keyword arguments make explicit which argument is the prediction.
    return pointerCriterion(input=y_Pointers, target=yPointers)

def WordLoss(yWords, y_Words):
    """Token-level cross-entropy via the module-level ``wordCriterion``.

    :param yWords: target token ids, e.g. (b, yt).
    :param y_Words: predicted logits whose last axis is the vocabulary,
        e.g. (b, yt, vocab).
    :return: scalar loss tensor.
    """
    # Take the vocabulary width from the logits instead of hard-coding 30000,
    # and use reshape so non-contiguous inputs also work.
    vocab_size = y_Words.size(-1)
    return wordCriterion(y_Words.reshape(-1, vocab_size), yWords.reshape(-1))

In [None]:
# Decode exactly as many steps as the padded target summaries have. The
# original ts = 75 against max_summary_length = 20 targets made PointerLoss
# (and WordLoss) fail with a size mismatch on the first batch.
ts = max_summary_length
epochs = 10
batches_per_epoch = 5
s = Summarizer()
# NOTE(review): this learning rate also applies to the pretrained BERT weights;
# BERT fine-tuning normally uses a much smaller rate — consider parameter groups.
optimizer = torch.optim.Adam(s.parameters(), lr=1e-3)

alpha = 1.  # weight of the word (generation) loss
beta = 1.   # weight of the pointer loss
gamma = 1.  # weight of the coverage loss

for k in range(epochs):
    for j in range(batches_per_epoch):
        optimizer.zero_grad()
        d, se, m, su, po = genBatch(bs=16)
        # Call the module (not .forward) so registered hooks run.
        words, coverage, pointers, atts = s(d, se, m, output_ts=ts)
        l = gamma * CoverageLoss(atts, coverage)
        l2 = beta * PointerLoss(po[:,:ts], pointers)
        l3 = alpha * WordLoss(su[:,:ts], words)
        total_loss = l3 + l2 + l
        total_loss.backward()
        optimizer.step()
        print(f"epoch {k} batch {j} loss {total_loss.item():.4f}")