# Attention Saccades

In the paper, we show that we can attend over a scene graph and localize each entity in a sequence. This notebook lets you visualize these attention saccades for user-specified scene graphs.

Note that this notebook is not commented and might require some changes to run. Feel free to update it and send us a pull request.

In [None]:
from iterator import SmartIterator
from utils.visualization_utils import get_att_map, objdict, get_dict, add_attention, add_bboxes, get_bbox_from_heatmap, add_bbox_to_image
from keras.models import load_model
from models import ReferringRelationshipsModel
from keras.utils import to_categorical
import numpy as np
import os
from PIL import Image
from keras.models import Model
import json
import matplotlib
import matplotlib.pyplot as plt
import h5py
from keras.models import Model
import keras.backend as K
from keras.layers import Dense, Flatten, UpSampling2D, Input
import seaborn as sns
import tensorflow as tf
from scipy.misc import imresize

matplotlib.rcParams.update({'font.size': 34})
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Install an interactive TensorFlow session as the default session so that
# the `.eval()` calls on backend tensors in later cells (see `shift` and
# `upsample`) can run without an explicit session argument.
sess = tf.InteractiveSession()

In [None]:
# --- User-editable: directory the input images are opened from. ---
img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'

# Vocabulary directory, test annotations and pretrained checkpoint.
vocab_dir = 'data/VRD'
annotations_file = "data/VRD/annotations_test.json"
model_checkpoint = "pretrained/vrd.h5"

In [None]:
# Load the test annotations. The `with` block closes the file handle as soon
# as the JSON has been parsed instead of leaving it open for the kernel's
# lifetime.
with open(annotations_file) as annotations_fp:
    annotations_test = json.load(annotations_fp)
predicate_dict, obj_subj_dict = get_dict(vocab_dir)
# Keep at most the first 1000 image ids, in sorted order.
image_ids = sorted(annotations_test.keys())[:1000]

# The model arguments are read from the args.json stored next to the checkpoint.
args_path = os.path.join(os.path.dirname(model_checkpoint), "args.json")
with open(args_path, "r") as args_fp:
    params = objdict(json.load(args_fp))
# Notebook-specific overrides of the stored arguments.
params.cnn = 'resnet'
params.discovery = False
relationships_model = ReferringRelationshipsModel(params)
test_generator = SmartIterator(params.test_data_dir, params)
images = test_generator.get_image_dataset()

# Print the available entity names and predicate names so the user can pick
# valid values for the `objects` / `predicates` lists further down.
print(' | '.join(obj_subj_dict))
print('')
print(' | '.join(predicate_dict))

In [None]:
# Build the network described by `params`, then restore the pretrained
# weights from the checkpoint file.
model = relationships_model.build_model()
model.load_weights(model_checkpoint)

In [None]:
# Keras input placeholders: a square RGB image of side `params.input_dim`,
# a predicate vector of length `params.num_predicates`, and a single-value
# object input.
# NOTE(review): none of these placeholders is referenced by the other cells
# visible in this notebook — confirm whether they are still needed.
input_im = Input(shape=(params.input_dim, params.input_dim, 3))
input_pred = Input(shape=(params.num_predicates,))
input_obj = Input(shape=(1,))

In [None]:
# Image model that returns image feature maps: a sub-model sharing the full
# model's inputs but stopping at the "conv2d_1" layer.
im_output = model.get_layer("conv2d_1").output
image_model = Model(inputs=model.inputs, outputs=im_output)

# Embedding weights, read directly from the checkpoint file.
# The file is opened explicitly read-only: without a mode, older h5py
# releases default to append mode, which opens the checkpoint writable and
# silently creates an empty file if the path is wrong.
# The handle is intentionally kept open because later cells read the
# convolution kernels from `model_weights`.
model_weights = h5py.File(model_checkpoint, "r")
embeddings = model_weights["embedding_1"]["embedding_1"]["embeddings:0"][()]

In [None]:
# Attention-shift kernels, keyed by predicate name. The predicate at position
# `pred_idx` owns the checkpoint layers named "conv<j>-predicate<pred_idx>",
# for j in [0, params.nb_conv_att_map).
convs = {}
for pred_idx, pred_name in enumerate(predicate_dict):
    shift_layer_names = ["conv{}-predicate{}".format(conv_idx, pred_idx)
                         for conv_idx in range(params.nb_conv_att_map)]
    convs[pred_name] = [model_weights[name][name]["kernel:0"][()]
                        for name in shift_layer_names]

# Number of x2 upsampling stages needed to go from the feature-map resolution
# back to the input resolution (base-2 log of the size ratio, truncated).
upsampling_factor = params.input_dim / params.feat_map_dim
k = int(np.log(upsampling_factor) / np.log(2))

