In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import argparse
import glob
import json
import os
import sys
from collections import defaultdict

import torch
import torch.nn
import h5py

In [5]:
# from sharpf.modules.losses import bce_loss, smooth_l1_loss, smooth_l1_reg_loss

In [30]:
# LOSSES_BY_NAME = {
#     'has_sharp': bce_loss,
#     'segment_sharp': bce_loss,
#     'regress_sharpdf': smooth_l1_loss,
#     'regress_sharpdirf': smooth_l1_reg_loss
# }
LOSSES_BY_NAME = {}

# Names of every attribute of torch.nn that ends in 'Loss'. This is a
# name-based scan standing in for a proper registry of available losses.
TORCH_NN_LOSSES = [name for name in dir(torch.nn) if name.endswith('Loss')]
LOSSES = TORCH_NN_LOSSES


def get_loss_function(metric_name, reduction='none'):
    """Build a loss object for the given metric name.

    The name is looked up in LOSSES_BY_NAME first and in TORCH_NN_LOSSES
    (resolved as an attribute of torch.nn) second. Whatever is found is
    called with ``reduction=reduction`` and the result is returned.

    :param metric_name: key of LOSSES_BY_NAME or name of a torch.nn loss class
    :param reduction: forwarded as the ``reduction`` keyword argument
    :returns: the instantiated loss
    :raises ValueError: if the name is found in neither collection
    """
    # Custom entries take precedence over the torch.nn ones.
    if metric_name in LOSSES_BY_NAME:
        return LOSSES_BY_NAME[metric_name](reduction=reduction)

    if metric_name in TORCH_NN_LOSSES:
        return getattr(torch.nn, metric_name)(reduction=reduction)

    raise ValueError('Metric {} cannot be instantiated, skipping'.format(metric_name))



