In [None]:
from iterator import SmartIterator
from utils.visualization_utils import get_att_map, objdict, get_dict
from models import ReferringRelationshipsModel
from utils.eval_utils import iou_bbox

from sklearn.metrics import roc_auc_score
from keras import backend as K
import numpy as np
import os
from PIL import Image
import json
import h5py
import seaborn as sns
import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
def sim_metric_np(y_true, y_pred, eps=10e-8):
    """Per-sample similarity (histogram intersection) between two maps.

    Each row (axis 1) of both inputs is first divided by its own sum plus
    ``eps`` so that it sums to approximately 1; the score of a row is then
    the sum of the element-wise minimum of the two normalised rows.

    Args:
        y_true: array with samples along axis 0 and map values along axis 1.
        y_pred: array of the same shape as ``y_true``.
        eps: small constant added to the row sums to avoid division by zero.

    Returns:
        Array with one similarity score per sample.
    """
    y_true = y_true / (eps + y_true.sum(axis=1, keepdims=True))
    y_pred = y_pred / (eps + y_pred.sum(axis=1, keepdims=True))
    # np.minimum also keeps the positions where both maps are equal; selecting
    # with two strict "<" masks would drop them, so identical maps would
    # score 0 instead of ~1.
    return np.minimum(y_true, y_pred).sum(axis=1)

def iou_np(y_true, y_pred, thresh=0.5, eps=10e-8):
    """Per-sample intersection over union of a thresholded prediction.

    Args:
        y_true: array with samples along axis 0 and map values along axis 1.
        y_pred: array of the same shape; values above ``thresh`` count as
            predicted foreground.
        thresh: binarisation threshold applied to ``y_pred``.
        eps: small constant added to the union to avoid division by zero.

    Returns:
        Array with one IoU value per sample.
    """
    binarized = y_pred > thresh
    # Sum of the ground-truth values that fall on predicted foreground.
    overlap = (binarized * y_true).sum(axis=1)
    # Number of positions that are foreground in either map.
    covered = ((binarized + y_true) > 0).sum(axis=1)
    return overlap / (eps + covered)

def kl_metric_np(y_true, y_pred, eps=10e-8):
    """Per-sample KL divergence of the prediction from the ground truth.

    Both inputs are normalised row-wise (axis 1) to sum to approximately 1,
    then ``sum(p * log(eps + p / (eps + q)))`` is computed per row, with
    ``p`` the normalised ground truth and ``q`` the normalised prediction.

    Args:
        y_true: array with samples along axis 0 and map values along axis 1.
        y_pred: array of the same shape as ``y_true``.
        eps: small constant guarding the divisions and the logarithm.

    Returns:
        Array with one divergence value per sample.
    """
    p = y_true / (eps + y_true.sum(axis=1, keepdims=True))
    q = y_pred / (eps + y_pred.sum(axis=1, keepdims=True))
    log_ratio = np.log(eps + (p / (eps + q)))
    return (log_ratio * p).sum(axis=1)

def cc_metric(y_true, y_pred, eps=10e-10):
    """Per-sample Pearson correlation coefficient between two maps.

    Args:
        y_true: array with samples along axis 0 and map values along axis 1.
        y_pred: array of the same shape as ``y_true``.
        eps: small constant added under the square root so that a constant
            map (zero variance) yields 0 instead of a division by zero.

    Returns:
        Array with one correlation value per sample.
    """
    var_true = y_true.var(axis=1)
    var_pred = y_pred.var(axis=1)
    # cov(x, y) = E[xy] - E[x]E[y], computed row-wise.
    mean_product = y_true.mean(axis=1, keepdims=True) * y_pred.mean(axis=1, keepdims=True)
    cov = (y_true * y_pred - mean_product).mean(axis=1)
    # Normalise by std_true * std_pred = sqrt(var_true * var_pred); using
    # var_true twice would make the result depend on the scale of y_pred.
    return cov / np.sqrt((var_true * var_pred) + eps)

