In [None]:
from iterator import SmartIterator
from utils.visualization_utils import get_att_map, objdict, get_dict
from models import ReferringRelationshipsModel
from utils.eval_utils import iou_bbox

from sklearn.metrics import roc_auc_score
from keras import backend as K
import numpy as np
import os
from PIL import Image
import json
import h5py
import seaborn as sns
import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
########### METRICS #########################################

def sim_metric_np(y_true, y_pred, eps=10e-8):
    """Per-row similarity (histogram intersection) between two sets of maps.

    Each row of ``y_true`` and ``y_pred`` is rescaled to sum to 1 (``eps`` keeps
    all-zero rows finite), and the score is ``sum_i min(p_i, q_i)``. Identical
    rows score ~1; rows with disjoint support score 0.

    Args:
        y_true: array whose rows (axis 0) are compared; reductions run over axis 1.
        y_pred: array of the same shape as ``y_true``.
        eps: stabiliser added to each row sum (note: 10e-8 == 1e-7).

    Returns:
        list with one similarity value per row.
    """
    p = y_true / (eps + y_true.sum(axis=1, keepdims=True))
    q = y_pred / (eps + y_pred.sum(axis=1, keepdims=True))
    # np.minimum also keeps elements where p == q; strict '<' masks on both sides
    # would drop every tie (identical maps would score 0).
    return list(np.minimum(p, q).sum(axis=1))

def iou_np(y_true, y_pred, thresh=0.5, eps=10e-8):
    """Per-row intersection-over-union between y_true and the thresholded y_pred.

    Args:
        y_true: reference masks; reductions run over axis 1, one score per row.
            Presumably binary 0/1 — TODO confirm (soft values would be summed as-is).
        y_pred: scores of the same shape; entries strictly above ``thresh`` count
            as predicted.
        thresh: decision threshold for ``y_pred``.
        eps: added to the union (10e-8 == 1e-7) so empty rows score 0.

    Returns:
        list with one IoU value per row.
    """
    predicted = y_pred > thresh
    overlap = np.sum(predicted * y_true, axis=1)
    covered = np.sum((predicted + y_true) > 0, axis=1)
    return list(overlap / (covered + eps))

def recall_np(y_true, y_pred, thresh=0.5, eps=10e-8):
    """Per-row recall of the thresholded y_pred against y_true.

    recall = tp / (tp + fn + eps), where tp sums ``y_true`` over predicted pixels
    and fn counts pixels where ``y_true`` exceeds the (0/1) prediction.

    Args:
        y_true: reference masks; reductions run over axis 1, one score per row.
            May be float, integer (including unsigned) or boolean.
        y_pred: scores of the same shape; entries strictly above ``thresh`` count
            as predicted.
        thresh: decision threshold for ``y_pred``.
        eps: stabiliser in the denominator (10e-8 == 1e-7).

    Returns:
        list with one recall value per row.
    """
    pred_mask = y_pred > thresh
    tp = (pred_mask * y_true).sum(axis=1)
    # Subtract in float64: bool - bool raises TypeError in numpy, and for unsigned
    # masks 0 - True wraps to 255, which would be miscounted as a false negative.
    missed = (y_true.astype(np.float64) - pred_mask) > 0
    fn = missed.sum(axis=1)
    return list(tp / (tp + fn + eps))

def precision_np(y_true, y_pred, thresh=0.5, eps=10e-8):
    """Per-row precision of the thresholded y_pred against y_true.

    precision = sum(y_true over predicted pixels) / (number of predicted pixels + eps)

    Args:
        y_true: reference masks; reductions run over axis 1, one score per row.
        y_pred: scores of the same shape; entries strictly above ``thresh`` count
            as predicted.
        thresh: decision threshold for ``y_pred``.
        eps: stabiliser (10e-8 == 1e-7) so rows with no prediction score 0.

    Returns:
        list with one precision value per row.
    """
    predicted = y_pred > thresh
    true_positives = np.sum(predicted * y_true, axis=1)
    n_predicted = np.sum(predicted, axis=1)
    return [score for score in true_positives / (n_predicted + eps)]

