In [38]:
import torch.nn as nn
from torch.autograd import Variable
import torch
from qanta.guesser.torch.dan import *
from qanta.datasets.quiz_bowl import QuizBowlDataset
import abc
from collections import defaultdict

In [14]:
# Build the Quiz Bowl dataset object and fetch its training data.
# NOTE(review): the meaning of the positional argument `1` is not visible in
# this notebook — confirm against QuizBowlDataset's signature.
dataset = QuizBowlDataset(1, guesser_train=True)
# `training_data` is indexed positionally further down (training_data[1]).
training_data = dataset.training_data()

In [15]:
# Unpack the seven values returned by preprocess_dataset. Going by the names:
# train/test texts and labels, the vocabulary, and the two class<->index
# mappings. NOTE(review): preprocess_dataset is not defined in this notebook
# (presumably it comes from the star import of qanta.guesser.torch.dan), so
# the exact contents are assumptions — confirm.
x_train_text, y_train, x_test_text, y_test, vocab, class_to_i, i_to_class = preprocess_dataset(
    training_data
)

In [16]:
# Load the embedding matrix plus `embedding_lookup`, a mapping whose items are
# (word, index) pairs (see how it is inverted into `i_to_word` below).
# NOTE(review): the effect of expand_glove=True is not visible here — confirm.
embeddings, embedding_lookup = load_embeddings(vocab=vocab, expand_glove=True)

2017-09-18 13:50:28,264 - qanta.guesser.torch.dan - INFO - Loading word embeddings from tmp cache


In [17]:
def _as_object_array(rows):
    """Pack `rows` into a 1-D object array holding one row per element.

    Calling ``np.array`` directly on sequences of different lengths relies on
    implicit ragged-to-object conversion, which recent NumPy releases reject
    with a ValueError. Filling a pre-allocated object array element by element
    yields a 1-D array whether or not the rows happen to share a length.
    """
    packed = np.empty(len(rows), dtype=object)
    for position, row in enumerate(rows):
        packed[position] = row
    return packed


# Convert every question's text into its sequence of embedding indices.
# Questions differ in length, so each split is stored as a 1-D object array
# (one index sequence per element); labels become ordinary NumPy arrays.
x_train = _as_object_array(
    [convert_text_to_embeddings_indices(q, embedding_lookup) for q in x_train_text]
)
y_train = np.array(y_train)

x_test = _as_object_array(
    [convert_text_to_embeddings_indices(q, embedding_lookup) for q in x_test_text]
)
y_test = np.array(y_test)

In [18]:
# Number of output classes, derived from the second element of training_data
# (presumably the answer labels — confirm against compute_n_classes).
n_classes = compute_n_classes(training_data[1])

In [19]:
# Inverse of embedding_lookup: maps each embedding index back to its word.
i_to_word = dict((ind, word) for word, ind in embedding_lookup.items())

In [23]:
def flatten_and_offset(x_batch):
    """Concatenate the rows of `x_batch` and report where each row starts.

    Parameters
    ----------
    x_batch : sequence of sequences
        One inner sequence per example.

    Returns
    -------
    (flat, offsets)
        `flat` is a NumPy array holding every element of every row in order;
        `offsets` is a NumPy array with the starting position of each row
        inside `flat` (a single 0 when `x_batch` is empty).
    """
    tokens = [token for row in x_batch for token in row]
    row_lengths = list(map(len, x_batch))
    # A row starts where all the rows before it end, so the last row's
    # length is never needed.
    starts = np.cumsum([0] + row_lengths[:-1])
    return np.array(tokens), starts

