In [None]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from Attentions import *
import json
import pickle
from pytorch_pretrained_bert import BertTokenizer
from pytorch_pretrained_bert import BertModel

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
with open("../../../Data/sumdata/training_0.pickle", "rb") as fh:
    data = pickle.load(fh)
print(len(data))
# Token-id lengths that genBatch pads (with 0) or truncates documents/summaries to.
max_doc_length = 100
max_summary_length = 20

## Network and Batching

In [None]:
def genBatch(bs = 5, dataset = None, doc_length = None, summary_length = None, device = "cuda"):
    """Sample ``bs`` random training examples and return them as padded tensors.

    Each document has its first token replaced by the CLS id (101) and is
    zero-padded / truncated to ``doc_length``; each summary is zero-padded /
    truncated to ``summary_length``. Pointer targets are a 0/1 vector over the
    summary positions, set to 1 at every index listed in the example's
    ``"pointers"`` that is below ``summary_length``.

    :param bs: number of examples (sampled with replacement).
    :param dataset: indexable collection of dicts with "story_tokens",
        "summary_tokens" and "pointers"; defaults to the global ``data``.
        It is never modified.
    :param doc_length: defaults to the global ``max_doc_length``.
    :param summary_length: defaults to the global ``max_summary_length``.
    :param device: device of the returned tensors (default "cuda", as before).
    :returns: (documents, segments, mask, summaries, pointers):
        documents (bs, doc_length) int64, segments all-zero like documents,
        mask = documents > 0, summaries (bs, summary_length) int64,
        pointers (bs, summary_length) float32.
    """
    if dataset is None:
        dataset = data
    if doc_length is None:
        doc_length = max_doc_length
    if summary_length is None:
        summary_length = max_summary_length

    indices = np.random.randint(0, len(dataset), (bs,))

    # Pre-allocated zero arrays give the padding for free, and writing into them
    # copies the token ids -- the dataset's own lists are never appended to
    # (the old code padded ``summary_tokens`` in place, mutating ``data``).
    documents = np.zeros((len(indices), doc_length), dtype=np.int64)
    summaries = np.zeros((len(indices), summary_length), dtype=np.int64)
    pointers = np.zeros((len(indices), summary_length), dtype=np.float32)
    for row, index in enumerate(indices):
        sample = dataset[index]

        # Drop the first token and put the CLS id (101) in its place.
        doc = ([101] + list(sample["story_tokens"][1:]))[:doc_length]
        documents[row, :len(doc)] = doc

        summ = list(sample["summary_tokens"])[:summary_length]
        summaries[row, :len(summ)] = summ

        # Only pointer positions that survive summary truncation are kept.
        positions = np.asarray(sample["pointers"], dtype=np.int64)
        pointers[row, positions[positions < summary_length]] = 1

    documents = torch.from_numpy(documents).to(device)
    summaries = torch.from_numpy(summaries).to(device)
    segments = torch.zeros_like(documents)
    pointers = torch.from_numpy(pointers).to(device)
    mask = documents > 0

    return documents, segments, mask, summaries, pointers
    
# Smoke test: draw one default-sized batch and show the shape of every tensor.
d, se, m, su, po = genBatch()
print(*(t.size() for t in (d, se, m, su, po)))

In [None]:
def resolvePreviouslyGeneratedText(arr, innerAttentionMatrix, resolutionMatrix):
    """Summarise everything generated so far into one context vector.

    :param arr: list of tensors joined along dim 1 (in ``Summarizer.forward``
        these are the per-step word logits).
    :param innerAttentionMatrix: weight handed to ``InnerAttention``.
    :param resolutionMatrix: projection applied to the pooled history.
    :returns: ``pooled @ resolutionMatrix``, where ``pooled`` is the
        ``InnerAttention`` output, summed over dim 1 unless it is already 2-D.
    """
    history = torch.cat(arr, dim=1)
    pooled = InnerAttention(history, innerAttentionMatrix)
    # A 2-D result already holds one vector per batch row; higher-rank
    # results are collapsed over their second axis.
    if pooled.dim() != 2:
        pooled = pooled.sum(dim=1)
    return pooled @ resolutionMatrix

