In [1]:
import torch
from torch.autograd import Variable
from tqdm import tqdm
import torch.nn as nn

import os
import argparse
from collections import Counter
import pickle
from tensorboard import Logger
from utils.ReCuda import ReCuda

from readData import get_data
from model import TextModel

In [2]:
##
def setup():
    if not os.path.isdir('logs'):
        os.mkdir('logs')
    logger = Logger('./logs')

    args = type('test', (), {})()
    args.train=False
    args.test = False
    args.ckpt = None
    
    args.source_dir = '/home/jiwan/tqa/prepro/data'
    args.ckpt_dir = './ckpt'
    args.emb_dim = 300
    args.repeat = False
    args.learning_rate = 0.001
    args.if_pair = False
    args.log_epoch = 4
    args.bi_gru = True
    args.batch_size = 36
    args.verbose = False
    args.end_epoch = 100
    args.single_topic = False

    args.test_iter = 'val'

    args.cuda = True
    if not torch.cuda.is_available():
        args.cuda = False

    config = args
    config.recuda = ReCuda(config)
    config.resume = False
    if config.ckpt is not None:
        config.resume = True
    config.single_topic_ckpt = ''
    if not config.single_topic:
        config.single_topic_ckpt = '_all'

    config.logger = logger

    config.recuda.torch.manual_seed(1)

    return config



##
# get net
def get_net(config, vocab):
    if config.resume:
        print('RESUME {}th epoch'.format(config.ckpt))
        assert os.path.isdir('ckpt'), 'Error: no dir'
        ckpt = torch.load(os.path.join(config.ckpt_dir, 'ckpt{}_{}.t7'.format(config.single_topic_ckpt, config.ckpt)))
        net = TextModel(vocab, config, 100)
        net.load_state_dict(ckpt['params'])
        best_acc = ckpt['acc']
        start_epoch = ckpt['epoch']
    else:
        net = TextModel(vocab, config, 100)
        best_acc = 0
        start_epoch = 0
    net = config.recuda.var(net)
    print('PARAMS: ', net.parameters)
    return net, best_acc, start_epoch


##
def run_net(net, config, data):
    answers_size = len(data.answers)
    answers = torch.stack(data.answers, dim=2)
        
    if config.single_topic:
        topics = data.topic.data
    else:
        topics = torch.stack(data.topic, dim=2)

    target = Variable(data.correct_answer.data, requires_grad=False)
    target = config.recuda.var(target)
    print('t:', topics.size(), type(topics))
    # run
    return net.forward(topics, data.question, answers, answers_size)

##
def train_epoch(net, config, data, train_iter, epoch):

    # train
    train_loss = 0
    for batch_index, data in tqdm(enumerate(train_iter)):
        net.zero_grad()
        
        y = run_net(net, config, data)
        if config.verbose:
            print('y:', y.data)
            print('t:', target.data)
        loss = config.loss_fn(y, target)
        # count loss
        loss.backward()
        # optimize
        config.optimizer.step()

        train_loss += loss.data[0]
        loss_per = train_loss/(batch_index+1)
        print("Training {} epoch, loss: {}".format(epoch, loss_per))
        config.logger.scalar_summary('tr_loss', loss_per, epoch+1)

##
def validate_epoch(net, config, data, val_iter, epoch):
    # validate from time to time

    print("begin validation")
    correct = 0
    total = 0
    for index_v, data in tqdm(enumerate(val_iter)):
        y = run_net(net, config, data)
        
        value, pred = torch.max(y, 1)
        check = torch.eq(data.correct_answer.data, pred.data)
        if config.verbose:
            print(torch.sum(check), check.size())
        correct += torch.sum(check)
        total += (check.size()[0])

    acc = 100.*correct/total
    print("Val {} epoch, acc: {}".format(epoch, acc))

    config.logger.scalar_summary('val_acc', acc, (epoch + 1))

    return acc