In [7]:
def display_sharpness(mesh=None, plot_meshvert=True,
                      samples=None, samples_distances=None,
                      sharp_vert=None, sharp_curves=None,
                      directions=None, directions_width=0.0025,
                      samples_color=0x0000ff, samples_psize=0.002,
                      mesh_color=0xbbbbbb, meshvert_color=0x666666, meshvert_psize=0.0025,
                      sharpvert_color=0xff0000, sharpvert_psize=0.0025,
                      sharpcurve_color=None, sharpcurve_width=0.0025,
                      as_image=False, plot_height=768,
                      cmap_distances=None,
                      distance_range=None):
    """Build and display a k3d plot of a mesh, sampled points and sharp features.

    Every input is optional; only the layers whose data is given are drawn.

    :param mesh: object exposing ``vertices`` and ``faces``; drawn as a k3d mesh
    :param plot_meshvert: if true, the mesh vertices are also drawn as points
    :param samples: points drawn as a point cloud; together with ``directions``
        they are stacked and reshaped to (2, 3) per row, i.e. 3-D points
    :param samples_distances: per-sample values mapped to colours through
        ``cmap_distances``; when None all samples use ``samples_color``
    :param sharp_vert: points drawn with ``sharpvert_color``
    :param sharp_curves: iterable of index collections into ``mesh.vertices``,
        each drawn as one line; only used when ``sharp_vert`` is also given
    :param directions: per-sample offsets; a segment from each sample to
        ``sample + direction`` is drawn, zero-length segments are skipped;
        only used when ``samples`` is given
    :param sharpcurve_color: colour for all curves; when None each curve gets
        a random red hue from the ``randomcolor`` package
    :param as_image: accepted for compatibility; not used by this function
    :param plot_height: height passed to ``k3d.plot``
    :param cmap_distances: k3d colour map for ``samples_distances``; None
        selects ``k3d.colormaps.basic_color_maps.WarmCool``
    :param distance_range: [min, max] range for the colour map; None selects
        ``[0, max(samples_distances)]``
    :returns: the k3d plot, which has also been displayed
    :raises ValueError: if ``sharp_curves`` would be drawn but ``mesh`` is None
    """
    # Curves are stored as indices into mesh.vertices, so they cannot be drawn
    # without a mesh. Fail early with a clear message instead of an
    # AttributeError on None half-way through building the plot.
    if None is not sharp_vert and None is not sharp_curves and None is mesh:
        raise ValueError('sharp_curves index into mesh.vertices, so mesh must be given')

    # Imported here so that defining the function does not depend on an
    # earlier cell having imported these modules.
    import numpy as np
    import k3d

    # The default is resolved at call time rather than in the signature, so
    # that k3d is not needed to evaluate the def statement itself.
    if None is cmap_distances:
        cmap_distances = k3d.colormaps.basic_color_maps.WarmCool

    plot = k3d.plot(height=plot_height)

    if None is not mesh:
        k3d_mesh = k3d.mesh(mesh.vertices, mesh.faces, color=mesh_color)
        plot += k3d_mesh

        if plot_meshvert:
            k3d_points = k3d.points(mesh.vertices,
                                    point_size=meshvert_psize, color=meshvert_color)
            plot += k3d_points
            k3d_points.shader = 'flat'

    if None is not samples:
        colors = None
        if None is not samples_distances:
            if None is distance_range:
                distance_range = [0, np.max(samples_distances)]

            colors = k3d.helpers.map_colors(
                samples_distances, cmap_distances, distance_range
            ).astype(np.uint32)
            k3d_points = k3d.points(samples, point_size=samples_psize, colors=colors)
        else:
            k3d_points = k3d.points(samples, point_size=samples_psize, color=samples_color)
        plot += k3d_points
        k3d_points.shader = 'flat'

        if None is not directions:
            # Each row holds a segment's start and end point side by side.
            directions_to_plot = np.hstack((samples, samples + directions))

            for i, dir_to_plot in enumerate(directions_to_plot):
                dir_to_plot = dir_to_plot.reshape((2, 3))
                # A zero direction gives a degenerate segment; skip it.
                if np.all(dir_to_plot[0] == dir_to_plot[1]):
                    continue
                # Reuse the sample's distance colour when one was computed.
                color = int(colors[i]) if None is not colors else samples_color
                plt_line = k3d.line(dir_to_plot,
                                    shader='mesh', width=directions_width, color=color)
                plot += plt_line

    if None is not sharp_vert:
        k3d_points = k3d.points(sharp_vert,
                                point_size=sharpvert_psize, color=sharpvert_color)
        plot += k3d_points
        k3d_points.shader = 'flat'

        if None is not sharp_curves:
            if None is not sharpcurve_color:
                color = sharpcurve_color
            else:
                # Only needed when no fixed curve colour was requested.
                import randomcolor
                rand_color = randomcolor.RandomColor()
            for i, vert_ind in enumerate(sharp_curves):
                sharp_points_curve = mesh.vertices[vert_ind]

                if None is sharpcurve_color:
                    # generate() yields '#rrggbb' strings; convert to an int colour.
                    color = rand_color.generate(hue='red')[0]
                    color = int('0x' + color[1:], 16)
                plt_line = k3d.line(sharp_points_curve,
                                    shader='mesh', width=sharpcurve_width, color=color)
                plot += plt_line

    plot.grid_visible = False
    plot.camera_auto_fit = True
    plot.display()

    return plot

In [11]:
!ls /data/points/high_0.02/val/val_1024_0.hdf5

/data/points/high_0.02/val/val_1024_0.hdf5


In [103]:
# Directory holding the ground-truth HDF5 files (globbed for '*.hdf5' below).
true_dir = '/data/points/high_0.02/val/'
# When true, progress messages are printed while reading files and computing metrics.
verbose = True
# Name of the HDF5 dataset read from both the ground-truth and the prediction file.
target_label = 'distances'
# Name of the HDF5 dataset used to group instances for per-group statistics;
# None disables reading it.
split_by = 'num_surfaces'

# Directory holding the files compared against the ground truth; each file is
# looked up under the same basename as its ground-truth counterpart.
pred_dir = '/data/points/low_0.125/val/'

# Loss names to be resolved by get_loss_function().
metrics = ['MSELoss']