In [None]:
def evaluate_model(model_checkpoint):
    """Evaluate a saved model on its test set and print mean metrics.

    The training arguments are read from the ``args.json`` file located in
    the same directory as the checkpoint. The batch size is overridden to 64,
    the model is rebuilt from those arguments and the checkpoint weights are
    loaded into it. For every test batch the two model outputs (index 0 and
    index 1, accumulated below as the ``_s`` and ``_o`` scores) are clipped
    at zero and scored with ``iou_np``, ``sim_metric_np``, ``kl_metric_np``
    and ``cc_metric``.

    Args:
        model_checkpoint: path to the ``.h5`` weights file.

    Returns:
        The clipped predictions of the last evaluated batch, or ``None`` if
        the test generator yields no batches.
    """
    args_path = os.path.join(os.path.dirname(model_checkpoint), "args.json")
    # Context manager so the args file handle is closed deterministically.
    with open(args_path, "r") as args_file:
        params = objdict(json.load(args_file))
    test_data_dir = params.test_data_dir
    params.batch_size = 64
    test_generator = SmartIterator(test_data_dir, params)
    # The checkpoint is only opened through model.load_weights; no separate
    # h5py handle is created (it was unused and never closed).
    relationships_model = ReferringRelationshipsModel(params)
    model = relationships_model.build_model()
    model.load_weights(model_checkpoint)
    model_name = os.path.basename(model_checkpoint)

    metric_fns = (
        ("iou", iou_np),
        ("sim", sim_metric_np),
        ("kl", kl_metric_np),
        ("cc", cc_metric),
    )
    # name -> (scores for output 0, scores for output 1), one entry per sample.
    scores = {name: ([], []) for name, _ in metric_fns}

    num_batches = len(test_generator)
    preds = None  # stays None when there is nothing to evaluate
    for i in range(num_batches):
        if i % 10 == 0:
            print("{}/{}".format(i, num_batches))
        batch_in, batch_out = test_generator[i]
        preds = model.predict(batch_in)
        # Clip negative activations before scoring.
        preds[0] = np.maximum(0, preds[0])
        preds[1] = np.maximum(0, preds[1])
        for name, metric_fn in metric_fns:
            scores_s, scores_o = scores[name]
            scores_s.extend(metric_fn(batch_out[0], preds[0]))
            scores_o.extend(metric_fn(batch_out[1], preds[1]))

    # Column order: iou_s, iou_o, sim_s, sim_o, kl_s, kl_o, cc_s, cc_o.
    print("{} | {:.4f} & {:.4f} & {:.4f} & {:.4f} & {:.4f} & {:.4f} & {:.4f} & {:.4f}".format(
        model_name,
        np.mean(scores["iou"][0]), np.mean(scores["iou"][1]),
        np.mean(scores["sim"][0]), np.mean(scores["sim"][1]),
        np.mean(scores["kl"][0]), np.mean(scores["kl"][1]),
        np.mean(scores["cc"][0]), np.mean(scores["cc"][1])))
    return preds

In [None]:
# Evaluate the VRD baseline checkpoint baseline/8/model12-0.51.h5.
# `preds` receives the predictions of the last evaluated batch.
model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline/8/model12-0.51.h5"
preds = evaluate_model(model_checkpoint)

In [None]:
# Evaluate the VRD ssn checkpoint ssn/10/model37-1.34.h5.
# `preds` receives the predictions of the last evaluated batch.
model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/11_02_2017/ssn/10/model37-1.34.h5"
preds = evaluate_model(model_checkpoint)

In [None]:
# Evaluate the VRD ssn checkpoint ssn/8/model18-0.92.h5.
# `preds` receives the predictions of the last evaluated batch.
model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/11_02_2017/ssn/8/model18-0.92.h5"
preds = evaluate_model(model_checkpoint)

In [None]:
# Evaluate the VRD sym_ssn checkpoint sym_ssn/2/model19-1.01.h5.
# `preds` receives the predictions of the last evaluated batch.
model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/11_02_2017/sym_ssn/2/model19-1.01.h5"
preds = evaluate_model(model_checkpoint)

In [None]:
# Evaluate the VRD sym_ssn checkpoint sym_ssn/2/model21-0.97.h5.
# `preds` receives the predictions of the last evaluated batch.
model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/11_02_2017/sym_ssn/2/model21-0.97.h5"
preds = evaluate_model(model_checkpoint)