def batchify(batch_size, x_array, y_array, truncate=True, use_cuda=True):
    """Shuffle the examples once and pre-build the tensors for every batch.

    Parameters
    ----------
    batch_size : int
        Number of examples per batch; must be positive.
    x_array : numpy.ndarray
        One index sequence per example along the first axis.
    y_array : numpy.ndarray
        One label per example, aligned with `x_array`.
    truncate : bool
        When True, examples that do not fill a complete batch are dropped.
        When False, they form one final, smaller batch.
    use_cuda : bool
        When True (the default, matching the previous behaviour) every tensor
        is moved to the GPU with ``.cuda()``; pass False to stay on the CPU.

    Returns
    -------
    (n_batches, t_x_batches, t_offset_batches, t_y_batches)
        `n_batches` is the number of batches actually built, including the
        trailing partial batch when `truncate` is False. The other three are
        1-D NumPy object arrays with one LongTensor per batch: the flattened
        indices, the per-example start offsets, and the labels.

    Raises
    ------
    ValueError
        If `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer')

    def to_tensor(array):
        # Same conversion as before; only the device move is now optional.
        tensor = torch.from_numpy(array).long()
        return tensor.cuda() if use_cuda else tensor

    def as_object_array(items):
        # Fill element by element so NumPy stores each tensor as an opaque
        # object instead of trying to convert it to numeric data (which
        # fails for GPU tensors and for tensors of different lengths).
        packed = np.empty(len(items), dtype=object)
        for position, item in enumerate(items):
            packed[position] = item
        return packed

    n_examples = x_array.shape[0]
    random_order = np.random.permutation(n_examples)
    x_array = x_array[random_order]
    y_array = y_array[random_order]

    # With truncation, stop at the last full batch; otherwise run to the end
    # so that the final slice picks up the leftover examples.
    n_full_batches = n_examples // batch_size
    stop = n_full_batches * batch_size if truncate else n_examples

    t_x_batches = []
    t_offset_batches = []
    t_y_batches = []
    for start in range(0, stop, batch_size):
        x_batch = x_array[start:start + batch_size]
        y_batch = y_array[start:start + batch_size]
        flat_x_batch, offsets = flatten_and_offset(x_batch)

        t_x_batches.append(to_tensor(flat_x_batch))
        t_offset_batches.append(to_tensor(offsets))
        t_y_batches.append(to_tensor(y_batch))

    # Report the number of batches that exist, so that callers iterating
    # `range(n_batches)` also visit the trailing partial batch.
    n_batches = len(t_x_batches)

    return (
        n_batches,
        as_object_array(t_x_batches),
        as_object_array(t_offset_batches),
        as_object_array(t_y_batches),
    )

In [24]:
# Pre-build the shuffled mini-batches for both splits.
batch_size = 512
# truncate=True: training examples that do not fill a whole batch are dropped.
n_batches_train, t_x_train, t_offset_train, t_y_train = batchify(batch_size, x_train, y_train, truncate=True)
# truncate=False: asks batchify to also build the trailing partial batch.
n_batches_test, t_x_test, t_offset_test, t_y_test = batchify(batch_size, x_test, y_test, truncate=False)

In [36]:
def run_epoch(model, n_batches, t_x_array, t_offset_array, t_y_array, evaluate=False):
    """Run one pass over pre-built batches and summarise it.

    Relies on the notebook-level `criterion` (loss function) and `optimizer`
    objects, which must be defined before this is called.

    Parameters
    ----------
    model : callable module
        Called as ``model(indices, offsets)``; must also provide ``train()``,
        ``eval()`` and ``zero_grad()`` (assumes a torch.nn.Module — confirm).
    n_batches : int
        Number of batches to process.
    t_x_array, t_offset_array, t_y_array : numpy.ndarray
        Per-batch tensors, indexable by batch number and by an index array.
    evaluate : bool
        When True no parameters are updated and the batches are visited in
        their stored order; when False the batch order is shuffled and the
        optimizer takes one step per batch.

    Returns
    -------
    (mean_accuracy, mean_loss, seconds)
        Unweighted means over the batches, plus the wall-clock duration.
    """
    import time  # the notebook's import cell does not import `time` explicitly

    if evaluate:
        # Put layers whose behaviour differs between training and inference
        # (e.g. dropout) into inference mode, so the reported test metrics
        # are not computed in training mode.
        model.eval()
    else:
        model.train()
        random_batch_order = np.random.permutation(n_batches)
        t_x_array = t_x_array[random_batch_order]
        t_offset_array = t_offset_array[random_batch_order]
        t_y_array = t_y_array[random_batch_order]

    batch_accuracies = []
    batch_losses = []
    epoch_start = time.time()
    for batch in range(n_batches):
        # Legacy (pre-0.4) PyTorch idiom: `volatile=True` disables graph
        # construction during evaluation.
        t_x_batch = Variable(t_x_array[batch], volatile=evaluate)
        t_offset_batch = Variable(t_offset_array[batch], volatile=evaluate)
        t_y_batch = Variable(t_y_array[batch], volatile=evaluate)

        if not evaluate:
            # Gradients accumulate across backward() calls, so clear them
            # before each training step. Nothing to clear when evaluating.
            model.zero_grad()

        out = model(t_x_batch, t_offset_batch)
        _, preds = torch.max(out, 1)
        # `.data[0]` reads the Python scalar out of a one-element result
        # (legacy API; NOTE(review): newer PyTorch spells this `.item()`).
        accuracy = torch.mean(torch.eq(preds, t_y_batch).float()).data[0]
        batch_loss = criterion(out, t_y_batch)
        if not evaluate:
            batch_loss.backward()
            optimizer.step()

        batch_accuracies.append(accuracy)
        batch_losses.append(batch_loss.data[0])

    epoch_end = time.time()

    return np.mean(batch_accuracies), np.mean(batch_losses), epoch_end - epoch_start

In [34]:
# Instantiate DanModel from the number of embedding rows and the number of
# classes, then seed it with the loaded embedding matrix.
model = DanModel(embeddings.shape[0], n_classes)
model.init_weights(initial_embeddings=embeddings)
# Assuming DanModel is a torch.nn.Module (confirm): select training mode and
# move the parameters to the GPU. `.cuda()` requires a CUDA-capable device.
model.train()
model.cuda()
# Adam over all model parameters with learning rate 0.001, and cross-entropy
# as the loss. Both are read as notebook globals by run_epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
class Callback(abc.ABC):
    """Abstract base class for hooks that run at the end of every epoch.

    Derives from ``abc.ABC`` so that subclasses are ordinary classes whose
    instances can be created; subclassing ``abc.ABCMeta`` would instead turn
    every callback into a metaclass.
    """

    @abc.abstractmethod
    def on_epoch_end(self, logs):
        """React to the metrics recorded so far.

        Parameters
        ----------
        logs : mapping
            Maps a metric name to the list of its per-epoch values, the most
            recent value last.

        Returns
        -------
        (stop_training, reason)
            `stop_training` is True when training should halt; `reason`
            describes why (ignored when `stop_training` is False).
        """

class BaseLogger(Callback):
    """Callback that prints a one-line summary of the latest epoch."""

    def on_epoch_end(self, logs):
        """Print the newest accuracy, loss and training-time values.

        The epoch number shown is the count of recorded `train_acc` entries.
        Never requests that training stop.
        """
        epoch_number = len(logs['train_acc'])
        reported_keys = ('train_acc', 'test_acc', 'train_loss', 'test_loss', 'train_time')
        latest_values = [logs[key][-1] for key in reported_keys]
        template = 'Epoch {}: train_acc={} test_acc={} | train_loss={} test_loss={} | time={}'
        print(template.format(epoch_number, *latest_values))
        return False, []

    def __repr__(self):
        return 'BaseLogger()'

class TerminateOnNaN(Callback):
    """Callback that requests a stop as soon as any logged series holds a NaN."""

    def on_epoch_end(self, logs):
        """Scan every logged series; stop on the first one containing NaN.

        Returns ``(True, message)`` naming the offending key and its values,
        or ``(False, None)`` when no series contains a NaN.
        """
        for name, values in logs.items():
            has_nan = np.any(np.isnan(values))
            if has_nan:
                return True, 'NaN encountered in {} containing {}'.format(name, values)
        return False, None

    def __repr__(self):
        return 'TerminateOnNaN()'

class EarlyStopping(Callback):
    """Callback that stops training once a monitored metric stops improving.

    Parameters
    ----------
    monitor : str
        Key in `logs` to watch. Names ending in ``loss`` are minimised;
        names ending in ``acc`` or ``accuracy`` are maximised.
    min_delta : number
        Minimum change in the improving direction that counts as an
        improvement.
    patience : int
        Number of consecutive non-improving epochs tolerated before stopping.
        Values of 0 and 1 both stop at the first non-improving epoch.

    Raises
    ------
    ValueError
        If `monitor` does not end in one of the recognised suffixes.
    """

    def __init__(self, monitor='test_loss', min_delta=0, patience=0):
        if monitor.endswith('loss'):
            self.improvement_sign = 1
        elif monitor.endswith(('acc', 'accuracy')):
            # 'acc' is the suffix of the accuracy keys written by
            # TrainingManager ('train_acc', 'test_acc').
            self.improvement_sign = -1
        else:
            raise ValueError('Unrecognized monitor')
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        # Worst possible score in the monitored direction, so that the first
        # epoch always counts as an improvement.
        self.best_monitor_score = self.improvement_sign * float('inf')
        # Remaining non-improving epochs before a stop is requested.
        self.current_patience = patience

    def __repr__(self):
        return 'EarlyStopping(monitor={}, min_delta={}, patience={})'.format(
            self.monitor, self.min_delta, self.patience)

    def on_epoch_end(self, logs):
        """Update the best score and decide whether training should stop.

        Returns ``(True, 'Ran out of patience')`` when the tolerated number
        of consecutive non-improving epochs is used up, else ``(False, None)``.
        """
        current_score = logs[self.monitor][-1]
        sign = self.improvement_sign
        # Multiplying by `sign` turns both cases into "smaller is better";
        # the score must beat the best one by at least `min_delta`.
        improved = sign * current_score < sign * self.best_monitor_score - self.min_delta

        if improved:
            self.best_monitor_score = current_score
            self.current_patience = self.patience
            # An improving epoch never triggers a stop, even with patience=0.
            return False, None

        self.current_patience -= 1
        # `<=` rather than `==` so that patience=0 also stops instead of
        # counting down past zero forever.
        if self.current_patience <= 0:
            return True, 'Ran out of patience'
        return False, None


class ModelCheckpoint(Callback):
    """Callback that writes the model's parameters to disk after an epoch.

    Parameters
    ----------
    model : object exposing ``state_dict()``
        The model to checkpoint (assumes a torch.nn.Module — confirm).
    filepath : str
        Destination file for the checkpoint.
    monitor : str
        Key in `logs` used to decide whether the model improved. Names ending
        in ``loss`` are minimised; ``acc`` / ``accuracy`` are maximised.
    save_best_only : bool
        When True, save only after an epoch that improves on the best
        monitored score; when False, save after every epoch.
    overwrite : bool
        When False, a file that already exists at `filepath` before this
        callback's first save is never replaced: that first save raises
        FileExistsError. Later saves in the same run replace the callback's
        own file. NOTE(review): the original code never used this flag, so
        this interpretation is an assumption — confirm the intended meaning.

    Raises
    ------
    ValueError
        If `monitor` does not end in one of the recognised suffixes.
    """

    def __init__(self, model, filepath, monitor='test_loss', save_best_only=True, overwrite=True):
        if monitor.endswith('loss'):
            self.improvement_sign = 1
        elif monitor.endswith(('acc', 'accuracy')):
            self.improvement_sign = -1
        else:
            raise ValueError('Unrecognized monitor')
        self.model = model
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.overwrite = overwrite
        self.monitor = monitor
        # Worst possible score in the monitored direction, so that the first
        # epoch always counts as an improvement.
        self.best_monitor_score = self.improvement_sign * float('inf')
        # Becomes True once this instance has written `filepath` itself.
        self._has_written = False

    def __repr__(self):
        return 'ModelCheckpoint(filepath={}, monitor={}, save_best_only={}, overwrite={})'.format(
            self.filepath, self.monitor, self.save_best_only, self.overwrite)

    def on_epoch_end(self, logs):
        """Save the model when warranted; never asks for training to stop.

        Returns ``(False, None)`` so the result can be unpacked like that of
        any other callback.
        """
        current_score = logs[self.monitor][-1]
        sign = self.improvement_sign
        improved = sign * current_score < sign * self.best_monitor_score
        if improved:
            self.best_monitor_score = current_score

        if improved or not self.save_best_only:
            self._save()
        return False, None

    def _save(self):
        """Write the model's state dict to `filepath`."""
        # 'xb' creates the file exclusively and raises FileExistsError when
        # it is already there, which protects files from earlier runs.
        mode = 'wb' if (self.overwrite or self._has_written) else 'xb'
        with open(self.filepath, mode) as checkpoint_file:
            torch.save(self.model.state_dict(), checkpoint_file)
        self._has_written = True


