In [None]:
import os
import json
import random
import numpy as np
from six.moves import range
from six import iteritems
import h5py
from IPython.display import Image, display


import skimage.io
from skimage.transform import resize
from sklearn.preprocessing import normalize
from nltk.tokenize import word_tokenize

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision

import options
from utils import utilities as utils
from dataloader import VisDialDataset
from torch.utils.data import DataLoader
from eval_utils.rank_answerer import rankABot
from eval_utils.rank_questioner import rankQBot
from utils import utilities as utils
from utils.visualize import VisdomVisualize

# Define Helper Functions

In [None]:
def image_id_to_suffix(image_id):
    """Return the image filename suffix for ``image_id``.

    The id is converted to a string, left-padded with zeros to 12 characters and
    given a ``.jpg`` extension, e.g. ``42 -> '000000000042.jpg'``.
    """
    padded_id = str(image_id).zfill(12)
    return '{}.jpg'.format(padded_id)

def get_image_path(image_id, images_path):
    """Return the first existing image file for ``image_id``.

    Each entry of ``images_path`` is a string prefix (directory plus any filename
    prefix) that is concatenated directly with ``image_id_to_suffix(image_id)``.
    Candidates are tried in order.

    Raises:
        ValueError: if none of the candidate files exists.
    """
    suffix = image_id_to_suffix(image_id)
    candidates = (prefix + suffix for prefix in images_path)
    found = next((candidate for candidate in candidates if os.path.exists(candidate)), None)
    if found is None:
        raise ValueError("Image id \"{}\" could not be found in given paths \"{}\""
                         .format(image_id, images_path))
    return found


def visualize_example(dialogue_entry, questions, answers, images_path, verbose=False):
    """Display an example's image and print its caption and dialog turns.

    Args:
        dialogue_entry: one element of raw_data['data']['dialogs']; the
            'image_id', 'caption' and 'dialog' keys are read.
        questions: raw_data['data']['questions'], indexed by each turn's
            'question' id.
        answers: raw_data['data']['answers'], indexed by each turn's 'answer'
            id or by the ids in 'answer_options'.
        images_path: candidate path prefixes passed to get_image_path.
        verbose: if True, also print the raw dialogue entry and image filename.
    """
    image_id = dialogue_entry['image_id']
    # get_image_path raises ValueError (it never returns None) when no file
    # exists, so catch it to still show the dialog text for a missing image.
    try:
        image_filename = get_image_path(image_id, images_path)
    except ValueError:
        image_filename = None

    if image_filename is not None:
        display(Image(filename=image_filename))
    else:
        print("Image could not be found.")

    if verbose:
        print("\nDialogue entry: \n{}".format(dialogue_entry))
        print("Image from filename {}\n".format(image_filename))

    print("Caption = \"{}\"\n".format(dialogue_entry['caption']))
    for turn in dialogue_entry['dialog']:
        question_id = turn['question']
        print("Question = \"{}\"".format(questions[question_id]))
        if 'answer' in turn:
            answer_id = turn['answer']
            print("\t\t\t\tAnswer = \"{}\"\n".format(answers[answer_id]))
        else:
            # Turns without a ground-truth answer only list candidate answers.
            answer_options_ids = turn['answer_options']
            print("\t\t\t\tAnswer options = \n{}".format([answers[a_id] for a_id in answer_options_ids]))
            

