In [None]:
import h5py
import numpy as np
import os
import pickle
import scipy
import yaml
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import adjusted_mutual_info_score

In [None]:
def compute_ll(config, data, results_all, eps=1e-10):
    """Average per-image log-likelihood of the images under the predicted mixture.

    Every pixel is scored by a mixture taken over axis 1 of each run's arrays:
    the mixture weights are ``results_all['mask']`` and each component is a
    Gaussian centred at ``results_all['apc']`` whose standard deviation is
    ``config['normal_scale']`` for every channel.

    Args:
        config: mapping providing ``'normal_scale'``.
        data: mapping providing ``'image'``.
        results_all: mapping providing ``'mask'`` and ``'apc'``. Both are
            iterated along their first axis (presumably one entry per
            evaluation run -- TODO confirm against the code that writes them).
        eps: small constant added to the mask before taking its logarithm.

    Returns:
        Array with one value per leading entry of mask/apc: the log-likelihood
        summed over all pixels of an image, then averaged over images.
    """
    inv_var = 1 / pow(config['normal_scale'], 2)
    log_norm = np.log(2 * np.pi / inv_var)
    # The inserted axis lets the images broadcast against axis 1 of ``apc``.
    targets = data['image'][:, None]

    def _run_ll(mask, apc):
        # Gaussian log-density, accumulated over the last axis.
        sq_err = np.square(apc - targets)
        component_ll = -0.5 * (log_norm + inv_var * sq_err).sum(-1, keepdims=True)
        # Mix the components in log-space for numerical stability.
        mixed = scipy.special.logsumexp(np.log(mask + eps) + component_ll, axis=1)
        return mixed.reshape(mixed.shape[0], -1).sum(-1)

    per_run = [_run_ll(mask, apc) for mask, apc in zip(results_all['mask'], results_all['apc'])]
    return np.array(per_run).mean(-1)

def compute_oca(data, results_all):
    """Object counting accuracy: fraction of images whose object count is right.

    The true count of an image is the number of distinct labels in its
    ``data['segment']`` entry, excluding the background label. The predicted
    count is the number of entries of ``results_all['pres']`` (all but the last
    one along the final axis) that are at least 0.5.

    Args:
        data: mapping providing ``'segment'``, iterated per image along axis 0.
        results_all: mapping providing ``'pres'``, iterated along its first
            axis (presumably one entry per evaluation run -- TODO confirm).

    Returns:
        Array with one accuracy value in [0, 1] per leading entry of ``pres``.
    """
    segments = data['segment']
    # The largest label found anywhere in the array is treated as background
    # (assumes at least one image actually contains it -- TODO confirm).
    seg_back = segments.max()
    counts = np.array([(np.unique(val) != seg_back).sum() for val in segments])
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the same dtype.
    oca_all = [
        ((pres[:, :-1] >= 0.5).sum(-1) == counts).astype(float)
        for pres in results_all['pres']
    ]
    return np.array(oca_all).mean(-1)

def compute_ami(config, data, results_all):
    """Adjusted mutual information between true and predicted segmentations.

    The predicted segment of a pixel is the argmax over axis 1 of the mask,
    ignoring the last entry of that axis. Only pixels selected by
    ``data['overlap']`` are scored: ``overlap >= 1`` when
    ``config['seg_overlap']`` is truthy, otherwise ``overlap == 1``.

    Args:
        config: mapping providing ``'seg_overlap'``.
        data: mapping providing ``'segment'`` and ``'overlap'``.
        results_all: mapping providing ``'mask'``, iterated along its first
            axis (presumably one entry per evaluation run -- TODO confirm).

    Returns:
        Dict with the single key ``'ami_obj'`` holding one image-averaged score
        per leading entry of ``mask``.
    """
    use_overlap = config['seg_overlap']
    true_segments = data['segment']
    overlap_counts = data['overlap']
    if use_overlap:
        valid_pixels = overlap_counts >= 1
    else:
        valid_pixels = overlap_counts == 1

    ami_per_run = []
    for run_mask in results_all['mask']:
        # Drop the last slot, then label each pixel with its strongest slot.
        pred_segments = np.argmax(run_mask[:, :-1], axis=1).squeeze(-1)
        ami_per_run.append([
            adjusted_mutual_info_score(truth[keep], pred[keep], average_method='max')
            for truth, keep, pred in zip(true_segments, valid_pixels, pred_segments)
        ])
    return {'ami_obj': np.array(ami_per_run).mean(-1)}