In [None]:
class Summarizer(torch.nn.Module):
    """Abstractive summarizer: a BERT encoder plus a hand-rolled GRU decoder.

    The document is encoded once by BERT. A GRU-style cell then unrolls for
    ``output_ts`` steps. At every step it:
    - attends over the encoded document;
    - attends over all previously generated word logits;
    - emits vocabulary logits, a pointer logit and the attention/coverage
      vectors used by the coverage loss.

    :param bert_model: name passed to ``BertModel.from_pretrained``; names
        containing "-large-" switch the hidden width from 768 to 1024.
    :param attention_dim: width of the inner-attention projections.
    :param cuda: if True (default) the BERT encoder is moved to the GPU at
        construction time; if False it stays on the CPU.
    :param vocab_size: output vocabulary size (default 30000, as before).
        NOTE(review): bert-base-uncased has 30522 word pieces, so target ids
        >= 30000 are out of range for the word loss. Changing this alters
        parameter shapes, so old checkpoints would no longer load.
    """

    def __init__(self,
                 bert_model = "bert-base-uncased",
                 attention_dim = 512,
                 cuda = True,
                 vocab_size = 30000):
        super(Summarizer, self).__init__()
        self.bert_width = 768
        self.bert_model = bert_model
        if ("-large-" in self.bert_model):
            self.bert_width = 1024
        self.vocab_size = vocab_size
        width = self.bert_width

        # NOTE(review): every weight matrix below is zero-initialised, so the
        # gates start at exactly sigmoid(0) = 0.5 and the word logits at 0.
        # Consider a random init (e.g. xavier) if training stalls.
        # GRU gates; their input is [document summary ; hidden state] (2 * width).
        self.wz = torch.nn.Parameter(torch.zeros((width*2, width)))
        self.wr = torch.nn.Parameter(torch.zeros((width*2, width*2)))
        # Candidate input: [r * gate input (2w) ; doc context (w) ; history context (w)].
        self.w_cand = torch.nn.Parameter(torch.zeros((width*4, width)))

        self.bert = BertModel.from_pretrained(bert_model)
        if cuda:
            self.bert = self.bert.cuda()
        self.innerXAttention = torch.nn.Parameter(torch.zeros((width, attention_dim)))

        # History attention/projection operate on vocabulary-sized word logits.
        self.innerPrevAttention = torch.nn.Parameter(torch.zeros((vocab_size, attention_dim)))
        self.prevToWidth = torch.nn.Parameter(torch.zeros((vocab_size, width)))
        self.attention_weights = torch.nn.Parameter(torch.zeros((width, width)))
        self.output_ = torch.nn.Parameter(torch.zeros(width, vocab_size))

        # Pointer logit from [hidden ; doc context ; history context].
        self.pointer_out = torch.nn.Linear(width * 3, 1)

    def init_hidden_state(self, size):
        """Return a random initial hidden state and the seed of the word history.

        :param size: hidden-state shape; ``forward`` passes (batch, 1, bert_width).
        :returns: (hs, [prev_word]). ``hs`` is uniform in [-1, 1); ``prev_word``
            is row 101 of ``output_`` repeated to (batch, 1, vocab_size). Both
            live on the same device as the parameters.
        """
        # NOTE(review): output_ is (bert_width, vocab), so output_[101] is the
        # 101st hidden unit's output weights, not an embedding of the [CLS]
        # token (id 101). A one-hot at id 101 may be what was intended.
        _prev_word = self.output_[101]
        _prev_word = _prev_word.repeat(size[0], 1).unsqueeze(1)
        # Sample on the CPU (same RNG stream as before), then follow the
        # parameters' device instead of assuming CUDA.
        hs = torch.empty(size).uniform_(-1,1).to(self.output_.device)
        return hs, [_prev_word]

    def forward(self, docs, segments, masks, output_ts = 75):
        """Encode ``docs`` with BERT and unroll the decoder for ``output_ts`` steps.

        :param docs: token ids, (batch, doc_len).
        :param segments: token-type ids for BERT, same shape as ``docs``.
        :param masks: attention mask for BERT, same shape as ``docs``; also
            used to zero out padded encodings.
        :param output_ts: number of decoding steps.
        :returns: a 4-tuple:
            - word logits, (batch, output_ts, vocab_size);
            - coverage, (batch, output_ts, doc_len): for each step, the sum of
              the attention of all *previous* steps (zeros at step 0);
            - pointer logits, (batch, output_ts);
            - attention of every step, stacked along dim 1.
        """
        batch_size, doc_len = docs.size()[0], docs.size()[1]
        coverages = []
        pointers = []
        atts = []
        hs, generated_words = self.init_hidden_state((batch_size, 1, self.bert_width))

        _docs, _ = self.bert(docs, segments, masks, output_all_encoded_layers = False)
        # Zero the encodings of masked-out (padding) positions.
        _docs = _docs * masks.unsqueeze(-1).float()
        # Presumably an attention-pooled summary of the whole document
        # (InnerAttention comes from Attentions) -- TODO confirm.
        _x = InnerAttention(_docs, self.innerXAttention).unsqueeze(1)

        # Running sum of the attention of all previous steps; created on the
        # input's device rather than hard-coding CUDA.
        coverage = torch.zeros((batch_size, doc_len), device = docs.device)
        for _ in range(output_ts):
            # Self attention / context vector over all previously generated words.
            _generatedContext = resolvePreviouslyGeneratedText(generated_words,
                                                               self.innerPrevAttention,
                                                               self.prevToWidth)

            # GRU gating.
            _gru_in = torch.cat([_x, hs], dim=-1)
            z = torch.sigmoid(torch.matmul(_gru_in, self.wz))
            r = torch.sigmoid(torch.matmul(_gru_in, self.wr))

            # Context vector over the document. Assumes dotProductAttention
            # returns (batch, doc_len, 1) weights -- TODO confirm.
            att = dotProductAttention(_docs, hs, self.attention_weights)
            doc_context_vector = torch.sum(_docs * att, dim=1).unsqueeze(1)

            # Candidate hidden state and the GRU update.
            _cand_in = torch.cat([_gru_in*r, doc_context_vector, _generatedContext.unsqueeze(1)], dim=-1)
            h_cand = torch.tanh(torch.matmul(_cand_in, self.w_cand))
            hs = (1-z)*hs + z*h_cand

            # Pointer logit.
            _pointer_in = torch.cat([hs, doc_context_vector, _generatedContext.unsqueeze(1)], dim=-1)
            pointer = self.pointer_out(_pointer_in)

            # Vocabulary logits for this step; they also feed the history attention.
            word = torch.matmul(hs, self.output_)

            generated_words.append(word)
            pointers.append(pointer)

            # Record coverage *before* adding this step's attention. Previously
            # the current attention was already included, so (for non-negative
            # attention) min(att, coverage) == att and the coverage loss
            # carried no information about repeated attention.
            coverages.append(coverage.unsqueeze(1))
            coverage = coverage + att.squeeze(-1)
            atts.append(att.transpose(-2,-1))

        return (torch.cat(generated_words[1:], dim=1),
                torch.cat(coverages, dim=1),
                torch.cat(pointers, dim=1).squeeze(-1),
                torch.cat(atts, dim=1))

