In [1]:
%matplotlib inline
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from Attentions import *
import json
import pickle
from pytorch_pretrained_bert import BertTokenizer
from pytorch_pretrained_bert import BertModel
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
# Load the pre-tokenised examples. NOTE(review): pickle can execute arbitrary
# code while loading, so this path must only ever point at a trusted file.
with open("./training_0.pickle", "rb") as f:
    all_data = pickle.load(f)
print(len(all_data))

# Hold out the last ~10% of the examples for validation. The hold-out size is
# rounded up (ceil(len / 10)), which matches the previous negative-index slicing
# for every non-empty dataset; an explicit split index also keeps both splits
# empty (instead of "everything is test data") when the dataset is empty.
holdout_size = -(-len(all_data) // 10)
split_at = len(all_data) - holdout_size
training = all_data[:split_at]
testing = all_data[split_at:]

# Fixed lengths every document / summary is padded or truncated to in genBatch.
max_doc_length = 100
max_summary_length = 20

1000000


## Network and Batching

In [3]:
# Global flag: True when a CUDA device is available. Read by the batching,
# loss and model-setup cells to decide whether tensors/models go to the GPU.
_cuda = torch.cuda.is_available()

In [4]:
def genBatch(bs = 5, validation = False):
    """Sample a random batch of fixed-length documents, summaries and pointer targets.

    Reads the module-level ``training`` / ``testing`` splits, the
    ``max_doc_length`` / ``max_summary_length`` limits and the ``_cuda`` flag.
    The dataset entries themselves are never modified (padding is applied to
    copies), so repeated sampling does not alter the stored examples.

    :param bs: number of examples to draw; indices are drawn independently, so
        the same example may appear more than once in a batch.
    :param validation: when true, sample from ``testing`` instead of ``training``.
    :return: tuple ``(documents, segments, mask, summaries, pointers)``:
        ``documents`` - long tensor (bs, max_doc_length) of token ids, first id
        forced to 101 and right-padded with 0;
        ``segments`` - zeros with the same shape/dtype as ``documents``;
        ``mask`` - ``documents > 0`` (true on non-padding positions);
        ``summaries`` - long tensor (bs, max_summary_length), right-padded with 0;
        ``pointers`` - float tensor (bs, max_summary_length) with 1 at every
        stored pointer index below ``max_summary_length`` and 0 elsewhere.
    """
    data = testing if validation else training
    indices = np.random.randint(0, len(data), (bs,))

    documents = []
    summaries = []
    pointers = []
    for index in indices:
        example = data[index]

        # Replace the first stored token with 101 (the CLS token id), then
        # truncate / zero-pad to exactly max_doc_length ids.
        doc = [101] + list(example["story_tokens"][1:])
        doc = doc[:max_doc_length]
        doc.extend([0] * (max_doc_length - len(doc)))
        documents.append(doc)

        # Work on a copy: appending the padding to the stored list would
        # permanently grow the dataset entry.
        summ = list(example["summary_tokens"][:max_summary_length])
        summ.extend([0] * (max_summary_length - len(summ)))
        summaries.append(summ)

        # Binary target over summary positions; pointer indices that fall
        # beyond the truncated summary are dropped.
        points = np.zeros((len(summ),))
        _point = np.asarray(example["pointers"])
        _point = _point[_point < max_summary_length]
        if (len(_point) > 0):
            points[_point] = 1
        pointers.append(points)

    device = torch.device("cuda" if _cuda else "cpu")
    documents = torch.tensor(documents, dtype=torch.long, device=device)
    summaries = torch.tensor(summaries, dtype=torch.long, device=device)
    segments = torch.zeros_like(documents)
    # Stack into one ndarray first; building a tensor from a list of arrays is slow.
    pointers = torch.tensor(np.array(pointers), dtype=torch.float, device=device)
    mask = documents > 0

    return documents, segments, mask, summaries, pointers
    
# Smoke test: draw one default-sized batch and print the size of each returned tensor.
d, se, m, su, po = genBatch()
print(d.size(), se.size(), m.size(), su.size(), po.size())

torch.Size([5, 100]) torch.Size([5, 100]) torch.Size([5, 100]) torch.Size([5, 20]) torch.Size([5, 20])


In [5]:
def resolvePreviouslyGeneratedText(arr, innerAttentionMatrix, resolutionMatrix):
    """Attend over previously generated text and project the pooled result.

    ``InnerAttention`` comes from the project's ``Attentions`` module; its exact
    semantics are not visible here.

    :param arr: tensor passed unchanged to ``InnerAttention``.
    :param innerAttentionMatrix: weight argument forwarded to ``InnerAttention``.
    :param resolutionMatrix: matrix the pooled attention output is multiplied by.
    :return: ``sum(attended, dim=1) @ resolutionMatrix``.
    """
    attended = InnerAttention(arr, innerAttentionMatrix)
    # A rank-2 result gets a singleton axis 1 so the reduction below is uniform.
    if attended.dim() == 2:
        attended = attended.unsqueeze(1)
    pooled = attended.sum(dim=1)
    return pooled.matmul(resolutionMatrix)

In [6]:
class Summarizer(torch.nn.Module):
    """Abstractive summarizer: BERT document encoder + attention-fed GRU-style decoder.

    The decoder is written out by hand from raw parameter matrices. At every
    output step it embeds the previously emitted token, computes additive
    attention of the hidden state over the encoded document, runs one gated
    update, and projects the new hidden state onto a 30000-way vocabulary.

    NOTE(review): the output projection and the token embedding are both sized
    to 30000 entries; token ids >= 30000 would be out of range. Confirm this
    against the tokenizer's vocabulary size.
    """

    def __init__(self,
                 bert_model = "bert-base-uncased",
                 attention_dim = 512,
                 tf = True,
                 isCuda = True):
        """
        :param bert_model: name passed to ``BertModel.from_pretrained``; a name
            containing ``-large-`` switches the hidden width from 768 to 1024.
        :param attention_dim: accepted for interface compatibility but unused;
            the attention projections below are fixed at 256 columns.
        :param tf: enable teacher forcing in ``forward`` (only applied when
            targets ``y`` are supplied).
        :param isCuda: place BERT and the initial decoder state on the GPU.
        """
        super(Summarizer, self).__init__()
        self.bert_width = 768
        self.bert_model = bert_model
        self.iscuda = isCuda
        self.teacherForcing = tf
        if ("-large-" in self.bert_model):
            self.bert_width = 1024

        # GRU-style gates. The input to every gate is the concatenation
        # [previous-word embedding, hidden state, document context] (3 * width).
        # The uniform values from torch.rand are overwritten by normal_ below.
        self.wz = torch.nn.Parameter(torch.rand((self.bert_width*3, self.bert_width)))
        self.wr = torch.nn.Parameter(torch.rand((self.bert_width*3, self.bert_width*3)))
        self.w_cand = torch.nn.Parameter(torch.rand((self.bert_width*3, self.bert_width)))

        torch.nn.init.normal_(self.wz)
        torch.nn.init.normal_(self.wr)
        torch.nn.init.normal_(self.w_cand)

        # BERT encoder
        if (self.iscuda):
            self.bert = BertModel.from_pretrained(bert_model).cuda()
        else:
            self.bert = BertModel.from_pretrained(bert_model)

        # Additive-attention weights: ua projects the encoded document, wa the
        # decoder hidden state, va scores the tanh of their sum.
        self.ua = torch.nn.Parameter(torch.rand(self.bert_width, 256))
        self.wa = torch.nn.Parameter(torch.rand(self.bert_width, 256))
        self.va = torch.nn.Parameter(torch.rand((256,)))
        torch.nn.init.normal_(self.ua)
        torch.nn.init.normal_(self.wa)
        torch.nn.init.normal_(self.va)

        # Output projection (initialised to all ones; random init is disabled)
        # and the embedding that feeds emitted tokens back into the decoder.
        self.output_ = torch.nn.Parameter(torch.ones((self.bert_width, 30000)))
        #torch.nn.init.normal_(self.output_)
        self.output_to_network_embedding = torch.nn.Embedding(30000, self.bert_width)
        self.dropout = torch.nn.Dropout(0.1)

    def init_hidden_state(self, size):
        """Create the initial decoder state and the start-token history.

        :param size: shape of the hidden state; ``size[0]`` is the batch size.
        :return: ``(hidden, [start_tokens])`` where ``hidden`` is a normally
            distributed tensor of shape ``size`` and ``start_tokens`` is a
            (size[0], 1) long tensor filled with 101 (the CLS token id).
        """
        _prev_word = torch.LongTensor([[101]])
        _hs = torch.rand(size)
        # Fixed: this used to test `self.cuda`, which is the bound method
        # torch.nn.Module.cuda and therefore always truthy, so the state was
        # moved to the GPU even when the model was built with isCuda=False.
        if (self.iscuda):
            _prev_word = _prev_word.cuda()
            _hs = _hs.cuda()
        _prev_word = _prev_word.repeat(size[0], 1)

        torch.nn.init.normal_(_hs)
        return _hs, [_prev_word]

    def forward(self, docs, segments, masks, output_ts = 75, y = None, tf_prob = 0.25):
        """Encode the documents and decode ``output_ts`` summary tokens.

        :param docs: token ids handed to BERT; ``docs.size()[0]`` is the batch size.
        :param segments: segment ids handed to BERT.
        :param masks: attention mask handed to BERT; also used to zero the
            encoder output at masked positions.
        :param output_ts: number of decoding steps.
        :param y: optional target token ids (batch, steps) for teacher forcing.
        :param tf_prob: per-step probability (in whole-percent granularity) of
            feeding the target token instead of the model's own prediction.
        :return: ``(logits, attentions)``; ``logits`` is (batch, output_ts, 30000)
            and ``attentions`` holds the per-step attention weights concatenated
            along dim 1 - presumably (batch, output_ts, doc_len), assuming BERT
            returns (batch, doc_len, width).
        """
        atts = []
        hs, _output_words = self.init_hidden_state((docs.size()[0],1, self.bert_width))

        _docs, _ = self.bert(docs, segments, masks, output_all_encoded_layers = False)
        _docs = _docs * masks.unsqueeze(-1).float()
        _docs = self.dropout(_docs)
        # The document-side attention projection does not depend on the step.
        _uahj = torch.matmul(_docs, self.ua)
        generated_words = []

        for i in range(output_ts):
            w = _output_words[-1]
            y_in = self.output_to_network_embedding(w)

            #context vector generation for the doc space
            _stwa = torch.matmul(hs, self.wa)
            _xatt = bahdanauAttention(_uahj, _stwa, self.va)
            _dcv = _docs * _xatt.unsqueeze(-1)
            doc_context_vector = torch.sum(_dcv, dim=1).unsqueeze(1)

            #gru gating
            _gru_in = torch.cat([y_in, hs, doc_context_vector], dim=-1)
            z = torch.sigmoid(torch.matmul(_gru_in, self.wz))
            r = torch.sigmoid(torch.matmul(_gru_in, self.wr))

            #candidate hidden state and final hidden state for the gru
            _cand_in = _gru_in*r
            h_cand = torch.tanh(torch.matmul(_cand_in, self.w_cand))
            hs = (1-z)*hs + z*h_cand

            #generate the output word
            word = torch.matmul(self.dropout(hs), self.output_)
            generated_words.append(word)

            # Greedy choice of the next input token; detached so no gradient
            # flows through the discrete selection.
            _word = F.softmax(word, dim=-1)
            _word = torch.max(_word, dim=-1)[1].detach()
            if (self.teacherForcing and (y is not None)):
                choice = np.random.randint(0, 100, (1,))[0]
                # Only force when a target exists for this step, so a target
                # shorter than output_ts no longer raises an IndexError.
                if (choice < tf_prob*100 and i < y.size(1)):
                    _word = y[:,i].unsqueeze(-1).detach()

            _output_words.append(_word)
            atts.append(_xatt.unsqueeze(1))

        return torch.cat(generated_words, dim=1), torch.cat(atts, dim=1)

## Training

In [7]:
# When True, the model-setup cell loads the checkpoint and loss history saved
# under the "continueForCause" name instead of starting from scratch.
continue_training = False

In [8]:
# Drop any previously built model before constructing a new one, then ask
# PyTorch to release its cached GPU memory.
network = None
torch.cuda.empty_cache()

In [9]:
# Build the model; it is only moved to the GPU when one is available
# (the previous unconditional .cuda() crashed on CPU-only machines).
network = Summarizer(isCuda = _cuda)
if (_cuda):
    network = network.cuda()
    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPU(s)...")
        network = torch.nn.DataParallel(network)

# Which saved checkpoint ("cause") to resume from when continue_training is set.
continueForCause = "BestTrainingLoss"

# Per-epoch history. Defined regardless of device, because Train, ValidateModel
# and _save read these lists on CPU runs as well.
epoch_losses = []
epoch_vals = []
epoch_accs = []

if (continue_training):
    # map_location lets a checkpoint written on a GPU machine load on CPU too.
    network.load_state_dict(torch.load("./summarizer_" + continueForCause + ".h5",
                                       map_location = "cuda" if _cuda else "cpu"))
    _training_cylce = None
    with open("./Summarizer_training_cycle_" + continueForCause + ".json") as f:
        _training_cylce = json.loads(f.read())
    epoch_losses = _training_cylce["training_losses"]
    epoch_vals = _training_cylce["validation_losses"]
    # Restore the accuracy history as well so the next save does not drop it.
    epoch_accs = _training_cylce.get("validation_accuracy", [])
    print(epoch_losses, epoch_vals)

optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

# Loss weights. Within the visible cells only alpha is used (it scales the
# word loss in Train).
alpha = 1.
beta = 1.
gamma = 1.


In [10]:
def viz(epoch):
    """Print the model's top-10 token ids per step next to the reference summary.

    Draws a batch of 8 training examples, decodes 10 steps without teacher
    forcing, and prints, for the first two examples, the ten highest-scoring
    vocabulary ids at every step followed by the first ten target ids.

    :param epoch: current epoch number; unused at the moment (it only named
        the files written by the disabled attention-heatmap dump).
    """
    # Decode in eval mode so dropout does not perturb what is shown, then
    # restore whatever mode the model was in.
    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            d, se, m, su, po = genBatch(bs=8)
            su = su[:,:10]
            words, atts = network(d, se, m, output_ts=10)
            # Fixed: normalise over the vocabulary axis (last). The previous
            # dim=1 normalised over time steps, so the top-k ranking below did
            # not reflect the per-step word distribution.
            words2 = F.softmax(words, dim=-1)
            top_ids = words2.topk(10, dim=-1)[1]
            # Disabled attention-heatmap dump (every 5th epoch), kept for reference:
            #   data = atts.squeeze(-1).cpu().numpy()
            #   plt.figure(figsize=(105, 15))
            #   _s = sns.heatmap(data[i], annot=True, vmin=0.01, vmax = 1)
            #   _s.get_figure().savefig("epoch_" + str(epoch) + "_" + str(i) + ".png")
            for i in range(2):
                print(top_ids[i])
                print(su[i])
    finally:
        network.train(was_training)

In [11]:
def _save(cause):
    """Write the model weights and the loss history under the given cause tag.

    :param cause: label embedded in both file names, e.g. "BestValidationLoss".
    """
    print("\t\t...saving model for cause", cause)
    weights_path = "./summarizer_" + cause + ".h5"
    history_path = "./Summarizer_training_cycle_"  + cause + ".json"

    torch.save(network.state_dict(), weights_path)

    history = {
        "training_losses":epoch_losses,
        "validation_losses":epoch_vals,
        "validation_accuracy":epoch_accs,
    }
    # The with-block closes the file; no explicit close() is needed.
    with open(history_path, "w") as f:
        f.write(json.dumps(history))

def ValidateModel(validation_bs = 2, epoch=0):
    """Measure validation word loss, persist the history and checkpoint on improvement.

    Decodes 10 steps for 10 random validation batches, computes a single
    ``WordLoss`` over all of them, appends it to ``epoch_vals``, rewrites
    ``./Summarizer_training_cycle.json`` and saves a checkpoint when the latest
    validation and/or training loss is the best so far. Ends by calling ``viz``.

    :param validation_bs: batch size of each validation batch.
    :param epoch: epoch number forwarded to ``viz``.
    """
    print("\n\tValidating...")
    # Evaluate with dropout disabled; restore the previous mode afterwards so
    # training continues unaffected.
    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            all_su = []
            all_pred = []
            for i in range(10):
                d, se, m, su, po = genBatch(bs=validation_bs, validation = True)
                words, atts = network(d, se, m, output_ts=10)
                all_su.append(su[:,:10])
                all_pred.append(words)
            # Concatenate once instead of growing the tensors every iteration.
            val_loss = WordLoss(torch.cat(all_su, dim=0), torch.cat(all_pred, dim=0)).item()
    finally:
        network.train(was_training)

    print("\tValidation Loss:", np.round(val_loss, 5))
    epoch_vals.append(val_loss)

    with open("./Summarizer_training_cycle.json", "w") as f:
        f.write(json.dumps(
            {
                "training_losses":epoch_losses,
                "validation_losses":epoch_vals,
                "validation_accuracy":epoch_accs,
            }
        ))

    if (np.min(epoch_vals) == epoch_vals[-1]):
        _save("BestValidationLoss")

    # epoch_losses is empty when validation runs before any training epoch;
    # np.min would raise on an empty list.
    if (len(epoch_losses) > 0 and np.min(epoch_losses) == epoch_losses[-1]):
        _save("BestTrainingLoss")

    print("\tVisualizing...")
    viz(epoch)

In [12]:
# Shared loss modules: binary cross-entropy on logits for PointerLoss and
# multi-class cross-entropy for WordLoss.
pointerCriterion = torch.nn.BCEWithLogitsLoss()
wordCriterion = torch.nn.CrossEntropyLoss()

def CoverageLoss(attentions):
    """Coverage penalty: overlap between each step's attention and what was attended before.

    For every output step the running coverage (sum of the attention of all
    earlier steps) is compared with the current attention; the element-wise
    minimum is summed. The total is divided by ``batch * steps``.

    :param attentions: tensor (b, yt, xt) - attention over the xt source
        positions for each of the yt output steps.
    :return: scalar tensor on the same device/dtype as ``attentions``; zero when
        there are no steps or no batch entries.
    """
    batch = attentions.size()[0]
    steps = attentions.size()[1]
    # Allocate from the input so device and dtype always match it, instead of
    # relying on a global CUDA flag.
    coverage = attentions.new_zeros((batch, attentions.size()[-1]))
    total = attentions.new_zeros(())

    for i in range(steps):
        step_att = attentions[:,i,:]
        total = total + torch.sum(torch.min(coverage, step_att))
        coverage = coverage + step_att

    if (batch * steps == 0):
        # Nothing was accumulated; avoid dividing by zero.
        return total
    return total/(batch*steps)

def PointerLoss(yPointers, y_Pointers):
    """Binary cross-entropy (on logits) between predicted and target pointers.

    :param yPointers: target pointer values.
    :param y_Pointers: predicted pointer logits.
    :return: the value of ``pointerCriterion(y_Pointers, yPointers)``.
    """
    predicted_logits = y_Pointers
    targets = yPointers
    return pointerCriterion(predicted_logits, targets)

def WordLoss(yWords, y_Words):
    """Cross-entropy between predicted word logits and target token ids.

    :param yWords: target token ids, any shape with one entry per prediction.
    :param y_Words: logits whose last dimension is the vocabulary.
    :return: the value of ``wordCriterion`` over all flattened positions.
    """
    # Take the vocabulary size from the logits instead of hard-coding 30000,
    # and use reshape so non-contiguous inputs (e.g. slices) are accepted.
    vocab_size = y_Words.size(-1)
    return wordCriterion(y_Words.reshape(-1, vocab_size), yWords.reshape(-1))

In [13]:
def annealTFRate(epoch):
    """Teacher-forcing probability for an epoch: starts at 0.75, drops 0.01 per epoch.

    :param epoch: zero-based epoch number.
    :return: ``0.75 - epoch/100`` (not clamped; negative from epoch 76 on).
    """
    decay = epoch/100
    return 0.75 - decay

def Train(epochs, batches_per_epoch, bs):
    """Run the training loop with word + coverage loss and per-epoch validation.

    Each batch decodes 10 steps with teacher forcing at the annealed rate,
    back-propagates ``alpha * WordLoss + CoverageLoss`` and steps the optimizer.
    A running-mean progress line is rewritten in place after every batch. After
    each epoch the mean total loss is appended to ``epoch_losses`` and
    ``ValidateModel`` is called.

    :param epochs: number of epochs to run.
    :param batches_per_epoch: number of random batches per epoch.
    :param bs: batch size, also used for validation batches.
    """
    for epoch_idx in range(epochs):
        total_history = []
        word_history = []
        coverage_history = []
        for batch_idx in range(batches_per_epoch):
            optimizer.zero_grad()
            d, se, m, su, po = genBatch(bs=bs)
            su = su[:,:10]
            words, atts = network.forward(
                d, se, m, output_ts=10, y=su, tf_prob=annealTFRate(epoch_idx)
            )

            word_term = alpha * WordLoss(su, words)
            coverage_term = CoverageLoss(atts)
            total_loss = word_term + coverage_term

            word_history.append(word_term.data.item())
            coverage_history.append(coverage_term.detach().cpu().numpy())
            total_loss.backward()
            optimizer.step()
            total_history.append(total_loss.data.item())

            # Running means over the epoch so far; "!s" keeps the exact str()
            # rendering of the numpy scalars.
            mean_total = np.round(np.mean(total_history), 5)
            mean_word = np.round(np.mean(word_history),5)
            mean_coverage = np.round(np.mean(coverage_history), 5)
            progress = (
                f"Epoch: {epoch_idx+1!s}"
                f";  Batch: {batch_idx+1!s}/{batches_per_epoch!s}"
                f"; Loss: {mean_total!s}"
                f" ({mean_word!s},{mean_coverage!s})"
            )
            print(progress, end = "\r")

        epoch_losses.append(np.mean(total_history))
        ValidateModel(validation_bs = bs, epoch=epoch_idx)
        print("\n")

In [14]:
# Run configuration for the training cell.
ts = max_summary_length
epochs = 40
batches_per_epoch = 2560

# Batch size by hardware: 2 on CPU, 20 on a single GPU, 56 per GPU otherwise.
if (not _cuda):
    bs = 2
elif (torch.cuda.device_count() > 1):
    bs = 56 * torch.cuda.device_count()
else:
    bs = 20

In [15]:
def bahdanauAttention(uahj, st_wa, va):
    """Additive attention weights: softmax over ``tanh(st_wa + uahj) @ va``.

    :param uahj: projected encoder states; broadcast-added to ``st_wa``.
    :param st_wa: projected decoder state.
    :param va: scoring vector contracted with the last dimension.
    :return: attention weights normalised over the last remaining dimension.
    """
    scores = torch.tanh(uahj + st_wa).matmul(va)
    return F.softmax(scores, dim=-1)

In [None]:
# Start training with the epoch count, batches per epoch and batch size set above.
Train(epochs, batches_per_epoch, bs = bs)

Epoch: 1;  Batch: 2560/2560; Loss: 8.47417 (7.67243,0.80174)
	Validating...
	Validation Loss: 9.5803
		...saving model for cause BestValidationLoss
		...saving model for cause BestTrainingLoss
	Visualizing...
tensor([[ 2149,  4895,  1005,  7802,  3306,  2158,  8956,  2006, 10991,  2446],
        [ 1999,  4895, 23876,  2125, 21931,  1005,  2000, 15549,  2395,  8408],
        [ 1999,  1996,  5712,  2034,  5618,  4380,  2149,  2125,  2005,  2859],
        [13492, 15549,  2102, 20351,  2154,  4030,  4119, 16558,  8366,  6335],
        [11096,  3926,  2110,  2391,  5375,  6921,  4552,  4001,  2645,  2772],
        [ 3771, 11074, 14995,  9739, 20129,  2733, 12532,  3030, 16181, 12913],
        [11074, 24636, 12944, 23902,  2733,  9368, 26775, 13525, 13317, 27105],
        [    0,  2733,  2311, 21329,  2007,  2504,  3204,  4399, 11110,  3357],
        [    0,  2923,  8858, 24168,  2007,  2733,  2504,  3864,  9109,  1999],
        [    0,  2504,  8858,  2622,  2733,  2095,  4945,  2007,  4332,

	Visualizing...
tensor([[26614, 14611,  8248, 19667, 25510, 19722, 13376,  9600, 23790, 23637],
        [17673, 22177, 15436,  7330, 26646, 14611,  2761, 28052,  5625, 13494],
        [27597, 28052, 16369, 22834, 11191, 20081, 13754, 23959, 25510, 25119],
        [13494,  4102, 16482, 14459, 14611,  5720, 28227,  9584, 12246, 20871],
        [15007, 25750, 14202, 29534, 12228,  6076, 15985, 16482,  8399, 23646],
        [12228, 28227, 28176,  8399, 15007,  6087, 19737,  5625, 13302, 15285],
        [24168, 13112, 20486, 10747,  8621,  9250, 16260, 14540, 28227, 16096],
        [24947,  5543, 19737, 26868, 28227, 18257, 12071, 27922, 21549,  5162],
        [    7,     6,     4,     5,     1,     0,     2,     3,     8,     9],
        [12910,  5738, 13619, 27178, 11687, 26614,  3170,  8248, 11409,  6131]],
       device='cuda:0')
tensor([ 2093,  9302, 17671,  2915,  2757,  1999,  2642, 14474,     0,     0],
       device='cuda:0')
tensor([[21575,  8695, 11783, 16096, 17060, 10665, 19145

In [None]:
# Manually checkpoint the current weights and loss history under the
# "LastForcedSave" tag (e.g. after interrupting training).
_save("LastForcedSave")

In [None]:
# Release the trained model and PyTorch's cached GPU memory before reloading
# a checkpoint into a fresh instance.
network = None
torch.cuda.empty_cache()

In [None]:
# Rebuild the model and load the best-training-loss checkpoint for inference.
# The device flag is passed explicitly: the constructor defaults to isCuda=True,
# which fails on a machine without a GPU.
network = Summarizer(isCuda = _cuda)
# Wrap exactly as at training time so the state-dict keys match the checkpoint.
if (torch.cuda.device_count() > 1):
    network = torch.nn.DataParallel(network)
# map_location lets a checkpoint written on a GPU machine load on CPU as well.
network.load_state_dict(torch.load("./summarizer_BestTrainingLoss.h5",
                                   map_location = "cuda" if _cuda else "cpu"))
if (_cuda):
    network.cuda()
# Inference only from here on: switch off dropout.
network.eval()

In [None]:
from pytorch_pretrained_bert import BertTokenizer
# Decode one batch of 8 training examples for 20 steps and convert the greedy
# prediction and the reference summary of the first example back to tokens.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
with torch.no_grad():
    d, se, m, su, po = genBatch(bs=8)
    words, atts = network.forward(d, se, m, output_ts=20)
    # Fixed: normalise over the vocabulary axis (last). The previous dim=1
    # normalised over time steps, which distorted both the top-k listing and
    # the arg-max prediction below.
    words2 = F.softmax(words, dim=-1)
    print(words2.topk(10, dim=-1)[1][0])
    print(su[0])

    # Greedy prediction: most probable vocabulary id at every step.
    w = torch.max(words2, dim=-1)[1]
    _pred = tokenizer.convert_ids_to_tokens(w.cpu().numpy()[0])
    _act = tokenizer.convert_ids_to_tokens(su.cpu().numpy()[0])

In [None]:
# Predicted tokens of the first batch example, joined into one string for display.
" ".join(_pred)

In [None]:
# Reference summary tokens of the first batch example, joined for comparison.
" ".join(_act)