In [None]:
import sys, os, re, json, time

import pandas as pd
import pickle
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import plotting
from PIL import Image
from tqdm import tqdm
from utils import imread, img_data_2_mini_batch, imgs2batch

from sklearn import metrics
from sklearn.metrics import accuracy_score

# from naive import EncDec
from attention import EncDec as FuseAttEncDec
# from rnn_att import EncDec
from data_loader import VQADataSet

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import transforms

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
# Build the VQA dataset for N questions once and cache it on disk as a pickle;
# later runs reuse the cached object instead of rebuilding it.
N = 5000
dataset_filename = "./data/data_{}.pkl".format(N)
print(dataset_filename)

dataset = None
if not os.path.exists(dataset_filename):
    # No cache yet: construct the dataset and persist it for the next run.
    dataset = VQADataSet(Q=N)
    with open(dataset_filename, 'wb') as cache_file:
        print("writing to " + dataset_filename)
        pickle.dump(dataset, cache_file)
else:
    # NOTE(review): pickle.load executes arbitrary code from the file, so this
    # cache must only ever be a file written by this notebook.
    with open(dataset_filename, 'rb') as cache_file:
        print("reading from " + dataset_filename)
        dataset = pickle.load(cache_file)

assert dataset is not None
def debug(v,q,a):
    """Print the ``.shape`` of the three batch components ``v``, ``q`` and ``a``.

    NOTE(review): the hyperparameter cell later rebinds the name ``debug`` to
    ``False``, so this helper is no longer reachable once that cell has run.
    """
    shapes = (v.shape, q.shape, a.shape)
    print('\nV: {}\nQ: {}\nA: {}'.format(*shapes))


In [None]:
# --- Model dimensions ---
embed_size        = 300
hidden_size       = 1024
attention_size    = 512
num_layers        = 1

# --- Vocabulary sizes, read from the loaded dataset's vocab tables ---
ques_vocab_size, ans_vocab_size = (
    len(dataset.vocab[kind]) for kind in ('question', 'answer')
)

# --- Optimisation settings ---
batch_size        = 50
n_epochs          = 30
learning_rate     = 0.001
# Currently unused: the active optimizer (Adam) is built with only `lr`.
momentum          = 0.98

# Flag forwarded to the model constructor.
# NOTE(review): this rebinds `debug`, which an earlier cell defines as a
# shape-printing helper function; that helper is shadowed from here on.
debug             = False

print(ques_vocab_size, ans_vocab_size)

In [None]:
def eval_model(data_loader, model, criterion, optimizer, batch_size, training=False,
              epoch = 0, total_loss_over_epochs=None, scores_over_epochs=None):
    """Run one full pass of ``model`` over ``data_loader``, optionally training it.

    Each minibatch is unpacked as ``(idxs, v, q, a, q_len)``. ``v`` and ``a`` are
    moved to the notebook-global ``device``; ``q`` is first passed through
    ``VQADataSet.batchify_questions``. The loss is ``F.nll_loss(logits, a)``.

    Args:
        data_loader: iterable of minibatches, or ``None`` for a no-op pass.
        model: module called as ``model(v, q, q_len)``; put into train or eval
            mode according to ``training``.
        criterion: accepted for interface compatibility but NOT used; the loss
            is always ``F.nll_loss``.
            NOTE(review): ``nll_loss`` expects log-probabilities — confirm the
            model's output layer is a log-softmax.
        optimizer: stepped once per batch when ``training`` is true and it is
            not ``None``.
        batch_size: accepted for interface compatibility; unused here.
        training: ``True`` for a training pass, ``False`` for evaluation.
        epoch: epoch number, used only in the progress message.
        total_loss_over_epochs: dict to which each batch loss (a Python float)
            is appended under ``'train_loss'`` / ``'test_loss'``. A fresh dict
            is used when omitted.
        scores_over_epochs: dict to which each batch accuracy is appended under
            ``'train_scores'`` / ``'test_scores'``. A fresh dict is used when
            omitted.

    Returns:
        ``(running_loss, final_labels, final_preds)``: the sum of per-batch
        losses, and the flat lists of ground-truth labels and predicted labels.
        When ``data_loader`` is ``None`` this is ``(0.0, [], [])`` so callers can
        always unpack three values.
    """
    running_loss = 0.
    final_labels, final_preds = [], []
    if data_loader is None:
        return running_loss, final_labels, final_preds

    run_type = 'train' if training else 'test'
    loss_key = '{}_loss'.format(run_type)
    score_key = '{}_scores'.format(run_type)

    # Fresh containers per call: a mutable default would be shared across calls.
    if total_loss_over_epochs is None:
        total_loss_over_epochs = {}
    if scores_over_epochs is None:
        scores_over_epochs = {}
    loss_history = total_loss_over_epochs.setdefault(loss_key, [])
    score_history = scores_over_epochs.setdefault(score_key, [])

    if training:
        model.train()
    else:
        model.eval()

    scores = []
    for i, minibatch in enumerate(data_loader):
        t0 = time.time()
        idxs, v, q, a, q_len = minibatch

        # torch's DataLoader hands the questions over as a list of tensors;
        # batchify_questions converts them back into a single batch tensor.
        v = v.to(device)
        q = VQADataSet.batchify_questions(q).to(device)
        a = a.to(device)

        # Only track gradients when training; evaluation needs no autograd graph.
        with torch.set_grad_enabled(training):
            logits = model(v, q, q_len)
            loss = F.nll_loss(logits, a)
        preds = torch.argmax(logits, dim=1)

        if training and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Store plain floats, never the loss tensor: keeping the tensor would
        # retain every batch's autograd graph for the life of the history.
        loss_value = loss.item()
        running_loss += loss_value

        batch_labels = a.tolist()
        batch_preds = preds.tolist()
        score = metrics.accuracy_score(batch_labels, batch_preds)

        scores.append(score)
        loss_history.append(loss_value)
        score_history.append(score)

        final_labels += batch_labels
        final_preds += batch_preds

        if i % 10 == 0:
            # Report the running mean accuracy over the batches seen so far.
            print("Epoch {}: {} Loss: {} Score: {} t: {}".format(
                epoch, run_type, loss_value, np.mean(scores), time.time() - t0))

    return running_loss, final_labels, final_preds

