In [61]:
import random
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

from qanta.datasets.quiz_bowl import QuestionDatabase
from qanta.util.constants import GUESSER_TRAIN_FOLD, GUESSER_DEV_FOLD
from qanta.guesser.rnn_entity import RnnEntityGuesser, BatchedDataset,\
    clean_question, repackage_hidden
from qanta.guesser.nn import convert_text_to_embeddings_indices
from qanta.preprocess import tokenize_question

In [2]:
# Load the trained RNN entity guesser and put its network on the GPU.
guesser = RnnEntityGuesser.load(
    'output/guesser/qanta.guesser.rnn_entity.RnnEntityGuesser/'
)
guesser.model.cuda()  # nn.Module.cuda() moves parameters in place and returns the same module

# Unreduced loss: one cross-entropy value per example, so callers can keep
# per-question losses and sum them before backward().
# NOTE: `reduce=False` is the legacy spelling; newer PyTorch prefers
# `reduction='none'`. Kept for the (presumably older) PyTorch this notebook runs on.
criterion = nn.CrossEntropyLoss(reduce=False)

In [3]:
# Gradients captured by the hooks below, keyed by hook name. get_stuff reads
# the 'embed' entry right after calling backward().
extracted_grads = {}


def extract_grad_hook(name):
    """Build a gradient hook that records the gradient it receives.

    (Defined once; the notebook previously contained an identical duplicate
    definition that silently shadowed this one.)

    Parameters
    ----------
    name : str
        Key under which the gradient is stored in ``extracted_grads``,
        overwriting any previous value for that key.

    Returns
    -------
    callable
        ``hook(grad)``: stores ``grad`` in ``extracted_grads[name]`` and
        returns None. In this notebook it is passed to ``guesser.model(...)``;
        presumably the model registers it on the word-embedding output --
        confirm in ``RnnEntityModel.forward``.
    """
    def hook(grad):
        extracted_grads[name] = grad

    return hook

guesser.model.cuda()  # idempotent re-move to GPU; cuda() works in place and returns the module
guesser.model.eval()  # inference mode: disables Dropout, BatchNorm uses running statistics

RnnEntityModel(
  (dropout): Dropout(p=0.25)
  (word_embeddings): Embedding(139580, 300, padding_idx=0)
  (rnn): WeightDrop(
    (module): GRU(300, 1000, batch_first=True, dropout=0.25, bidirectional=True)
  )
  (classification_layer): Sequential(
    (0): Linear(in_features=2000, out_features=8297)
    (1): BatchNorm1d(8297, eps=1e-05, momentum=0.1, affine=True)
    (2): Dropout(p=0.15)
  )
)

In [161]:
# Keep only the questions in the guesser's dev fold.
questions = [
    question
    for question in QuestionDatabase().all_questions().values()
    if question.fold == GUESSER_DEV_FOLD
]

