In [None]:
import h5py
import numpy as np
import os
import pickle
import torch
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score

In [None]:
def compute_order(data, results):
    """Match predicted slots to ground-truth objects and build the true pairwise order.

    Args:
        data: dict with
            'segment': integer label map, presumably (B, H, W). The largest label
                value (background) is dropped after one-hot encoding.
            'masks': per-object shape masks, presumably (B, K + 1, H, W, 1). The
                last slot is dropped.
        results: dict with 'mask': predicted masks, presumably
            (S, B, K' + 1, H, W, 1) for S samples. The last slot is dropped.

    Returns:
        order_all: int array, (S, B, K) when K <= K'. ``order_all[s, b, k]`` is the
            predicted slot assigned to ground-truth object ``k``. The assignment
            comes from a linear sum assignment that maximises the summed overlap
            between the one-hot true mask and the predicted mask.
        binary_mat_true: float array (B, K, K). For i < j the entry is +1 when
            object i shows at least as much visible area as object j inside the
            overlap of their shapes, else -1. Every other entry is 0.
    """
    segment = data['segment']
    num_labels = int(segment.max()) + 1
    # One-hot encode the label map to (B, num_labels, *spatial, 1), then drop the
    # last channel, i.e. the largest (background) label.
    obj_mask_true = np.moveaxis(np.eye(num_labels)[segment], -1, 1)[..., None][:, :-1]
    obj_shp_true = data['masks'][:, :-1]
    num_batch, num_obj = obj_shp_true.shape[:2]
    binary_mat_true = np.zeros((num_batch, num_obj, num_obj))
    for i in range(num_obj - 1):
        for j in range(i + 1, num_obj):
            # Compare how much of the shared shape region each object actually shows.
            region = np.minimum(obj_shp_true[:, i], obj_shp_true[:, j])
            area_i = (obj_mask_true[:, i] * region).reshape(num_batch, -1).sum(-1)
            area_j = (obj_mask_true[:, j] * region).reshape(num_batch, -1).sum(-1)
            binary_mat_true[:, i, j] = (area_i >= area_j) * 2 - 1
    # cost[s, b, k, j] = -overlap(true object k, predicted slot j). Contract the
    # pixel axis with einsum. The previous version materialised the full
    # (S, B, K, K', *spatial) broadcast product, which could exhaust memory.
    obj_mask_all = results['mask'][:, :, :-1]
    true_flat = obj_mask_true.reshape(*obj_mask_true.shape[:2], -1)
    pred_flat = obj_mask_all.reshape(*obj_mask_all.shape[:3], -1)
    order_cost_all = -np.einsum('bkp,sbjp->sbkj', true_flat, pred_flat)
    order_all = np.array([
        [linear_sum_assignment(cost)[1] for cost in cost_list]
        for cost_list in order_cost_all
    ])
    return order_all, binary_mat_true

def compute_oca(data, results):
    """Object counting accuracy, one value per leading index of results['pres'].

    The true count of an image is the number of distinct labels in its
    ``data['segment']`` map, excluding the largest label over the whole array
    (treated as background). The predicted count is the number of slots,
    excluding the last one, whose presence score is at least 0.5.

    Returns:
        float64 array: the fraction of images whose predicted count equals
        the true count (presumably one entry per sample S).
    """
    labels = data['segment']
    background = labels.max()
    num_true = np.array([np.sum(np.unique(image) != background) for image in labels])
    num_pred = np.sum(results['pres'][..., :-1] >= 0.5, axis=-1)
    is_correct = num_pred == num_true[np.newaxis]
    return is_correct.astype(np.float64).mean(axis=-1)

