In [None]:
import torch.nn as nn
from torch.autograd import Variable
import torch
from qanta.guesser.torch.dan import *
from qanta.datasets.quiz_bowl import QuizBowlDataset
from qanta.manager import BaseLogger, TerminateOnNaN, EarlyStopping, ModelCheckpoint, MaxEpochStopping, TrainingManager
import abc
from collections import defaultdict
import os

In [None]:
# Load the quiz bowl dataset (with guesser_train=True) and fetch its training data.
# NOTE(review): the meaning of the positional `1` is defined by QuizBowlDataset's
# signature, which is not visible here -- consider passing it by keyword.
dataset = QuizBowlDataset(1, guesser_train=True)
training_data = dataset.training_data()

In [None]:
# Turn the raw training data into train/test question text, labels, a vocabulary
# and class <-> index mappings.
# NOTE(review): this reading comes from the names of the returned values; the
# exact split/encoding is defined by preprocess_dataset in qanta.guesser.torch.dan.
x_train_text, y_train, x_test_text, y_test, vocab, class_to_i, i_to_class = preprocess_dataset(
    training_data
)

In [None]:
# Load the embedding matrix plus a word -> index lookup for `vocab`.
# `embedding_lookup` maps words to integer indices (inverted into i_to_word below).
# NOTE(review): the indices are presumably rows of `embeddings`, and
# `expand_glove` is defined by load_embeddings -- confirm there.
embeddings, embedding_lookup = load_embeddings(vocab=vocab, expand_glove=True)

In [None]:
def _to_ragged_index_array(rows):
    """Pack a list of (possibly different-length) sequences into a 1-D object array.

    ``np.array(list_of_lists)`` raises ``ValueError`` on ragged input in
    NumPy >= 1.24. Filling a preallocated object array always gives one item
    per row. It also keeps the fancy indexing by a permutation
    (``arr[order]``) that ``batchify`` applies.
    """
    packed = np.empty(len(rows), dtype=object)
    for i, row in enumerate(rows):
        packed[i] = row
    return packed


# Map every question's text to embedding indices. Questions generally differ
# in length, so the inputs are stored as ragged object arrays.
x_train = _to_ragged_index_array(
    [convert_text_to_embeddings_indices(q, embedding_lookup) for q in x_train_text]
)
y_train = np.array(y_train)

x_test = _to_ragged_index_array(
    [convert_text_to_embeddings_indices(q, embedding_lookup) for q in x_test_text]
)
y_test = np.array(y_test)

In [None]:
# Number of output classes; passed to DanModel when the model is built.
# NOTE(review): assumes training_data[1] holds the answer labels -- confirm.
n_classes = compute_n_classes(training_data[1])

In [None]:
# Reverse of embedding_lookup (index -> word); handy for turning index
# sequences back into text when inspecting data.
i_to_word = dict(zip(embedding_lookup.values(), embedding_lookup.keys()))

In [None]:
def flatten_and_offset(x_batch):
    """Concatenate a batch of index sequences into one flat array plus offsets.

    Parameters
    ----------
    x_batch : iterable of sequences
        One sequence of indices per example.

    Returns
    -------
    flat_x_batch : np.ndarray
        All sequences of ``x_batch`` concatenated in order.
    offsets : np.ndarray
        Start position of each sequence inside ``flat_x_batch``. There is
        exactly one entry per sequence, so an empty batch gives an empty array.
    """
    flat_x_batch = []
    for r in x_batch:
        flat_x_batch.extend(r)
    flat_x_batch = np.array(flat_x_batch)
    x_lengths = [len(r) for r in x_batch]
    # Exclusive prefix sum: the offset of sequence i is the total length of all
    # sequences before it. Dropping the last cumsum entry (rather than the last
    # length) keeps the result empty when the batch is empty.
    offsets = np.cumsum([0] + x_lengths)[:-1]
    return flat_x_batch, offsets

