In [None]:
import h5py
import numpy as np
import os
import pickle
from sklearn.metrics import adjusted_rand_score

In [None]:
def compute_ari(data, results):
    """Score predicted segmentations against the reference with the adjusted Rand index.

    Two mask sources are evaluated: ``results['mask']`` (reported under the
    ``'gen'`` mode) and ``results['mask_att']`` (reported under ``'att'``).
    Each source is iterated along its leading axis; every entry is turned
    into a label map with an argmax over axis 1 followed by a squeeze of the
    trailing axis, and each resulting map is compared with the matching
    entry of ``data['segment']`` after both are flattened.

    Args:
        data: Mapping holding the reference label maps under ``'segment'``.
        results: Mapping holding the predicted masks under ``'mask'`` and
            ``'mask_att'``.
            NOTE(review): presumably the leading axis enumerates repeated
            test runs and the second axis the samples -- confirm.

    Returns:
        dict with keys ``'ari_all_gen'`` and ``'ari_all_att'``; each value is
        an array holding one row per leading-axis entry and one column per
        sample.
    """
    reference = data['segment']
    # Both mask sources are looked up before any scoring starts.
    sources = (
        ('gen', results['mask']),
        ('att', results['mask_att']),
    )
    collected = {}
    for mode, mask_stack in sources:
        rows = []
        for mask in mask_stack:
            # Winning index along axis 1 becomes the predicted label of each pixel.
            labels = np.argmax(mask, axis=1).squeeze(-1)
            rows.append([
                adjusted_rand_score(truth.reshape(-1), guess.reshape(-1))
                for truth, guess in zip(reference, labels)
            ])
        collected['{}_{}'.format('ari_all', mode)] = rows
    return {name: np.array(rows) for name, rows in collected.items()}

# Evaluation driver: for every dataset / phase pair, read the ground-truth
# file and the matching result file in batches, score each batch with
# compute_ari, average the scores over all samples and pickle the summary.
folder_out = 'outs'
folder_data = '../data'
# For each dataset: the result file name (phase) and the suffix of the data
# file it is evaluated against.
# NOTE(review): the 'room' / 'test' phase reads the '_train' data file --
# confirm that this pairing is intended.
configs = {
    'room': [
        {'phase': 'test', 'suffix': '_train'},
        {'phase': 'empty_room', 'suffix': '_empty_room'},
        {'phase': 'six_objects', 'suffix': '_six_objects'},
        {'phase': 'identical_color', 'suffix': '_identical_color'},
    ],
    'dsprites': [
        {'phase': 'test', 'suffix': ''},
    ],
    'clevr': [
        {'phase': 'test', 'suffix': ''},
    ],
}
# Number of samples read from the HDF5 files per iteration.
batch_size = 1000
metrics = {key: {} for key in configs}
for name_data, cfg_list in configs.items():
    for cfg in cfg_list:
        phase = cfg['phase']
        suffix = cfg['suffix']
        metrics[name_data][phase] = {}
        path_data = os.path.join(folder_data, '{}{}.h5'.format(name_data, suffix))
        path_result = os.path.join(folder_out, name_data, '{}.h5'.format(phase))
        # Context managers guarantee both files are closed even when a batch
        # fails part-way through (or the second file cannot be opened).
        with h5py.File(path_data, 'r') as f_data, h5py.File(path_result, 'r') as f_result:
            num_data = f_data['image'].shape[0]
            for offset in range(0, num_data, batch_size):
                index_sel = slice(offset, offset + batch_size)
                data = {}
                for key in f_data:
                    val = f_data[key][index_sel]
                    if key in ['segment', 'overlap']:
                        # Label-like arrays stay integral.
                        data[key] = val.astype(np.int64)
                    else:
                        # Everything else is rescaled by 1/255
                        # (presumably stored as 8-bit values -- confirm).
                        data[key] = val.astype(np.float64) / 255
                # Result arrays are sliced along their second axis, so the
                # leading axis is kept whole for every batch.
                results = {key: f_result[key][:, index_sel] / 255 for key in f_result}
                metrics_ari = compute_ari(data, results)
                for key, val in metrics_ari.items():
                    metrics[name_data][phase].setdefault(key, []).append(val)
        # Stitch the batches back together along the sample axis and average over it.
        for key, val in metrics[name_data][phase].items():
            metrics[name_data][phase][key] = np.concatenate(val, axis=-1).mean(-1)
with open('metrics.pkl', 'wb') as f:
    pickle.dump(metrics, f)

In [None]:
# Reload the pickled summary and print, for every dataset and phase, the mean
# and standard deviation of each ARI entry on a single line.
with open('metrics.pkl', 'rb') as f:
    metrics = pickle.load(f)
for name_data, phase_table in metrics.items():
    print(name_data)
    for phase, phase_metrics in phase_table.items():
        print(phase)
        # Each inner list of keys is rendered as one output line.
        for key_list in [['ari_all_gen', 'ari_all_att']]:
            text = (' ' * 4).join(
                '{}: {:6.3f}'.format(key, phase_metrics[key].mean())
                + u'\u00B1'
                + '{:.0e}'.format(phase_metrics[key].std())
                for key in key_list
            )
            print(text)
    print()