In [None]:
from iterator import PredicateIterator
from utils.visualization_utils import get_att_map, objdict, get_dict
from keras.models import load_model

import numpy as np
import os
from PIL import Image
import json
import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

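# VRD

Run **either** this VRD configuration cell **or** the CLEVR one, not both. Both cells assign the same variables, so whichever runs last wins.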


In [None]:
# VRD test split: image-id-keyed annotations, the data directory read by
# PredicateIterator, the raw image folder, the vocabulary directory and the
# trained checkpoint to visualise.
# NOTE(review): img_dir is an absolute, machine-specific path — adjust locally.
with open("data/VRD/annotations_test.json") as annotations_file:
    annotations_test = json.load(annotations_file)
test_data_dir = 'data/predicate-vrd/test'
img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'
vocab_dir = 'data/VRD'
model_checkpoint = "temp/pred-vrd/model49-1.08.h5"

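# CLEVR

Alternative configuration for the CLEVR dataset. It assigns the same variables as the VRD cell, so run only the one you need.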

In [None]:
# CLEVR test split. This cell assigns the same names as the VRD cell, so
# running it replaces those settings.
# NOTE(review): img_dir is an absolute, machine-specific path — adjust locally.
with open("data/clevr/annotations_test.json") as annotations_file:
    annotations_test = json.load(annotations_file)
test_data_dir = 'data/pred-clevr/test'
img_dir = '/data/ranjaykrishna/clevr/images/test/'
# NOTE(review): the vocabulary directory and checkpoint below are the VRD ones,
# which looks copied from the VRD cell — confirm (or point them at CLEVR
# artefacts) before trusting CLEVR results.
vocab_dir = 'data/VRD'
model_checkpoint = "temp/pred-vrd/model49-1.08.h5"

In [None]:
# Vocabularies: later cells map a word to its id with `.index(word)`, so a
# word's position in these sequences is the id the model consumes.
predicate_dict, obj_subj_dict = get_dict(vocab_dir)
# Annotation keys are image file names (joined with img_dir when displaying).
image_ids = sorted(annotations_test)

# Training arguments are stored next to the checkpoint.
with open(os.path.join(os.path.dirname(model_checkpoint), "args.json"), "r") as args_file:
    params = objdict(json.load(args_file))
# Forced to 3 regardless of args.json.
# NOTE(review): presumably needed for checkpoints whose args.json predates this
# option — confirm it matches the architecture that was trained.
params.nb_conv_move_map = 3
if params.model == 'ssn':
    from ssn import ReferringRelationshipsModel
else:
    from model import ReferringRelationshipsModel
relationships_model = ReferringRelationshipsModel(params)

test_generator = PredicateIterator(test_data_dir, params)
images = test_generator.get_image_dataset()
# NOTE(review): later cells index both `image_ids` and `images` with the same
# image_index, which assumes the dataset is stored in sorted-id order — confirm.

# (1, 1) id buffers, filled in by the relationship cell and fed to model.predict.
subj_id = np.zeros((1, 1))
predicate_id = np.zeros((1, 1))
obj_id = np.zeros((1, 1))

### Load the model.

In [None]:
# Restore the trained model from the HDF5 checkpoint. load_model rebuilds the
# architecture from the file itself, so building a fresh graph with
# relationships_model.build_model() beforehand would only be discarded.
# NOTE(review): if the checkpoint contains custom layers/losses, load_model may
# need `custom_objects=...` — confirm against the model definition.
model = load_model(model_checkpoint)

In [None]:
# Exploratory: a stand-alone ImageNet ResNet50 backbone and one of its
# intermediate activations. Neither `base_model` nor `output` is used by the
# other cells shown in this notebook.
# NOTE(review): 'activation_40' is an auto-generated layer name; in some Keras
# versions such names depend on how many layers were already created in the
# session, so this lookup may hit a different layer or fail — check
# base_model.summary().
# NOTE(review): input_shape is fixed at 224x224 rather than params.input_dim.
import keras
from keras.applications.resnet50 import ResNet50

base_model = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)
output = base_model.get_layer(name='activation_40').output

### USER INPUT - Pick an image 

In [None]:
# USER INPUT: position of the image to inspect in `image_ids` (and, at the
# same position, in the preloaded `images` array).
image_index = 10
# Must be non-negative: the prediction cell slices
# images[image_index:image_index+1], which is empty for image_index == -1,
# while image_ids[image_index] would still resolve to an image.
assert image_index >= 0, "image_index must be a non-negative position"

In [None]:
# Load the selected test image for display and overlay.
# - Converting to RGB drops any alpha/palette channel (CLEVR renders may be RGBA
#   PNGs — confirm). get_att_map is presumed to expect 3-channel input.
# - Resizing to params.input_dim matches the resolution passed to get_att_map.
# The `with` block closes the underlying file once the resized copy exists.
with Image.open(os.path.join(img_dir, image_ids[image_index])) as source_img:
    img = source_img.convert('RGB').resize((params.input_dim, params.input_dim))
plt.figure(figsize=(5, 5))
plt.imshow(img);

### USER INPUT - Pick a relationship 

In [None]:
# USER INPUT: the <subject, predicate, object> triple to localise.
subj, predicate, obj = "bus", "next to", "building"
relationship = [subj, predicate, obj]

# Encode each word as its position in the matching vocabulary, written into
# the (1, 1) id buffers that the prediction cell passes to the model.
subj_id[0][0] = obj_subj_dict.index(subj)
predicate_id[0][0] = predicate_dict.index(predicate)
obj_id[0][0] = obj_subj_dict.index(obj)

In [None]:
# Single-example batch: the chosen image plus the three (1, 1) id arrays.
# The model's two outputs are unpacked into subject and object heatmap batches.
model_inputs = [images[image_index:image_index + 1], subj_id, predicate_id, obj_id]
subject_heatmap, object_heatmap = model.predict(model_inputs)

In [None]:
# Overlay the predicted heatmaps on the displayed image. `img` is a PIL image,
# which has no .astype, so convert it to a float32 NumPy array first.
# [0] selects the single example from each prediction batch.
att_map = get_att_map(np.asarray(img, dtype=np.float32), subject_heatmap[0], object_heatmap[0], params.input_dim, relationship)

In [None]:
# Show the attention overlay, cast to uint8 for display.
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(att_map.astype(np.uint8))

In [None]:
# Sanity-check get_att_map on synthetic heatmaps. The subject heatmap covers
# the whole image. The object heatmap is 0.2 in rows 0-49 / columns 0-99 and
# 0 elsewhere.
# demo_* names keep the real predictions in subject_heatmap/object_heatmap
# intact, so re-running the att_map cell afterwards still uses model output.
demo_subject_heatmap = np.ones((params.input_dim, params.input_dim))
demo_object_heatmap = np.zeros((params.input_dim, params.input_dim))
demo_object_heatmap[0:50, 0:100] = 0.2
# Same image form and display cast as the real attention-map cells.
new = get_att_map(np.asarray(img, dtype=np.float32), demo_subject_heatmap, demo_object_heatmap, params.input_dim, relationship)
plt.figure(figsize=(10, 45))
plt.imshow(new.astype(np.uint8))

In [None]:
# Scratch test: composite semi-transparent text onto the selected image with
# PIL (e.g. for labelling attention maps).
from PIL import Image, ImageDraw, ImageFont

try:
    fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 40)
except OSError:
    # This relative path only resolves from inside a Pillow source checkout;
    # fall back to Pillow's built-in bitmap font so the cell still runs.
    fnt = ImageFont.load_default()

# Fully transparent RGBA layer, the same size as `img`, to draw the text on.
txt = Image.new('RGBA', img.size, (255, 255, 255, 0))
d = ImageDraw.Draw(txt)
d.text((10, 10), "Hello", font=fnt, fill=(255, 255, 255, 128))
d.text((10, 60), "World", font=fnt, fill=(255, 255, 255, 255))
# alpha_composite needs both images in RGBA mode.
out = Image.alpha_composite(img.convert('RGBA'), txt)
plt.imshow(out)