In [None]:
from iterator import SmartIterator
from utils.visualization_utils import get_att_map, objdict, get_dict
import keras.backend as K

import numpy as np
import os
from PIL import Image
import json
import h5py
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf

%matplotlib inline
%load_ext autoreload
%autoreload 2

### Specify the data type here: clevr or vrd (vg is not configured yet)

In [None]:
###################
# Dataset whose checkpoints are inspected below: "vrd" or "clevr".
data_type = "vrd"
###################
# Per-dataset settings:
#   nrows, ncols        -- grid shape of the per-predicate plots below (one panel per predicate)
#   ssn_checkpoint      -- .h5 weights of the SSN model
#   sym_ssn_checkpoint  -- .h5 weights of the symmetric SSN model (predicate + inverse predicate)
#   vocab_dir           -- directory passed to get_dict to build the predicate / object-subject vocabularies
if data_type == "vrd":
    nrows = 7
    ncols = 5
    ssn_checkpoint = "/data/ranjaykrishna/ReferringRelationships/temp/vrd_ssn_convs3/model29-1.33.h5"
    sym_ssn_checkpoint = "/data/ranjaykrishna/ReferringRelationships/temp/vrd_sym_ssn_convs3/model29-1.29.h5"
    vocab_dir = 'data/VRD'
elif data_type == "clevr":
    nrows = 3
    ncols = 2
    ssn_checkpoint = "/data/ranjaykrishna/ReferringRelationships/temp/clevr_ssn/model03-0.15.h5"
    sym_ssn_checkpoint = "/data/ranjaykrishna/ReferringRelationships/temp/clevr_sym_ssn_convs3_iterations2/model00-0.18.h5"
    #annotations_test = json.load(open("/data/chami/ReferringRelationships/data/Clevr/annotations_test.json"))
    #img_dir = '/data/chami/ReferringRelationships/data/Clevr/images/val'
    vocab_dir = '/data/chami/ReferringRelationships/data/Clevr/'
else:
    # Fail fast with a clear message instead of a NameError on `vocab_dir` below.
    raise ValueError("Unsupported data_type {!r}; expected 'vrd' or 'clevr'".format(data_type))
predicate_dict, obj_subj_dict = get_dict(vocab_dir)

## SSN Model ONLY

In [None]:
# Load the SSN training arguments and copy every per-predicate conv kernel into memory.
# In this checkpoint each kernel lives at <layer_name>/<layer_name>/kernel:0.
with open(os.path.join(os.path.dirname(ssn_checkpoint), "args.json"), "r") as args_file:
    params = objdict(json.load(args_file))
# NOTE(review): an earlier version set `params.nb_conv_att_map = params.nb_conv_move_map`,
# presumably for older args.json files -- confirm before loading such a checkpoint.

conv_filters = {}  # predicate name -> list of kernels, one per conv layer
# Open read-only (older h5py defaults to append mode) and close the file when done:
# `[()]` copies each dataset into a numpy array, so the kernels outlive the file handle.
with h5py.File(ssn_checkpoint, "r") as model_weights:
    for i in range(params.num_predicates):
        predicate = predicate_dict[i]
        conv_filters[predicate] = []
        for j in range(params.nb_conv_att_map):
            conv_weights_name = "conv{}-predicate{}-0".format(j, i)
            conv_filters[predicate].append(
                model_weights[conv_weights_name][conv_weights_name]['kernel:0'][()])

### Show layer 0 — change `layer` in the cell below to visualize a different layer.

In [None]:
###################
# Valid `interpolation` values for imshow:
#   None, 'none', 'nearest', 'bilinear', 'bicubic', 'spline16', 'spline36', 'hanning',
#   'hamming', 'hermite', 'kaiser', 'quadric', 'catrom', 'gaussian', 'bessel',
#   'mitchell', 'sinc', 'lanczos'
interp_method = "spline16"
cmap_0 = sns.cubehelix_palette(light=0.9, as_cmap=True, dark=0.3)
cmap_1 = sns.light_palette("navy", as_cmap=True)
cmap_2 = sns.cubehelix_palette(8, start=2, rot=0, dark=0.3, light=.95, reverse=True, as_cmap=True)
cmap_3 = sns.color_palette("coolwarm", 8)
sns.set_palette(cmap_3)
layer = 0  # index of the conv layer to plot
###################
# Collapse each 4-D kernel to a 2-D map by summing over its last two axes
# (assumed Keras layout: rows, cols, in_channels, out_channels -- TODO confirm).
# The raw float weights are plotted: casting `weights * 255` to uint8 wraps negative
# values and values above 1, which produced meaningless images.
num_shown = min(nrows * ncols, params.num_predicates)
layer_maps = {
    predicate_dict[idx]: conv_filters[predicate_dict[idx]][layer].sum(axis=(2, 3))
    for idx in range(num_shown)
}
# One shared color scale so panels are comparable and the single colorbar is meaningful.
vmin = min(m.min() for m in layer_maps.values())
vmax = max(m.max() for m in layer_maps.values())

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20,12))
fig.suptitle("Conv filters for layer {}".format(layer), fontsize=16)
for idx, ax in enumerate(axs.ravel()):  # row-major, same order as predicate indices
    ax.axis("off")
    if idx >= num_shown:
        continue  # more panels than predicates: leave the rest blank
    predicate = predicate_dict[idx]
    plot = ax.imshow(layer_maps[predicate], interpolation=interp_method, vmin=vmin, vmax=vmax)
    ax.set_title(predicate)
fig.colorbar(plot, ax=axs.ravel().tolist());

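### Visualize the combined effect of all layers: each predicate's first-layer kernel is passed through every later conv layer (with a ReLU after each), then summed over channels.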

In [None]:
# For each predicate, convolve the first-layer kernel through every later conv layer
# (ReLU after each) and sum over channels to get one 2-D map per predicate.
#
# K.conv2d with channels_last treats axis 0 of its input as the batch axis and convolves
# over axes 1-2. The layer-0 kernel is assumed to be (rows, cols, in_channels, out_channels)
# (Keras layout -- TODO confirm), so it is transposed to (in_channels, rows, cols, out_channels).
# That way the convolution runs over the kernel's spatial axes rather than over (cols, in_channels).
num_shown = min(nrows * ncols, params.num_predicates)
composed_maps = {}
sess = tf.InteractiveSession()
try:
    for idx in range(num_shown):
        predicate = predicate_dict[idx]
        kernels = conv_filters[predicate]
        kernel = K.constant(np.transpose(kernels[0], (2, 0, 1, 3)))
        for layer_idx in range(1, params.nb_conv_att_map):
            kernel = K.conv2d(kernel, kernels[layer_idx], padding='same', data_format='channels_last')
            kernel = K.relu(kernel)
        # Sum over output channels (axis 3), then over the former input-channel axis (axis 0).
        kernel = K.sum(kernel, axis=3)
        kernel = K.sum(kernel, axis=0)
        composed_maps[predicate] = kernel.eval()
finally:
    # Release the session even if a convolution fails.
    sess.close()

# Plot raw float values on one shared scale (a uint8 cast of `x * 255` would wrap out-of-range values).
vmin = min(m.min() for m in composed_maps.values())
vmax = max(m.max() for m in composed_maps.values())
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20,12))
fig.suptitle("Conv filters for all layers", fontsize=16)
for idx, ax in enumerate(axs.ravel()):
    ax.axis("off")
    if idx >= num_shown:
        continue  # more panels than predicates: leave the rest blank
    predicate = predicate_dict[idx]
    plot = ax.imshow(composed_maps[predicate], interpolation=interp_method, vmin=vmin, vmax=vmax)
    ax.set_title(predicate)
fig.colorbar(plot, ax=axs.ravel().tolist());

## SYM SSN model - PREDICATE AND INVERSE PREDICATE

In [None]:
# Load the symmetric SSN checkpoint read-only. The file is deliberately left open because
# the inspection cells below read raw datasets from `model_weights`.
model_weights = h5py.File(sym_ssn_checkpoint, "r")
with open(os.path.join(os.path.dirname(sym_ssn_checkpoint), "args.json"), "r") as args_file:
    params = objdict(json.load(args_file))

# The checkpoint may name its layers either "conv{j}-predicate{i}" / "conv{j}-inv-predicate{i}"
# or "conv{j}-predicate{i}-0" / "conv{j}-predicate{i}-1"; detect which scheme once.
uses_inv_names = 'conv0-predicate0' in model_weights

conv_filters = {}      # predicate name -> forward-predicate kernels, one per conv layer
inv_conv_filters = {}  # predicate name -> inverse-predicate kernels, one per conv layer
for i in range(params.num_predicates):
    predicate = predicate_dict[i]
    conv_filters[predicate] = []
    inv_conv_filters[predicate] = []
    for j in range(params.nb_conv_att_map):
        if uses_inv_names:
            conv_weights_name = "conv{}-predicate{}".format(j, i)
            inv_conv_weights_name = "conv{}-inv-predicate{}".format(j, i)
        else:
            conv_weights_name = "conv{}-predicate{}-0".format(j, i)
            inv_conv_weights_name = "conv{}-predicate{}-1".format(j, i)
        # `[()]` copies each kernel dataset into a numpy array.
        conv_filters[predicate].append(
            model_weights[conv_weights_name][conv_weights_name]['kernel:0'][()])
        inv_conv_filters[predicate].append(
            model_weights[inv_conv_weights_name][inv_conv_weights_name]['kernel:0'][()])

### Inspect layer-0 kernels, then plot one layer for every predicate (change `layer` in the grid cell to visualize a different layer).

In [None]:
# Pick a random predicate and plot its first-layer kernel for input channel 0, taking the
# max over output channels (assumed Keras layout: rows, cols, in, out -- TODO confirm).
# Draws from all predicates of the loaded model instead of a hard-coded 70, and reads from
# `conv_filters` so it works with either checkpoint layer-naming scheme.
k = np.random.randint(params.num_predicates)
first_layer_kernel = conv_filters[predicate_dict[k]][0]
# Plot the raw floats: a uint8 cast of `x * 255` wraps negative weights around.
plt.imshow(first_layer_kernel[:, :, 0, :].max(axis=2), interpolation="gaussian", cmap=cmap_0)
plt.axis('off')
plt.title(predicate_dict[k])
predicate_dict[k]

In [None]:
# Shape of predicate k's first-layer kernel, taken from the dataset metadata.
inspect_name = 'conv0-predicate{}-0'.format(k)
model_weights[inspect_name][inspect_name]['kernel:0'].shape

In [None]:
# Raw first-layer weights of predicate 0 for input channel 0 and output channel 0.
first_kernel_name = 'conv0-predicate0-0'
first_kernel = model_weights[first_kernel_name][first_kernel_name]['kernel:0'][()]
first_kernel[:, :, 0, 0]

In [None]:
###################
layer = 0  # index of the conv layer to plot
###################
# Forward and inverse-predicate kernels of one layer, side by side. Each 4-D kernel is
# collapsed to 2-D by summing over its last two (channel) axes; imshow cannot draw the raw
# 4-D array. Float values are plotted directly, since a uint8 cast of `x * 255` would wrap them.
num_shown = min(nrows * ncols, params.num_predicates)
fwd_maps = [conv_filters[predicate_dict[idx]][layer].sum(axis=(2, 3)) for idx in range(num_shown)]
inv_maps = [inv_conv_filters[predicate_dict[idx]][layer].sum(axis=(2, 3)) for idx in range(num_shown)]
# Shared color scale across all panels so forward and inverse kernels are comparable.
vmin = min(m.min() for m in fwd_maps + inv_maps)
vmax = max(m.max() for m in fwd_maps + inv_maps)

# Create the figure before titling it (previously the title was set on the prior cell's figure).
fig, axs = plt.subplots(nrows=nrows, ncols=ncols*2, figsize=(20,22))
fig.suptitle("Conv and inverse conv filters for layer {}".format(layer), fontsize=16)
for ax in axs.ravel():
    ax.axis("off")
for idx in range(num_shown):
    row, col = divmod(idx, ncols)
    predicate = predicate_dict[idx]
    ax = axs[row, 2 * col]
    plot = ax.imshow(fwd_maps[idx], interpolation=interp_method, vmin=vmin, vmax=vmax)
    ax.set_title(predicate)
    ax = axs[row, 2 * col + 1]
    plot = ax.imshow(inv_maps[idx], interpolation=interp_method, vmin=vmin, vmax=vmax)
    ax.set_title("INV {}".format(predicate))
fig.colorbar(plot, ax=axs.ravel().tolist());

### Visualize the weights averaged over all layers (forward and inverse predicates).

In [None]:
def _mean_layer_map(kernels):
    """Average a list of per-layer 4-D kernels into one 2-D map.

    Each kernel is first collapsed to 2-D by summing over its last two (channel) axes.
    The per-layer maps are then averaged. Layers may differ in channel count, so the
    raw kernels cannot be averaged directly.
    NOTE(review): assumes every layer shares the same spatial kernel size; np.stack fails otherwise.
    """
    return np.mean(np.stack([kern.sum(axis=(2, 3)) for kern in kernels]), axis=0)

# The filter values are Python lists of arrays, so the previous `.mean(axis=0)` call raised AttributeError.
num_shown = min(nrows * ncols, params.num_predicates)
fwd_maps = [_mean_layer_map(conv_filters[predicate_dict[idx]]) for idx in range(num_shown)]
inv_maps = [_mean_layer_map(inv_conv_filters[predicate_dict[idx]]) for idx in range(num_shown)]
# Plot raw floats on one shared scale (a uint8 cast of `x * 255` would wrap out-of-range values).
vmin = min(m.min() for m in fwd_maps + inv_maps)
vmax = max(m.max() for m in fwd_maps + inv_maps)

fig, axs = plt.subplots(nrows=nrows, ncols=ncols*2, figsize=(20,22))
fig.suptitle("Averaged conv and inverse conv filters for all layers", fontsize=16)
for ax in axs.ravel():
    ax.axis("off")
for idx in range(num_shown):
    row, col = divmod(idx, ncols)
    predicate = predicate_dict[idx]
    ax = axs[row, 2 * col]
    plot = ax.imshow(fwd_maps[idx], interpolation=interp_method, vmin=vmin, vmax=vmax)
    ax.set_title(predicate)
    ax = axs[row, 2 * col + 1]
    plot = ax.imshow(inv_maps[idx], interpolation=interp_method, vmin=vmin, vmax=vmax)
    ax.set_title("INV {}".format(predicate))
fig.colorbar(plot, ax=axs.ravel().tolist());

### Let's visualize the conv kernels of a single predicate.