In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

from qanta.datasets.quiz_bowl import QuestionDatabase
from qanta.util.constants import GUESSER_TRAIN_FOLD, GUESSER_DEV_FOLD
from qanta.guesser.dan import DanGuesser, batchify
from qanta.guesser.nn import convert_text_to_embeddings_indices
from qanta.preprocess import tokenize_question

In [2]:
guesser = DanGuesser.load('output/guesser/qanta.guesser.dan.DanGuesser')
guesser.model = guesser.model.cuda()
criterion = nn.CrossEntropyLoss()

In [3]:
extracted_grads = {}

def extract_grad_hook(name):
    def hook(grad):
        extracted_grads[name] = grad
    return hook

def extract_grad_hook(name):
    def hook(grad):
        extracted_grads[name] = grad
    return hook

In [4]:
questions = QuestionDatabase().all_questions().values()
questions = [q for q in questions if q.fold == GUESSER_DEV_FOLD]

In [5]:
def prepare_batches(question_list):
    x_test = [convert_text_to_embeddings_indices(q, guesser.embedding_lookup)
            for q in question_list]
    
    for r in x_test:
        if len(r) == 0:
            r.append(guesser.embedding_lookup['UNK'])

    x_test = np.array(x_test)
    y_test = np.zeros(len(x_test))

    _, t_x_batches, t_len_batches, t_y_batches = batchify(
        guesser.batch_size, x_test, y_test, truncate=False, shuffle=False)
    
    return t_x_batches, t_len_batches

def guess(question_list, max_n_guesses):
    t_x_batches, t_len_batches = prepare_batches(question_list)
    guesses = []
    for b in range(len(t_x_batches)):
        t_x = Variable(t_x_batches[b])
        t_len = Variable(t_len_batches[b])
        
        guesser.model.eval()
        out = guesser.model(t_x, t_len, 0).data.cpu().numpy()
        for preds_scores in out:
            guesses.append([])
            preds = np.argsort(preds_scores)[:-10:-1]
            scores = preds_scores[preds]
            for p, s in zip(preds, scores):
                guesses[-1].append((guesser.i_to_class[p], s))
    return guesses

def get_onehot_grad(question_list):
    t_x_batches, t_len_batches = prepare_batches(question_list)
    grads = []
    for b in range(len(t_x_batches)):
        t_x = Variable(t_x_batches[b])
        t_len = Variable(t_len_batches[b])
        
        guesser.model.eval()
        out = guesser.model(t_x, t_len, 0, extract_grad_hook('embed'))
        scores, preds = torch.max(out, 1) # take gradient w.r.t top guess

        guesser.model.zero_grad()
        loss = criterion(out, preds)
        loss.backward()

        batch_size, length = t_x.size()
        embed = guesser.model.embeddings(t_x)
        embed_grad = extracted_grads['embed'].contiguous()
        onehot_grad = embed.view(-1) * embed_grad.view(-1)
        onehot_grad = onehot_grad.view(batch_size, length, -1).sum(-1)
        onehot_grad = onehot_grad.data.cpu().numpy()
        grads.append(onehot_grad)
    return np.concatenate(grads)

In [44]:
def greedy_remove(question_list):
    _xs = [tokenize_question(x.flatten_text()) for x in question_list]
    _ys = [x[0][0] for x in guess(_xs, 10)]

    xs, ys = list(_xs), list(_ys)

    removed_indices = [[] for _ in question_list]
    indices = [list(range(len(x))) for x in xs]
    alive = [True for _ in xs]

    while True:
        onehot_grad = get_onehot_grad(xs)
        for i, x in enumerate(xs):
            if len(x) == 1:
                alive[i] = False
            if alive[i]:
                drop_idx = np.argmax(onehot_grad[i][:len(x)])
                removed_indices[i].append(indices[i][drop_idx])
                indices[i] = indices[i][:drop_idx] + indices[i][drop_idx + 1:]
                x = x[:drop_idx] + x[drop_idx + 1:]
            xs[i] = x

        if sum(alive) == 0:
            break
                
        pred = [x[0][0] for x in guess(xs, 10)]
        for i, (x, y, z) in enumerate(zip(xs, ys, pred)):
            if z != y:
                alive[i] = False
                removed_indices[i] = removed_indices[i][:-1]
    return removed_indices

In [46]:
question_list = questions[:30]
removed_indices = greedy_remove(question_list)
xs = [tokenize_question(x.flatten_text()) for x in question_list]

In [47]:
for i in range(30):
    print('Original Question')
    print(' '.join(xs[i]))
    print()
    x_after = [w for j, w in enumerate(xs[i]) if j not in removed_indices[i]]
    print('Modified Question')
    print(' '.join(x_after))
    print()
    print('Original Guesses')
    for g, s in guess([xs[i]], 4)[0]:
        print(g, s)
    print()
    print('Modified Guesses')
    for g, s in guess([x_after], 4)[0]:
        print(g, s)
    print()
    print()

Original Question
he rescues a hindu widow from ritual burning and takes her as his bride but first he must convince detective fix that he is not a bank robber this is after he has crossed three continents and two oceans on trains steamers and an elephant with his valet passepartout name this welltraveled jules verne character who circumnavigated the globe in 1920 hours

Modified Question
passepartout

Original Guesses
Around_the_World_in_Eighty_Days 5.0191
Robinson_Crusoe 4.03111
Phileas_Fogg 3.26174
Ferdinand_Magellan 2.58749
Sancho_Panza 2.37755
Galahad 2.2981
Captain_Nemo 2.28571
Theseus 2.167
Vishnu 2.14411

Modified Guesses
Around_the_World_in_Eighty_Days 26.0226
Phileas_Fogg 22.0607
International_Date_Line 19.9724
William_Lyon_Mackenzie_King 10.9296
Middlemarch 10.4852
Abortion 10.2257
Hypnos 9.58031
A_Confederacy_of_Dunces 9.48001
Lens 9.40255


Original Question
some scholars identify him with the two horned one mentioned in the koran he defeated darius iii at issus in 333 bc 

Marc_Chagall 3.0457
The_Crucible 2.2388
Death_of_a_Salesman 2.14662
Nighthawks 2.03471
Tom_Stoppard 1.99829
Poland 1.74604
Sylvia_Plath 1.73453
Karma 1.72444

Modified Guesses
Ulysses_(novel) 20.8699
Finnegans_Wake 6.71746
Gideon 6.34709
Nighthawks 5.40759
Sister_Carrie 5.12291
Tree 4.99581
Éamon_de_Valera 4.98108
Marc_Chagall 4.6637
J._D._Salinger 4.63881


Original Question
caused by the spirochete borrelia burgdorferi if caught early it can be treated with tetracycline if not treated the characteristic rash is followed by periodic arthritis and possibly meningitis name this disease carried by some species of deer tick

Modified Question
caused by the borrelia burgdorferi if early can be treated with if not treated the characteristic rash followed by arthritis and possibly meningitis disease by some of deer tick

Original Guesses
Lyme_disease 12.4911
Syphilis 4.48661
Malaria 3.45534
Huntington's_disease 2.95676
Precession 2.87862
Cholera 2.81971
Joseph_Haydn 2.782
Leprosy 2.72935
Add