In [None]:
# Build the model, optimizer and data loaders, then alternate one training pass
# and one test pass per epoch, reporting test accuracy after each epoch.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = FuseAttEncDec(embed_size, hidden_size, attention_size,
                      ques_vocab_size, ans_vocab_size, num_layers, debug).to(device)

# NOTE(review): eval_model computes its loss with F.nll_loss and never calls
# `criterion`; it is only passed through to keep the call signature intact.
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.get_parameters(), lr=learning_rate, momentum=momentum)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train_loader = dataset.build_data_loader(train=True, args={'batch_size': batch_size})
test_loader  = dataset.build_data_loader(test=True, args={'batch_size': batch_size})

best_score = 0

train_all_loss, train_all_labels, train_all_preds = [], [], []
print("model built, start training.")
# Separate history dictionaries for the training passes and the test passes.
total_loss_over_epochs, scores_over_epochs = plotting.get_empty_stat_over_n_epoch_dictionaries()
total_loss_over_epochs2, scores_over_epochs2 = plotting.get_empty_stat_over_n_epoch_dictionaries()
for epoch in tqdm(range(n_epochs)):
    t0 = time.time()
    tr_loss, tr_labels, tr_preds = eval_model(data_loader = train_loader,
                                     model       = model,
                                     criterion   = criterion,
                                     optimizer   = optimizer,
                                     batch_size  = batch_size,
                                     training    = True,
                                     epoch       = epoch,
                                     total_loss_over_epochs = total_loss_over_epochs,
                                     scores_over_epochs     = scores_over_epochs)

    # Keep the test loss in its own name so it does not overwrite tr_loss.
    ts_loss, ts_labels, ts_preds = eval_model(data_loader = test_loader,
                                     model       = model,
                                     criterion   = criterion,
                                     optimizer   = None,
                                     batch_size  = batch_size,
                                     training    = False,
                                     epoch       = epoch,
                                     total_loss_over_epochs = total_loss_over_epochs2,
                                     scores_over_epochs     = scores_over_epochs2)

    # accuracy_score takes (y_true, y_pred).
    score = metrics.accuracy_score(ts_labels, ts_preds)
    # Track the best test accuracy seen so far across epochs.
    best_score = max(best_score, score)

    print("\n"+"#==#"*7 + "epoch: {}".format(epoch) + "#==#"*7)
    print('TEST ACC: {}'.format(score))
    print("#==#"*7 + "time: {}".format(time.time()-t0) + "#==#"*7 + "\n")
#     plotting.plot_score_over_n_epochs(scores_over_epochs, score_type='precision', fig_size=(8,5))
#     plotting.plot_loss_over_n_epochs(total_loss_over_epochs, fig_size=(8, 5), title="Loss")
    
    
    
    

In [None]:
# One standalone evaluation pass over the test split with the current model.
# Batch losses and scores are appended to the test-side history dictionaries.
epoch = 0
final_eval_kwargs = dict(
    data_loader=test_loader,
    model=model,
    criterion=criterion,
    optimizer=None,
    batch_size=batch_size,
    training=False,
    epoch=epoch,
    total_loss_over_epochs=total_loss_over_epochs2,
    scores_over_epochs=scores_over_epochs2,
)
ts_loss, ts_labels, ts_preds = eval_model(**final_eval_kwargs)
score = metrics.accuracy_score(ts_preds, ts_labels)
print("ACC: " + str(score))

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# Error analysis: run the model over the test split, display each image, and
# record (question, predicted answer, true answer) for every sample. Every
# misclassified sample is printed with a running counter.
count = 1
err_anal_data = []

model.eval()  # inference only: make sure the model is not left in training mode
with torch.no_grad():  # no gradients are needed for error analysis
    for batch_idx, minibatch in enumerate(test_loader):
        idxs, v, q, a, q_len = minibatch

        v = v.to(device)
        q = VQADataSet.batchify_questions(q).to(device)
        a = a.to(device)

        logits = model(v, q, q_len)
        preds = torch.argmax(logits, dim=1)

        # Separate index name so the batch counter above is not overwritten.
        for j in range(len(a)):
            enc_ans = a[j].item()
            enc_pred = preds[j].item()
            enc_ques = q[j].cpu().numpy()
            img_v = v[j].cpu().numpy()

            question = dataset.decode_question(enc_ques)
            # Decode prediction and label from plain ints, the same way for both.
            answer_dec = dataset.decode_answer(enc_pred)
            answer = dataset.decode_answer(enc_ans)

            # Show only the first slice along the leading axis of the image array.
            plt.figure()
            plt.imshow(img_v[0,:,:], interpolation='nearest')
            plt.show()

            # Strip the special tokens so only the question words remain.
            question = question.replace("<pad>", "")
            question = question.replace("<start>", "")
            question = question.replace("<end>", "").strip()

            result = answer_dec==answer
            err_anal_data.append([question, answer_dec, answer])
            if not result:
                print("{}. [Q] {} [A] {} [PRED] {}".format(count, question, answer, answer_dec))
                count+=1