def select_by_order(val_all, order_all):
    """Reorder the leading slots of every sample and keep the last slot in place.

    For each run and each sample, the sample's array is indexed with the
    matching entry of ``order_all``; the sample's final slot (``[:, -1:]`` of
    the run) is then appended after the reordered slots.

    Args:
        val_all: array indexed as (run, sample, slot, ...).
        order_all: integer indices indexed as (run, sample, position).

    Returns:
        Array stacking the reordered values of all runs.
    """
    reordered_runs = []
    for run_vals, run_orders in zip(val_all, order_all):
        picked = np.array([sample[perm] for sample, perm in zip(run_vals, run_orders)])
        last_slot = run_vals[:, -1:]
        reordered_runs.append(np.concatenate([picked, last_slot], axis=1))
    return np.array(reordered_runs)

def compute_order(config, data, results_all):
    """Match ground-truth object layers to predicted slots.

    Ground-truth visibility masks are derived from the last channel of
    ``data['layers']``: each layer's shape is multiplied by the cumulative
    product of ``1 - shape`` over the layers before it. For every run and
    image, the assignment that maximises the summed product of true and
    predicted masks is found with ``linear_sum_assignment``. The last true
    layer and the last predicted slot are excluded from the matching.

    Pixels with ``data['overlap'] > 1`` are ignored unless
    ``config['seg_overlap']`` is truthy.

    Args:
        config: mapping providing ``'seg_overlap'``.
        data: mapping providing ``'overlap'`` and ``'layers'``.
        results_all: mapping providing ``'mask'``. It is NOT modified.

    Returns:
        Array indexed as (run, image, true layer) whose values are the indices
        of the predicted slots assigned to the true layers.
    """
    seg_overlap = config['seg_overlap']
    overlaps = data['overlap'][:, None, ..., None]
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is equivalent.
    mask_valid = np.ones(overlaps.shape) if seg_overlap else (overlaps <= 1).astype(float)
    # Multiply into a fresh array: an in-place ``*=`` would silently overwrite
    # the caller's ``results_all['mask']``.
    mask_all = results_all['mask'] * mask_valid[None]
    shp_true = data['layers'][..., -1:]
    # Fraction of each layer left uncovered by the layers preceding it.
    part_cumprod = np.concatenate([
        np.ones((shp_true.shape[0], 1, *shp_true.shape[2:])),
        np.cumprod(1 - shp_true[:, :-1], 1),
    ], axis=1)
    mask_true = shp_true * part_cumprod * mask_valid
    # Negated overlap, so that minimising the cost maximises the overlap.
    order_cost = -(mask_true[None, :, :-1, None] * mask_all[:, :, None, :-1])
    order_cost = order_cost.reshape(*order_cost.shape[:-3], -1).sum(-1)
    order_all = [
        [linear_sum_assignment(cost)[1] for cost in cost_list]
        for cost_list in order_cost
    ]
    return np.array(order_all)