# One transposed-convolution kernel per upsampling stage.
upsample_layer_names = ["subject-convT-{}".format(stage) for stage in range(k)]
convs_T = [model_weights[name][name]["kernel:0"][()]
           for name in upsample_layer_names]

In [None]:
def shift(att, convs, predicate):
    """Shift an attention map using the convolution kernels of a predicate.

    The attention is passed through `params.nb_conv_att_map` successive
    "same"-padded convolutions, each followed by a ReLU, and the result is
    squashed with a tanh.

    Args:
        att: attention map (array-like) in channels-last layout, as expected
            by `K.conv2d(..., data_format='channels_last')`.
        convs: dict mapping a predicate to its list of convolution kernels.
        predicate: key into `convs` selecting which kernels to apply.

    Returns:
        numpy array holding the tanh of the shifted attention.

    Note:
        Evaluating the backend tensor with `.eval()` requires a default
        TensorFlow session to be active.
    """
    att = K.constant(att)
    for j in range(params.nb_conv_att_map):
        kernel = convs[predicate][j]
        att = K.conv2d(att, kernel, padding='same', data_format='channels_last')
        att = K.relu(att)
    # Previously the tanh was stored in an unused local and the raw
    # (un-squashed) activations were returned; return the tanh itself.
    return np.tanh(att.eval())

def get_att(obj_idx, embeddings, im_features):
    """Compute the rectified attention of one entity over image features.

    Args:
        obj_idx: row index of the entity in `embeddings`.
        embeddings: 2-D array with one embedding per row; the row length must
            equal the size of the last axis of `im_features`.
        im_features: 4-D feature array whose axis 3 is summed over.

    Returns:
        Array with the same leading dimensions as `im_features` and a last
        axis of size 1: the per-location dot product between the features and
        the selected embedding, with non-positive values zeroed.
    """
    n_channels = im_features.shape[-1]
    # Reshape the embedding so it broadcasts against the 4-D feature array.
    obj_vec = embeddings[obj_idx, :].reshape((1, 1, 1, n_channels))
    scores = np.sum(im_features * obj_vec, axis=3, keepdims=True)
    # ReLU through a boolean mask (a tanh squashing is left disabled here).
    return scores * (scores > 0)

def upsample(att, convs_transpose, k):
    """Upsample an attention map through `k` doubling stages.

    Each stage doubles the two spatial axes by element repetition, applies a
    "same"-padded transposed convolution with the stage's kernel, then a
    ReLU. A tanh is applied after the last stage.

    Args:
        att: 4-D attention array; axis 1 gives the spatial size.
            NOTE(review): the output shape reuses that size for axis 2 as
            well, so a square map with batch size 1 and a single channel is
            assumed — confirm against callers.
        convs_transpose: sequence of at least `k` transposed-convolution
            kernels, one per stage.
        k: number of doubling stages to apply.

    Returns:
        2-D numpy array: the upsampled attention of the first batch element
        and first channel.

    Note:
        Evaluating the backend tensor with `.eval()` requires a default
        TensorFlow session to be active.
    """
    _, shape, _, _ = att.shape
    att = K.constant(att)
    for i in range(k):
        # Use the kernels passed by the caller; the previous version ignored
        # this argument and always read the notebook-level `convs_T`.
        kernel = convs_transpose[i]
        att = K.repeat_elements(att, 2, axis=1)
        att = K.repeat_elements(att, 2, axis=2)
        # After i+1 doublings the spatial size is 2**(i+1) times the input's.
        side = (2 ** (i + 1)) * shape
        att = K.conv2d_transpose(att, kernel, padding='same', output_shape=(1, side, side, 1))
        att = K.relu(att)
    att = K.tanh(att)
    att = att.eval()
    return att[0, :, :, 0]

In [None]:
#################
# Pick a random image. The upper bound follows the number of loaded image
# ids instead of a hard-coded 1000, so the index stays valid when fewer ids
# are available. Replace with a fixed integer to inspect a specific image.
image_index = np.random.randint(len(image_ids))
print(image_index)
#################
# Load the image and resize it to the square input resolution of the model.
img = Image.open(os.path.join(img_dir, image_ids[image_index]))
img = img.resize((params.input_dim, params.input_dim))
plt.figure(figsize=(5,5))
plt.imshow(img)
plt.axis("off")

In [None]:
# Zero-filled placeholders for a predicate vector and an entity id.
predicate_id = np.zeros((1, params.num_predicates))
obj_id = np.zeros((1, 1))

