In [None]:
from iterator import SmartIterator
from utils.visualization_utils import get_att_map, objdict, get_dict, add_attention, add_bboxes
from keras.models import load_model
from models import ReferringRelationshipsModel
from keras.utils import to_categorical
import numpy as np
import os
from PIL import Image
from keras.models import Model
import json
import matplotlib
import matplotlib.pyplot as plt
import h5py
from keras.models import Model
import keras.backend as K
from keras.layers import Dense, Flatten, UpSampling2D, Input
import seaborn as sns
import tensorflow as tf
from scipy.misc import imresize

# Global default font size for every matplotlib figure drawn in this notebook.
# NOTE(review): this is set before `%matplotlib inline`; some older inline-backend
# versions re-apply their own rc defaults on activation — confirm the size sticks.
matplotlib.rcParams.update({'font.size': 34})
%matplotlib inline
# Automatically reload imported modules before each cell runs, so edits to the
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Open an interactive TensorFlow session. The `.eval()` calls inside `shift` and
# `upsample` pass no explicit session, so they presumably rely on this one being
# installed as the default session — TODO confirm.
sess = tf.InteractiveSession()

In [None]:
###################
data_type = "vrd"
###################
# JSON file holding the test-split annotations, keyed by image id.
annotations_file = "data/VRD/annotations_test.json"
# Directory containing the test images referenced by those ids.
img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'
# Directory passed to `get_dict` to load the predicate / object vocabularies.
vocab_dir = 'data/VRD'
# Trained weights to visualise. `args.json` is read from the same directory.
# NOTE(review): absolute, machine-specific paths — adjust for your environment.
model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/11_07_2017/vgg_14/1/model30-1.61.h5"

In [None]:
# Load the test annotations; the context manager closes the file handle
# instead of leaving it to the garbage collector.
with open(annotations_file) as annotations_fp:
    annotations_test = json.load(annotations_fp)
predicate_dict, obj_subj_dict = get_dict(vocab_dir)
# Keep at most the first 1000 image ids, in sorted order.
image_ids = sorted(annotations_test)[:1000]
# The training arguments are read from `args.json` next to the checkpoint.
args_path = os.path.join(os.path.dirname(model_checkpoint), "args.json")
with open(args_path, "r") as args_fp:
    params = objdict(json.load(args_fp))
relationships_model = ReferringRelationshipsModel(params)
test_generator = SmartIterator(params.test_data_dir, params)
images = test_generator.get_image_dataset()
# Print the available object and predicate names for reference.
print(' | '.join(obj_subj_dict))
print('')
print(' | '.join(predicate_dict))

In [None]:
# Build the network from the loaded hyper-parameters, then restore the trained
# weights from the checkpoint file.
model = relationships_model.build_model()
model.load_weights(model_checkpoint)

In [None]:
# Standalone Keras input placeholders: a square RGB image of side
# `params.input_dim`, a vector of length `params.num_predicates`, and a
# single-element input (presumably an object index — TODO confirm).
# NOTE(review): none of these are referenced in the cells visible here —
# confirm they are still needed.
input_im = Input(shape=(params.input_dim, params.input_dim, 3))
input_pred = Input(shape=(params.num_predicates,))
input_obj = Input(shape=(1,))

In [None]:
# Image model that returns image feature maps: it shares the trained model's
# inputs but stops at the output of the "conv2d_1" layer.
im_output = model.get_layer("conv2d_1").output
image_model = Model(inputs=model.inputs, outputs=im_output)

# Embedding weights that returns object embeddings.
# Open the checkpoint explicitly read-only: older h5py versions default to a
# read/write mode when no mode is given, which risks modifying the checkpoint.
# The handle stays open because later cells read kernels from it.
model_weights = h5py.File(model_checkpoint, "r")
embeddings = model_weights["embedding_1"]["embedding_1"]["embeddings:0"][()]

In [None]:
# Attention-shift kernels per predicate: convs[predicate][j] is the kernel of
# layer "conv{j}-predicate{i}", where i is the predicate's position in
# `predicate_dict`.
convs = {}
for i, predicate in enumerate(predicate_dict):
    layer_names = ["conv{}-predicate{}".format(j, i)
                   for j in range(params.nb_conv_att_map)]
    convs[predicate] = [model_weights[name][name]["kernel:0"][()]
                        for name in layer_names]

# Number of x2 upsampling stages between the feature map and the input image.
# np.log2 is used rather than a ratio of natural logs, whose rounding error
# could land just below an integer and truncate one stage short.
upsampling_factor = params.input_dim / params.feat_map_dim
k = int(np.log2(upsampling_factor))
# One transposed-convolution kernel per upsampling stage.
convs_T = []
for i in range(k):
    layer_name = "subject-convT-{}".format(i)
    convs_T.append(model_weights[layer_name][layer_name]["kernel:0"][()])

In [None]:
def shift(att, convs, predicate):
    """Shift an attention map using the convolution kernels of one predicate.

    Applies `params.nb_conv_att_map` rounds of same-padded convolution followed
    by ReLU, then squashes the result with tanh.

    Args:
        att: attention map array, fed to `K.conv2d` with
            `data_format='channels_last'` (presumably shaped
            (1, height, width, 1) — TODO confirm).
        convs: mapping from predicate name to its list of convolution kernels.
        predicate: key into `convs` selecting which kernels to apply.

    Returns:
        The tanh-squashed shifted attention map, evaluated to an array.
    """
    att = K.constant(att)
    for j in range(params.nb_conv_att_map):
        kernel = convs[predicate][j]
        att = K.conv2d(att, kernel, padding='same', data_format='channels_last')
        att = K.relu(att)
    # `.eval()` is called without a session, so a default one must be active.
    att = att.eval()
    # Return the tanh-squashed map rather than the raw ReLU output.
    shifted_att = np.tanh(att)
    return shifted_att

