In [None]:
import h5py
import numpy as np
import os
import pickle
import yaml

In [None]:
def compute_ll(config, data, results_all):
    """Average Gaussian log-likelihood of each set of reconstructions.

    Every entry of ``results_all['recon']`` is compared against
    ``data['image']`` under an element-wise normal density whose standard
    deviation is ``config['normal_scale']``. The element-wise log-densities
    are summed over all axes except the first, and the resulting per-item
    values are averaged.

    Args:
        config: Mapping providing ``'normal_scale'``.
        data: Mapping providing ``'image'``.
        results_all: Mapping providing ``'recon'``, an iterable of arrays
            that can be subtracted from ``data['image']``.

    Returns:
        Array with one averaged log-likelihood per entry of
        ``results_all['recon']``.
    """
    inv_variance = 1 / pow(config['normal_scale'], 2)
    # log(2 * pi * sigma^2), expressed through the inverse variance.
    log_norm = np.log(2 * np.pi / inv_variance)
    target = data['image']

    def _summed_ll(recon):
        # Element-wise log-density, then flattened and summed per leading index.
        sq_err = np.square(recon - target)
        log_density = -0.5 * (log_norm + inv_variance * sq_err)
        return log_density.reshape(log_density.shape[0], -1).sum(1)

    per_run = [_summed_ll(recon) for recon in results_all['recon']]
    return np.array(per_run).mean(-1)

def compute_oca(data, results_all):
    """Fraction of items whose predicted object count matches the reference.

    The reference count for each item of ``data['segment']`` is the number of
    distinct labels in it that differ from the largest label found anywhere
    in ``data['segment']``. The predicted count for an item is the number of
    entries of ``pres[:, :-1]`` that are at least 0.5.

    Args:
        data: Mapping providing ``'segment'``, an integer label array that is
            iterated along its first axis.
        results_all: Mapping providing ``'pres'``, an iterable of arrays.
            NOTE(review): each array is presumably (items, slots) with the
            last slot reserved for the background -- confirm against the
            code that writes the result files.

    Returns:
        Float array with one accuracy value per entry of
        ``results_all['pres']``.
    """
    segments = data['segment']
    # The global maximum label is excluded from the per-item object count.
    seg_back = segments.max()
    counts = np.array([(np.unique(val) != seg_back).sum() for val in segments])
    pres_all = results_all['pres']
    oca_all = []
    for pres in pres_all:
        # Count entries at or above the 0.5 threshold, ignoring the final slot.
        num_present = (pres[:, :-1] >= 0.5).sum(-1)
        # np.float was removed in NumPy 1.24; np.float64 is the dtype it
        # used to resolve to, so results are unchanged on older versions.
        oca = (num_present == counts).astype(np.float64)
        oca_all.append(oca)
    return np.array(oca_all).mean(-1)

# Where the model outputs and the reference datasets live.
folder_out = 'outs'
folder_data = '../data'
name_list = ['mnist_extrapol', 'mnist_interpol']
phase_list = ['test', 'general']
with open('config.yaml') as f:
    config = yaml.safe_load(f)

# metrics[name][phase] -> {'ll': ..., 'oca': ...}
metrics = dict((name, {}) for name in name_list)
for name in name_list:
    path_data = os.path.join(folder_data, '{}.h5'.format(name))
    for phase in phase_list:
        # Reference data: label-like arrays are kept as integers, every other
        # array is converted to float and divided by 255.
        with h5py.File(path_data, 'r') as f:
            group = f[phase]
            data = {}
            for key in group:
                val = group[key][()]
                if key in ('segment', 'overlap'):
                    data[key] = val.astype(np.int64)
                else:
                    data[key] = val.astype(np.float64) / 255
        # Model outputs for this dataset and phase, all divided by 255.
        path_results = os.path.join(folder_out, name, '{}.h5'.format(phase))
        with h5py.File(path_results, 'r') as f:
            results_all = {key: f[key][()] / 255 for key in f}
        ll = compute_ll(config, data, results_all)
        oca = compute_oca(data, results_all)
        metrics[name][phase] = {'ll': ll, 'oca': oca}

# Persist the metrics so they can be reported without recomputing them.
with open('metrics.pkl', 'wb') as f:
    pickle.dump(metrics, f)

In [None]:
# Reload the metrics written to 'metrics.pkl' (only load pickles you trust).
with open('metrics.pkl', 'rb') as f:
    metrics = pickle.load(f)

# Format used for the mean of each metric; the spread always uses '{:.1e}'.
fmt = {
    'll': '{:7.1f}',
    'oca': '{:6.3f}',
}
separator = ' ' * 4
for name in name_list:
    print(name)
    for phase in phase_list:
        print(phase)
        phase_metrics = metrics[name][phase]
        # Each inner list of keys is printed on one line.
        for key_list in [['ll', 'oca']]:
            text_list = []
            for key in key_list:
                val = phase_metrics[key]
                if val is not None:
                    mean_text = fmt[key].format(val.mean())
                    std_text = '{:.1e}'.format(val.std())
                    text_list.append('{}:'.format(key) + mean_text + u'\u00B1' + std_text)
            print(separator.join(text_list))
    print()