### VRD Models

In [None]:
# Checkpoint paths used by the evaluate_model(best_*) cells further down.
# NOTE: the CLEVR models cell below rebinds four of these names
# (best_baseline, best_baseline_no_pred, best_ssn, best_sym_ssn), so the
# checkpoints that get evaluated depend on which of the two cells ran last.
# models selected with val iou
best_baseline_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline/8/model12-0.51.h5"
best_baseline_no_pred_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline_no_predicate/22/model36-0.53.h5"
best_ssn_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_27_2017/ssn/3/model35-0.54.h5"
best_sym_ssn_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_27_2017/sym_ssn/21/model11-0.52.h5"
best_sym_ssn_internal_loss_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_31_2017/1/model18-1.06.h5"

# Alternative selections, kept commented out; uncomment one group to use it.
# models selected with val iou bbox
#best_baseline_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline/3/model17-0.50.h5"
#best_baseline_no_pred_checkpoint = '/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline_no_predicate/3/model25-0.51.h5'
#best_ssn_checkpoint = '/data/chami/ReferringRelationships/models/VRD/10_27_2017/ssn/22/model12-0.55.h5'
#best_sym_ssn_checkpoint = '/data/chami/ReferringRelationships/models/VRD/10_27_2017/sym_ssn/12/model17-0.54.h5'

# models selected with val_loss
#best_baseline_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline/2/model34-0.49.h5"
#best_baseline_no_pred_checkpoint = '/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline_no_predicate/3/model31-0.50.h5'
#best_ssn_checkpoint = '/data/chami/ReferringRelationships/models/VRD/10_27_2017/ssn/19/model47-0.50.h5'
#best_sym_ssn_checkpoint = '/data/chami/ReferringRelationships/models/VRD/10_27_2017/sym_ssn/5/model47-0.50.h5'


### CLEVR Models

In [None]:
# CLEVR checkpoint paths. These rebind the same names as the VRD models
# cell above, so running this cell switches the evaluate_model(best_*) cells
# to CLEVR.
# NOTE: best_sym_ssn_internal_loss_checkpoint is not assigned here; it keeps
# whatever value the VRD cell gave it, or is undefined if that cell never ran.
# models selected with val iou
best_baseline_checkpoint = "/data/chami/ClevrModels/baseline/model.h5"
best_baseline_no_pred_checkpoint = "/data/chami/ClevrModels/baseline_no_predicate/model16-0.05.h5"
best_ssn_checkpoint = "/data/chami/ClevrModels/ssn/model.h5"
best_sym_ssn_checkpoint = "/data/chami/ClevrModels/sym_ssn/model.h5"

In [None]:
# Build a CLEVR test iterator from the arguments saved next to the checkpoint.
model_checkpoint = "/data/chami/ClevrModels/baseline_no_predicate/model16-0.05.h5"
# Context manager so the args file handle is closed deterministically.
with open(os.path.join(os.path.dirname(model_checkpoint), "args.json"), "r") as args_file:
    params = objdict(json.load(args_file))
# The test_data_dir recorded in args.json is not used here: the CLEVR test
# directory is set explicitly.
test_data_dir = '/data/chami/CLEVR/test'
# Overrides applied on top of the saved arguments.
# NOTE: evaluate_model() re-reads args.json and builds its own iterator, so
# these overrides only affect code that uses this `params` / `test_generator`.
params.use_internal_loss = False
params.batch_size = 512
params.att_activation = 'tanh'
params.use_predicate = 0
test_generator = SmartIterator(test_data_dir, params)

## Test Results 

In [None]:
# Evaluate the baseline checkpoint chosen by whichever models cell (VRD or CLEVR) ran last.
evaluate_model(best_baseline_checkpoint)

In [None]:
# Evaluate the baseline_no_predicate checkpoint chosen by whichever models cell ran last.
evaluate_model(best_baseline_no_pred_checkpoint)

In [None]:
# Evaluate the ssn checkpoint chosen by whichever models cell ran last.
evaluate_model(best_ssn_checkpoint)

