# Visualize some results.

Sometimes it's useful to visualize your model's results to see what it gets right and what it gets wrong. This notebook guides you through iterating over the test set and visualizing some of the model's predictions.

In [None]:
from iterator import SmartIterator
from utils.visualization_utils import get_att_map, objdict, get_dict, add_attention, add_bboxes, get_bbox_from_heatmap, add_bbox_to_image
from keras.models import load_model
from models import ReferringRelationshipsModel
from keras.utils import to_categorical
import numpy as np
import os
from PIL import Image
import json
import matplotlib
import matplotlib.pyplot as plt
import h5py
from keras.models import Model
import seaborn as sns
from scipy.misc import imresize
from urllib.request import urlopen
from io import BytesIO
from keras.applications.resnet50 import preprocess_input

matplotlib.rcParams.update({'font.size': 34})
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Choose the dataset you want to visualize.

Set `data_type` to one of `"vrd"`, `"clevr"`, or `"visualgenome"`. Note that you will have to point the `img_dir` variable to the directory where you saved that dataset's images.

In [None]:
###################
data_type = "clevr"  # One of: "vrd", "clevr", "visualgenome".
###################
# Paths for the chosen dataset. `img_dir` is machine-specific: you will have to
# change it to wherever that dataset's images are stored.
if data_type == "vrd":
    annotations_file = "data/VRD/annotations_test.json"
    img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'
    vocab_dir = 'data/VRD'
    model_checkpoint = "pretrained/vrd.h5"
elif data_type == "clevr":
    annotations_file = "data/clevr/annotations_test.json"
    img_dir = '/data/ranjaykrishna/clevr/images/test'
    vocab_dir = 'data/clevr'
    model_checkpoint = "pretrained/clevr.h5"
elif data_type == "visualgenome":
    annotations_file = "data/VisualGenome/annotations_test.json"
    # NOTE(review): this is the same directory as the VRD test images; it looks like a
    # copy-paste placeholder — confirm it points at the Visual Genome images.
    img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'
    vocab_dir = 'data/VisualGenome'
    model_checkpoint = "pretrained/visualgenome.h5"
else:
    # Without this branch a typo in `data_type` only surfaces later, as a confusing
    # NameError on `annotations_file` in the setup cell.
    raise ValueError(
        "Unknown data_type {!r}; expected one of 'vrd', 'clevr', 'visualgenome'.".format(data_type))

## Setup

In [None]:
def iou(y_true, y_pred, thresh=0.5, eps=10e-8):
    """Per-example intersection-over-union of a ground-truth mask and thresholded predictions.

    Args:
        y_true: Ground-truth masks with the batch on axis 0 (assumed to hold 0/1 or
            bool values — TODO confirm against the iterator's outputs).
        y_pred: Predicted scores with the same shape as ``y_true``; binarised with ``> thresh``.
        thresh: Score threshold above which a location counts as predicted.
        eps: Added to the union to avoid 0/0 on empty masks (note ``10e-8`` == 1e-7).

    Returns:
        A list with one IoU value per example (per index of axis 0).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred) > thresh
    # Reduce over every non-batch axis so flat (batch, N) and (batch, H, W) masks both
    # give one score per example; for 2-D inputs this equals the old axis=1 sum.
    reduce_axes = tuple(range(1, y_pred.ndim))
    intersection = (y_pred * y_true).sum(axis=reduce_axes)
    union = eps + ((y_pred + y_true) > 0).sum(axis=reduce_axes)
    return list(intersection / union)

def recall(y_true, y_pred, thresh=0.5, eps=10e-8):
    """Per-example recall, TP / (TP + FN), of thresholded predictions against a mask.

    Args:
        y_true: Ground-truth masks with the batch on axis 0 (assumed to hold 0/1 or
            bool values — TODO confirm). Bool and unsigned-int masks are supported.
        y_pred: Predicted scores with the same shape as ``y_true``; binarised with ``> thresh``.
        thresh: Score threshold above which a location counts as predicted.
        eps: Added to the denominator to avoid 0/0 (note ``10e-8`` == 1e-7).

    Returns:
        A list with one recall value per example (per index of axis 0).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred) > thresh
    # Reduce over every non-batch axis so (batch, N) and (batch, H, W) masks both work.
    reduce_axes = tuple(range(1, y_pred.ndim))
    tp = (y_pred * y_true).sum(axis=reduce_axes)
    # Subtract in float: bool - bool raises TypeError and uint8 0 - 1 wraps to 255,
    # which would miscount false negatives.
    fn = ((y_true - y_pred.astype(float)) > 0).sum(axis=reduce_axes)
    recall_per_example = tp / (tp + fn + eps)
    return list(recall_per_example)

