In [None]:
import h5py
import numpy as np
import os
import pickle
import yaml
from sklearn.metrics import adjusted_mutual_info_score

In [None]:
def compute_ami(config, data, results_all):
    """Score predicted object segmentations against ground truth with adjusted mutual information.

    Only pixels whose ``data['overlap']`` value passes a validity test are scored:
    ``overlap >= 1`` when ``config['seg_overlap']`` is truthy, otherwise ``overlap == 1``
    (presumably "covered by exactly one object" -- confirm against the dataset generator).

    Args:
        config: mapping providing the ``'seg_overlap'`` flag.
        data: mapping with ``'segment'`` (ground-truth label map per sample) and
            ``'overlap'`` (per-pixel values with a shape matching ``'segment'``).
        results_all: mapping whose ``'mask'`` entry is iterated; each element is a batch
            of predicted masks with the object/slot axis at position 1 and a trailing
            singleton axis -- presumably (batch, K, H, W, 1); TODO confirm.

    Returns:
        ``{'ami_obj': array}`` with one entry per element of ``results_all['mask']``:
        the AMI (``average_method='max'``) averaged over the samples of that batch.
    """
    seg_overlap = config['seg_overlap']
    true_labels_all = data['segment']
    overlap = data['overlap']
    if seg_overlap:
        valid_all = overlap >= 1
    else:
        valid_all = overlap == 1
    scores_per_mask = []
    for mask in results_all['mask']:
        # Hard assignment: each pixel is labelled with the slot holding its largest mask value.
        pred_labels_all = np.argmax(mask, axis=1).squeeze(-1)
        scores_per_mask.append([
            adjusted_mutual_info_score(true_labels[valid], pred_labels[valid], average_method='max')
            for true_labels, valid, pred_labels in zip(true_labels_all, valid_all, pred_labels_all)
        ])
    # Rows = mask batches, columns = samples; average over samples.
    return {'ami_obj': np.array(scores_per_mask).mean(-1)}

# Model outputs are read from <folder_out>/<experiment>/<phase>.h5, datasets from <folder_data>/<name_data>.h5.
folder_out = 'outs'
folder_data = '../data'
# Key: experiment name (also selects config_<key>.yaml and the <folder_out>/<key>/ folder).
# 'name_data': dataset file stem under folder_data; 'phase_list': HDF5 groups to evaluate.
# NOTE(review): 'flying_shapes_3_5' / 'flying_shapes_5_5' read the *_3 dataset files --
# presumably the second number in the experiment name is not a data setting; confirm.
configs = {
    'shapes': {'name_data': 'shapes', 'phase_list': ['test']},
    'flying_shapes_3_3': {'name_data': 'flying_shapes_3_3', 'phase_list': ['test', 'general']},
    'flying_shapes_3_5': {'name_data': 'flying_shapes_3_3', 'phase_list': ['test', 'general']},
    'flying_shapes_5_3': {'name_data': 'flying_shapes_5_3', 'phase_list': ['test', 'general']},
    'flying_shapes_5_5': {'name_data': 'flying_shapes_5_3', 'phase_list': ['test', 'general']},
    'flying_mnist': {'name_data': 'flying_mnist_2_3', 'phase_list': ['test', 'general']},
}


def _load_data(name_data, phase, num_iters, folder):
    """Read group ``phase`` of ``<folder>/<name_data>.h5`` into numpy arrays.

    If the group's ``'image'`` dataset has size 1 along axis 1, that axis is squeezed
    out of every dataset; otherwise the entry at index ``num_iters`` along axis 1 is taken.
    ``'segment'`` and ``'overlap'`` are cast to int64; every other key is cast to
    float64 and divided by 255.
    """
    with h5py.File(os.path.join(folder, '{}.h5'.format(name_data)), 'r') as f:
        group = f[phase]
        if group['image'].shape[1] == 1:
            raw = {key: group[key][()].squeeze(1) for key in group}
        else:
            # NOTE(review): selects a single index (num_iters) along axis 1, which assumes
            # that axis has more than num_iters entries -- confirm against the data layout.
            raw = {key: group[key][:, num_iters] for key in group}
    return {
        key: val.astype(np.int64) if key in ('segment', 'overlap') else val.astype(np.float64) / 255
        for key, val in raw.items()
    }


def _load_results(name, phase, folder):
    """Read every dataset in ``<folder>/<name>/<phase>.h5``, divided by 255."""
    with h5py.File(os.path.join(folder, name, '{}.h5'.format(phase)), 'r') as f:
        return {key: f[key][()] / 255 for key in f}


metrics = {key: {} for key in configs}
for name, cfg in configs.items():
    # The YAML config depends only on the experiment, not on the phase, so parse it once.
    with open('config_{}.yaml'.format(name)) as f:
        config = yaml.safe_load(f)
    for phase in cfg['phase_list']:
        num_iters = config['phase_param'][phase]['num_iters']
        data = _load_data(cfg['name_data'], phase, num_iters, folder_data)
        results_all = _load_results(name, phase, folder_out)
        metrics[name][phase] = compute_ami(config, data, results_all)
# Persist the scores so the report cell can be re-run without recomputing them.
with open('metrics.pkl', 'wb') as f:
    pickle.dump(metrics, f)

In [None]:
# Report: for every experiment and phase, print each metric as "mean±std", where the
# statistics are taken over the array stored per metric in metrics.pkl.
# pickle.load runs arbitrary code from the file -- only load a metrics.pkl this notebook wrote.
with open('metrics.pkl', 'rb') as f:
    metrics = pickle.load(f)
for name_data, metrics_by_phase in metrics.items():
    print(name_data)
    for phase, phase_scores in metrics_by_phase.items():
        print(phase)
        # Each inner list becomes one output line; metrics on a line are separated by 4 spaces.
        for key_list in [['ami_obj']]:
            text = (' ' * 4).join(
                '{}: {:6.3f}'.format(key, phase_scores[key].mean())
                + u'\u00B1'
                + '{:.1e}'.format(phase_scores[key].std())
                for key in key_list
            )
            print(text)
    print()