In [None]:
from iterator import SmartIterator
from utils.visualization_utils import get_att_map, objdict, get_dict
from keras.models import load_model
from models import ReferringRelationshipsModel
from keras.utils import to_categorical
import numpy as np
import os
from PIL import Image
import json
import matplotlib.pyplot as plt
import h5py
from keras.models import Model
import seaborn as sns
%matplotlib inline
%load_ext autoreload
%autoreload 2

# VRD (Visual Relationship Detection)


In [None]:
# Paths for the VRD test split; edit these to match the local setup.
# The annotations file is read inside a context manager so its handle is
# closed as soon as the JSON is parsed, instead of lingering until garbage
# collection as with a bare `json.load(open(...))`.
with open("data/VRD/annotations_test.json") as annotations_file:
    annotations_test = json.load(annotations_file)
img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'
# Directory handed to get_dict() to load the vocabularies.
vocab_dir = 'data/VRD'
# The checkpoint's directory is also where args.json is looked up later.
model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/11_02_2017/ssn/16/model10-1.37.h5"

### Setup

In [None]:
# Vocabularies used below to turn relationship words into model inputs.
predicate_dict, obj_subj_dict = get_dict(vocab_dir)
# Keep only the first 1000 image ids, in sorted order.
image_ids = sorted(annotations_test)[:1000]
# The arguments saved next to the checkpoint (args.json) configure both the
# model and the data iterator. The file is closed right after parsing.
args_path = os.path.join(os.path.dirname(model_checkpoint), "args.json")
with open(args_path, "r") as args_file:
    params = objdict(json.load(args_file))
# NOTE(review): presumably switches off an auxiliary loss that is only needed
# during training -- confirm in ReferringRelationshipsModel.
params.use_internal_loss = False
relationships_model = ReferringRelationshipsModel(params)
test_generator = SmartIterator(params.test_data_dir, params)
images = test_generator.get_image_dataset()

### Load the model.

In [None]:
# Instantiate the network described by `params`, then restore the weights
# stored in the checkpoint file.
model = relationships_model.build_model()
model.load_weights(filepath=model_checkpoint)

In [None]:
# Wrap two named intermediate layers in their own models. Both share the
# full model's inputs, so they accept exactly the same input list.
probes = []
for layer_name in ("before-pred-subj", "after-pred-subj"):
    output = model.get_layer(layer_name).output
    probes.append(Model(inputs=model.input, outputs=output))
before_pred, after_pred = probes

### USER INPUT - Pick an image 

In [None]:
#################
# Draw the index from the ids that are actually available. A hard-coded
# upper bound of 1000 raises IndexError whenever the annotations contain
# fewer than 1000 images.
# NOTE(review): assumes row i of `images` corresponds to image_ids[i] -- confirm.
image_index = np.random.randint(len(image_ids))
print(image_index)
#################
# Load the picked image and resize it to the model's input resolution.
img = Image.open(os.path.join(img_dir, image_ids[image_index]))
img = img.resize((params.input_dim, params.input_dim))
plt.figure(figsize=(5,5))
plt.imshow(img)
plt.axis("off")

### USER INPUT - Pick a relationship 

In [None]:
#################
subj = "sky"
predicate = "in"
obj = "building"
#################
# Subject and object are each passed as one vocabulary index in a (1, 1)
# array; the predicate is a one-hot row of length params.num_predicates.
subj_id = np.zeros((1, 1))
predicate_id = np.zeros((1, params.num_predicates))
obj_id = np.zeros_like(subj_id)
# Kept as a list: used for the plot title and passed to get_att_map.
relationship = [subj, predicate, obj]
# list.index raises ValueError when a word is missing from its vocabulary.
subj_id[0, 0] = obj_subj_dict.index(subj)
predicate_id[0, predicate_dict.index(predicate)] = 1
obj_id[0, 0] = obj_subj_dict.index(obj)

### Run the model and visualize the heatmaps.

In [None]:
# Slice (rather than index) the image so it keeps a leading batch axis of 1.
single_image = images[image_index:image_index+1]
subject_heatmap, object_heatmap = model.predict(
    [single_image, subj_id, predicate_id, obj_id]
)
# Negative activations are clamped to zero before building the overlay.
subject_map = np.maximum(subject_heatmap[0], 0)
object_map = np.maximum(object_heatmap[0], 0)
att_map = get_att_map(img, subject_map, object_map, params.input_dim, relationship)
plt.figure(figsize=(15, 15))
plt.imshow(att_map)
plt.title("-".join(relationship))
plt.axis("off")

