# Visualizing the predicate shifts

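In the paper, we visualize all the predicate shifts that we learn. This notebook walks you through how those visualizations are produced: it applies each learned predicate's convolutional shift (and its inverse) to a Gaussian attention map and plots the results.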

In [None]:
from utils.visualization_utils import get_att_map, objdict, get_dict
from scipy.stats import multivariate_normal
import keras.backend as K

import numpy as np
import os
from PIL import Image
import json
import h5py
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Let's create the initial attention map.
# Grid size shared by every attention map in this notebook
# (rows = height, columns = width).
im_height, im_width = 14, 14

def create_gaussian(center, cov=None, width=None, height=None, lim=(-2, 2)):
    """Return a 2-D Gaussian density sampled on a regular grid.

    The square [lim[0], lim[1]] x [lim[0], lim[1]] is sampled with ``width``
    evenly spaced x values (columns) and ``height`` evenly spaced y values
    (rows), so the result can be passed straight to ``imshow``.

    Args:
        center: (x, y) mean of the Gaussian, in the coordinates spanned by ``lim``.
        cov: 2x2 covariance matrix; defaults to the identity.
        width: number of columns; defaults to the global ``im_width``.
        height: number of rows; defaults to the global ``im_height``.
        lim: (low, high) extent of the sampled square along both axes.

    Returns:
        np.ndarray of shape (height, width) with the pdf value at each grid point.
    """
    # Resolve defaults at call time so later changes to the globals are honoured.
    if width is None:
        width = im_width
    if height is None:
        height = im_height
    if cov is None:
        cov = np.eye(2)
    kernel = multivariate_normal(mean=center, cov=cov)
    xs = np.linspace(lim[0], lim[1], width)
    ys = np.linspace(lim[0], lim[1], height)
    # meshgrid(xs, ys) yields (height, width) arrays: rows follow y, columns follow x.
    xx, yy = np.meshgrid(xs, ys)
    points = np.column_stack([xx.ravel(), yy.ravel()])
    density = kernel.pdf(points)
    return density.reshape((height, width))

# This centred Gaussian is the input attention that every predicate shift
# below is applied to.
in_att = create_gaussian((0, 0))
fig, ax = plt.subplots()
ax.imshow(in_att, interpolation='spline16')
ax.set_title("Input attention")
plt.show()

## Choose the dataset whose predicates we want to visualize

Set `data_type` below to `"vrd"`, `"clevr"`, or `"visualgenome"`.

In [None]:
###################
data_type = "visualgenome"
###################
# Per-dataset settings:
#   nrows, ncols, figsize -- layout of the final grid of shift plots
#   sym_ssn_checkpoint    -- trained weights; args.json must sit next to it
#   vocab_dir             -- directory handed to get_dict to load the vocabularies
if data_type == "vrd":
    nrows = 7
    ncols = 10
    figsize = (14, 20)
    sym_ssn_checkpoint = "pretrained/vrd.h5"
    vocab_dir = 'data/VRD'
elif data_type == "clevr":
    nrows = 3
    ncols = 4
    figsize = (12, 6)
    sym_ssn_checkpoint = "pretrained/clevr.h5"
    vocab_dir = 'data/Clevr/'
elif data_type == "visualgenome":
    nrows = 7
    ncols = 10
    figsize = (14, 20)
    ssn_checkpoint = ""
    sym_ssn_checkpoint = "pretrained/visualgenome.h5"
    vocab_dir = 'data/VisualGenome/'
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(
        "Unknown data_type {!r}; expected 'vrd', 'clevr' or 'visualgenome'.".format(data_type))

In [None]:
# Grab all the weights.
# For every predicate, collect the kernels of its forward conv stack
# (conv_filters) and of its inverse conv stack (inv_conv_filters): one entry
# per conv layer, keyed by predicate name.
predicate_dict, obj_subj_dict = get_dict(vocab_dir)
# The training arguments are stored in args.json next to the checkpoint.
with open(os.path.join(os.path.dirname(sym_ssn_checkpoint), "args.json"), "r") as args_file:
    params = objdict(json.load(args_file))
conv_filters = {}
inv_conv_filters = {}
# Open explicitly read-only and close when done. Older h5py versions default
# to append mode when no mode is given. `[()]` copies each kernel into memory,
# so nothing below needs the open file.
with h5py.File(sym_ssn_checkpoint, "r") as model_weights:
    # Checkpoints name layers either 'conv{j}-predicate{i}' / 'conv{j}-inv-predicate{i}'
    # or 'conv{j}-predicate{i}-0' / 'conv{j}-predicate{i}-1'; probe the first
    # layer name once to find out which scheme this file uses.
    has_inv_infix_names = 'conv0-predicate0' in model_weights
    for i in range(params.num_predicates):
        predicate = predicate_dict[i]
        conv_filters[predicate] = []
        inv_conv_filters[predicate] = []
        for j in range(params.nb_conv_att_map):
            if has_inv_infix_names:
                conv_weights_name = "conv{}-predicate{}".format(j, i)
                inv_conv_weights_name = "conv{}-inv-predicate{}".format(j, i)
            else:
                conv_weights_name = "conv{}-predicate{}-0".format(j, i)
                inv_conv_weights_name = "conv{}-predicate{}-1".format(j, i)
            # Each kernel is stored at <layer>/<layer>/'kernel:0'.
            conv_filters[predicate].append(
                model_weights[conv_weights_name][conv_weights_name]['kernel:0'][()])
            inv_conv_filters[predicate].append(
                model_weights[inv_conv_weights_name][inv_conv_weights_name]['kernel:0'][()])

## Before we continue, let's visualize just one of them to make sure that everything works.

Make sure that `above` is actually a predicate in the dataset you are visualizing; otherwise, set `predicate` in the cell below to a different one.

In [None]:
###################
predicate = "above"
###################
# Sanity check on one predicate: run the input attention through its forward
# conv stack (conv -> ReLU for each layer), then sum over channels.
if predicate not in conv_filters:
    raise KeyError("Predicate {!r} is not in this dataset; available: {}".format(
        predicate, ", ".join(str(p) for p in conv_filters)))
sess = tf.InteractiveSession()
try:
    # Batch of one, channels last (matches data_format below).
    att = K.constant(in_att.reshape(1, im_height, im_width, 1))
    for kernel in conv_filters[predicate]:
        att = K.conv2d(att, np.array(kernel), padding='same', data_format='channels_last')
        att = K.relu(att)
    # .eval() runs in the InteractiveSession created above (the default session).
    att = K.sum(att, axis=3).eval().reshape((im_height, im_width))
finally:
    # Release the session even if graph construction or evaluation fails.
    sess.close()
fig, ax = plt.subplots()
ax.imshow(att, interpolation='gaussian')
ax.set_title(predicate)
plt.show()

## Let's compute the shifts (and inverse shifts) for all the predicates

In [None]:
# Compute all the attentions: apply each predicate's forward and inverse conv
# stacks to the same Gaussian input attention.
def _apply_conv_stack(att_4d, kernels):
    """Run ``att_4d`` through ``kernels`` (conv -> ReLU per layer) and sum channels.

    Evaluated in the current default TF session; returns a numpy array of
    shape (im_height, im_width).
    """
    att = K.constant(att_4d)
    for kernel in kernels:
        att = K.conv2d(att, np.array(kernel), padding='same', data_format='channels_last')
        att = K.relu(att)
    return K.sum(att, axis=3).eval().reshape((im_height, im_width))


shifts = {}
inv_shifts = {}
# Batch of one, channels last. Use a new name instead of rebinding `in_att`,
# so `in_att` keeps the 2-D shape it was created with.
in_att_4d = in_att.reshape(1, im_height, im_width, 1)
sess = tf.InteractiveSession()
try:
    for i in range(params.num_predicates):
        predicate = predicate_dict[i]
        shifts[predicate] = _apply_conv_stack(in_att_4d, conv_filters[predicate])
        inv_shifts[predicate] = _apply_conv_stack(in_att_4d, inv_conv_filters[predicate])
finally:
    # Release the session even if a predicate fails to evaluate.
    sess.close()

## Now let's visualize all of them.

In [None]:
######################
interp_method = 'spline16'
######################
# Plot every predicate's shift next to its inverse shift. Each predicate uses
# two adjacent panels in a row (shift, then "INV" shift), so one row of the
# figure holds ncols // 2 predicates.
n_panel_rows = nrows * 2
pairs_per_row = ncols // 2
if params.num_predicates > n_panel_rows * pairs_per_row:
    raise ValueError(
        "A {} x {} panel grid cannot hold {} predicate pairs; increase nrows/ncols.".format(
            n_panel_rows, ncols, params.num_predicates))
# squeeze=False keeps `axs` 2-D even for a single row or column.
fig, axs = plt.subplots(nrows=n_panel_rows, ncols=ncols, figsize=figsize, squeeze=False)
for idx in range(params.num_predicates):
    predicate = predicate_dict[idx]
    row, pair = divmod(idx, pairs_per_row)
    col = 2 * pair
    axs[row, col].imshow(shifts[predicate], interpolation=interp_method)
    axs[row, col].set_title(predicate)
    axs[row, col + 1].imshow(inv_shifts[predicate], interpolation=interp_method)
    axs[row, col + 1].set_title("INV {}".format(predicate))
# Hide axes on every panel, including unused trailing ones.
for ax in axs.flat:
    ax.axis("off")
# Lay out only after titles exist so they are taken into account.
fig.tight_layout()
plt.show()