class TrainingManager:
    """Collect per-epoch metrics and ask callbacks whether to stop training.

    Parameters
    ----------
    callbacks : iterable
        Objects exposing ``on_epoch_end(logs)``.
    default_callbacks : iterable or None
        Additional callbacks that run before `callbacks`. Defaults to none.
    """

    def __init__(self, callbacks, default_callbacks=None):
        # None instead of a mutable `[]` default, so no list is shared
        # between instances.
        defaults = list(default_callbacks) if default_callbacks is not None else []
        # Defaults run first, followed by the caller's callbacks. The inputs
        # are copied so that later mutation by the caller has no effect here.
        self.callbacks = defaults + list(callbacks)

        # metric name -> list of per-epoch values, newest last
        self.logs = defaultdict(list)

    def instruct(self, train_time, train_loss, train_acc, test_time, test_loss, test_acc):
        """Record one epoch's metrics and poll every callback.

        Returns
        -------
        (stop_training, reasons)
            `stop_training` is True when at least one callback asked to stop;
            `reasons` lists the reason given by each such callback and is
            empty otherwise.
        """
        self.logs['train_time'].append(train_time)
        self.logs['train_loss'].append(train_loss)
        self.logs['train_acc'].append(train_acc)
        self.logs['test_time'].append(test_time)
        self.logs['test_loss'].append(test_loss)
        self.logs['test_acc'].append(test_acc)

        callback_stop_reasons = []
        for c in self.callbacks:
            outcome = c.on_epoch_end(self.logs)
            if outcome is None:
                # A callback that returns nothing has no opinion on stopping.
                continue
            stop_training, reason = outcome
            if stop_training:
                callback_stop_reasons.append(reason)

        return bool(callback_stop_reasons), callback_stop_reasons
        

