In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import gensim
import pickle
USE_CUDA = torch.cuda.is_available()

from data import load_squad_data,preprop,getBatch,pad_to_batch
from model import CoattentionEncoder, DynamicDecoder

In [2]:
MAX_LEN=400

In [3]:
dataset = load_squad_data('dataset/train-v1.1.json',MAX_LEN)
word2index,train_data = preprop(dataset)

Skipped 761, 86655 question/answer
Successfully Build 114855 vocabs
Preprop Complete!


In [4]:
pickle.dump(word2index,open('dataset/vocab.squad','wb'))
pickle.dump(train_data,open('dataset/train.squad','wb'))

In [None]:
%%time
#python3 -m gensim.scripts.glove2word2vec --input  glove.840B.300d.txt --output glove.840B.300d.w2vformat.txt
model = gensim.models.KeyedVectors.load_word2vec_format('dataset/glove.840B.300d.w2vformat.txt')

In [5]:
# oov=[]
# for k in word2index.keys():
#     if k not in ['<pad>','<unk>','<s>','</s>'] and model.vocab.get(k) is None:
#         oov.append(k)
# for o in oov:
#     word2index.pop(o)
# print(len(oov),len(word2index))

22527 92328


In [None]:
pretrained = []

for i in range(len(word2index)):
    try:
        pretrained.append(model[word2index[i]])
    except:
        pretrained.append(np.zeros(300))
        
pretrained_vectors = np.vstack(pretrained)

In [5]:
# del oov
# del pretrained
# del model

# word2index,train_data = preprop(dataset,word2index)

In [5]:
RESTORE=True
EMBED_SIZE=300
HIDDEN_SIZE=200
MAXOUT_POOL=4
MAX_ITER=4
BATCH_SIZE=64
STEP=20
LR=0.001
encoder = CoattentionEncoder(len(word2index),EMBED_SIZE,HIDDEN_SIZE)
decoder = DynamicDecoder(HIDDEN_SIZE,MAXOUT_POOL,max_iter=MAX_ITER)
if RESTORE is False:
    encoder.init_embed(pretrained_vectors,is_static=False)

if RESTORE:
    encoder.load_state_dict(torch.load('models/enc_params.pkl'))
    decoder.load_state_dict(torch.load('models/dec_params.pkl'))

if USE_CUDA:
    encoder.use_cuda=True
    decoder.use_cuda=True
    encoder = encoder.cuda()
    decoder = decoder.cuda()
loss_function = nn.CrossEntropyLoss()
enc_optim = optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),lr=LR)
dec_optim = optim.Adam(decoder.parameters(),lr=LR)

In [125]:
LR=0.00001
STEP=10
enc_optim = optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),lr=LR)
dec_optim = optim.Adam(decoder.parameters(),lr=LR*5)