def compute_ooa(data, results, order_all, binary_mat_true):
    """Object ordering accuracy, one value per entry of ``order_all``.

    Each ground-truth pair (i, j) with i < j is weighted by the summed overlap
    of their shapes in ``data['masks']`` (last slot dropped). The predicted
    relation is +1 when object i was matched to a lower slot index than object j,
    and -1 otherwise. A pair counts as correct when that sign agrees with
    ``binary_mat_true``.

    Args:
        data: dict with 'masks', presumably (B, K + 1, H, W, 1).
        results: unused; presumably kept so all compute_* functions share a signature.
        order_all: matched slot indices, (S, B, K).
        binary_mat_true: true pairwise relation, (B, K, K), upper triangle filled.

    Returns:
        float64 array (S,) of weighted accuracies. NOTE: this is nan when no
        pair of objects overlaps, because the total weight is then zero.
    """
    def upper_pairs(n):
        # All index pairs (a, b) with a < b, in row-major order.
        return [(a, b) for a in range(n) for b in range(a + 1, n)]

    shapes = data['masks'][:, :-1]
    num_batch, num_obj = shapes.shape[:2]
    weights = np.zeros((num_batch, num_obj, num_obj))
    for i, j in upper_pairs(num_obj):
        overlap = np.minimum(shapes[:, i], shapes[:, j])
        weights[:, i, j] = overlap.reshape(num_batch, -1).sum(-1)
    total_weight = weights.sum()
    scores = []
    for order in order_all:
        pred = np.zeros(weights.shape)
        for i, j in upper_pairs(order.shape[1]):
            pred[:, i, j] = 2.0 * (order[:, i] < order[:, j]) - 1.0
        agree = (binary_mat_true * pred) == 1
        scores.append((agree * weights).sum() / total_weight)
    return np.array(scores)

def compute_ari_ami(data, results):
    """Adjusted Rand index / adjusted mutual information between true and predicted segmentations.

    For every predicted mask stack in ``results['mask']``, each image's predicted
    segmentation is the argmax over the slot axis (axis 1):
      * '*_all': argmax over all slots, compared on every pixel;
      * '*_obj': argmax over all slots except the last, compared only on pixels
        where ``data['overlap'] >= 1``.

    Returns:
        dict with keys 'ari_all', 'ari_obj', 'ami_all', 'ami_obj'. Each value is
        an array with one entry per predicted mask stack, holding the per-image
        score averaged over images.
    """
    keys = ['ari_all', 'ari_obj', 'ami_all', 'ami_obj']
    labels_true = data['segment']
    foreground = data['overlap'] >= 1
    collected = dict((name, []) for name in keys)
    for mask in results['mask']:
        pred_all = mask.argmax(axis=1).squeeze(-1)
        pred_obj = mask[:, :-1].argmax(axis=1).squeeze(-1)
        per_image = dict((name, []) for name in keys)
        for true_img, fg, img_all, img_obj in zip(labels_true, foreground, pred_all, pred_obj):
            true_all, true_obj = true_img.ravel(), true_img[fg]
            flat_all, flat_obj = img_all.ravel(), img_obj[fg]
            per_image['ari_all'].append(adjusted_rand_score(true_all, flat_all))
            per_image['ari_obj'].append(adjusted_rand_score(true_obj, flat_obj))
            per_image['ami_all'].append(adjusted_mutual_info_score(true_all, flat_all, average_method='arithmetic'))
            per_image['ami_obj'].append(adjusted_mutual_info_score(true_obj, flat_obj, average_method='arithmetic'))
        for name in keys:
            collected[name].append(per_image[name])
    return dict((name, np.mean(np.array(collected[name]), axis=-1)) for name in keys)

def select_by_order(val_all, order_all):
    """Reorder the slot axis of per-image values by matched slot indices.

    ``val_all[s][b]`` is indexed with ``order_all[s][b]``. The result is
    stacked back into one array of shape (S, B, len(order), ...).
    """
    return np.array([
        np.array([val[order] for val, order in zip(val_list, order_list)])
        for val_list, order_list in zip(val_all, order_all)
    ])