def get_att(obj_idx, embeddings, im_features):
    """Compute a non-negative attention map for one object embedding.

    The embedding row `obj_idx` is dotted with the last axis of `im_features`
    at every position, passed through tanh, and negative responses are zeroed.

    Args:
        obj_idx: row index into `embeddings`.
        embeddings: 2-D array whose row length equals `im_features.shape[-1]`.
        im_features: 4-D feature array; the dot product is taken over axis 3.

    Returns:
        Array shaped like `im_features` with the last axis reduced to size 1.
    """
    n_channels = im_features.shape[-1]
    # Reshape so the embedding broadcasts across the three leading axes.
    query = embeddings[obj_idx, :].reshape((1, 1, 1, n_channels))
    scores = np.tanh(np.sum(im_features * query, axis=3, keepdims=True))
    # Multiply by a boolean mask to keep only strictly positive responses.
    positive_mask = scores > 0
    return positive_mask * scores

def upsample(att, convs_transpose, k):
    """Upsample an attention map through `k` learned x2 stages.

    Each stage doubles both spatial axes by element repetition, applies a
    same-padded transposed convolution and a ReLU. The final result is
    squashed with tanh.

    Args:
        att: 4-D attention array. Its second dimension is used as the side
            length for both output axes, so a square map with batch size 1
            and one channel is assumed — TODO confirm.
        convs_transpose: sequence of transposed-convolution kernels, one per
            stage; entry `i` is used at stage `i`.
        k: number of upsampling stages to apply.

    Returns:
        2-D array: the first batch item and first channel of the result.
    """
    _, shape, _, _ = att.shape
    att = K.constant(att)
    for i in range(k):
        # Use the kernels passed by the caller, not a notebook-level global.
        kernel = convs_transpose[i]
        att = K.repeat_elements(att, 2, axis=1)
        att = K.repeat_elements(att, 2, axis=2)
        # After stage i the side length is the input side times 2**(i+1).
        side = (2 ** (i + 1)) * shape
        att = K.conv2d_transpose(att, kernel, padding='same', output_shape=(1, side, side, 1))
        att = K.relu(att)
    # `.eval()` is called without a session, so a default one must be active.
    att = att.eval()
    att = np.tanh(att)
    return att[0, :, :, 0]

In [None]:
#################
# Sample within the ids that were actually loaded, so the index stays valid
# even when the annotation file holds fewer than 1000 images.
image_index = np.random.randint(len(image_ids))
print(image_index)
#################
# Resize to the network input resolution. `resize` returns a new image, so
# the source file can be closed as soon as the block exits.
with Image.open(os.path.join(img_dir, image_ids[image_index])) as source_img:
    img = source_img.resize((params.input_dim, params.input_dim))
plt.figure(figsize=(5,5))
plt.imshow(img)
plt.axis("off")

In [None]:
# All-zero stand-ins for the non-image inputs.
predicate_id = np.zeros((1, params.num_predicates))
obj_id = np.zeros((1, 1))

# `image_model` was built on all of `model.inputs`, so it needs a value for
# every input even though only the feature maps are requested. The selected
# image is passed as a batch of one, followed by zero placeholders.
image_batch = images[image_index:image_index + 1]
placeholder_inputs = [
    np.zeros_like(obj_id),
    np.zeros_like(predicate_id),
    np.zeros_like(obj_id),
]
im_features = image_model.predict([image_batch] + placeholder_inputs)

In [None]:
objects = ["table", "chair", "computer"]
predicates = ["above", "below"]

# Step i applies predicates[i] to move the attention from objects[i] towards
# objects[i + 1], so one predicate is needed per consecutive object pair.
nb_steps = len(objects) - 1
if len(predicates) < nb_steps:
    raise ValueError("need {} predicates for {} objects, got {}".format(
        nb_steps, len(objects), len(predicates)))
# Panels: input image, first object, then (predicate, object) for each step.
nb_plots = 2 + 2 * nb_steps


def _show_att(ax, att_map, title):
    """Upsample `att_map` to image resolution and draw it on `ax` under `title`."""
    ax.imshow(upsample(att_map, convs_T, k), interpolation='spline16')
    ax.axis('off')
    ax.set_title(title)


att = get_att(obj_subj_dict.index(objects[0]), embeddings, im_features)
fig, axes = plt.subplots(1, nb_plots, figsize=(20, 5))
panels = iter(axes)

input_ax = next(panels)
input_ax.imshow(img)
input_ax.axis("off")
input_ax.set_title("input image")

_show_att(next(panels), att, objects[0])
for i in range(nb_steps):
    shifted_att = shift(att, convs, predicates[i])
    # Plot the shifted attention under the predicate title, not the
    # un-shifted map, which would just repeat the previous panel.
    _show_att(next(panels), shifted_att, predicates[i])
    # Re-attend for the next object on features weighted by the shifted map.
    att = get_att(obj_subj_dict.index(objects[i+1]), embeddings, im_features*shifted_att)
    _show_att(next(panels), att, objects[i+1])