## Loss Functions

In [None]:
# Pointer targets are 0/1 per summary position; the network emits raw logits.
pointerCriterion = torch.nn.BCEWithLogitsLoss()
# Token id 0 is the padding value genBatch appends to short summaries. Ignore it
# so the model is not trained to predict padding; the training loop's unused
# ``loss_mask = su > 0`` shows this was the intent.
# NOTE: a batch whose targets are all padding yields NaN.
wordCriterion = torch.nn.CrossEntropyLoss(ignore_index=0)

def CoverageLoss(attentions, coverages):
    """Coverage loss: overlap between each step's attention and its coverage.

    Computes sum_x min(attention, coverage) for every (batch, step), then
    averages over the decoding steps and over the batch.

    :param attentions: (b, yt, xt) attention of every decoding step.
    :param coverages: (b, yt, xt) coverage vector of every decoding step.
    :returns: scalar tensor.
    """
    overlap = torch.min(attentions, coverages)  # b, yt, xt
    # Sum over xt, then mean over yt and b. The old code divided by a
    # hard-coded 20, which was only right when exactly 20 steps were decoded.
    return overlap.sum(dim=-1).mean()

def PointerLoss(yPointers, y_Pointers):
    """Pointer loss via ``pointerCriterion`` (BCE on logits).

    :param yPointers: 0/1 pointer targets.
    :param y_Pointers: predicted pointer logits, same shape as the targets.
    """
    predictions, targets = y_Pointers, yPointers
    return pointerCriterion(predictions, targets)

def WordLoss(yWords, y_Words):
    """Cross-entropy between predicted word logits and target token ids.

    :param yWords: target token ids, (b, yt).
    :param y_Words: predicted logits, (b, yt, vocab).
    :returns: scalar loss computed by ``wordCriterion``.
    """
    # Flatten with the logits' own vocabulary width instead of a hard-coded
    # 30000, so the loss follows the model's vocab size. reshape (unlike view)
    # also accepts non-contiguous logits.
    vocab = y_Words.size(-1)
    return wordCriterion(y_Words.reshape(-1, vocab), yWords.flatten())

## Training

In [None]:
# Model and optimiser.
# NOTE(review): lr=1e-3 is also applied to the pretrained BERT weights, which are
# usually fine-tuned with a much smaller rate -- confirm this is intended.
network = Summarizer().cuda()
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

# Loss weights: alpha -> word loss, beta -> pointer loss, gamma -> coverage loss.
alpha = beta = gamma = 1.

# Per-epoch history, serialised by _save / saveModel.
epoch_losses, epoch_vals, epoch_accs = [], [], []

