In [1]:
import torch.nn as nn
import torch
import json
import numpy as np
from tqdm import tqdm_notebook as tqdm
from PointerGenerator import PointerGenerator
# from tqdm import tqdm
import math

In [2]:
# All preprocessed artefacts live under one folder (trailing slash included).
folder = 'preprocessing-cnn-all/'

# Token-id sequences for each split.
data_name = '{}train_seq.json'.format(folder)
validation_name = '{}valid_seq.json'.format(folder)
testdata_name = '{}testdata_seq.json'.format(folder)

# Vocabulary file and the word-vector matrix stem (".npy" is appended where it is loaded).
vocab_name = '{}vocab.json'.format(folder)
wv_name = '{}wv_matrix'.format(folder)

In [3]:
# --- model dimensions (passed to PointerGenerator) ---
HIDDEN_DIM, EMB_DIM = 256, 50

# --- fixed sequence lengths: documents / summaries are padded or cut to these ---
INPUT_MAX, OUTPUT_MAX = 400, 100

# --- training schedule ---
num_epochs = 12
save_rate = 1  # write a checkpoint every `save_rate` epochs
clip_grad_norm = 5.0  # maximum L2 norm of the gradients

# Checkpoint to resume from; set to None to start from scratch.
continue_from = "models/covModel5"
# continue_from = None

epsilon = 1e-10  # small constant guarding divisions
validation_size = 5000  # upper bound on the number of validation examples loaded

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")
print(device)

cuda


In [5]:
# Load the vocabulary (indexed by token strings such as '<pad>').
# A context manager is used so the file handle is closed deterministically
# instead of being left to the garbage collector.
with open(vocab_name, 'r') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)

# Pre-trained word vectors are currently disabled; the model receives None.
# word_vectors = np.load(wv_name+".npy")
word_vectors = None

In [6]:
# vocab={'<pad>':0, '<bos>': 1, '<eos>': 2, '<unk>': 3}
from torch.utils import data

class Dataset(data.Dataset):
    """Fixed-length (document, summary-input, summary-target) examples.

    The JSON file at ``data_name`` must contain two parallel lists under the
    keys ``'document'`` and ``'summary'``; each entry is a list of token ids.

    For every example three sequences are produced:

    * the document, left-padded / truncated to ``INPUT_MAX`` tokens,
    * the summary, right-padded / truncated to ``OUTPUT_MAX`` tokens
      (decoder input),
    * the summary shifted left by one token, right-padded / truncated to
      ``OUTPUT_MAX`` tokens (decoder target).

    Args:
        data_name: path of the JSON file to read.
        vocab: mapping that provides the padding id under ``'<pad>'``.
        cutoff: if not None, only the first ``cutoff`` examples are kept.
    """

    def __init__(self, data_name, vocab, cutoff=None):
        # Close the file as soon as it is parsed (the original leaked the
        # handle) and avoid rebinding the name of the imported `data` module.
        with open(data_name, 'r') as handle:
            raw = json.load(handle)
        sum_list = raw['summary']
        data_list = raw['document']

        if cutoff is not None:
            sum_list = sum_list[:cutoff]
            data_list = data_list[:cutoff]

        self.size = len(sum_list)
        self.dataset = []
        self.sum_len = 0  # never updated; kept so existing attribute reads still work

        pad_id = vocab['<pad>']
        for i in tqdm(range(len(sum_list))):
            document = self._fit(data_list[i], INPUT_MAX, pad_id, pad_left=True)
            sum_in = self._fit(sum_list[i], OUTPUT_MAX, pad_id)
            # Target is the summary without its first token (next-token prediction).
            sum_out = self._fit(sum_list[i][1:], OUTPUT_MAX, pad_id)
            self.dataset.append([document, sum_in, sum_out])

    @staticmethod
    def _fit(tokens, length, pad_id, pad_left=False):
        """Return a new list of exactly ``length`` ids: truncate, then pad with ``pad_id``."""
        clipped = list(tokens[:length])
        padding = [pad_id] * (length - len(clipped))
        return padding + clipped if pad_left else clipped + padding

    def __len__(self):
        """Number of examples kept after the optional cutoff."""
        return self.size

    def __getitem__(self, index):
        """Return ``(document, summary_input, summary_target)`` as tensors."""
        document, sum_in, sum_out = self.dataset[index]
        return (torch.tensor(document),
                torch.tensor(sum_in),
                torch.tensor(sum_out))