def visualize_predictions(dialogue_entry, questions, answers, images_path, sortedScoreAll,
                          verbose=False, turn_index=None):
    """Display one example's image and dialog together with the A-Bot's ranked answers.

    Every turn's question and ground-truth answer (or its answer options when no
    answer is stored) is printed. For the selected turn, the candidate answers
    are also printed in ranked order together with their scores.

    Args:
        dialogue_entry: one element of raw_data['data']['dialogs']; the
            'image_id', 'caption' and 'dialog' keys are read.
        questions: raw_data['data']['questions'], indexed by question id.
        answers: raw_data['data']['answers'], indexed by answer id.
        images_path: candidate path prefixes passed to get_image_path.
        sortedScoreAll: list of (sortedScores, sortedInds) pairs from
            rankOptions; entry i is assumed to belong to dialog turn i.
        verbose: if True, also print the raw dialogue entry and image filename.
        turn_index: dialog turn whose ranked answers are printed (negative
            values count from the end). Defaults to the notebook-level
            ``question_turn_index`` for backward compatibility.
    """
    if turn_index is None:
        turn_index = question_turn_index
    dialog = dialogue_entry['dialog']
    n_turns = len(dialog)

    image_filename = get_image_path(dialogue_entry['image_id'], images_path)
    display(Image(filename=image_filename))

    if verbose:
        print("\nDialogue entry: \n{}".format(dialogue_entry))
        print("Image from filename {}\n".format(image_filename))

    print("Caption = \"{}\"\n".format(dialogue_entry['caption']))
    for i, turn in enumerate(dialog):
        question_id = turn['question']
        print("\nQuestion = \"{}\"".format(questions[question_id]))
        if 'answer' in turn:
            answer_id = turn['answer']
            print("\t\t\t\tGround Truth Answer = \"{}\"\n".format(answers[answer_id]))
        else:
            answer_options_ids = turn['answer_options']
            print("\t\t\t\tAnswer options = \n{}".format([answers[a_id] for a_id in answer_options_ids]))

        # Match both a non-negative index and its negative (from-the-end) form.
        if i != turn_index and i != n_turns + turn_index:
            continue
        if i >= len(sortedScoreAll):
            # Rounds skipped during evaluation never get a ranking entry.
            print("\t\t\t\tNo ranked answers available for this turn.")
            continue

        sortedScores, sortedInds = sortedScoreAll[i]
        # Map ranked option positions to answer ids through this turn's own
        # answer options instead of relying on a notebook global.
        option_ids = turn['answer_options']
        sortedAnswers = [option_ids[k] for k in sortedInds.data[0]]
        sortedAnswersStr = [answers[a_id] for a_id in sortedAnswers]
        print("\t\t\t\tSelected answers = {}".format(list(zip(sortedAnswersStr, list(sortedScores[0].data)))))
            
            
def rankOptions(options, scores):
    """Rank a batch of examples against a list of options.

    Sorts each row of ``scores`` in descending order (higher score = better
    rank).

    Args:
        options: the option tensor for the batch. It is not needed for the
            ranking and is kept only so existing call sites keep working.
        scores: scores to sort along dim 1; presumably (batch, numOptions)
            log-probabilities from aBot.evalOptions -- confirm.

    Returns:
        (sortedScore, sortedInds): the sorted scores and, for each row, the
        option indices in ranked order.
    """
    sortedScore, sortedInds = torch.sort(scores, 1, descending=True)
    return sortedScore, sortedInds
    

# Load data

In [None]:
# Parse dataset

# Root of the preprocessed VisDial files. Set the VISDIAL_DATA_DIR environment
# variable to point elsewhere instead of editing an absolute path in the notebook.
VISDIAL_DATA_DIR = os.environ.get('VISDIAL_DATA_DIR', '/scr/anarc/motm/data/visdial/data')

static_params = {
    'numRounds': 10,
    'useGPU': False,
    'imgNorm': 0,

    # NOTE(review): relative to the notebook's working directory, unlike the
    # paths below -- confirm this is intended.
    'inputImg': 'data/visdial/data_1.0_img.h5',

    'inputJson': os.path.join(VISDIAL_DATA_DIR, 'visdial_params_v2.json'),
    'inputQues': os.path.join(VISDIAL_DATA_DIR, 'visdial_data_v2.h5'),
    'cocoDir': os.path.join(VISDIAL_DATA_DIR, 'visdial_images'),
    'cocoInfo': os.path.join(VISDIAL_DATA_DIR, 'visdial_images', 'coco_info.json'),
}

splits = ['val', 'test']
dataset = VisDialDataset(static_params, splits)


