In [1]:
import torch

from models import TextModel, ModuleNet

import argparse
import os
from tensorboard import Logger
from utils.ReCuda import ReCuda

logger = Logger('./logs')

args = type('args_test', (object,), {})()
args.load_ckpt = None
args.train = True
args.test = False
args.source_dir = '/home/jiwan/tqa/prepro/data'
args.ckpt_dir = './ckpt'
args.emb_dim = 300
args.repeat = False
args.learning_rate = 0.001
args.if_pair = False
args.log_epoch = 4
args.bi_gru = True
args.batch_size = 36
args.verbose = False
args.end_epoch = 100
args.single_topic = False
args.embed_size = 100
args.shuffle = True
args.large_topic = False
args.reversible = True
args.fix_length = True
args.reasoning_planes = 16
args.k = 4
args.conf = 0.7
args.h_size = 128
args.hyper = False
args.hidden_size = 300
args.dim_words = 2
args.ans_k = 7
args.bi = 2 if args.bi_gru else 1

args.resume = False
if args.load_ckpt is not None:
    args.resume = True

args.test_iter = 'val'

args.cuda = True
if not torch.cuda.is_available():
    args.cuda = False

config = args
config.recuda = ReCuda(config)
config.ckpt_name = '_single'
if not config.single_topic:
    config.ckpt_name = '_all'
if config.large_topic:
    config.ckpt_name = '_full'

config.logger = logger

config.recuda.torch.manual_seed(1)

config.model = ModuleNet



In [2]:
from readData import get_data

data, iters, vocab, stats = get_data(config)

config.q_size = stats['question_size']
config.a_size = stats['answer_size']
config.c_size = stats['topic_size']
config.keys = ['A', 'c']
config.sizes = {'A': config.a_size, 'c': config.c_size}

loading data_train_full.tsv, data_test_full.tsv, data_val_full.tsv


In [3]:
def get_net(config, vocab):
    if config.resume:
        assert os.path.isdir('ckpt'), 'Error: no dir'
        ckpt = torch.load(os.path.join(config.ckpt_dir, config.load_ckpt))
        net = config.model(config, vocab)
        net.load_state_dict(ckpt['params'])
        best_acc = ckpt['acc']
        start_epoch = ckpt['epoch']
        print('RESUME {}th epoch'.format(start_epoch))
    else:
        net = config.model(config, vocab)
        best_acc = 0
        start_epoch = 0
    net = config.recuda.var(net)
    print('PARAMS: ', net.parameters)
    return net, best_acc, start_epoch

net, best_acc, config.start_epoch = get_net(config, vocab)

('PARAMS: ', <bound method ModuleNet.parameters of ModuleNet(
  (encoder): Encoder(
    (embed): Embedding(27818, 300)
    (embed_context): GRU(300, 100, bidirectional=True)
    (embed_question): GRU(300, 100, bidirectional=True)
    (embed_answer): GRU(300, 100, bidirectional=True)
    (normalize_row): Softmax()
  )
  (controller): POCController(
    (module): SimpleModule(
      (memory_attention): MemoryAttention(
      )
      (reasoning): Reasoning(
        (res_conv_a): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
            (relu): ReLU(inplace)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
          )
          (1): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(

In [4]:
import torch
from torch.autograd import Variable
from tqdm import tqdm
import torch.nn as nn

import os
from collections import Counter
import pickle

##
def run_net(net, config, data):
    answers_size = len(data.answers)
    answers = torch.stack(data.answers, dim=2)

    if config.single_topic:
        topics = data.topic.data
    else:
        topics = torch.stack(data.topic, dim=2)

    target = Variable(data.correct_answer.data, requires_grad=False)
    target = config.recuda.var(target)
    # print('t:', topics.size(), type(topics))
    # run
    return net.forward(topics, data.question, answers)


##
def train_epoch(net, config, data, train_iter, epoch):

    # train
    train_loss = 0
    for batch_index, data in tqdm(enumerate(train_iter)):
        net.zero_grad()
        target = Variable(data.correct_answer.data, requires_grad=False)
        target = config.recuda.var(target)

        if config.verbose:
            if config.single_topic:
                print('context:', data.topic.data)
            else:
                print('context_list:', data.topic[0])

        # run
        y = run_net(net, config, data)

        if config.verbose:
            print('y:', y.data)
            print('t:', target.data)

        loss = config.loss_fn(y, target)
        # count loss
        loss.backward()
        # optimize
        config.optimizer.step()

        train_loss += loss.data[0]
        loss_per = train_loss / (batch_index + 1)
        print("Training {} epoch, loss: {}".format(epoch, loss_per))
        config.logger.scalar_summary('tr_loss{}'.format(config.ckpt_name), loss_per, epoch + 1)


##
def validate_epoch(net, config, data, val_iter, epoch):
    # validate from time to time

    print("begin validation")
    correct = 0
    total = 0
    for index_v, data in tqdm(enumerate(val_iter)):
        # run
        y = run_net(net, config, data)

        value, pred = torch.max(y, 1)
        check = torch.eq(data.correct_answer.data, pred.data)
        if config.verbose:
            print(torch.sum(check), check.size())
        correct += torch.sum(check)
        total += (check.size()[0])

    acc = 100.*correct/total
    print("Val {} epoch, acc: {}".format(epoch, acc))

    config.logger.scalar_summary('val_acc{}'.format(config.ckpt_name), acc, (epoch + 1))

    return acc


##
def save_net(net, config, epoch, acc):
    print('saving')
    state = {
        'params': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    if not os.path.isdir('ckpt'):
        os.mkdir('ckpt')
    if not os.path.isdir('ckpt/temp'):
        os.mkdir('ckpt/temp')
    torch.save(state, os.path.join(config.ckpt_dir, 'temp', 'ckpt{}_{}.t7'.format(config.ckpt_name,epoch)))


##
def train_all(net, data, iters, config):
    config.loss_fn = nn.CrossEntropyLoss()
    config.optimizer = torch.optim.Adam(net.parameters(), lr=config.learning_rate)

    for epoch in range(config.start_epoch, config.end_epoch):
        print("{} epoch".format(epoch))
        train_epoch(net, config, data, iters['train'], epoch)
        acc = validate_epoch(net, config, data, iters['val'], epoch)

        save_net(net, config, epoch, acc)

In [21]:
# refresh

net, best_acc, config.start_epoch = get_net(config, vocab)

reload(net)

('PARAMS: ', <bound method ModuleNet.parameters of ModuleNet(
  (encoder): Encoder(
    (embed): Embedding(27818, 300)
    (embed_context): GRU(300, 100, bidirectional=True)
    (embed_question): GRU(300, 100, bidirectional=True)
    (embed_answer): GRU(300, 100, bidirectional=True)
    (normalize_row): Softmax()
  )
  (controller): POCController(
    (module): SimpleModule(
      (memory_attention): MemoryAttention(
      )
      (reasoning): Reasoning(
        (res_conv_a): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
            (relu): ReLU(inplace)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
          )
          (1): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(

TypeError: reload() argument must be module

In [22]:
train_all(net, data, iters, config)

0it [00:00, ?it/s]

0 epoch


Exception KeyError: KeyError(<weakref at 0x7f6387622470; to 'tqdm' at 0x7f638c26c710>,) in <bound method tqdm.__del__ of 0it [00:00, ?it/s]> ignored


TypeError: Performing basic indexing on a tensor and encountered an error indexing dim 0 with an object of type torch.cuda.LongTensor. The only supported types are integers, slices, numpy scalars, or if indexing with a torch.cuda.LongTensor or torch.cuda.ByteTensor only a single Tensor may be passed.