In [None]:
def _save(cause):
    """Persist the weights and the training history, both tagged with ``cause``.

    Writes ``./summarizer_<cause>.h5`` (a ``torch.save`` of the state dict,
    despite the extension) and ``./Summarizer_training_cycle_<cause>.json``.
    """
    torch.save(network.state_dict(), "./summarizer_" + cause + ".h5")
    json_path = "./Summarizer_training_cycle_" + cause + ".json"
    # The context manager closes the file; no explicit close() is needed.
    with open(json_path, "w") as f:
        history = {
            "training_losses": epoch_losses,
            "validation_losses": epoch_vals,
            "validation_accuracy": epoch_accs,
        }
        f.write(json.dumps(history))

def saveModel():
    """Checkpoint when the latest epoch has the lowest training loss so far.

    On a new best, calls ``_save("BestTrainingLoss")`` and also writes the
    history to ``./Summarizer_training_cycle.json``. Does nothing when no
    epoch has been recorded yet (``np.min`` would raise on an empty list).
    """
    if not epoch_losses:
        return
    if (np.min(epoch_losses) == epoch_losses[-1]):
        print("Saving model for cause BestTrainingLoss")
        _save("BestTrainingLoss")
        with open("./Summarizer_training_cycle.json", "w") as f:
            f.write(json.dumps(
                {
                    "training_losses":epoch_losses,
                    "validation_losses":epoch_vals,
                    "validation_accuracy":epoch_accs,
                }
            ))

In [None]:
# Decode exactly as many steps as summaries are padded to, so the word logits
# line up with the targets in WordLoss.
ts = max_summary_length
epochs = 1
batches_per_epoch = 25000


def _fmt_mean(values):
    # np.mean([]) warns and returns nan; print the same "nan" without the warning.
    return str(np.round(np.mean(values), 5)) if len(values) else "nan"


network.train()  # ensure dropout is active even if another cell switched to eval
for k in range(epochs):
    batch_losses = []
    b_wl = []  # word losses
    b_pl = []  # pointer losses (only batches where the pointer loss was used)
    b_cl = []  # coverage losses (logged, not part of the optimised objective)
    for j in range(batches_per_epoch):
        optimizer.zero_grad()
        d, se, m, su, po = genBatch(bs=8)
        words, coverage, pointers, atts = network(d, se, m, output_ts=ts)

        # Coverage loss is tracked for every batch. It was previously logged only
        # when the pointer branch ran. As before, it is not added to total_loss.
        l = gamma * CoverageLoss(atts, coverage)
        b_cl.append(l.item())

        l3 = alpha * WordLoss(su, words)
        b_wl.append(l3.item())

        total_loss = l3
        # Pointer loss only when the batch has more than one positive target.
        # (The missing colon here used to be a SyntaxError.)
        if torch.sum(po) > 1:
            l2 = beta * PointerLoss(po, pointers)
            b_pl.append(l2.item())
            total_loss = total_loss + l2
        total_loss.backward()
        optimizer.step()
        batch_losses.append(total_loss.item())
        _str = ("Epoch: " + str(k+1) +
                ";  Batch: " + str(j+1) + "/" + str(batches_per_epoch) +
                "; Loss:" + _fmt_mean(batch_losses) +
                " (" + _fmt_mean(b_wl) + "," + _fmt_mean(b_pl) +
                "," + _fmt_mean(b_cl) + ")")
        print(_str, end = "\r")
    print("\n")
    epoch_losses.append(np.mean(batch_losses))
    saveModel()
    print("\n\t Epoch:", str(k+1), "; " + str(np.round(np.mean(batch_losses), 5)))
    print("\n")

In [None]:
from pytorch_pretrained_bert import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Predict in eval mode (no dropout), then restore whatever mode the network was in.
was_training = network.training
network.eval()
try:
    with torch.no_grad():
        d, se, m, su, po = genBatch(bs=8)
        # ts is the decoding length set in the training cell.
        words, coverage, pointers, atts = network(d, se, m, output_ts=ts)
        print(su)
        print(words.size())
        # Normalise over the vocabulary (last axis). dim=1 normalised across
        # time steps, which changes which word wins the argmax.
        words2 = F.softmax(words, dim=-1)
        w = torch.max(words2, dim=-1)  # (values, indices), each (batch, ts)
        print(w)
        # w is a (values, indices) tuple, not a tensor: decode the predicted ids
        # (w[1]) of the second example in the batch.
        _word = tokenizer.convert_ids_to_tokens(w[1][1].cpu().tolist())
finally:
    network.train(was_training)

In [None]:
# Display the word pieces decoded from the model's predictions in the previous cell.
_word