def compute_ooa(data, order_all):
    """Weighted accuracy of the predicted ordering of object pairs.

    For every pair of ground-truth layers ``a < b`` (the last layer is
    excluded) a weight is computed as the squared appearance difference,
    summed over channels, multiplied by both shapes, and summed over pixels.
    A pair counts as correct when ``order[:, a] < order[:, b]``.

    Args:
        data: mapping providing ``'layers'``; the last channel is used as the
            shape and the remaining channels as the appearance.
        order_all: array indexed as (run, image, layer).

    Returns:
        Array with one value per run: the mean weighted count of correct pairs
        divided by the mean total weight.
    """
    layers = data['layers']
    obj_apc = layers[:, :-1, ..., :-1]
    obj_shp = layers[:, :-1, ..., -1:]
    num_images, num_objects = obj_shp.shape[0], obj_shp.shape[1]

    # Only the strict upper triangle (a < b) of the weight matrix is filled.
    weights = np.zeros((num_images, num_objects, num_objects))
    for a, b in zip(*np.triu_indices(num_objects, k=1)):
        gap = np.square(obj_apc[:, a] - obj_apc[:, b]).sum(-1, keepdims=True)
        gap = gap * (obj_shp[:, a] * obj_shp[:, b])
        weights[:, a, b] = gap.reshape(gap.shape[0], -1).sum(-1)
    total_weight = weights.reshape(num_images, -1).sum(-1)

    ooa_all, weights_all = [], []
    for order in order_all:
        is_before = np.zeros(weights.shape)
        for a, b in zip(*np.triu_indices(order.shape[1], k=1)):
            is_before[:, a, b] = order[:, a] < order[:, b]
        ooa_all.append((is_before * weights).reshape(num_images, -1).sum(-1))
        weights_all.append(total_weight)
    return np.array(ooa_all).mean(-1) / np.array(weights_all).mean(-1)

def compute_layer_mse(layers, apc_all, shp_all, rng=None):
    """Shape-weighted mean squared error between true and predicted layers.

    True and predicted layers are each composited over the SAME uniform noise
    (``apc * shp + noise * (1 - shp)``), so a disagreement in shape is
    penalised as well as a disagreement in appearance. Squared differences are
    averaged over channels and weighted by ``max(target_shp, shp)``.

    Args:
        layers: ground-truth layers; the last channel is the shape and the
            remaining channels are the appearance.
        apc_all: predicted appearances, iterated along the first axis.
        shp_all: predicted shapes, iterated along the first axis.
        rng: optional source of randomness exposing ``uniform(low, high, size)``
            (e.g. ``np.random.default_rng(seed)`` or ``np.random.RandomState``).
            Defaults to the global ``np.random`` state, so results are only
            reproducible if that state is seeded or an ``rng`` is passed.

    Returns:
        Array with one value per leading entry of ``apc_all``: the mean
        weighted squared error divided by the mean weight.
    """
    if rng is None:
        rng = np.random
    target_apc, target_shp = layers[..., :-1], layers[..., -1:]
    layer_mse_all, weights_all = [], []
    for apc, shp in zip(apc_all, shp_all):
        noise = rng.uniform(0, 1, size=apc.shape)
        target_recon = target_apc * target_shp + noise * (1 - target_shp)
        recon = apc * shp + noise * (1 - shp)
        sq_diffs = np.square(recon - target_recon).mean(-1, keepdims=True)
        # A pixel counts wherever either the true or the predicted shape is present.
        mask_valid = np.maximum(target_shp, shp)
        layer_mse_all.append((sq_diffs * mask_valid).reshape(mask_valid.shape[0], -1).mean(-1))
        weights_all.append(mask_valid.reshape(mask_valid.shape[0], -1).mean(-1))
    return np.array(layer_mse_all).mean(-1) / np.array(weights_all).mean(-1)

# Evaluation driver: compute every metric for each (mode, color, dataset, phase)
# combination and cache the nested result dict in ``metrics.pkl``.
folder_out = 'outs'
folder_data = '../data'
mode_list = ['sep', 'occ']
color_list = ['gray', 'rgb_1', 'rgb_2', 'rgb_3', 'rgb_4']
name_list = ['shapes', 'mnist']
phase_list = ['test', 'general_4', 'general_10']
with open('config.yaml') as f:
    config = yaml.safe_load(f)