In [117]:
def get_stuff(question_list):
    """Per-token gradient-x-embedding scores w.r.t. the model's own top guess.

    Each question is run through ``guesser.model``. The cross-entropy loss
    against the model's top prediction is back-propagated, and the gradient
    captured by the ``'embed'`` hook is dotted with the word embeddings. This
    gives one score per token position.

    Parameters
    ----------
    question_list : sequence
        Tokenized questions (in this notebook, lists of tokens produced by
        ``guesser.nlp``). They are handed unchanged to ``BatchedDataset``.

    Returns
    -------
    grads : np.ndarray, shape (len(question_list), max_len)
        Per-token scores, one row per question, in ``question_list`` order.
        Entries past a question's own length are padding artefacts, so
        callers should slice ``grads[i][:len(question_list[i])]``.
    outputs : torch tensor, shape (len(question_list), n_classes)
        Raw model outputs (the values fed to ``criterion``), same row order.
    losses : torch tensor, shape (len(question_list),)
        Unreduced loss of each question against its own top prediction.

    Notes
    -----
    Relies on the notebook globals ``guesser``, ``criterion``,
    ``extracted_grads`` and ``extract_grad_hook``.
    """
    x_test_tokens = list(question_list)
    # Placeholder labels: the loss below is taken against the model's own
    # predictions, so the dataset's targets are not used by this function.
    y_test = np.zeros(len(x_test_tokens))
    dataset = BatchedDataset(
        guesser.batch_size, guesser.multi_embedding_lookup, guesser.rel_position_vocab, guesser.rel_position_lookup,
        x_test_tokens, y_test,
        truncate=False, shuffle=False, train=False
    )

    grads, outputs, losses = [], [], []
    guesser.model.eval()  # loop-invariant: the mode never changes inside the loop

    for b in range(len(dataset.t_x_w_batches)):
        t_x_w_batch = Variable(dataset.t_x_w_batches[b])
        t_x_pos_batch = Variable(dataset.t_x_pos_batches[b])
        t_x_iob_batch = Variable(dataset.t_x_iob_batches[b])
        t_x_type_batch = Variable(dataset.t_x_type_batches[b])
        t_x_mention_batch = Variable(dataset.t_x_mention_batches[b])
        length_batch = dataset.length_batches[b]
        sort_batch = dataset.sort_batches[b]

        # Fresh hidden state sized to this batch (the last batch may be short).
        hidden = guesser.model.init_hidden(len(length_batch))
        out, hidden = guesser.model(
            t_x_w_batch, t_x_pos_batch, t_x_iob_batch, t_x_type_batch, t_x_mention_batch,
            length_batch, hidden, extract_grad_hook('embed')
        )
        _, preds = torch.max(out, 1)  # take gradient w.r.t. the top guess

        guesser.model.zero_grad()
        loss = criterion(out, preds)
        loss.sum().backward()

        # Gradient recorded by the 'embed' hook during backward(), dotted with
        # the embeddings of the same (sorted) batch -> one score per position.
        # NOTE(review): assumes the model registers the hook on the word-embedding output.
        batch_size, length = t_x_w_batch.size()
        embed_grad = extracted_grads['embed'].contiguous()
        embed = guesser.model.word_embeddings(t_x_w_batch)
        onehot_grad = (embed.view(-1) * embed_grad.view(-1)).view(batch_size, length, -1).sum(-1)

        # Rows of this batch are in the dataset's `sort_batch` order. Previously
        # they were returned as-is, while callers index results by position in
        # question_list. Invert the permutation to restore input order.
        # NOTE(review): assumes sort_batch is the permutation applied to the batch
        # rows (the guesser's own unsort convention) -- confirm in BatchedDataset.
        unsort = np.argsort(sort_batch).astype(np.int64)
        unsort_index = torch.from_numpy(unsort)
        if out.data.is_cuda:
            unsort_index = unsort_index.cuda()

        grads.append(onehot_grad.data.cpu().numpy()[unsort])
        outputs.append(out.data.index_select(0, unsort_index))
        losses.append(loss.data.index_select(0, unsort_index))

    # Batches can be padded to different lengths, which made a plain
    # np.concatenate fail. Right-pad every batch to the widest one; callers
    # only read the first len(x) entries of each row.
    width = max(g.shape[1] for g in grads)
    grads = np.concatenate([
        np.pad(g, ((0, 0), (0, width - g.shape[1])), mode='constant')
        for g in grads
    ])
    outputs = torch.cat(outputs)
    losses = torch.cat(losses)

    return grads, outputs, losses

In [6]:
# def greedy_remove(question_list):
#     _xs = [list(guesser.nlp(clean_question(x.flatten_text()))) for x in question_list]
#     _ys = [x[0][0] for x in guesser.guess(_xs, 10, tokenize=False)]

#     removed_indices = [[] for _ in question_list]
#     indices = [list(range(len(x))) for x in _xs]
#     alive = [True for _ in _xs]
    
#     xs = list(_xs)
        
#     while True:
#         onehot_grad, out, loss = get_stuff(xs)
#         for i, x in enumerate(xs):
#             if len(x) == 1:
#                 alive[i] = False # stop removing when there is only one token left
#             if alive[i]:
#                 drop_idx = np.argmax(onehot_grad[i][:len(x)])
#                 removed_indices[i].append(indices[i][drop_idx])
#                 indices[i] = indices[i][:drop_idx] + indices[i][drop_idx + 1:]
#                 x = x[:drop_idx] + x[drop_idx + 1:]
#             xs[i] = x

#         if sum(alive) == 0:
#             break
        
#         pred = [x[0][0] for x in guesser.guess(xs, 10, tokenize=False)]
#         for i, (x, y, z) in enumerate(zip(xs, _ys, pred)):
#             if z != y:
#                 alive[i] = False
#                 removed_indices[i] = removed_indices[i][:-1]
#     return removed_indices