def precision(y_true, y_pred, thresh=0.5, eps=10e-8):
    """Per-example precision, TP / predicted positives, of thresholded predictions.

    Args:
        y_true: Ground-truth masks with the batch on axis 0 (assumed to hold 0/1 or
            bool values — TODO confirm).
        y_pred: Predicted scores with the same shape as ``y_true``; binarised with ``> thresh``.
        thresh: Score threshold above which a location counts as predicted.
        eps: Added to the denominator to avoid 0/0 when nothing is predicted
            (note ``10e-8`` == 1e-7).

    Returns:
        A list with one precision value per example (per index of axis 0).
    """
    y_pred = np.asarray(y_pred) > thresh
    # Reduce over every non-batch axis so (batch, N) and (batch, H, W) masks both work.
    reduce_axes = tuple(range(1, y_pred.ndim))
    tp = (y_pred * np.asarray(y_true)).sum(axis=reduce_axes)
    predicted_pos = y_pred.sum(axis=reduce_axes)
    return list(tp / (predicted_pos + eps))

In [None]:
# Load the test annotations and the subject/object and predicate vocabularies.
# `with` closes the files; the original `json.load(open(...))` leaked the handles.
with open(annotations_file) as annotations_fh:
    annotations_test = json.load(annotations_fh)
predicate_dict, obj_subj_dict = get_dict(vocab_dir)
# Sorted image ids give a stable ordering (only needed when loading raw images from img_dir).
image_ids = sorted(annotations_test)

# Rebuild the hyper-parameters saved next to the checkpoint as args.json.
args_path = os.path.join(os.path.dirname(model_checkpoint), "args.json")
with open(args_path, "r") as args_fh:
    params = objdict(json.load(args_fh))
# Inference-time overrides of the saved training arguments.
params.norm_scale = 1
params.use_internal_loss = False
params.cnn = 'resnet'
params.batch_size = 1  # one example per test_generator[i], indexed as [0] / [0, 0] below
params.discovery = False

relationships_model = ReferringRelationshipsModel(params)
# `test_data_dir` comes from the saved args.json, not from the dataset cell.
test_generator = SmartIterator(params.test_data_dir, params)

# Show the available subject/object categories and predicates.
print(' | '.join(obj_subj_dict))
print('')
print(' | '.join(predicate_dict))

## Load the model.

In [None]:
# Fail fast with a clear message, before building the network, if the checkpoint
# selected in the dataset cell is missing.
if not os.path.exists(model_checkpoint):
    raise FileNotFoundError(
        "Model checkpoint not found: {!r}; update `model_checkpoint` in the dataset cell.".format(model_checkpoint))
# Build the network described by `params`, then restore the trained weights.
model = relationships_model.build_model()
model.load_weights(model_checkpoint)

## Visualize some results.

`rel_range = [start, end]` selects which test examples to visualize: the results for indices `start` through `end - 1` are shown (so `end - start` examples in total).

In [None]:
#################
rel_range = [200, 300]  # [start, end): test_generator indices to visualize
#################
# Metrics reported per example, keyed by the label used in `results`.
# A name -> function dict replaces the old list of lambdas: a lambda's __name__ is
# '<lambda>', which made the result keys unreadable ('s_<lambda>').
# Other options: {'iou': iou, 'recall': recall, 'precision': precision}
metrics = {
    'precision@0.8': lambda y_true, y_pred: precision(y_true, y_pred, thresh=0.8),
}

# Per-channel means (BGR order) subtracted by Keras' ResNet50 `preprocess_input`
# (caffe mode), which also flips RGB -> BGR.
# NOTE(review): assumes SmartIterator feeds images through that preprocessing — confirm.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68])

for rel_idx in range(rel_range[0], rel_range[1]):
    inputs, outputs = test_generator[rel_idx]
    s_pred, o_pred = model.predict(inputs)

    # Evaluate subject and object predictions with every metric.
    results = {}
    for metric_name, metric in metrics.items():
        results['s_' + metric_name] = metric(outputs[0], s_pred)
        results['o_' + metric_name] = metric(outputs[1], o_pred)

    # Decode the (subject, predicate, object) ids of this example into names.
    relationship = [obj_subj_dict[int(inputs[1][0, 0])],
                    predicate_dict[int(inputs[2][0, 0])],
                    obj_subj_dict[int(inputs[3][0, 0])]]

    # Reconstruct a displayable image from the network input: add the means back,
    # flip BGR -> RGB (previously skipped, which swapped red and blue), and clip
    # before the uint8 cast so out-of-range values do not wrap around.
    # TODO: alternatively load the raw image from `img_dir` (needs an image index).
    img_bgr = inputs[0][0] + IMAGENET_BGR_MEAN
    img_rgb = np.clip(img_bgr[..., ::-1], 0, 255).astype('uint8')
    img = Image.fromarray(img_rgb, 'RGB')
    # NOTE(review): only the subject map is clamped at 0; the object map is passed
    # through unchanged — confirm this asymmetry is intended.
    att_map = get_att_map(img, np.maximum(s_pred[0], 0), o_pred[0], params.input_dim, relationship)

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(att_map)
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across loop iterations
    print(relationship)
    print(results)