In [None]:
interp_method = 'gaussian'
# Both sub-models take the same inputs as the full model.
probe_inputs = [images[image_index:image_index+1], subj_id, predicate_id, obj_id]
map_1 = before_pred.predict(probe_inputs)
map_2 = after_pred.predict(probe_inputs)
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
# Each prediction is viewed as a square feat_map_dim x feat_map_dim grid.
grid_shape = (params.feat_map_dim, params.feat_map_dim)
img_1 = map_1.reshape(grid_shape)
plot0 = axes[0].imshow(img_1, interpolation=interp_method)
plot1 = axes[1].imshow(map_2.reshape(grid_shape), interpolation=interp_method)
fig.colorbar(plot0, ax=axes[0])
for ax, title in zip(axes, ("before {}", "after {}")):
    ax.axis("off")
    ax.set_title(title.format(predicate))
fig.colorbar(plot1, ax=axes[1])

In [None]:
# Same two intermediate maps as numeric grids, annotated cell by cell.
heatmap_inputs = [images[image_index:image_index+1], subj_id, predicate_id, obj_id]
grid_shape = (params.feat_map_dim, params.feat_map_dim)
map_1 = before_pred.predict(heatmap_inputs).reshape(grid_shape)
map_2 = after_pred.predict(heatmap_inputs).reshape(grid_shape)
fig, axes = plt.subplots(1, 2, figsize=(30, 10))
for ax, grid in zip(axes, (map_1, map_2)):
    sns.heatmap(grid, annot=True, linewidths=.5, ax=ax)
for ax in axes:
    ax.axis("off")
axes[0].set_title("before-pred")
axes[1].set_title("after-pred {}".format(predicate))

# CLEVR

In [None]:
# CLEVR configuration. This cell rebinds the names used by the VRD section
# (annotations_test, img_dir, params, images, ...), so re-run the VRD setup
# cells before going back to that section.
# Both JSON files are read inside context managers so their handles are
# closed right after parsing.
with open("/data/ranjaykrishna/ReferringRelationships/data/clevr/annotations_test.json") as annotations_file:
    annotations_test = json.load(annotations_file)
test_data_dir = '/data/ranjaykrishna/ReferringRelationships/data/dataset-clevr-small/test/'
img_dir = '/data/ranjaykrishna/clevr/images/val/'
vocab_dir = '/data/chami/ReferringRelationships/data/Clevr/'
model_checkpoint = "/data/chami/ReferringRelationships/models/Clevr/10_14_2017/2/model04-0.13.h5"
predicate_dict, obj_subj_dict = get_dict(vocab_dir)
# Keep only the first 1000 image ids, in sorted order.
image_ids = sorted(annotations_test)[:1000]
# Training arguments are stored next to the checkpoint as args.json.
args_path = os.path.join(os.path.dirname(model_checkpoint), "args.json")
with open(args_path, "r") as args_file:
    params = objdict(json.load(args_file))
params.use_internal_loss = False
params.categorical_predicate = False
# The model is intentionally not rebuilt here; only the data iterator is
# needed by the cell below.
#relationships_model = ReferringRelationshipsModel(params)
test_generator = SmartIterator(test_data_dir, params)
images = test_generator.get_image_dataset()
# Reset the relationship inputs to all zeros.
subj_id = np.zeros((1, 1))
predicate_id = np.zeros((1, params.num_predicates))
obj_id = np.zeros((1, 1))

In [None]:
#################
image_index = 0
#################
# seaborn is already imported as `sns` in the notebook's import cell.
cmap_2 = sns.cubehelix_palette(8, start=2, rot=0, dark=0.3, light=.95, reverse=True, as_cmap=True)
fig, axes = plt.subplots(1, 3, figsize=(15,5))
img = Image.open(os.path.join(img_dir, image_ids[image_index]))
img = img.resize((params.input_dim, params.input_dim))
axes[0].imshow(img)
axes[0].axis("off")
axes[0].set_title("Original image")

# Fetch the first batch once and reuse it for both masks.
# NOTE(review): the masks always come from sample 0 of batch 0, whatever
# `image_index` is -- confirm that this sample belongs to
# image_ids[image_index] before changing the index above.
first_batch = test_generator[0]
subject_mask = first_batch[1][0][0]
object_mask = first_batch[1][1][0]


def _as_square(mask):
    """View a flattened square mask as a 2-D (side, side) array.

    The side length is derived from the number of elements instead of being
    hard-coded to 224, so other mask resolutions also work.
    """
    side = int(round(np.sqrt(mask.size)))
    return mask.reshape(side, side)


axes[1].imshow(255*_as_square(subject_mask), cmap=cmap_2)
axes[1].axis("off")
axes[1].set_title("Subject bounding box")
axes[2].imshow(255*_as_square(object_mask), cmap=cmap_2)
axes[2].axis("off")
axes[2].set_title("Object bounding box")