In [None]:
# Raw, unparsed data from json
inputs_val = {
    "dialog_path": "visdial_1.0_val.json",
    "image_locations": ["visdial_images/VisualDialog_val2018"],
    "image_prefix": ["VisualDialog_val2018_"],
}

inputs_test = {
    "dialog_path": "visdial_1.0_test.json",
    "image_locations": ["visdial_images/VisualDialog_test2018"],
    "image_prefix": ["VisualDialog_test2018_"],
}
# NOTE(review): this relative base differs from the dataset paths configured in
# the dataset cell -- confirm both point at the same data.
data_basedir = "../data/visdial/data"


def _load_raw_split(inputs, basedir):
    """Return (dialog_path, image_path_prefixes, raw_json) for one split.

    Each image path prefix is directory + filename prefix; get_image_path
    appends the zero-padded image id suffix to it.
    """
    dialog_path = os.path.join(basedir, inputs["dialog_path"])
    image_paths = [os.path.join(basedir, location, prefix)
                   for location, prefix in zip(inputs["image_locations"], inputs["image_prefix"])]
    # Context manager so the file handle is closed once parsed.
    with open(dialog_path, 'r') as dialog_file:
        raw = json.load(dialog_file)
    return dialog_path, image_paths, raw


val_dialog_path, val_image_paths, val_raw_data = _load_raw_split(inputs_val, data_basedir)
test_dialog_path, test_image_paths, test_raw_data = _load_raw_split(inputs_test, data_basedir)



In [None]:
# Close the file once parsed instead of leaking the handle.
with open('data/qa_category_mapping.json', 'r') as mapping_file:
    qa_category_mapping = json.load(mapping_file)

cat = "color"
split = "val"
print("examples for category \"{}\":".format(cat))
print(qa_category_mapping[split][cat])

In [None]:
# Show the top-level keys of the mapping (the cell output). The previous cell
# indexes it as qa_category_mapping[split][category].
qa_category_mapping.keys()

# Configurable Parameters and Example Selection

In [None]:
params = {
    # A-Bot checkpoint to evaluate.
    # Alternative: "./checkpoints/all_duplicate_duplicate_duplicate/abot_ep_22.vd"
    'startFrom': "./checkpoints/color/abot_ep_37.vd",

    # A Q-Bot checkpoint should be given (as 'qstartFrom', e.g.
    # "./checkpoints/qbot_sl.vd") if an interactive dialog is required.

    'beamSize': 5,
    'imgFeatureSize': 16384,
}

# Copy in the dataset/runtime settings; duplicate keys take the static value.
params.update(static_params)

In [None]:
# Example to inspect:
#   example_split       -- 'val' or 'test' (any other value falls back to test)
#   example_index       -- dialog index within that split
#   question_turn_index -- dialog turn whose answers are ranked; negative counts from the end
example_split, example_index, question_turn_index = 'val', 4, -1

# View Example

In [None]:
# Presumably selects which split dataset[...] reads from -- confirm in VisDialDataset.
dataset.split = example_split

example_parsed = dataset[example_index]
# collate_fn expects a list of parsed examples; wrap the single one as a batch of 1.
example_batched = dataset.collate_fn([example_parsed])

for template, value in (("Parsed example from dataset = \n{}", example_parsed),
                        ("Batched parsed example from dataset = \n{}", example_batched)):
    print(template.format(value))

In [None]:
# Pick the raw JSON and image path prefixes matching the chosen split.
if example_split == 'val':
    raw_data, image_paths = val_raw_data, val_image_paths
else:
    raw_data, image_paths = test_raw_data, test_image_paths

example_raw_data = raw_data['data']['dialogs'][example_index]
# Answer-option ids of the turn whose answers will be ranked.
options_str = example_raw_data['dialog'][question_turn_index]['answer_options']
num_turns = len(example_raw_data['dialog'])

visualize_example(example_raw_data,
                  questions=raw_data['data']['questions'],
                  answers=raw_data['data']['answers'],
                  images_path=image_paths)



# Build, load, and run the model