def kl_metric_np(y_true, y_pred, eps=10e-8):
    """Per-row KL divergence KL(p || q) between row-normalised y_true and y_pred.

    Both inputs are rescaled so each row sums to 1, then
    ``sum_i p_i * log(eps + p_i / (eps + q_i))`` is taken over axis 1. ``eps``
    keeps the ratio and the log finite when q or p contain zeros.

    Args:
        y_true: reference maps (p); reductions run over axis 1.
        y_pred: predicted maps (q) of the same shape.
        eps: stabiliser (10e-8 == 1e-7).

    Returns:
        list with one divergence value per row.
    """
    p = y_true / (eps + y_true.sum(axis=1, keepdims=True))
    q = y_pred / (eps + y_pred.sum(axis=1, keepdims=True))
    log_ratio = np.log(eps + p / (eps + q))
    return list(np.sum(log_ratio * p, axis=1))

def cc_metric_np(y_true, y_pred, eps=10e-8):
    """Per-row Pearson correlation coefficient between y_true and y_pred.

    cc = cov(t, p) / sqrt(var(t) * var(p) + eps), all statistics taken over axis 1.

    Args:
        y_true: reference maps; one score is produced per row.
        y_pred: predicted maps of the same shape.
        eps: stabiliser under the square root (10e-8 == 1e-7) for constant rows.

    Returns:
        list with one correlation value (about -1..1) per row.
    """
    var_true = y_true.var(axis=1)
    var_pred = y_pred.var(axis=1)
    centered_true = y_true - y_true.mean(axis=1, keepdims=True)
    centered_pred = y_pred - y_pred.mean(axis=1, keepdims=True)
    cov = (centered_true * centered_pred).mean(axis=1)
    # Normalise by both variances; using var_true twice made the score scale with
    # the ground truth and ignore the prediction's spread.
    return list(cov / np.sqrt(var_true * var_pred + eps))

def iou_bbox_np(y_true, y_pred, thresh=0.5, eps=10e-8):
    """Per-sample IoU between the box masks derived from y_true and y_pred.

    Both inputs go through ``get_bbox_from_heatmap`` (threshold ``thresh``,
    default 224x224 reshape), which returns (N, D, D) masks. The masks are
    flattened per sample so ``iou`` reduces over every pixel of the image.

    Args:
        y_true: reference heatmaps, reshapeable to (-1, 224, 224).
        y_pred: predicted heatmaps, same layout.
        thresh: threshold applied to both heatmaps before boxing.
        eps: stabiliser forwarded to ``iou`` (10e-8 == 1e-7).

    Returns:
        list with one IoU value per sample.
    """
    gt_bbox = get_bbox_from_heatmap(y_true, thresh)
    pred_bbox = get_bbox_from_heatmap(y_pred, thresh)
    # Flatten (N, D, D) -> (N, D*D): reducing only axis 1 of the 3-D masks would
    # give D per-row scores per sample instead of one score per image.
    n_samples = gt_bbox.shape[0]
    return list(iou(gt_bbox.reshape(n_samples, -1),
                    pred_bbox.reshape(n_samples, -1),
                    eps=eps))

########### HELPERS #########################################

def get_bbox_from_heatmap(heatmap, threshold, input_dim=224):
    """Turn each heatmap into a 0/1 mask over its active rows x active columns.

    Values below ``threshold`` are ignored; a row (column) is "active" when the
    remaining values in it sum to > 0. The mask is 1 wherever an active row
    meets an active column. The input array is not modified.

    Args:
        heatmap: array reshapeable to (-1, input_dim, input_dim).
        threshold: values strictly below this are treated as 0.
        input_dim: side length of each square map.

    Returns:
        float array of shape (N, input_dim, input_dim) with 0./1. entries.
    """
    maps = heatmap.reshape((-1, input_dim, input_dim))
    # Build a new array instead of assigning in place: reshape usually returns a
    # view, so in-place zeroing would overwrite the caller's heatmap.
    maps = np.where(maps < threshold, 0, maps)
    row_active = 1. * (maps.sum(axis=2, keepdims=True) > 0)  # (N, D, 1)
    col_active = 1. * (maps.sum(axis=1, keepdims=True) > 0)  # (N, 1, D)
    # NOTE(review): this is the outer product of active rows and columns; with
    # several disjoint blobs it is not one filled bounding box — confirm intent.
    return row_active * col_active  # broadcasts to (N, D, D)