metrics = {}
for mode in mode_list:
    metrics[mode] = {}
    for color in color_list:
        metrics[mode][color] = {}
        mode_color = '{}_{}'.format(mode, color)
        for name in name_list:
            metrics[mode][color][name] = {}
            for phase in phase_list:
                # Only the part before the first '_' names the group in the data
                # file, e.g. 'general_4' and 'general_10' both read 'general'.
                data_phase = phase.split('_')[0]
                with h5py.File(os.path.join(folder_data, mode_color, '{}.h5'.format(name)), 'r') as f:
                    group = f[data_phase]
                    # Label-like arrays stay integral; everything else is
                    # rescaled by 1/255.
                    data = {
                        key: (group[key][()].astype(np.int64) if key in ['segment', 'overlap']
                              else group[key][()].astype(np.float64) / 255)
                        for key in group
                    }
                with h5py.File(os.path.join(folder_out, mode_color, name, '{}.h5'.format(phase)), 'r') as f:
                    results_all = {key: f[key][()] / 255 for key in f}
                sub_metrics = {
                    'll': compute_ll(config, data, results_all),
                    'oca': compute_oca(data, results_all),
                }
                sub_metrics.update(compute_ami(config, data, results_all))
                # The slot ordering, and both metrics built on it, need the
                # ground-truth layers. Record None when the data file has none
                # instead of failing inside compute_order with a KeyError.
                has_layers = 'layers' in data
                order_all = compute_order(config, data, results_all) if has_layers else None
                # NOTE(review): ordering accuracy is only reported for the 'occ'
                # mode with colors 'rgb_1' / 'rgb_3' -- rationale not visible here.
                wants_ooa = mode == 'occ' and color in ['rgb_1', 'rgb_3']
                sub_metrics['ooa'] = compute_ooa(data, order_all) if has_layers and wants_ooa else None
                if has_layers:
                    layers = data['layers']
                    apc_all_sel = select_by_order(results_all['apc'], order_all)
                    shp_all_sel = select_by_order(results_all['shp'], order_all)
                    sub_metrics['layer_mse'] = compute_layer_mse(layers, apc_all_sel, shp_all_sel)
                else:
                    sub_metrics['layer_mse'] = None
                metrics[mode][color][name][phase] = sub_metrics
with open('metrics.pkl', 'wb') as f:
    pickle.dump(metrics, f)

In [None]:
# Print "mean±std" for every cached metric, grouped by phase, then mode, then
# color/dataset.
# NOTE: pickle.load can execute arbitrary code; only load a metrics.pkl that
# this notebook produced itself.
with open('metrics.pkl', 'rb') as f:
    metrics = pickle.load(f)
fmt = {
    'll': '{:7.1f}',
    'ami_obj': '{:6.3f}',
    'layer_mse': '{:9.2e}',
    'oca': '{:6.3f}',
    'ooa': '{:6.3f}',
}
report_keys = ['ll', 'ami_obj', 'layer_mse', 'oca', 'ooa']
# Recover the evaluation grid from the loaded dict (insertion order is kept),
# so this cell runs on a fresh kernel without re-executing the expensive cell
# that defines mode_list / color_list / name_list / phase_list.
report_modes = list(metrics)
report_colors = list(dict.fromkeys(
    c for by_color in metrics.values() for c in by_color))
report_names = list(dict.fromkeys(
    n for by_color in metrics.values() for by_name in by_color.values() for n in by_name))
report_phases = list(dict.fromkeys(
    p for by_color in metrics.values() for by_name in by_color.values()
    for by_phase in by_name.values() for p in by_phase))
for phase in report_phases:
    print(phase)
    for mode in report_modes:
        print(mode)
        for color in report_colors:
            for name in report_names:
                print('{}_{}'.format(color, name))
                text_list = []
                for key in report_keys:
                    # Metrics that are absent or stored as None are skipped.
                    val = metrics[mode][color][name][phase].get(key)
                    if val is None:
                        continue
                    text_list.append('{}:'.format(key) + fmt[key].format(val.mean()) + u'\u00B1' + '{:.1e}'.format(val.std()))
                print((' ' * 4).join(text_list))
        print()
    print()