In [None]:


# RNG seed -- seed every RNG source the notebook imports (random, numpy, torch)
# so that re-running the notebook is repeatable.
manualSeed = 1597
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)
if params['useGPU']:
    torch.cuda.manual_seed_all(manualSeed)

print('Loading json file: ' + params['inputJson'])
with open(params['inputJson'], 'r') as fileId:
    info = json.load(fileId)

wordCount = len(info['word2ind'])
# Add <START> and <END> to vocabulary as the two ids after the existing words
info['word2ind']['<START>'] = wordCount + 1
info['word2ind']['<END>'] = wordCount + 2
startToken = info['word2ind']['<START>']
endToken = info['word2ind']['<END>']
# Padding token is at index 0
vocabSize = wordCount + 3
print('Vocab size with <START>, <END>: %d' % vocabSize)

# Construct the reverse map (index -> word)
info['ind2word'] = {
    int(ind): word
    for word, ind in info['word2ind'].items()
}
    
def loadModel(params, agent='abot', verbose=False):
    """Build an A-Bot ('abot') or Q-Bot ('qbot') and load its checkpoint weights.

    The checkpoint path is read from params['startFrom'] (abot) or
    params['qstartFrom'] (qbot). Every option saved in the checkpoint's
    'params' that is missing from ``params`` is copied into ``params``. The
    caller's dict is mutated, and values already present are kept.

    Args:
        params: notebook parameter dict; 'useGPU' and the checkpoint key are read.
        agent: 'abot' or 'qbot'.
        verbose: if True, print the encoder parameters and encoder module (abot).

    Returns:
        The constructed model with the checkpoint's state dict loaded.

    Raises:
        ValueError: for an unknown ``agent`` or when no checkpoint path is set.
    """
    if agent not in ('abot', 'qbot'):
        raise ValueError("Unknown agent {!r}; expected 'abot' or 'qbot'".format(agent))

    # should be everything used in encoderParam, decoderParam below
    encoderOptions = [
        'encoder', 'vocabSize', 'embedSize', 'rnnHiddenSize', 'numLayers',
        'useHistory', 'useIm', 'imgEmbedSize', 'imgFeatureSize', 'numRounds',
        'dropout'
    ]
    decoderOptions = [
        'decoder', 'vocabSize', 'embedSize', 'rnnHiddenSize', 'numLayers',
        'dropout'
    ]
    modelOptions = encoderOptions + decoderOptions

    startArg = 'startFrom' if agent == 'abot' else 'qstartFrom'
    # .get(): a missing key (e.g. 'qstartFrom' left commented out) gets the same
    # clear error as an empty path instead of a bare KeyError.
    checkpointPath = params.get(startArg)
    if not checkpointPath:
        raise ValueError("Need checkpoint for {}".format(agent))

    print('Loading model (weights and config) from {}'.format(checkpointPath))
    gpuFlag = params['useGPU']
    if gpuFlag:
        mdict = torch.load(checkpointPath)
    else:
        # Keep every stored tensor on the CPU.
        mdict = torch.load(checkpointPath,
                           map_location=lambda storage, location: storage)

    # Model options are the union of the standard options above and those saved
    # in the checkpoint. Values already in params are not overwritten.
    savedParams = mdict['params']
    for opt in set(modelOptions).union(savedParams):
        if opt not in params:
            params[opt] = savedParams[opt]

    # Initialize model class
    encoderParam = {k: params[k] for k in encoderOptions}
    decoderParam = {k: params[k] for k in decoderOptions}

    # <START>/<END> are the last two vocabulary ids (see the vocab cell).
    for partParam in (encoderParam, decoderParam):
        partParam['startToken'] = partParam['vocabSize'] - 2
        partParam['endToken'] = partParam['vocabSize'] - 1

    if agent == 'abot':
        encoderParam['type'] = params['encoder']
        decoderParam['type'] = params['decoder']
        encoderParam['isAnswerer'] = True
        from visdial.models.answerer import Answerer
        model = Answerer(encoderParam, decoderParam)
        if verbose:
            print("e param = ", encoderParam)
            print("e = ", model.encoder)
    else:  # agent == 'qbot'
        encoderParam['type'] = params['qencoder']
        decoderParam['type'] = params['qdecoder']
        encoderParam['isAnswerer'] = False
        encoderParam['useIm'] = False
        from visdial.models.questioner import Questioner
        model = Questioner(
            encoderParam,
            decoderParam,
            imgFeatureSize=encoderParam['imgFeatureSize'])

    if gpuFlag:
        model.cuda()

    model.load_state_dict(mdict['model'])

    print("Loaded agent {}".format(agent))
    return model