def batchify(batch_size, x_array, y_array, truncate=True):
    """Shuffle the examples and split them into CUDA tensor batches.

    Parameters
    ----------
    batch_size : int
        Number of examples per batch.
    x_array : np.ndarray
        Indexed along axis 0 by example; each item is a sequence of indices.
    y_array : np.ndarray
        Labels aligned with ``x_array``.
    truncate : bool
        If True, drop the trailing partial batch. If False, keep it as a final,
        smaller batch.

    Returns
    -------
    n_batches : int
        Number of batches actually built, including the partial batch when
        ``truncate`` is False. This equals the length of each returned array,
        so ``range(n_batches)`` visits every batch.
    t_x_batches, t_offset_batches, t_y_batches : np.ndarray
        1-D object arrays of CUDA LongTensors, one item per batch:
        - flattened inputs;
        - start offsets of each example in the flattened inputs;
        - labels.
    """
    def _as_object_array(tensors):
        # Fill a 1-D object array item by item. np.array(list_of_tensors) may
        # look inside the tensors, e.g. turning equal-length batches into a 2-D
        # numeric array. The object array still supports indexing by a
        # permutation.
        packed = np.empty(len(tensors), dtype=object)
        for i, tensor in enumerate(tensors):
            packed[i] = tensor
        return packed

    n_examples = x_array.shape[0]
    random_order = np.random.permutation(n_examples)
    x_array = x_array[random_order]
    y_array = y_array[random_order]

    # Region to batch: when truncating, stop before the remainder that does not
    # fill a whole batch.
    if truncate:
        stop = (n_examples // batch_size) * batch_size
    else:
        stop = n_examples

    t_x_batches = []
    t_offset_batches = []
    t_y_batches = []
    for start in range(0, stop, batch_size):
        end = min(start + batch_size, stop)
        flat_x_batch, offsets = flatten_and_offset(x_array[start:end])

        t_x_batches.append(torch.from_numpy(flat_x_batch).long().cuda())
        t_offset_batches.append(torch.from_numpy(offsets).long().cuda())
        t_y_batches.append(torch.from_numpy(y_array[start:end]).long().cuda())

    # Count what was built, not n_examples // batch_size. Otherwise the kept
    # partial batch (truncate=False) is never reached by callers that loop over
    # range(n_batches).
    n_batches = len(t_x_batches)
    return (
        n_batches,
        _as_object_array(t_x_batches),
        _as_object_array(t_offset_batches),
        _as_object_array(t_y_batches),
    )

In [None]:
batch_size = 512

# Shuffle and batch both splits. Training drops the trailing partial batch
# (truncate=True); the test split keeps it (truncate=False).
n_batches_train, t_x_train, t_offset_train, t_y_train = batchify(
    batch_size, x_array=x_train, y_array=y_train, truncate=True,
)
n_batches_test, t_x_test, t_offset_test, t_y_test = batchify(
    batch_size, x_array=x_test, y_array=y_test, truncate=False,
)

In [None]:
# Smoke test: run one forward pass over the first test batch and keep the
# predicted class and its softmax probability for the cells below.
# `model` is normally built in a later cell. On a fresh kernel
# (Restart & Run All) fall back to an untrained model so this cell does not
# raise NameError; its predictions then only check shapes and plumbing.
if 'model' not in globals():
    model = DanModel(embeddings.shape[0], n_classes)
    model.init_weights(initial_embeddings=embeddings)
    model.cuda()

t_x = Variable(t_x_test[0], volatile=True)
t_offset = Variable(t_offset_test[0], volatile=True)
t_y = Variable(t_y_test[0], volatile=True)

model.eval()
out = model(t_x, t_offset)
# Softmax over the class dimension (dim 1, the same dim argmax uses below).
# Passing dim explicitly avoids the implicit-dim deprecation warning.
probs = torch.nn.functional.softmax(out, dim=1)
scores, preds = torch.max(probs, 1)

In [None]:
# Predicted class indices for the first test batch, copied to host memory as a
# numpy array for display.
preds.data.cpu().numpy()

In [None]:
# Highest softmax probability (the model's confidence) for each prediction in
# the first test batch.
scores

In [None]:
def run_epoch(model, n_batches, t_x_array, t_offset_array, t_y_array, evaluate=False):
    """Run ``model`` over ``n_batches`` pre-built batches (one epoch).

    Reads the notebook-level globals ``criterion`` (the loss) and, when
    training, ``optimizer``.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(x, offsets)``; its output is passed to ``criterion``.
    n_batches : int
        Number of batches to process, taken from the front of the arrays.
    t_x_array, t_offset_array, t_y_array : np.ndarray
        Aligned arrays of per-batch tensors, as returned by ``batchify``.
    evaluate : bool
        If False, shuffle the batch order, then backpropagate and take an
        optimizer step for every batch. If True, only compute metrics
        (volatile Variables, no weight updates).

    Returns
    -------
    (accuracy, loss, seconds) : tuple
        - accuracy and loss, averaged per example: each batch is weighted by
          its size, so a smaller final batch does not skew the averages;
        - wall-clock duration of the epoch in seconds.
        Accuracy and loss are NaN when no batches are processed.
    """
    import time  # not imported explicitly at the top of the notebook

    if not evaluate:
        # Shuffle the batch order; one shared permutation keeps x/offset/y aligned.
        batch_order = np.random.permutation(n_batches)
        t_x_array = t_x_array[batch_order]
        t_offset_array = t_offset_array[batch_order]
        t_y_array = t_y_array[batch_order]

    total_correct = 0.0
    total_loss = 0.0
    total_examples = 0
    epoch_start = time.time()
    for batch in range(n_batches):
        t_x_batch = Variable(t_x_array[batch], volatile=evaluate)
        t_offset_batch = Variable(t_offset_array[batch], volatile=evaluate)
        t_y_batch = Variable(t_y_array[batch], volatile=evaluate)
        n_in_batch = t_y_batch.size(0)

        model.zero_grad()
        out = model(t_x_batch, t_offset_batch)
        _, preds = torch.max(out, 1)
        accuracy = torch.mean(torch.eq(preds, t_y_batch).float()).data[0]
        batch_loss = criterion(out, t_y_batch)
        if not evaluate:
            batch_loss.backward()
            optimizer.step()

        # `accuracy` is a batch mean. The loss is too, assuming criterion
        # uses its default mean reduction. Scale both back to per-example sums
        # so batches of different sizes are weighted correctly.
        total_correct += accuracy * n_in_batch
        total_loss += batch_loss.data[0] * n_in_batch
        total_examples += n_in_batch

    epoch_end = time.time()
    elapsed = epoch_end - epoch_start

    if total_examples == 0:
        return float('nan'), float('nan'), elapsed
    return total_correct / total_examples, total_loss / total_examples, elapsed

In [None]:
# Build the DAN classifier and initialise it from the loaded `embeddings` via
# DanModel.init_weights. Move it to the GPU *before* constructing the
# optimizer, as the torch.optim docs recommend. The loss is cross-entropy.
model = DanModel(embeddings.shape[0], n_classes)
model.init_weights(initial_embeddings=embeddings)
model.cuda()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
def create_save_model(model):
    """Return a one-argument callback that serialises ``model`` to a path.

    The callback closes over ``model`` and calls ``torch.save(model, path)``,
    which pickles the whole module rather than just its state_dict. It returns
    None.
    """
    return lambda path: torch.save(model, path)

# Training driver: a TrainingManager with callbacks decides when to stop.
# NOTE(review): callback behaviour below is inferred from the qanta.manager
# class names -- confirm against that module.
#   - BaseLogger: presumably reports per-epoch metrics.
#   - TerminateOnNaN: presumably stops training if a metric becomes NaN.
#   - EarlyStopping(patience=5): presumably stops after 5 epochs without
#     improvement.
#   - ModelCheckpoint: gets a save callback (create_save_model(model), which
#     calls torch.save) and the hard-coded path /tmp/dan.pt.
manager = TrainingManager([
    BaseLogger(), TerminateOnNaN(),
    EarlyStopping(patience=5), ModelCheckpoint(create_save_model(model), '/tmp/dan.pt')
])

# Hard cap of 100 epochs (MaxEpochStopping is imported but not used here);
# the manager can end training earlier.
for epoch in range(100):
    print('Starting epoch... ', end='')
    # Training pass: put the model in training mode; run_epoch with
    # evaluate=False shuffles the batches and updates the weights.
    model.train()
    train_acc, train_loss, train_time = run_epoch(
        model, n_batches_train,
        t_x_train, t_offset_train, t_y_train, evaluate=False
    )
    
    # Evaluation pass over the test batches in eval mode, without weight updates.
    model.eval()
    test_acc, test_loss, test_time = run_epoch(
        model, n_batches_test,
        t_x_test, t_offset_test, t_y_test, evaluate=True
    )
    
    # Hand this epoch's metrics to the callbacks; they return whether to stop
    # and the reasons.
    stop_training, reasons = manager.instruct(
        train_time, train_loss, train_acc,
        test_time, test_loss, test_acc
    )
    
    if stop_training:
        print(reasons)
        break