In [35]:
# Train for at most `max_epochs` epochs: one shuffled training pass followed
# by one evaluation pass per epoch. The manager records the metrics and
# reports whether any of its callbacks asked for training to stop.
max_epochs = 100
manager = TrainingManager([])

for epoch in range(max_epochs):
    print(f'Epoch {epoch}: ', end='')
    train_accuracy, train_loss, train_time = run_epoch(
        model, n_batches_train,
        t_x_train, t_offset_train, t_y_train, evaluate=False
    )

    test_accuracy, test_loss, test_time = run_epoch(
        model, n_batches_test,
        t_x_test, t_offset_test, t_y_test, evaluate=True
    )

    # Complete the line started by the `end=''` print above.
    print(
        f'time={train_time:.1f} '
        f'train_accuracy={train_accuracy:.4f} train_loss={train_loss:.4f} '
        f'test_accuracy={test_accuracy:.4f} test_loss={test_loss:.4f}'
    )

    # Pass the values returned by run_epoch; the names `train_acc` and
    # `test_acc` are only parameter names of instruct().
    stop_training, reasons = manager.instruct(
        train_time, train_loss, train_accuracy,
        test_time, test_loss, test_accuracy
    )

    if stop_training:
        print(reasons)
        break


Epoch 0: time=12.4 train_accuracy=0.0000 train_loss=0.0000test_accuracy=0.0026 test_loss=8.7164
Epoch 1: 

KeyboardInterrupt: 