In [None]:
import torch
from torch import nn, optim
from torch.autograd import Variable as var
from torch.nn import functional as F
import torchtext.vocab as vocab
from tqdm import tqdm
from pprint import pprint
import json
import _pickle as pkl

In [None]:
# Load the dataset. Later cells read its 'X_train', 'y_train', 'X_val' and
# 'y_val' entries. The encoding is given explicitly so that non-ASCII text is
# decoded the same way on every platform instead of via the locale default.
with open('../data/data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

def clean(token):
    """Normalise one whitespace-delimited token.

    Leading and trailing punctuation/quote/bracket characters are stripped,
    then a trailing "'s" is dropped and a trailing "'t" is collapsed to "t"
    (so "don't" becomes "dont"). The two suffix rules are applied in that
    order, each on the result of the previous step.

    Args:
        token: the raw token string.

    Returns:
        The normalised token; may be empty if the token was punctuation only.
    """
    word = token.strip(".,?!-:;'()[]\"`")
    if word.endswith("'s"):
        word = word[:-2]
    if word.endswith("'t"):
        word = word[:-2] + 't'
    return word

def vectorize(input_txt, max_len, pad_idx=400000, unk_idx=400001):
    """Convert text into a fixed-length list of GloVe vocabulary indices.

    The text is split on single spaces, each piece is normalised with
    `clean`, and pieces that are empty or whitespace-only after cleaning are
    dropped. Tokens missing from `glove.stoi` map to `unk_idx`. Sequences
    shorter than `max_len` are left-padded with `pad_idx`; longer ones keep
    only their first `max_len` ids.

    Args:
        input_txt: text to encode (callers in this notebook lower-case it first).
        max_len: length of the returned list.
        pad_idx: id used for padding (default: the <pad> row id used so far).
        unk_idx: id used for out-of-vocabulary tokens (default: the <unk> row
            id used so far).

    Returns:
        A list of `max_len` integer ids.
    """
    # Clean each token once, then drop the ones that end up blank.
    cleaned = (clean(w) for w in input_txt.split(" "))
    input_seq = [w for w in cleaned if w.strip()]

    # A dict lookup with a default replaces the former bare `except:`, which
    # would also have hidden unrelated errors.
    glove_vec = [glove.stoi.get(w, unk_idx) for w in input_seq]

    missing = max_len - len(glove_vec)
    if missing > 0:
        glove_vec = [pad_idx] * missing + glove_vec  # pad on the left
    return glove_vec[:max_len]
    
def make_data(raw_X, context_len=600, ques_len=100):
    """Build model inputs from (context, question, answer) triples.

    Each context and question is lower-cased and encoded with `vectorize`,
    then the two id lists are concatenated (context first). The answer text
    is not used here.

    Args:
        raw_X: iterable of 3-element (context, question, answer) items.
        context_len: number of ids kept for the context.
        ques_len: number of ids kept for the question.

    Returns:
        A list with one list of `context_len + ques_len` ids per example.
    """
    X = []
    for (c, q, _answer) in raw_X:
        context_rep = vectorize(c.lower(), context_len)
        ques_rep = vectorize(q.lower(), ques_len)
        X.append(context_rep + ques_rep)  # context ids followed by question ids
    return X

DIM = 50
glove = vocab.GloVe(name='6B', dim=DIM)

# Register two special tokens. Each token's id must equal the row its vector
# is appended at, i.e. the number of rows *before* the append. The previous
# `len(glove.stoi)+1` was one too high for <pad> and two too high for <unk>,
# so the <unk> id pointed past the end of the embedding table.
glove.stoi['<pad>'] = glove.vectors.size(0)
glove.vectors = torch.cat((glove.vectors, torch.zeros(1, DIM)))  # <pad> -> all zeros
glove.stoi['<unk>'] = glove.vectors.size(0)  # token->index for unknown/oov
glove.vectors = torch.cat((glove.vectors, torch.ones(1, DIM) * -1))  # index->vec for unknown/oov

# Sanity check: <pad> and <unk> occupy the last two rows of the table.
assert glove.stoi['<pad>'] == glove.vectors.size(0) - 2
assert glove.stoi['<unk>'] == glove.vectors.size(0) - 1

print(glove.vectors.size())
VOCAB_SIZE = glove.vectors.size()[0]

In [None]:
# Show one training example and encode it as a single-item batch.
idx = 5
example_X, example_y = data['X_train'][idx], data['y_train'][idx]

# example_X is indexed as [0] context, [1] question, [2] answer text.
print("Context:", example_X[0])
print("Question:", example_X[1])
print("Answer Span:", example_y)
print("Answer:", example_X[2])

X = make_data([example_X])

In [None]:
# Number of examples taken from the front of the training split.
num_ex = 4000
# Token-id features for those examples (context ids followed by question ids).
X_pass = make_data(data['X_train'][:num_ex])
# Targets for the same examples.
y_pass = data['y_train'][:num_ex]

In [None]:
class ModelV1(nn.Module):
    """LSTM span predictor over GloVe token ids.

    Token ids are embedded (embeddings initialised from `glove.vectors`), run
    through an LSTM, and two linear heads score `output_size` candidate
    positions for the span start and the span end.

    Recognised `config` keys (all optional):
        input_size (700, stored only), hidden_size (128), output_size (5000),
        n_layers (1), vocab (VOCAB_SIZE), embedding_dim (DIM),
        Bidirectional (True), learning_rate (1e-3), batch_size (1),
        epochs (5), opt ("SGD"; "Adam" selects Adam, anything else SGD).
    """

    def __init__(self, config):
        super(ModelV1, self).__init__()

        self.input_size = config.get("input_size", 700)
        self.hidden_size = config.get("hidden_size", 128)
        self.output_size = config.get("output_size", 5000)
        self.n_layers = config.get("n_layers", 1)
        self.vocab_size = config.get("vocab", VOCAB_SIZE)
        self.emb_dim = config.get("embedding_dim", DIM)
        self.bidir = config.get("Bidirectional", True)
        self.dirs = int(self.bidir)+1
        self.lr = config.get("learning_rate", 1e-3)
        self.batch_size = config.get("batch_size", 1)
        self.epochs = config.get("epochs", 5)
        self.opt = config.get("opt", "SGD")

        # Replace the optimizer name with the optimizer class.
        if self.opt == 'Adam':
            self.opt = optim.Adam
        else:
            self.opt = optim.SGD

        self.encoder = nn.Embedding(self.vocab_size, self.emb_dim)
        self.lstm = nn.LSTM(self.emb_dim, self.hidden_size, self.n_layers, bidirectional=self.bidir)
        self.decoder_start = nn.Linear(self.hidden_size, self.output_size)
        self.decoder_end = nn.Linear(self.hidden_size, self.output_size)
        self.init_weights()

    def init_weights(self):
        """Load GloVe vectors into the embedding and initialise both heads."""
        weight_scale = 0.01
        # Clone so that optimizer updates to the embedding do not write into
        # the shared `glove.vectors` table.
        self.encoder.weight.data = glove.vectors.clone()
        self.decoder_start.bias.data.fill_(0)
        self.decoder_start.weight.data.uniform_(-weight_scale, weight_scale)
        self.decoder_end.bias.data.fill_(0)
        self.decoder_end.weight.data.uniform_(-weight_scale, weight_scale)

    def init_hidden(self, bs=None):
        """Return a zero tensor shaped (n_layers*dirs, bs, hidden_size).

        `bs` defaults to `self.batch_size`. The tensor uses the dtype and
        device of the model's first parameter.
        """
        if bs is None:
            bs = self.batch_size
        weight = next(self.parameters())
        return weight.new_zeros(self.n_layers*self.dirs, bs, self.hidden_size)

    def _as_batch(self, inputs):
        """Coerce `inputs` to a 2-D (batch, seq_len) LongTensor.

        Accepts a tensor or nested lists. A 3-D input whose leading dimension
        is 1 (a batch wrapped in one extra list) is unwrapped, and a 1-D input
        is treated as a single sequence. Deciding by dimensionality rather
        than `len(inputs) == 1` keeps a genuine batch of one intact.
        """
        tokens = torch.as_tensor(inputs, dtype=torch.long)
        if tokens.dim() == 3 and tokens.size(0) == 1:
            tokens = tokens.squeeze(0)
        elif tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        if tokens.dim() != 2:
            raise ValueError(
                "expected token ids shaped (batch, seq_len), got %s" % (tuple(tokens.shape),)
            )
        return tokens

    def forward(self, inputs):
        """Score start and end positions for a batch of token-id sequences.

        `self.hidden` must hold an (h, c) pair for the batch before the call
        (`fit` and `predict` set it); it is overwritten with the LSTM's final
        state.

        Returns:
            Tensor of shape (batch, 2*output_size): log-probabilities for the
            start position in the first `output_size` columns, for the end
            position in the rest.
        """
        tokens = self._as_batch(inputs)
        embeds = self.encoder(tokens).permute(1, 0, 2)  # (seq_len, bs, emb_dim)
        lstm_op, self.hidden = self.lstm(embeds, self.hidden)
        lstm_op = lstm_op.permute(1, 0, 2)  # (seq_len, bs, hdim)->(bs, seq_len, hdim)

        # Forward direction: its state after the whole sequence sits at the
        # last time step.
        end_pred = lstm_op[:, -1, :self.hidden_size]
        if self.bidir:
            # Reverse direction: it reads the sequence back to front, so its
            # state after the whole sequence sits at time step 0. At the last
            # step it has only seen the final token.
            start_pred = lstm_op[:, 0, self.hidden_size:]
        else:
            # No reverse half exists; both heads read the forward summary.
            start_pred = end_pred

        out_start = F.log_softmax(self.decoder_start(start_pred), dim=-1)
        out_end = F.log_softmax(self.decoder_end(end_pred), dim=-1)
        return torch.cat((out_start, out_end), 1)

    def fit(self, X, y):
        """Train on (X, y) in fixed-size mini-batches.

        Args:
            X: list of equal-length token-id lists.
            y: list of [start, end] target indices, each below `output_size`.

        A trailing partial batch is skipped. Returns one number per epoch:
        the sum over batches of (batch loss / batch_size).
        """
        opt = self.opt(self.parameters(), self.lr)
        bs = self.batch_size
        losses = []  # epoch loss
        for epoch in range(self.epochs):
            print("epoch:", epoch)
            bloss = 0.0  # batch loss
            # Iterating the range directly lets tqdm show the batch total.
            for i in tqdm(range(0, len(y)-bs+1, bs)):
                # Fresh zero state for every batch.
                self.hidden = (self.init_hidden(), self.init_hidden())
                opt.zero_grad()
                Xb = torch.LongTensor(X[i:i+bs])
                yb = torch.LongTensor(y[i:i+bs])
                pred = self.forward(Xb)  # prediction on batch features

                loss = F.nll_loss(pred[:, :self.output_size], yb[:, 0]) \
                     + F.nll_loss(pred[:, self.output_size:], yb[:, 1])
                # The loss is a 0-dim tensor; .item() extracts the Python float.
                bloss += loss.item()/bs
                loss.backward()
                opt.step()
            losses.append(bloss)
            print("loss:", losses[-1], end=', change: ')
            if len(losses) > 1 and losses[-2] != 0:
                diff = losses[-2]-losses[-1]
                # Scale to a percentage to match the "%" printed after it.
                rel_diff = 100.0*diff/losses[-2]
                print("%s"%rel_diff, "%")
            else:
                print("00.0%")
        return losses

    def predict(self, X, bs=None):
        """Return predicted [start, end] indices for each sequence in X.

        Args:
            X: token ids, in any form accepted by `forward`.
            bs: batch size used for the initial hidden state; when omitted it
                is taken from X.

        Returns:
            LongTensor of shape (batch, 2).
        """
        tokens = self._as_batch(X)
        if bs is None:
            bs = tokens.size(0)
        self.hidden = (self.init_hidden(bs), self.init_hidden(bs))
        with torch.no_grad():  # inference only; no graph needed
            result = self.forward(tokens)
        return self.get_span_indices(result)

    def get_span_indices(self, preds):
        """Arg-max the start and end halves of `preds` into (batch, 2) indices."""
        s_pred = preds[:, :self.output_size]
        e_pred = preds[:, self.output_size:]
        _, s_index = torch.max(s_pred, -1)
        _, e_index = torch.max(e_pred, -1)
        return torch.cat((s_index.unsqueeze(1), e_index.unsqueeze(1)), -1)

In [None]:
# Hyper-parameters for this run; keys not listed fall back to ModelV1's defaults.
# NOTE(review): 0.4 looks very large as an Adam step size - confirm it is intended.
conf = dict(
    learning_rate=0.4,
    epochs=10,
    hidden_size=128,
    batch_size=40,
    opt="Adam",
)
model = ModelV1(conf)
print(model, model.lr, model.hidden_size, model.batch_size, model.opt)

In [None]:
# Train on the prepared subset; fit returns a list with one loss total per epoch.
res = model.fit(X_pass, y_pass)

In [None]:
# Print predictions for the first 20 validation examples.
# The prediction is kept in `span` rather than `res`, so the value assigned to
# `res` by model.fit is not overwritten before later cells read it.
for x, y in zip(data['X_val'][:20], data['y_val'][:20]):
    c = x[0]  # context text
    a = x[2]  # reference answer text
    x = make_data([x])
    span = model.predict([x], bs=1).data.tolist()[0]
    print("Predicted span:", span)
    if span[0]>span[1]:
        # Keep the pair ordered so that the slice below is not empty.
        span[0], span[1] = span[1], span[0]
        print("switched to:", span)
    print("Predicted Answer:", c[span[0]:span[1]])
    print("Actual:", a)
    print("="*50)

In [None]:
import matplotlib.pyplot as plt
# Line plot of the values in `res` against their positions.
plt.plot(range(len(res)), res)
plt.show()