In [None]:
# Evaluate the sym_ssn checkpoint chosen by whichever models cell ran last.
evaluate_model(best_sym_ssn_checkpoint)

In [None]:
# Evaluate the sym_ssn internal-loss checkpoint.
# NOTE: this name is only assigned in the VRD models cell, so this is always a VRD checkpoint.
evaluate_model(best_sym_ssn_internal_loss_checkpoint)

In [None]:
def sim_metric(y_true, y_pred):
    """Keras-backend similarity (histogram intersection), averaged over the batch.

    Each row (axis 1) of both tensors is normalised by its own sum, the
    element-wise minimum of the two normalised rows is summed per sample,
    and the per-sample scores are averaged over axis 0.

    Args:
        y_true: backend tensor with samples along axis 0 and values along axis 1.
        y_pred: backend tensor of the same shape as ``y_true``.

    Returns:
        Scalar backend tensor holding the mean similarity of the batch.
    """
    # keepdims=True keeps a trailing axis of size 1 so every row is divided by
    # its own sum; without it the sums broadcast along axis 1, which is only
    # correct for a batch of one. K.epsilon() avoids NaN for all-zero maps and
    # mirrors the eps used by sim_metric_np.
    y_true = y_true / (K.sum(y_true, axis=1, keepdims=True) + K.epsilon())
    y_pred = y_pred / (K.sum(y_pred, axis=1, keepdims=True) + K.epsilon())
    mini = K.minimum(y_true, y_pred)
    res = K.mean(K.sum(mini, axis=1), axis=0)
    return res

In [None]:
# Load the VRD annotations, vocabularies and a trained model for interactive inspection.
model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline_no_predicate/3/model25-0.51.h5"
# Context managers so the JSON file handles are closed deterministically.
with open("data/VRD/annotations_test.json") as annotations_file:
    annotations_test = json.load(annotations_file)
img_dir = '/data/chami/VRD/sg_dataset/sg_test_images/'
vocab_dir = 'data/VRD'
predicate_dict, obj_subj_dict = get_dict(vocab_dir)
with open(os.path.join(os.path.dirname(model_checkpoint), "args.json"), "r") as args_file:
    params = objdict(json.load(args_file))
test_data_dir = params.test_data_dir
# Open the checkpoint explicitly read-only: the default mode of h5py.File
# depends on the installed h5py version and may be read/write.
model_weights = h5py.File(model_checkpoint, "r")
relationships_model = ReferringRelationshipsModel(params)
model = relationships_model.build_model()
model.load_weights(model_checkpoint)

In [None]:
def iou_bbox_metric(gt, pred):
    """Bounding-box IoU metric: iou_bbox with threshold 0.5 and the model's input_dim."""
    return iou_bbox(gt, pred, 0.5, params.input_dim)


# Explicit name attached to the metric function before it is passed to compile().
iou_bbox_metric.__name__ = "iou_bbox"
model.compile(
    loss=['binary_crossentropy', 'binary_crossentropy'],
    optimizer='sgd',
    metrics=[iou_bbox_metric, sim_metric],
)

In [None]:
# Evaluate the compiled model with one example per batch.
# NOTE: this mutates the shared `params` object used by earlier cells.
params.batch_size = 1
test_generator = SmartIterator(test_data_dir, params)
test_steps = len(test_generator)
# NOTE(review): the full length computed above is immediately overridden, so
# only 10 batches are evaluated. This looks like a debugging leftover; confirm
# before reporting numbers from `outputs`.
test_steps = 10
outputs = model.evaluate_generator(generator=test_generator,
                                       steps=test_steps,
                                       use_multiprocessing=params.multiprocessing,
                                       workers=params.workers)

