In [None]:
import h5py
import numpy as np
import os
import pickle
import torch
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score

In [None]:
def compute_oca(data, results):
    """Object-count accuracy: fraction of images whose predicted object count is exact.

    Args:
        data: dict with 'segment', an integer label array indexed by image along
            axis 0. The largest label in the whole array is treated as background.
        results: dict with 'pres', slot presence scores whose last axis indexes
            slots (presumably (num_runs, num_images, num_slots) — TODO confirm).

    Returns:
        float64 array: accuracy averaged over the image axis, one value per
        leading index of results['pres'].
    """
    segments = data['segment']
    background = segments.max()
    # Ground-truth count = number of distinct non-background labels in each image.
    num_true = np.array([np.count_nonzero(np.unique(seg) != background) for seg in segments])
    # A slot counts as present when its score is >= 0.5. One slot is subtracted,
    # presumably the always-present background slot (TODO confirm).
    num_pred = np.sum(results['pres'] >= 0.5, axis=-1) - 1
    correct = num_pred == num_true[np.newaxis]
    return correct.astype(np.float64).mean(axis=-1)

def compute_ooa(data, results):
    """Placeholder for the 'ooa' metric, which this evaluation does not compute.

    Both arguments are accepted so that every compute_* function shares the same
    (data, results) signature; they are ignored. Returns None, which the report
    cell prints as 'N/A'.
    """
    ooa = None
    return ooa

def compute_ari_ami(data, results):
    """Adjusted Rand index (ARI) and adjusted mutual information (AMI) of predicted segmentations.

    Each predicted segmentation is the per-pixel argmax over the slot axis of a
    mask array. Scores are computed per image in two variants:
      * '*_all': over every pixel of the image;
      * '*_obj': only over pixels where data['overlap'] >= 1 (presumably pixels
        covered by at least one object — TODO confirm).

    Args:
        data: dict with 'segment' (ground-truth labels) and 'overlap', both
            indexed by image along axis 0.
        results: dict with 'mask'. Each element along axis 0 is one set of
            predictions with the slot axis at position 1 and a trailing
            singleton axis (presumably (num_images, num_slots, H, W, 1) — TODO confirm).

    Returns:
        dict with keys 'ari_all', 'ari_obj', 'ami_all', 'ami_obj', each an array
        holding the mean-over-images score for every element of results['mask'].
    """
    metric_keys = ['ari_all', 'ari_obj', 'ami_all', 'ami_obj']
    seg_true_all = data['segment']
    obj_region_all = data['overlap'] >= 1
    scores = {key: [] for key in metric_keys}
    for mask in results['mask']:
        # Hard assignment: index of the highest-scoring slot at every pixel.
        seg_pred_all = np.argmax(mask, axis=1).squeeze(-1)
        # Object-only scores use the same prediction; kept under a separate name
        # so the two variants can be changed independently.
        seg_pred_obj_all = seg_pred_all
        per_image = {key: [] for key in metric_keys}
        for seg_true, obj_region, seg_pred, seg_pred_obj in zip(
                seg_true_all, obj_region_all, seg_pred_all, seg_pred_obj_all):
            variants = (
                ('ari_all', 'ami_all', seg_true.reshape(-1), seg_pred.reshape(-1)),
                ('ari_obj', 'ami_obj', seg_true[obj_region], seg_pred_obj[obj_region]),
            )
            for ari_key, ami_key, labels_true, labels_pred in variants:
                per_image[ari_key].append(adjusted_rand_score(labels_true, labels_pred))
                per_image[ami_key].append(
                    adjusted_mutual_info_score(labels_true, labels_pred, average_method='arithmetic'))
        for key in metric_keys:
            scores[key].append(per_image[key])
    return {key: np.array(vals).mean(-1) for key, vals in scores.items()}

def compute_iou_f1(data, results):
    """Placeholder for the 'iou' and 'f1' metrics, which this evaluation does not compute.

    Both arguments are ignored. Returns a dict mapping 'iou' and 'f1' to None,
    which the report cell prints as 'N/A'.
    """
    return dict.fromkeys(('iou', 'f1'))

# Evaluation configuration: model outputs are read from `folder_out/<dataset>/<phase>.h5`,
# ground-truth datasets from `folder_data/<dataset>.h5` (one HDF5 group per phase).
folder_out = 'outs'
folder_data = '../../compositional-scene-representation-datasets'
name_data_list = ['mnist', 'dsprites', 'clevr', 'shop']
phase_list = ['test', 'general']


def _load_data(folder, name_data, phase):
    """Load the `phase` group of `<folder>/<name_data>.h5` as a dict of numpy arrays.

    'segment' and 'overlap' are cast to int64; every other array is cast to
    float64 and divided by 255 (assumes 8-bit storage — TODO confirm).
    """
    with h5py.File(os.path.join(folder, '{}.h5'.format(name_data)), 'r') as f:
        raw = {key: f[phase][key][()] for key in f[phase]}
    return {
        key: val.astype(np.int64) if key in ['segment', 'overlap'] else val.astype(np.float64) / 255
        for key, val in raw.items()
    }


def _load_results(folder, name_data, phase):
    """Load every dataset in `<folder>/<name_data>/<phase>.h5`, each divided by 255."""
    with h5py.File(os.path.join(folder, name_data, '{}.h5'.format(phase)), 'r') as f:
        return {key: f[key][()] / 255 for key in f}


def _evaluate(data, results):
    """Compute all metrics for one (dataset, phase) pair, in the order the report expects."""
    scores = {'oca': compute_oca(data, results), 'ooa': compute_ooa(data, results)}
    scores.update(compute_ari_ami(data, results))
    scores.update(compute_iou_f1(data, results))
    return scores


# Data and results are passed as temporaries. Each split's (potentially large,
# float64) arrays are therefore released before the next split is loaded, and
# they do not linger in the notebook namespace after the loop.
metrics = {}
for name_data in name_data_list:
    metrics[name_data] = {}
    for phase in phase_list:
        metrics[name_data][phase] = _evaluate(
            _load_data(folder_data, name_data, phase),
            _load_results(folder_out, name_data, phase),
        )
with open('metrics.pkl', 'wb') as f:
    pickle.dump(metrics, f)

In [None]:
# Reload the metrics saved by the evaluation cell. This is a self-produced local
# pickle; never point pickle.load at untrusted files.
with open('metrics.pkl', 'rb') as f:
    metrics = pickle.load(f)


def _format_metric(key, val):
    """Render one metric as a fixed-width 'key: mean±std' cell, or 'N/A' when val is None."""
    if val is None:
        return '{:<7}: {:<11}'.format(key, 'N/A')
    return '{:<7}: {:.3f}'.format(key, val.mean()) + '\u00b1' + '{:.0e}'.format(val.std())


# For every dataset and phase, print two rows of metrics (mean±std over runs).
for name_data, metrics_data in metrics.items():
    print(name_data)
    for phase, metrics_phase in metrics_data.items():
        print(phase)
        for key_list in (['ari_all', 'ami_all', 'ari_obj', 'ami_obj'], ['iou', 'f1', 'oca', 'ooa']):
            row = [_format_metric(key, metrics_phase[key]) for key in key_list]
            print((' ' * 8).join(row))
    print()