# Run the feature sub-model on the selected image only (batch of one). It
# shares the full model's four inputs, so the three non-image inputs are fed
# with zeros.
# NOTE(review): presumably the feature maps do not depend on those dummy
# inputs — confirm against the model definition.
image_batch = images[image_index:image_index+1]
dummy_inputs = [
    np.zeros((1, 1)),
    np.zeros((1, params.num_predicates)),
    np.zeros((1, 1)),
]
im_features = image_model.predict([image_batch] + dummy_inputs)

In [None]:
# Entities and predicates of the scene graph to walk through:
# objects[0] -predicates[0]-> objects[1] -predicates[1]-> objects[2] ...
objects = ["plate", "table", "person"]
predicates = ["on", "on the right of"]
# One panel for the input image, one for the first entity, then a
# (shifted attention, next entity) pair of panels per predicate.
nb_plots = 2 + 2 * len(predicates)
att = get_att(obj_subj_dict.index(objects[0]), embeddings, im_features)
fig, axes = plt.subplots(1, nb_plots, figsize=(20, 5))
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
ax_counter = 0
axes[ax_counter].imshow(img)
axes[ax_counter].set_xlabel("input image", {'fontsize': 18})
ax_counter += 1
axes[ax_counter].imshow(upsample(att, convs_T, k), interpolation='spline16')
axes[ax_counter].set_xlabel(objects[0], {'fontsize': 18})
for i in range(len(objects)-1):
    ax_counter += 1
    shifted_att = shift(att, convs, predicates[i])
    axes[ax_counter].set_xlabel(predicates[i], {'fontsize': 18})
    # The panel labelled with the predicate shows the *shifted* attention;
    # previously it re-plotted the unshifted entity attention.
    axes[ax_counter].imshow(upsample(shifted_att, convs_T, k), interpolation='spline16')
    # Attend to the next entity on the image features masked by the shifted attention.
    att = get_att(obj_subj_dict.index(objects[i+1]), embeddings, im_features*shifted_att)
    ax_counter += 1
    axes[ax_counter].imshow(upsample(att, convs_T, k), interpolation='spline16')
    axes[ax_counter].set_xlabel(objects[i+1], {'fontsize': 18})

## Visualize saccades

The cell below attends over 'plate -> on -> table -> on the right of -> person'. The entities are listed in the `objects` list and the predicates in the `predicates` list; the `threshold` list holds one heatmap threshold per entity, used to extract its bounding box.

In [None]:
#############################
# User-editable scene graph: objects[i] -predicates[i]-> objects[i+1].
# `threshold[i]` is the heatmap threshold used to extract the bounding box
# of objects[i], so it needs one entry per object.
objects = ["plate", "table", "person"]
predicates = ["on", "on the right of"]
threshold = [0.5, 0.3, 0.2]
#############################
# Grid layout: the input image occupies the first two columns; each entity
# then gets a two-column slot with its heatmap on top and the image with the
# extracted bounding box below.
ncols = 2*(len(objects) + 1)
nrows = 4
fig = plt.figure(figsize=(14, 6))

ax = plt.subplot2grid((nrows, ncols), (1, 0), colspan=2, rowspan=2)
ax.imshow(img)
ax.set_xticks([])
ax.set_yticks([])

features = im_features
for i in range(len(objects)):
    att = get_att(obj_subj_dict.index(objects[i]), embeddings, features)
    up_att = upsample(att, convs_T, k)
    ax = plt.subplot2grid((nrows, ncols), (0, 2*i+2), colspan=2, rowspan=2)
    ax.imshow(up_att, interpolation='spline16')
    ax.set_xticks([])
    ax.set_yticks([])

    bbox = get_bbox_from_heatmap(up_att, threshold=threshold[i])
    bboxed_image = add_bbox_to_image(img, bbox, color='red', width=3)
    ax = plt.subplot2grid((nrows, ncols), (2, 2*i+2), colspan=2, rowspan=2)
    ax.imshow(bboxed_image)
    # Label this panel; previously the label was written to
    # `axes[ax_counter]`, a leftover axis of the figure from an earlier cell.
    ax.set_xlabel(objects[i], {'fontsize': 18})
    ax.set_xticks([])
    ax.set_yticks([])

    # No predicate follows the last entity.
    if i >= len(predicates):
        break
    # Mask the features with the attention shifted by this predicate.
    # Previously the result of shift() was discarded and a stale
    # `shifted_att` left over from an earlier cell was used instead.
    shifted_att = shift(att, convs, predicates[i])
    features = features*shifted_att

plt.tight_layout(pad=0.1, w_pad=-1, h_pad=-2)