##
def save_net(net, config, epoch, acc):
    print('saving')
    state = {
        'params': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    if not os.path.isdir('ckpt'):
        os.mkdir('ckpt')
    torch.save(state, os.path.join(config.ckpt_dir, 'ckpt{}_{}.t7'.format(config.single_topic_ckpt,epoch)))


##
def train_all(net, data, iters, config):
    config.loss_fn = nn.CrossEntropyLoss()
    config.optimizer = torch.optim.Adam(net.parameters(), lr=config.learning_rate)

    for epoch in range(config.start_epoch, config.end_epoch):
        print("{} epoch".format(epoch))
        train_epoch(net, config, data, iters['train'], epoch)
        acc = validate_epoch(net, config, data, iters['val'], epoch)

        save_net(net, config, epoch, acc)
##


def test_epoch(net, config, data, test_iter):
    test_net = Counter()
    net_dict = {}

    print("begin testing")
    for index_t, data in tqdm(enumerate(test_iter)):
        y = run_net(net, config, data)
        
        value, pred = torch.max(y, 1)
        check = torch.eq(data.correct_answer.data, pred.data)
        for i in range(len(check)):
            test_net[data.id[i]] += int(check[i])
            net_dict[data.id[i]] = [pred.data[i], data.correct_answer.data[i]]

    return test_net, net_dict


def test_all(net, data, test_iter, config):
    test_counter, test_dict = test_epoch(net, config, data, test_iter)

    with open(os.path.join(config.source_dir, 'correct_counter_{}.pickle'.format(config.test_iter)), 'wb') as outfile:
        pickle.dump(test_counter, outfile)

    with open(os.path.join(config.source_dir, 'correct_dict_{}.pickle'.format(config.test_iter)), 'wb') as outfile:
        pickle.dump(test_dict, outfile)


In [3]:
config = setup()

In [4]:
data, iters, vocab = get_data(config)

loading data_train_full.tsv, data_test_full.tsv, data_val_full.tsv


In [5]:
print('loading model')
net, best_acc, config.start_epoch = get_net(config, vocab)

loading model
('PARAMS: ', <bound method TextModel.parameters of TextModel(
  (embed): Embedding(27226, 300)
  (embed_context): GRU(300, 100, bidirectional=True)
  (embed_question): GRU(300, 100, bidirectional=True)
  (embed_answer): GRU(300, 100, bidirectional=True)
)>)


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class TextModelMain(nn.Module):
    def __init__(self, vocab, config, embed_size):
        super(TextModelMain, self).__init__()

        self.embed_size = embed_size
        self.config = config

        self.embed = nn.Embedding(len(vocab), config.emb_dim)
        self.embed.weight.data.copy_(vocab.vectors)

        self.bi = 2 if config.bi_gru else 1

        self.embed_context = nn.GRU(config.emb_dim, embed_size, bidirectional=config.bi_gru)
        self.embed_question = nn.GRU(config.emb_dim, embed_size, bidirectional=config.bi_gru)
        self.embed_answer = nn.GRU(config.emb_dim, embed_size, bidirectional=config.bi_gru)

    def forward(self, context, question, answers, answers_size):

        if not self.config.single_topic:
            context_shape = list(context.data.size())
            context_shape.append(self.config.emb_dim)
            context = context.view(-1, context.size()[2])
        
        context = self.embed(context)
        question = self.embed(question)

        if not self.config.single_topic:
            context = context.view(*context_shape)
            context = torch.sum(context, 1) # sum along num of topics

        M, hm = self.embed_context(context) # P x embed_size
        U, hu = self.embed_question(question) # Q X embed_size

        M = M.permute(1,0,2)
        U = U.permute(1,2,0)
        S = torch.matmul(M, U)
        S, S_index = torch.max(S, dim=2)
        a = F.softmax(S).unsqueeze(0).permute(1,2,0)
        a = a.expand(M.data.size())
        m = torch.mul(a, M)
        m = torch.sum(m, 1).unsqueeze(0)

        origin_size = answers.data.size()
        answers = answers.view(-1, answers.size()[2])
        if self.config.verbose:
            if len(answers.data.size()) < 3:
                print(answers.data)
        answers = self.embed(answers)
        C, hc = self.embed_answer(answers) # A X embed_size
        C = C.unsqueeze(0).view(origin_size[0], origin_size[1], origin_size[2], self.bi*self.embed_size)
        c = torch.sum(C, dim=0)
        r = torch.matmul(m.permute(1,0,2), c.permute(1,2,0)).squeeze()

        return r

net = TextModelMain(vocab, config, 100)
net.cuda()

TextModelMain(
  (embed): Embedding(27226, 300)
  (embed_context): GRU(300, 100, bidirectional=True)
  (embed_question): GRU(300, 100, bidirectional=True)
  (embed_answer): GRU(300, 100, bidirectional=True)
)

In [9]:

if config.train:
    print("Let\'s start Training")
    train_all(net, data, iters, config)
else:
    print("Let\'s start Testing")
    test_all(net, data, iters[config.test_iter], config)


0it [00:00, ?it/s]

Let's start Testing
begin testing
('t:', torch.Size([105, 16, 36]), <class 'torch.autograd.variable.Variable'>)


3it [00:01,  2.97it/s]

('t:', torch.Size([169, 16, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([169, 14, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([248, 8, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([248, 18, 36]), <class 'torch.autograd.variable.Variable'>)


7it [00:01,  5.41it/s]

('t:', torch.Size([254, 18, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([254, 12, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([163, 13, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([145, 9, 36]), <class 'torch.autograd.variable.Variable'>)


12it [00:01,  7.55it/s]

('t:', torch.Size([145, 21, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([171, 21, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([229, 19, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([451, 9, 36]), <class 'torch.autograd.variable.Variable'>)


14it [00:01,  8.01it/s]

('t:', torch.Size([451, 8, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([282, 14, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([249, 18, 36]), <class 'torch.autograd.variable.Variable'>)


18it [00:02,  8.50it/s]

('t:', torch.Size([515, 18, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([515, 8, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([338, 5, 36]), <class 'torch.autograd.variable.Variable'>)


23it [00:02,  9.74it/s]

('t:', torch.Size([209, 4, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([283, 4, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([234, 5, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([198, 5, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([239, 11, 36]), <class 'torch.autograd.variable.Variable'>)


27it [00:02, 10.44it/s]

('t:', torch.Size([292, 3, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([226, 9, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([284, 6, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([185, 7, 36]), <class 'torch.autograd.variable.Variable'>)


30it [00:02, 10.98it/s]

('t:', torch.Size([256, 6, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([276, 8, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([202, 7, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([181, 10, 36]), <class 'torch.autograd.variable.Variable'>)


35it [00:02, 11.70it/s]

('t:', torch.Size([192, 10, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([286, 8, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([255, 8, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([264, 10, 36]), <class 'torch.autograd.variable.Variable'>)


39it [00:03, 12.00it/s]

('t:', torch.Size([264, 19, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([199, 19, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([199, 13, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([194, 15, 36]), <class 'torch.autograd.variable.Variable'>)


43it [00:03, 12.21it/s]

('t:', torch.Size([158, 15, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([225, 14, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([459, 14, 36]), <class 'torch.autograd.variable.Variable'>)


45it [00:03, 12.29it/s]

('t:', torch.Size([399, 6, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([498, 6, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([303, 6, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([327, 3, 36]), <class 'torch.autograd.variable.Variable'>)


49it [00:03, 12.55it/s]

('t:', torch.Size([410, 5, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([492, 4, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([433, 4, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([234, 5, 36]), <class 'torch.autograd.variable.Variable'>)


53it [00:04, 12.85it/s]

('t:', torch.Size([234, 5, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([207, 12, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([295, 12, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([301, 11, 36]), <class 'torch.autograd.variable.Variable'>)


57it [00:04, 13.05it/s]

('t:', torch.Size([301, 5, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([252, 5, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([284, 7, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([495, 10, 36]), <class 'torch.autograd.variable.Variable'>)


61it [00:04, 13.11it/s]

('t:', torch.Size([495, 10, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([248, 8, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([218, 13, 36]), <class 'torch.autograd.variable.Variable'>)


66it [00:04, 13.44it/s]

('t:', torch.Size([251, 13, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([278, 4, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([229, 4, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([218, 4, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([342, 5, 36]), <class 'torch.autograd.variable.Variable'>)


71it [00:05, 13.73it/s]

('t:', torch.Size([441, 10, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([348, 4, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([291, 5, 36]), <class 'torch.autograd.variable.Variable'>)
('t:', torch.Size([291, 3, 4]), <class 'torch.autograd.variable.Variable'>)