In [183]:
def beam_search_remove(question, max_beam_size=10, verbose=True):
    """Beam-search token deletions that keep the model's top guess unchanged.

    Starting from the full question, each step does the following:
    - Expands every beam entry by deleting, one at a time, each of its
      ``max_beam_size`` tokens with the smallest ``get_stuff`` score.
    - Keeps only the reduced questions whose top guess still equals the
      original top guess.
    - Retains the ``max_beam_size`` of those with the lowest top-guess score.

    The search stops when no reduced question keeps the original guess, or
    when there are no tokens left to delete.

    Parameters
    ----------
    question : sequence
        Tokenized question (as produced by ``guesser.nlp``).
    max_beam_size : int
        Used both as the number of deletions tried per beam entry and as the
        beam width.
    verbose : bool
        If True (the original behaviour), print the current question length
        at every step, space-separated on one line.

    Returns
    -------
    list of list of int
        One entry per beam element of the last step that preserved the
        guess. Each entry lists the positions in ``question`` that were
        deleted, in deletion order. The result is ``[[]]`` when no single
        deletion preserves the guess.
    """
    original = guesser.guess([question], 10, tokenize=False)[0][0][0]

    xs = [list(question)]
    removed_indices = [[]]
    # kept_positions[i][j] is the position in `question` of token xs[i][j].
    kept_positions = [list(range(len(question)))]

    # Invariant: xs is never empty; every path that would empty it returns.
    while True:
        if verbose:
            print(len(xs[0]), end=' ')
        # Every beam entry has the same length (exactly one deletion per step),
        # so an empty first entry means nothing is left to delete. This check
        # previously came too late: get_stuff could be fed empty questions.
        if len(xs[0]) == 0:
            return removed_indices

        assert len(removed_indices) == len(xs)
        assert len(kept_positions) == len(xs)

        onehot_grad, _, _ = get_stuff(xs)

        new_xs, new_removed, new_kept = [], [], []
        for i, x in enumerate(xs):
            # Smallest scores first: these tokens are tried for deletion.
            for k in np.argsort(onehot_grad[i][:len(x)])[:max_beam_size]:
                new_xs.append(x[:k] + x[k + 1:])
                new_removed.append(removed_indices[i] + [kept_positions[i][k]])
                new_kept.append(kept_positions[i][:k] + kept_positions[i][k + 1:])

        # Top (guess, score) pair for every candidate.
        top_guesses = [g[0] for g in guesser.guess(new_xs, 10, tokenize=False)]

        # Candidates that keep the original guess, ordered by ascending score.
        # NOTE(review): this keeps the LEAST confident survivors -- confirm intended.
        survivors = [(j, score) for j, (guess, score) in enumerate(top_guesses) if guess == original]
        survivors = [j for j, _ in sorted(survivors, key=lambda js: js[1])[:max_beam_size]]
        if not survivors:
            return removed_indices

        xs = [new_xs[j] for j in survivors]
        removed_indices = [new_removed[j] for j in survivors]
        kept_positions = [new_kept[j] for j in survivors]

In [190]:
# For the first five dev questions, record which token positions the best
# beam entry deleted (beam_search_remove prints lengths while it runs).
removed_indices = []
for dev_question in questions[:5]:
    tokens = list(guesser.nlp(clean_question(dev_question.flatten_text())))
    beams = beam_search_remove(tokens, max_beam_size=20)
    removed_indices.append(beams[0])  # first entry of the final beam
    print()  # terminate the line of per-step lengths

  result = self.forward(*input, **kwargs)
  probs = F.softmax(out)


74 73 72 71 70 69 68 67 66 65 64 63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 
53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 
77 76 75 74 73 72 71 70 69 68 67 66 65 64 63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 
74 73 72 71 70 69 68 67 66 65 64 63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 
43 42 41 40 39 38 37 36 35 34 33 32 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 


In [194]:
# Compare the top guess on each full question with the top guess on its
# reduced version (all tokens at the recorded positions deleted).
for question_idx, removed_positions in enumerate(removed_indices):
    tokens = list(guesser.nlp(clean_question(questions[question_idx].flatten_text())))
    print(guesser.guess([tokens], 10, tokenize=False)[0][0])

    dropped = set(removed_positions)
    reduced = [tok for pos, tok in enumerate(tokens) if pos not in dropped]
    reduced_top = guesser.guess([reduced], 10, tokenize=False)[0][0]
    print(' '.join(tok.lower_ for tok in reduced), reduced_top)
    print('\n')  # two blank lines between questions

('Around_the_World_in_Eighty_Days', 0.16779573)
fix ('Around_the_World_in_Eighty_Days', 0.088940158)


('Alexander_the_Great', 0.99782467)
horned ('Alexander_the_Great', 0.02978701)


('Seminole_Wars', 0.10996531)
wars them costing mexican everglades its ('Seminole_Wars', 0.032432083)


('Joseph_Stalin', 0.35837778)
this soviet epidemics occidental ('Joseph_Stalin', 0.017363174)


('Gulf_Stream', 0.37491512)
water current ('Gulf_Stream', 0.062081344)




  result = self.forward(*input, **kwargs)
  probs = F.softmax(out)


In [None]:
# question_list = questions[:30]
# xs = [guesser.nlp(clean_question(x.flatten_text())) for x in question_list]
# removed_indices = greedy_remove(question_list)

In [None]:
# for i in range(30):
#     print('Original Question')
#     print(' '.join([x.lower_ for x in xs[i]]))
#     print()
#     x_after = [w for j, w in enumerate(xs[i]) if j not in removed_indices[i]]
#     print('Modified Question')
#     print(' '.join([x.lower_ for x in x_after]))
#     print()
#     print('Original Guesses')
#     for g, s in guesser.guess([xs[i]], 4, tokenize=False)[0]:
#         print(g, s)
#     print()
#     print('Modified Guesses')
#     for g, s in guesser.guess([x_after], 4, tokenize=False)[0]:
#         print(g, s)
#     print()
#     print()