# The four settings below are read as globals by compute_statistics().
# Number of histogram bins.
n_histogram_bins = 100
# How many lowest-loss, highest-loss and near-median instance ids to report;
# None skips the corresponding list.
n_best_instances = 10
n_worst_instances = 10
n_avg_instances = 10


# Path of the JSON file the final statistics are written to.
save_filename = '/logs/metrics.json'

In [78]:
# All ground-truth HDF5 files in true_dir, in sorted path order.
filenames = sorted(glob.glob(os.path.join(true_dir, '*.hdf5')))

In [79]:
# Stays None unless split_by is set and the grouping dataset is read below.
split_by_values = None
# metric name -> file name -> {'values_by_instance': ..., 'split_by': ...}
metrics_by_name = defaultdict(dict)  # metric

In [80]:
# Only the first ground-truth file is used from here on; this raises
# IndexError when the glob above matched nothing.
true_pathname = filenames[0]

In [81]:
# Load the target dataset (and, when split_by is set, the grouping dataset)
# from the ground-truth file as torch tensors.
true_filename = os.path.basename(true_pathname)
if verbose:
    print('=== Reading GT file %s ===' % true_filename)
with h5py.File(true_pathname, 'r') as f:
    # [:] materialises the dataset so it can be handed to torch.from_numpy.
    true_label = torch.from_numpy(f[target_label][:])
    if None is not split_by:
        split_by_values = torch.from_numpy(f[split_by][:])

=== Reading GT file val_1024_0.hdf5 ===


In [82]:
split_by_values

tensor([2, 2, 2,  ..., 1, 1, 1], dtype=torch.int8)

In [83]:
# The prediction file is expected under pred_dir with the same basename as
# the ground-truth file; the same dataset name is read from it.
pred_pathname = os.path.join(pred_dir, true_filename)
if verbose:
    print('=== Reading pred file %s ===' % os.path.basename(pred_pathname))
with h5py.File(pred_pathname, 'r') as f:
    pred_label = torch.from_numpy(f[target_label][:])

=== Reading pred file val_1024_0.hdf5 ===


In [84]:
# Only the first configured metric is evaluated in the cells below.
metric_name = metrics[0]

In [85]:
if verbose:
    print('=== Computing %s ===' % metric_name)
try:
    loss_function = get_loss_function(metric_name, reduction='none')
except Exception as e:
    # Report the failure, then stop. Swallowing it would leave loss_function
    # undefined (or bound to a loss from an earlier run of this cell), and the
    # following cells would fail later or silently compute the wrong metric.
    print(str(e), file=sys.stderr)
    raise
    

=== Computing MSELoss ===


In [86]:
# torch.nn losses are called as loss(input, target): the prediction is the
# input and the ground truth is the target. The order does not matter for
# symmetric losses such as MSELoss, but it does for asymmetric ones.
loss_values = loss_function(pred_label, true_label)

In [87]:
# Keep one loss value per instance by averaging over dimension 1
# (assumes loss_values is 2-D with instances along dim 0 — TODO confirm),
# together with the grouping values read from the ground-truth file.
metrics_by_name[metric_name][true_filename] = {
    'values_by_instance': loss_values.mean(1),
    'split_by': split_by_values
}