In [7]:
# Build both splits; the validation split is capped at `validation_size` examples.
training_set = Dataset(data_name, vocab)
validation_set = Dataset(validation_name, vocab, cutoff=validation_size)

# Shared DataLoader settings (the training loop also reads params['batch_size']).
params = dict(batch_size=8, shuffle=True, num_workers=4)

# One loader per split, both using the same settings.
training_generator, validation_generator = (
    data.DataLoader(split, **params) for split in (training_set, validation_set)
)

HBox(children=(IntProgress(value=0, max=284367), HTML(value='')))




HBox(children=(IntProgress(value=0, max=2860), HTML(value='')))




In [8]:
# Build the pointer-generator network with coverage enabled and move it to `device`.
model = PointerGenerator(HIDDEN_DIM, EMB_DIM, INPUT_MAX, OUTPUT_MAX, VOC_SIZE, word_vectors, coverage=True).to(device)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-6)

if continue_from is None:
    # Fresh run: start counting epochs from zero.
    epoch = 0
else:
    # map_location lets a checkpoint saved on a GPU be restored when `device`
    # has fallen back to the CPU (and vice versa); without it torch.load tries
    # to restore tensors onto the device they were saved from.
    saved_model = torch.load(continue_from, map_location=device)
    model.load_state_dict(saved_model['model'])
    # Resume with the epoch after the one stored in the checkpoint.
    # NOTE(review): checkpoints hold no optimizer state, so Adam's moment
    # estimates restart from scratch on every resume.
    epoch = int(saved_model['epoch'] // 1) + 1


In [9]:
# class PointerGenerator2(nn.Module):
#     def __init__(self, hidden_dim, emb_dim, input_len, output_len, voc_size, word_vectors=None, eps=1e-10, coverage=False):
#         super(PointerGenerator2, self).__init__()
        
#         self.hidden_dim = hidden_dim
#         self.emb_dim = emb_dim
#         self.input_len = input_len
#         self.output_len = output_len
#         self.voc_size = voc_size
#         self.teacher_prob = 1.
#         self.epsilon = eps
#         self.coverage = coverage
        
#         self.emb_layer = model.emb_layer
#         self.encoder = model.encoder
#         self.decoder = model.decoder
        
#         self.attention_softmax = model.attention_softmax
        
#         self.pro_layer = model.pro_layer
#         self.pgen_layer = model.pgen_layer
        
#         self.cov_weight = nn.Parameter(torch.randn(1, dtype=torch.float)/10)
        
# model2 = PointerGenerator2(HIDDEN_DIM, EMB_DIM, INPUT_MAX, OUTPUT_MAX, VOC_SIZE, word_vectors, coverage=True).to(device)

# torch.save({
#                 'epoch': epoch - 1,
#                 'loss': saved_model['loss'],
#                 'val_loss': saved_model['val_loss'],
#                 'model': model2.state_dict()
#             }, './models/covModel' + str(epoch-1))  

In [10]:
# Negative log-likelihood over tokens; targets equal to the '<pad>' id are ignored.
# NOTE(review): NLLLoss expects log-probabilities — confirm PointerGenerator outputs them.
loss_function = nn.NLLLoss(ignore_index=vocab['<pad>'])

In [11]:
def validate(model, valid_gen):
    """Return the mean per-batch NLL loss of ``model`` over ``valid_gen``.

    Args:
        model: network called as ``model(document, summary_input)`` and
            returning ``(prediction, coverage_loss)``.
        valid_gen: iterable yielding ``(document, summary_input,
            summary_target)`` batches.

    Returns:
        The average of the batch losses as a float. The coverage loss is not
        included. ``epsilon`` in the denominator makes an empty ``valid_gen``
        yield 0.0 instead of raising ZeroDivisionError.
    """
    vlosses = []
    # Evaluate in eval mode (affects layers such as dropout) and restore the
    # caller's mode afterwards so training continues unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Iterate the loader that was passed in; the original ignored this
            # argument and always read the global `validation_generator`.
            for vin1, vin2, vout in valid_gen:
                vin1, vin2, vout = vin1.to(device), vin2.to(device), vout.to(device)

                vpredict, covloss = model(vin1, vin2) ## teacher
#                 vpredict = model.forward(vin1, vin2, 0.) ## full decode
                # Flatten (batch, step) so every target position is one NLL sample.
                vloss = loss_function(vpredict.view(vpredict.shape[0]*vpredict.shape[1], VOC_SIZE), vout.view(vout.shape[0]*vout.shape[1]))
                vlosses.append(vloss.item())
    finally:
        model.train(was_training)
    return sum(vlosses) / (len(vlosses)+epsilon)

In [12]:
# Loss histories: one entry is appended every 2400 training iterations.
train_losses = []
val_losses = []
# NOTE(review): `<=` makes the loop also run for epoch == num_epochs; confirm
# that the extra epoch is intended.
while epoch <= num_epochs:

    print("Epoch", epoch)
    running_loss = 0.0
    iters = int(math.ceil(len(training_set)/params['batch_size']))
    for i, (in1, in2, out) in tqdm(enumerate(training_generator), total=iters):
        in1, in2, out = in1.to(device), in2.to(device), out.to(device)

        predict, covloss = model(in1, in2)

        if torch.isnan(predict).any():
            raise RuntimeError("NaN in model output (epoch %d, iteration %d)" % (epoch, i))

        # NLL over flattened (batch, step) positions plus the coverage penalty.
        loss = loss_function(predict.view(predict.shape[0]*predict.shape[1], VOC_SIZE), out.view(out.shape[0]*out.shape[1]))
        loss = loss + 1. * covloss[0]

        # Check the loss BEFORE backward/step: the original raised only after
        # optimizer.step(), i.e. after NaN gradients had already been applied
        # to the in-memory weights.
        loss_value = loss.item()
        if np.isnan(loss_value):
            raise RuntimeError("NaN loss (epoch %d, iteration %d)" % (epoch, i))

        model.zero_grad()
        loss.backward()
        # gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm, norm_type=2)
        optimizer.step()

        running_loss += loss_value

        # Periodic validation; the running training average is logged alongside it.
        if i % 2400 == 2399:
            val_loss = validate(model, validation_generator)
            train_losses.append(running_loss/(i+1))
            val_losses.append(val_loss)
            print('\nval_loss:', val_loss)

#         print(running_loss/(i+1), end='\r')
        # Live status line: running loss, current coverage loss, coverage weight.
        print(running_loss/(i+1), covloss[0].item(), model.cov_weight.item(),   end='\r')
    print('loss:', running_loss / iters)
#     train_losses.append(running_loss / iters)
    if epoch % save_rate == 0:
        # Checkpoint: epoch index, loss histories and model weights (no optimizer state).
        torch.save({
                'epoch': epoch,
                'loss': train_losses,
                'val_loss': val_losses,
                'model': model.state_dict()
            }, './models/covModel' + str(epoch))

    epoch += 1

Epoch 6


HBox(children=(IntProgress(value=0, max=35546), HTML(value='')))

2.9238118674169735 0.07659244537353516 -1.15426385402679441
val_loss: 2.982838207116173
2.9109369034716477 0.07607347518205643 -1.15918290615081793
val_loss: 2.9804309806335905
2.9161628306549545 0.08453889936208725 -1.15978419780731223
val_loss: 2.9751662042545557
2.9162335401401505 0.0825531929731369 -1.162144899368286197
val_loss: 2.979126229964655
2.9133666700815715 0.08794687688350677 -1.16031718254089368
val_loss: 2.97895218340295
2.911317778267838 0.07796411216259003 -1.164819002151489336
val_loss: 2.973126296889684
2.9114860364302078 0.06331135332584381 -1.16598284244537359
val_loss: 2.972892109241401
2.91168302490386 0.07533876597881317 -1.1669613122940063376
val_loss: 2.974148467931503
2.9114834577586124 0.07967305928468704 -1.17215764522552533
val_loss: 2.974524546601088
2.913330653994435 0.08784729987382889 -1.172015309333801335
val_loss: 2.9680621620654764
2.9141482907292735 0.06932928413152695 -1.17220580577850342
val_loss: 2.9671010747959863
2.913608938903667 0.081872686

HBox(children=(IntProgress(value=0, max=35546), HTML(value='')))

2.8829459402152726 0.06724505126476288 -1.18016541004180961
val_loss: 2.967766378511101
2.8843806076357827 0.08148952573537827 -1.18064033985137945
val_loss: 2.9681600745155525
2.8835949322501393 0.0825953334569931 -1.183231830596923846
val_loss: 2.9656995178592647
2.8843777030442506 0.10143079608678818 -1.18688428401947025
val_loss: 2.9642222347197964
2.8855992440969134 0.07087964564561844 -1.19099617004394535
val_loss: 2.964073338987279
2.8846577248380236 0.07225760072469711 -1.18833374977111822
val_loss: 2.964358540885779
2.8885186619826446 0.08442594110965729 -1.19077777862548833
val_loss: 2.9593787656141783
2.887278992866835 0.07139317691326141 -1.195430636405944839
val_loss: 2.962925726474169
2.8876992970894064 0.06562183797359467 -1.19702970981597915
val_loss: 2.967258476011417
2.8880464886222224 0.07324668020009995 -1.19461107254028322
val_loss: 2.9596244409755346
2.889097456534869 0.07138185948133469 -1.194617271423339862
val_loss: 2.963972473609885
2.8896076268357342 0.083600

HBox(children=(IntProgress(value=0, max=35546), HTML(value='')))

2.853701148543968 0.07540273666381836 -1.194965600967407277
val_loss: 2.9594898207219344
2.8632902555799555 0.09399731457233429 -1.19575774669647227
val_loss: 2.9628128978784907
2.865143145384234 0.07022473216056824 -1.196420788764953668
val_loss: 2.9595829065277317
2.867405683525603 0.0710943192243576 -1.1977131366729736346
val_loss: 2.9551684889705228
2.868992186743117 0.06709316372871399 -1.198863863945007354
val_loss: 2.9616872308632236
2.8704050432866595 0.055935367941856384 -1.2005281448364258
val_loss: 2.955148843745786
2.8710112460376718 0.08304592221975327 -1.20022845268249519
val_loss: 2.9564596534433045
2.871028612508048 0.07257931679487228 -1.203784584999084582
val_loss: 2.9545687117381876
2.8720308350293666 0.07024525851011276 -1.20548200607299898
val_loss: 2.9588430107630557
2.8723309487570097 0.07340016216039658 -1.20518517494201663
val_loss: 2.9579103322660187
2.8721575227360496 0.0853748470544815 -1.207008957862854732
val_loss: 2.952262439540486
2.872796053327435 0.082

HBox(children=(IntProgress(value=0, max=35546), HTML(value='')))

2.8455965693367675 0.06539100408554077 -1.21368491649627698
val_loss: 2.956498902269422
2.8486757790155526 0.07902847230434418 -1.21808457374572756
val_loss: 2.951428675117947
2.851537873056435 0.07045388966798782 -1.221939921379089437
val_loss: 2.9577519560651138
2.8539770283988646 0.0731828436255455 -1.219545602798462665
val_loss: 2.9547737420595963
2.857980412745021 0.08748210966587067 -1.217201232910156267
val_loss: 2.9589832012205544
2.857566317333696 0.058413855731487274 -1.21877753734588623
val_loss: 2.9538536404755935
2.8577789876480755 0.07212791591882706 -1.22111487388610849
val_loss: 2.9526402370881515
2.85807928932701 0.06730855256319046 -1.2196600437164307272
val_loss: 2.9488288517096612
2.857944864206576 0.09005158394575119 -1.219484210014343346
val_loss: 2.9538314335830793
2.8588499716904923 0.09337694197893143 -1.22174251079559338
val_loss: 2.951834317025367
2.8587447581966203 0.0901583656668663 -1.226122260093689647
val_loss: 2.9519756965788795
2.859480809344157 0.0758

HBox(children=(IntProgress(value=0, max=35546), HTML(value='')))

2.8401937452441905 0.09029575437307358 -1.22715640068054222
val_loss: 2.9512185627513507
2.839367817007613 0.0880170613527298 -1.2279794216156006453

KeyboardInterrupt: 

In [None]:
# Evaluate the current in-memory model on the validation loader; as the last
# expression of the cell, the returned mean loss is displayed as its output.
validate(model, validation_generator)