In [126]:
for step in range(STEP):
    losses=[]
    for i,batch in enumerate(getBatch(BATCH_SIZE,train_data)):
        documents,questions,starts,ends = pad_to_batch(batch,word2index)

        encoder.zero_grad()
        decoder.zero_grad()
        U = encoder(documents,questions,True)
        _,_,entropies = decoder(U,True)

        s_ents, e_ents = list(zip(*entropies)) # x MAX_ITER
        loss_start,loss_end=0,0
        for m in range(MAX_ITER+1):
            loss_start+=loss_function(s_ents[m],starts.view(-1))
            loss_end+=loss_function(s_ents[m],ends.view(-1))

        loss = loss_start+loss_end
        losses.append(loss.data[0])
        loss.backward()
        #torch.nn.utils.clip_grad_norm(encoder.parameters(), 50) # gradient clipping
        #torch.nn.utils.clip_grad_norm(decoder.parameters(), 50) 
        enc_optim.step()
        dec_optim.step()
        
        if i % 100 == 0:
            print("[%d/%d] [%d/%d] loss : %.3f" % (step,STEP,i,len(train_data)//BATCH_SIZE,np.mean(losses)))
            losses=[]

[0/10] [0/1353] loss : 4.921
[0/10] [100/1353] loss : 5.015
[0/10] [200/1353] loss : 5.150
[0/10] [300/1353] loss : 5.116
[0/10] [400/1353] loss : 5.094
[0/10] [500/1353] loss : 5.113
[0/10] [600/1353] loss : 5.069
[0/10] [700/1353] loss : 5.080
[0/10] [800/1353] loss : 5.013
[0/10] [900/1353] loss : 5.058
[0/10] [1000/1353] loss : 5.037
[0/10] [1100/1353] loss : 5.041
[0/10] [1200/1353] loss : 5.062
[0/10] [1300/1353] loss : 4.981
[1/10] [0/1353] loss : 4.770
[1/10] [100/1353] loss : 4.985
[1/10] [200/1353] loss : 4.981
[1/10] [300/1353] loss : 5.025
[1/10] [400/1353] loss : 5.004
[1/10] [500/1353] loss : 5.014
[1/10] [600/1353] loss : 5.074
[1/10] [700/1353] loss : 5.054
[1/10] [800/1353] loss : 4.995
[1/10] [900/1353] loss : 5.058
[1/10] [1000/1353] loss : 4.999
[1/10] [1100/1353] loss : 5.064
[1/10] [1200/1353] loss : 5.097
[1/10] [1300/1353] loss : 5.023
[2/10] [0/1353] loss : 4.827
[2/10] [100/1353] loss : 5.024
[2/10] [200/1353] loss : 4.906
[2/10] [300/1353] loss : 5.029
[2/10]

KeyboardInterrupt: 

### Test 

In [6]:
import random
index2word={v:k for k,v in word2index.items()}

In [155]:
encoder = encoder.cuda()
decoder = decoder.cuda()

In [157]:
test = random.choice(test_data)

U = encoder(test[0],test[1])
s,e,entropies = decoder(U)

#         s_ents, e_ents = list(zip(*entropies)) 
test_paragraph=[index2word[p] for p in test[0].data.tolist()[0]]
print(" ".join(test_paragraph))
print(" ")
print(" ".join([index2word[p] for p in test[1].data.tolist()[0]]))
print(" ")
print("Prediction : "," ".join(test_paragraph[s.data[0]:e.data[0]+1]))
print("Groud Truth : "," ".join(test_paragraph[test[2].data.tolist()[0][0]:test[3].data.tolist()[0][0]+1]))

Sky UK Limited ( formerly British Sky Broadcasting or BSkyB ) is a British telecommunications company which serves the United Kingdom . Sky provides television and broadband internet services and fixed line telephone services to consumers and businesses in the United Kingdom . It is the UK 's largest pay-TV broadcaster with 11 million customers as of 2015 . It was the UK 's most popular digital TV service until it was overtaken by Freeview in April 2007 . Its corporate headquarters are based in <unk> .
 
How many customers does Sky UK Limited have as a pay-TV broadcaster as of 2015 ?
 
Prediction :  million
Groud Truth :  11 million


In [57]:
test_data = load_squad_data('./dataset/dev-v1.1.json')
word2index, test_data = preprop(test_data,word2index)

Skipped 177, 10384 question/answer
Successfully Build 114855 vocabs
Preprop Complete!


In [159]:
overlap=0
predicted=0
truth=0

for test in test_data:
    U = encoder(test[0],test[1])
    s,e,entropies = decoder(U)
    
    pred_span = list(range(s.data[0],e.data[0]+1))
    truth_span = list(range(test[2].squeeze(0).data[0],test[3].squeeze(0).data[0]+1))
    overlap+=len(set(truth_span) & set(pred_span))
    predicted+=len(pred_span)
    truth+=len(truth_span)
#     break
precision = overlap/predicted
recall = overlap/truth

f1_score = 2*precision*recall/(precision+recall)
print(f1_score)

0.06666666666666667


In [151]:
import datetime
cdate = datetime.datetime.strftime(datetime.datetime.now(),"%m_%d")

In [235]:
if USE_CUDA:
    encoder = encoder.cpu()
    decoder = decoder.cpu()
torch.save(encoder.state_dict(), 'models/enc_params_'+cdate+'.pkl')
torch.save(decoder.state_dict(), 'models/dec_params_'+cdate+'.pkl')