In [98]:
def compute_statistics(loss_values, indexes=None):
    """Summarise a tensor of per-instance loss values.

    Reads the module-level settings ``n_histogram_bins``, ``n_best_instances``,
    ``n_worst_instances`` and ``n_avg_instances``.

    :param loss_values: tensor of loss values, one per instance
        (expected to be 1-D; callers pass per-instance means)
    :param indexes: tensor of instance ids aligned with ``loss_values``;
        defaults to ``0 .. len(loss_values) - 1``
    :returns: dict with ``values_by_instance``, ``mean_value``,
        ``median_value``, ``values_hist`` and ``bins_hist``, plus ``best_ids``,
        ``worst_ids`` and ``avg_ids`` (ids in ascending loss order) for each
        of the three counts that is not None
    :raises ValueError: if ``loss_values`` is empty
    """
    n_instances = len(loss_values)
    if n_instances == 0:
        raise ValueError('compute_statistics requires at least one loss value')
    if None is indexes:
        indexes = torch.arange(n_instances)
    ascending_idx = torch.argsort(loss_values)
    mean_value = loss_values.mean()
    median_value = loss_values.median()
    values_hist = torch.histc(loss_values, bins=n_histogram_bins)
    # NOTE(review): histc splits [min, max] into n_histogram_bins equal bins,
    # which have n_histogram_bins + 1 edges. This linspace yields only
    # n_histogram_bins points, so the entries match neither the bin edges nor
    # the bin centres exactly. Left unchanged so that the output keeps its
    # length; confirm what the consumer of 'bins_hist' expects.
    bins_hist = torch.linspace(torch.min(loss_values).item(),
                               torch.max(loss_values).item(),
                               n_histogram_bins)

    statistics = {
        'values_by_instance': loss_values.tolist(),
        'mean_value': mean_value.item(),
        'median_value': median_value.item(),
        'values_hist': values_hist.tolist(),
        'bins_hist': bins_hist.tolist(),
    }
    if None is not n_best_instances:
        statistics['best_ids'] = indexes[ascending_idx[:n_best_instances]].tolist()

    if None is not n_worst_instances:
        # An explicit start index is used because a slice of [-0:] would
        # select every instance instead of none when the count is 0.
        worst_start = max(0, n_instances - n_worst_instances)
        statistics['worst_ids'] = indexes[ascending_idx[worst_start:]].tolist()

    if None is not n_avg_instances:
        # average here denotes "close to median"
        # The start is clamped at 0: with fewer instances than requested a
        # negative start would wrap around and drop the lowest-loss entries.
        # The end is derived from the start so that an odd count still
        # returns exactly n_avg_instances ids.
        avg_min_idx = max(0, n_instances // 2 - n_avg_instances // 2)
        avg_max_idx = avg_min_idx + n_avg_instances
        statistics['avg_ids'] = indexes[ascending_idx[avg_min_idx:avg_max_idx]].tolist()

    return statistics


In [99]:
statistics_by_metric = defaultdict(dict)


In [100]:
# Join the per-instance losses of every processed file into one tensor,
# in the order the files were added to metrics_by_name.
loss_values = torch.cat(
    [per_file['values_by_instance'] for per_file in metrics_by_name[metric_name].values()]
)

In [101]:
# 'default' holds the statistics over all instances, without any grouping.
statistics_by_metric[metric_name] = {
    'default': compute_statistics(loss_values)
}

In [105]:
if None is not split_by:
    # Join the grouping values of every processed file, iterating the same
    # dict as the loss values so that both tensors stay aligned.
    split_by_values = torch.cat([value['split_by']
                                 for value in metrics_by_name[metric_name].values()])
    # Distinct grouping values as a sorted Python list.
    unique_split_by_values = sorted(torch.unique(split_by_values).tolist())


In [106]:
unique_split_by_values

[1, 2, 3, 4, 5, 6, 7, 8]

In [122]:
# Per-group statistics exist only when a grouping dataset was requested;
# without this guard the loop fails on the undefined unique_split_by_values
# when split_by is None.
if None is not split_by:
    for value in unique_split_by_values:
        key = '{}={}'.format(split_by, value)
        # Boolean mask of the instances belonging to this group.
        selector = (split_by_values == value)
        # Passing the positions of the selected instances keeps the reported
        # best/worst/avg ids relative to the full, ungrouped tensor.
        statistics_by_metric[metric_name].update({
            key : compute_statistics(loss_values[selector], indexes=selector.nonzero().reshape((-1)))
        })

In [125]:
statistics_by_metric['MSELoss'].keys()

dict_keys(['default', 'num_surfaces=1', 'num_surfaces=2', 'num_surfaces=3', 'num_surfaces=4', 'num_surfaces=5', 'num_surfaces=6', 'num_surfaces=7', 'num_surfaces=8'])

In [104]:
# Write the statistics of every metric to save_filename as JSON,
# overwriting any existing file.
with open(save_filename, 'w') as write_file:
    json.dump(statistics_by_metric, write_file)