def compute_iou_f1(data, results, order_all, eps=1e-6):
    """Soft IoU and F1 between true object shapes and their matched predicted shapes.

    Predicted shapes ``results['shp']`` are reordered with ``order_all`` so that
    slot k lines up with ground-truth object k. Intersection and union are the
    element-wise min and max summed over pixels. Denominators are clipped at
    ``eps``. Only ground-truth objects whose shape is not all zero are averaged.

    Returns:
        dict with 'iou' and 'f1', each an array with one value per entry of
        ``order_all``. Values are nan if no ground-truth object is present.
    """
    true_shapes = data['masks'][:, :-1]
    matched_shapes = select_by_order(results['shp'], order_all)
    true_flat = true_shapes.reshape(true_shapes.shape[0], true_shapes.shape[1], -1)
    present = (true_flat.max(-1) != 0).astype(np.float64)
    num_present = present.sum()
    iou_scores, f1_scores = [], []
    for shapes in matched_shapes:
        pred_flat = shapes.reshape(shapes.shape[0], shapes.shape[1], -1)
        inter = np.minimum(true_flat, pred_flat).sum(-1)
        union = np.maximum(true_flat, pred_flat).sum(-1)
        # min + max of two soft masks equals their sum, so this is the Dice/F1 score.
        iou = inter / np.clip(union, eps, None)
        f1 = 2 * inter / np.clip(inter + union, eps, None)
        iou_scores.append((iou * present).sum() / num_present)
        f1_scores.append((f1 * present).sum() / num_present)
    return {'iou': np.array(iou_scores), 'f1': np.array(f1_scores)}

# Evaluate every (dataset, phase) pair and cache the metrics dict in metrics.pkl.
folder_out = 'outs'
folder_data = '../../compositional-scene-representation-datasets'
name_data_list = ['mnist', 'dsprites', 'clevr', 'shop']
phase_list = ['test', 'general']
metrics = {}
for name_data in name_data_list:
    metrics[name_data] = metrics_data = {}
    path_data = os.path.join(folder_data, '{}.h5'.format(name_data))
    for phase in phase_list:
        metrics_data[phase] = metrics_phase = {}
        # Ground truth: label maps are cast to int64. All other arrays are
        # divided by 255 (presumably stored as 0-255 integers).
        with h5py.File(path_data, 'r') as f:
            group = f[phase]
            data = {}
            for key in group:
                raw = group[key][()]
                data[key] = raw.astype(np.int64) if key in ['segment', 'overlap'] else raw.astype(np.float64) / 255
        # Model outputs: every array is divided by 255.
        path_out = os.path.join(folder_out, name_data, '{}.h5'.format(phase))
        with h5py.File(path_out, 'r') as f:
            results = {key: f[key][()] / 255 for key in f}
        order_all, binary_mat_true = compute_order(data, results)
        metrics_phase['oca'] = compute_oca(data, results)
        metrics_phase['ooa'] = compute_ooa(data, results, order_all, binary_mat_true)
        metrics_phase.update(compute_ari_ami(data, results))
        metrics_phase.update(compute_iou_f1(data, results, order_all))
with open('metrics.pkl', 'wb') as f:
    pickle.dump(metrics, f)

In [None]:
# Print "mean±std" for every metric, grouped by dataset and phase.
# metrics.pkl is the file written by the evaluation cell above; only unpickle trusted files.
with open('metrics.pkl', 'rb') as f:
    metrics = pickle.load(f)
key_groups = [['ari_all', 'ami_all', 'ari_obj', 'ami_obj'], ['iou', 'f1', 'oca', 'ooa']]
separator = ' ' * 8
for name_data, metrics_data in metrics.items():
    print(name_data)
    for phase, metrics_phase in metrics_data.items():
        print(phase)
        for key_list in key_groups:
            cells = [
                '{:<7}: {:<11}'.format(key, 'N/A') if metrics_phase[key] is None
                else '{:<7}: {:.3f}'.format(key, metrics_phase[key].mean()) + '\u00b1' + '{:.0e}'.format(metrics_phase[key].std())
                for key in key_list
            ]
            print(separator.join(cells))
    print()