In [None]:
def get_bbox(pred, gt, thresh=0.5, input_dim=224):
    """Build a box-shaped mask from a predicted map and print its IoU with gt.

    The prediction is thresholded, then a pixel is set in the mask when both
    its row and its column contain at least one predicted pixel. For a single
    contiguous predicted region this is its bounding box; for several
    separate regions it is the product of the row and column projections.

    Args:
        pred: array with ``input_dim * input_dim`` elements (any shape).
        gt: ground-truth array with ``input_dim * input_dim`` elements.
        thresh: values of ``pred`` above this threshold count as foreground.
        input_dim: side length of the square map (defaults to 224).

    Returns:
        Boolean array of shape ``(input_dim, input_dim)``.
    """
    pred = pred.reshape(input_dim, input_dim)
    gt = gt.reshape(input_dim, input_dim)
    pred = pred > thresh
    # Rows / columns that contain at least one foreground pixel.
    rows_hit = pred.sum(axis=1, keepdims=True) > 0  # shape (input_dim, 1)
    cols_hit = pred.sum(axis=0, keepdims=True) > 0  # shape (1, input_dim)
    # Broadcasting the two projections gives the full 2-D mask.
    mask = rows_hit & cols_hit
    intersection = mask*gt
    union = (mask+gt)>0
    print("iou bbox = {}".format(round(intersection.sum()*1./union.sum(), 2)))
    return mask

In [None]:
# Visualise one random test example: input image, ground-truth maps,
# predicted maps and the box masks derived from the predictions.
nb_examples = 50
test_generator.batch_size = 1
#k = 36
k = np.random.randint(nb_examples)
print(k)
# Fetch the batch once and unpack both inputs and targets from it.
batch_in, batch_out = test_generator[k]
image, subj, obj = batch_in
gt_subj, gt_obj = batch_out
plt.imshow(123+image[0].astype(np.uint8))
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
axes[0].imshow((255*gt_subj[0]).astype(np.uint8).reshape(224, 224))
axes[1].imshow((255*gt_obj[0]).astype(np.uint8).reshape(224, 224))
print(obj_subj_dict[int(subj[0])])
print(obj_subj_dict[int(obj[0])])
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
out = model.predict_generator(test_generator, steps=nb_examples)
axes[0].imshow((255*out[0][k]).astype(np.uint8).reshape(224, 224))
axes[1].imshow((255*out[1][k]).astype(np.uint8).reshape(224, 224))
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
# The *_np metrics reduce over axis 1, so each single example is reshaped to
# one row of shape (1, n) and the single resulting score is read back with [0].
gt_subj_row = gt_subj[0].reshape(1, -1)
pred_subj_row = out[0][k].reshape(1, -1)
gt_obj_row = gt_obj[0].reshape(1, -1)
pred_obj_row = out[1][k].reshape(1, -1)
bbox1 = get_bbox(out[0][k], gt_subj[0])
print("sim_metric subj = {}".format(sim_metric_np(gt_subj_row, pred_subj_row)[0]))
print("KL subj = {}".format(kl_metric_np(gt_subj_row, pred_subj_row)[0]))
# NOTE(review): roc_auc_score stands in for the undefined `auc_metric`; it
# assumes the ground-truth map is binary with both classes present — confirm.
print("auc score subj = {}".format(roc_auc_score(gt_subj_row.ravel(), pred_subj_row.ravel())))
bbox2 = get_bbox(out[1][k], gt_obj[0])
print("sim_metric obj = {}".format(sim_metric_np(gt_obj_row, pred_obj_row)[0]))
print("KL obj = {}".format(kl_metric_np(gt_obj_row, pred_obj_row)[0]))
print("auc score obj = {}".format(roc_auc_score(gt_obj_row.ravel(), pred_obj_row.ravel())))
axes[0].imshow((255*bbox1).astype(np.uint8).reshape(224, 224))
axes[1].imshow((255*bbox2).astype(np.uint8).reshape(224, 224))

In [None]:
# Evaluate the VRD checkpoint baseline_no_predicate/3/model25-0.51.h5.
evaluate_model('/data/chami/ReferringRelationships/models/VRD/10_27_2017/baseline_no_predicate/3/model25-0.51.h5')

In [None]:
# Evaluate the VRD checkpoint ssn/22/model12-0.55.h5.
evaluate_model('/data/chami/ReferringRelationships/models/VRD/10_27_2017/ssn/22/model12-0.55.h5')

In [None]:
# Evaluate the VRD checkpoint sym_ssn/12/model17-0.54.h5.
evaluate_model('/data/chami/ReferringRelationships/models/VRD/10_27_2017/sym_ssn/12/model17-0.54.h5')