aBot = None
qBot = None

# load aBot
if params.get('startFrom'):
    aBot = loadModel(params, 'abot')
    assert aBot.encoder.vocabSize == vocabSize, "Vocab size mismatch!"
    aBot.eval()

# load qBot -- 'qstartFrom' is optional (commented out in the params cell), so
# use .get() instead of indexing, which raised KeyError when it was absent.
if params.get('qstartFrom'):
    qBot = loadModel(params, 'qbot')
    assert qBot.encoder.vocabSize == vocabSize, "Vocab size mismatch!"
    qBot.eval()

# load pre-trained VGG 19
print("Loading image feature extraction model")
feat_extract_model = torchvision.models.vgg19(pretrained=True)

# Drop the last three layers of the classifier head.
feat_extract_model.classifier = nn.Sequential(*list(feat_extract_model.classifier.children())[:-3])
feat_extract_model.eval()

if params['useGPU']:
    feat_extract_model.cuda()

print("Done!")

In [None]:
# Wrap the collated batch for inference. `volatile=True` is the legacy
# (pre-0.4 PyTorch) Variable flag that turns off gradient tracking.
def _volatile(key):
    """Wrap example_batched[key] in a volatile Variable."""
    return Variable(example_batched[key], volatile=True)


image = _volatile('img_feat')
caption = _volatile('cap')
captionLens = _volatile('cap_len')
questions = _volatile('ques')
quesLens = _volatile('ques_len')
answers = _volatile('ans')
ansLens = _volatile('ans_len')
# Named answerOptions (not `options`) so the imported `options` module is not shadowed.
answerOptions = _volatile('opt')
optionLens = _volatile('opt_len')

numRounds = dataset.numRounds
sortedScoreAll = []
logProbsAll = [[] for _ in range(numRounds)]
scoringFunction = utils.maskedNll

aBot.reset()
aBot.observe(-1, image=image, caption=caption, captionLens=captionLens)
# roundIdx (not `round`) so the builtin round() stays usable in later cells.
for roundIdx in range(numRounds):
    print("Round = ", roundIdx)
    # A question of length 1 is treated as an empty round and skipped; such
    # rounds therefore get no entry in sortedScoreAll.
    if quesLens[0][roundIdx].data[0] == 1:
        print("skipping round")
        continue

    aBot.observe(
        roundIdx,
        ques=questions[:, roundIdx],
        quesLens=quesLens[:, roundIdx],
        ans=answers[:, roundIdx],
        ansLens=ansLens[:, roundIdx])
    # Score every candidate answer for this round.
    logProbs = aBot.evalOptions(answerOptions[:, roundIdx],
                                optionLens[:, roundIdx], scoringFunction)
    # Score of the observed answer answers[:, roundIdx].
    logProbsCurrent = aBot.forward()
    logProbsAll[roundIdx].append(
        scoringFunction(logProbsCurrent,
                        answers[:, roundIdx].contiguous()))
    sortedScore, sortedInds = rankOptions(answerOptions[:, roundIdx], logProbs)
    sortedScoreAll.append((sortedScore, sortedInds))

# Analyze the results

In [None]:
visualize_predictions(
    example_raw_data,
    questions=raw_data['data']['questions'],
    answers=raw_data['data']['answers'],
    images_path=image_paths,
    sortedScoreAll=sortedScoreAll,
)
    