def iou(y_true, y_pred, eps=10e-8):
    """Per-sample intersection-over-union of two masks.

    Reduces over every axis except the first, so (N, P) and (N, H, W) inputs
    both yield one value per sample.

    Args:
        y_true: reference masks, first axis indexes samples.
        y_pred: masks of the same shape.
        eps: added to the union (10e-8 == 1e-7) so empty pairs score 0.

    Returns:
        ndarray of shape (N,) with one IoU per sample.
    """
    # For 2-D input this is exactly axis=1; higher-rank masks are reduced fully.
    axes = tuple(range(1, np.ndim(y_true)))
    intersection = (y_pred * y_true).sum(axis=axes)
    union = eps + ((y_pred + y_true) > 0).sum(axis=axes)
    return intersection / union

In [None]:
def load_model(model_checkpoint):
    """Rebuild a ReferringRelationshipsModel and load a checkpoint's weights into it.

    The training arguments are read from ``args.json`` in the checkpoint's
    directory, and ``discovery`` is forced to False before the model is built.

    Args:
        model_checkpoint: path to the weights file passed to ``model.load_weights``.

    Returns:
        (model, params): the built model with weights loaded, and the objdict of
        training arguments (with ``discovery`` set to False).
    """
    args_path = os.path.join(os.path.dirname(model_checkpoint), "args.json")
    with open(args_path, "r") as args_file:
        params = objdict(json.load(args_file))
    params.discovery = False
    # model.load_weights opens the checkpoint itself; no separate h5py handle is
    # kept open (a bare h5py.File() defaults to writable 'a' mode on h5py < 3).
    relationships_model = ReferringRelationshipsModel(params)
    model = relationships_model.build_model()
    model.load_weights(model_checkpoint)
    return model, params
    

def evaluate_model(model, params, metrics=(iou_np, recall_np, precision_np), batch_size=490):
    """Score ``model`` on the test split with each metric and print the means.

    Args:
        model: built model; ``model.predict(batch_in)`` is indexed [0] and [1]
            and compared with ``batch_out[0]`` / ``batch_out[1]``.
        params: run configuration. ``params.test_data_dir`` is read, and
            ``params.batch_size`` is overwritten with ``batch_size`` (this
            mutates the caller's object) before building the SmartIterator.
        metrics: callables ``f(y_true, y_pred) -> list`` of per-sample scores.
        batch_size: evaluation batch size (490 was the previously fixed value).

    Returns:
        dict mapping ``"<metric>_s"`` and ``"<metric>_o"`` to lists of per-sample
        scores for output 0 and output 1 (presumably subject / object — TODO confirm).
    """
    params.batch_size = batch_size
    test_generator = SmartIterator(params.test_data_dir, params)
    results = {}
    for metric in metrics:
        results[metric.__name__ + '_s'] = []
        results[metric.__name__ + '_o'] = []
    n_batches = len(test_generator)
    for i in range(n_batches):
        print("{}/{}".format(i, n_batches))
        batch_in, batch_out = test_generator[i]
        preds = model.predict(batch_in)
        for metric in metrics:
            results[metric.__name__ + '_s'] += metric(batch_out[0], preds[0])
            results[metric.__name__ + '_o'] += metric(batch_out[1], preds[1])
    for metric in metrics:
        print("{} : {:.4f} & {:.4f} ".format(metric.__name__,
                                            np.mean(results[metric.__name__ + '_s']),
                                            np.mean(results[metric.__name__ + '_o'])))
    return results

In [None]:
# Test-set metrics for the sym_ssn/1 checkpoint; swap in a commented path to score another run.
model_checkpoint = (
    "/data/chami/ReferringRelationships/models/VRD/11_07_2017/sym_ssn/1/model18-1.66.h5"
)
# model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/11_07_2017/sym_ssn/3/model21-0.96.h5"
# model_checkpoint = "/data/chami/ReferringRelationships/models/VRD/11_07_2017/sym_ssn/2/model17-1.37.h5"
model, params = load_model(model_checkpoint)
# `results` maps "<metric>_s" / "<metric>_o" to per-sample score lists.
results = evaluate_model(model=model, params=params)

In [None]:
# Test-set metrics for the ssn/5 checkpoint.
model_checkpoint = (
    "/data/chami/ReferringRelationships/models/VRD/11_07_2017/ssn/5/model19-2.03.h5"
)
model, params = load_model(model_checkpoint)
# NOTE: `preds` holds evaluate_model's metric-score dict, not raw model predictions.
preds = evaluate_model(model=model, params=params)

In [None]:
# Test-set metrics for the baseline/7 checkpoint.
model_checkpoint = (
    "/data/chami/ReferringRelationships/models/VRD/11_07_2017/baseline/7/model28-1.51.h5"
)
model, params = load_model(model_checkpoint)
# NOTE: `preds` holds evaluate_model's metric-score dict, not raw model predictions.
preds = evaluate_model(model=model, params=params)

In [None]:
# Test-set metrics for the baseline_no_predicate/8 checkpoint.
model_checkpoint = (
    "/data/chami/ReferringRelationships/models/VRD/11_07_2017/baseline_no_predicate/8/model24-1.38.h5"
)
model, params = load_model(model_checkpoint)
# NOTE: `preds` holds evaluate_model's metric-score dict, not raw model predictions.
preds = evaluate_model(model=model, params=params)

### VRD Models

In [None]:
# Root of the trained VRD checkpoints; edit this one constant to point at another copy.
VRD_MODEL_DIR = "/data/chami/ReferringRelationships/models/VRD"

# models selected with val iou
best_baseline_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/baseline/8/model12-0.51.h5")
best_baseline_no_pred_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/baseline_no_predicate/22/model36-0.53.h5")
best_ssn_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/ssn/3/model35-0.54.h5")
best_sym_ssn_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/sym_ssn/21/model11-0.52.h5")
# NOTE(review): unlike the other paths, this one has no model-type directory
# between the date and the run number — confirm it is the intended checkpoint.
best_sym_ssn_internal_loss_checkpoint = os.path.join(VRD_MODEL_DIR, "10_31_2017/1/model18-1.06.h5")

# models selected with val iou bbox
#best_baseline_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/baseline/3/model17-0.50.h5")
#best_baseline_no_pred_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/baseline_no_predicate/3/model25-0.51.h5")
#best_ssn_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/ssn/22/model12-0.55.h5")
#best_sym_ssn_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/sym_ssn/12/model17-0.54.h5")

# models selected with val_loss
#best_baseline_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/baseline/2/model34-0.49.h5")
#best_baseline_no_pred_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/baseline_no_predicate/3/model31-0.50.h5")
#best_ssn_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/ssn/19/model47-0.50.h5")
#best_sym_ssn_checkpoint = os.path.join(VRD_MODEL_DIR, "10_27_2017/sym_ssn/5/model47-0.50.h5")


## Test Results 

In [None]:
import math

def divisorGenerator(n):
    """Yield every positive divisor of ``n`` in ascending order.

    Divisors up to sqrt(n) are yielded as they are found; their cofactors
    ``n // i`` are collected and yielded afterwards in reverse, which keeps the
    whole sequence sorted.

    Args:
        n: non-negative int (0 yields nothing; negative n makes math.sqrt raise
            ValueError on first iteration).

    Yields:
        int divisors of n, ending with n itself.
    """
    large_divisors = []
    for i in range(1, int(math.sqrt(n)) + 1):
        if n % i == 0:
            yield i
            if i * i != n:
                # Floor division keeps the cofactor an int (n / i is a float on Python 3).
                large_divisors.append(n // i)
    for divisor in reversed(large_divisors):
        yield divisor
        
# `test_generator` exists only inside evaluate_model(), so rebuild it here from the
# most recently loaded `params` rather than relying on leftover kernel state.
test_generator = SmartIterator(params.test_data_dir, params)
# NOTE(review): the reason for subtracting 3 from the sample count is not recorded — confirm.
list(divisorGenerator(test_generator.samples - 3))