In [89]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.spatial as spatial

### Helper functions

In [90]:
def pred_eps_acc(gt_endpoints, pred_endpoints, threshold):
    """Score predicted endpoints against ground-truth endpoints.

    A ground-truth endpoint counts as found when at least one predicted
    endpoint lies within ``threshold`` of it (Euclidean distance, inclusive).

    Parameters
    ----------
    gt_endpoints : sequence of coordinate sequences
        Ground-truth endpoints, one row of coordinates per endpoint.
    pred_endpoints : sequence of coordinate sequences
        Predicted endpoints, with the same number of coordinates per row.
    threshold : float
        Largest distance at which a prediction still matches a ground-truth
        endpoint.

    Returns
    -------
    accuracy : float
        Fraction of ground-truth endpoints with at least one prediction
        within ``threshold``.
    extra_valid_pairs : list
        ``[gt_endpoint, pred_endpoint]`` pairs for every prediction that is
        within ``threshold`` of a ground-truth endpoint but is not that
        endpoint's closest prediction. On an exact distance tie the first
        prediction is treated as the closest and the others are reported.

    Raises
    ------
    ValueError
        If ``gt_endpoints`` is empty (accuracy would be undefined).
    """
    if len(gt_endpoints) == 0:
        raise ValueError("gt_endpoints must contain at least one endpoint")
    if len(pred_endpoints) == 0:
        # Nothing was predicted, so nothing can be matched.
        return 0.0, []

    # Pairwise distances: rows are ground-truth endpoints, columns predictions.
    dist_matrix = spatial.distance.cdist(gt_endpoints, pred_endpoints, metric='euclidean')

    # Use a boolean mask rather than overwriting far distances with 0, so a
    # prediction that coincides exactly with a ground-truth endpoint
    # (distance 0) still counts as a match.
    within_threshold = dist_matrix <= threshold

    matches_per_gt = within_threshold.sum(axis=1)
    accuracy = np.count_nonzero(matches_per_gt) / len(gt_endpoints)

    # For ground-truth endpoints matched by several predictions, keep the
    # closest one as "the" match and report the rest. Iterate over row
    # indices, not over the match counts themselves.
    extra_valid_pairs = []
    for gt_idx in np.flatnonzero(matches_per_gt > 1):
        closest_idx = np.argmin(dist_matrix[gt_idx])
        for pred_idx in np.flatnonzero(within_threshold[gt_idx]):
            if pred_idx != closest_idx:
                extra_valid_pairs.append([gt_endpoints[gt_idx], pred_endpoints[pred_idx]])

    return accuracy, extra_valid_pairs


### Setup: example data (can be deleted once this is integrated into the workflow)

In [91]:
# Example segment used to exercise pred_eps_acc: one row holding the
# ground-truth and predicted endpoint coordinates for a single seg_id.
example_gt_eps = [
    [402188, 228684, 24029], [401258, 224832, 24029], [401314, 228366, 24424],
    [400199, 220721, 24029], [397870, 232292, 24004], [403272, 227529, 24394],
]
example_pred_eps = [
    [400292, 220879, 24027], [401253, 228342, 24408], [401253, 228342, 24409],
    [402923, 227458, 24372], [402099, 228716, 24037], [402627, 231471, 24026],
]

# Build the frame in one go from a record instead of enlarging an empty frame
# with .loc, and derive the counts from the lists so they cannot drift apart.
data = pd.DataFrame(
    [
        {
            "seg_id": 864691136577830164,
            "num_gt_eps": len(example_gt_eps),
            "gt_eps": example_gt_eps,
            "num_pred_eps": len(example_pred_eps),
            "pred_eps": example_pred_eps,
        }
    ],
    columns=["seg_id", "num_gt_eps", "gt_eps", "num_pred_eps", "pred_eps"],
)
data

Unnamed: 0,seg_id,num_gt_eps,gt_eps,num_pred_eps,pred_eps
0,864691136577830164,6,"[[402188, 228684, 24029], [401258, 224832, 240...",6,"[[400292, 220879, 24027], [401253, 228342, 244..."


In [92]:
# Score the example segment's predictions against its ground truth with a
# matching threshold of 100, then show the accuracy and any surplus matches.
acc, extra_valid_pairs = pred_eps_acc(
    gt_endpoints=data.at[0, "gt_eps"],
    pred_endpoints=data.at[0, "pred_eps"],
    threshold=100,
)
print(acc, extra_valid_pairs, sep="\n")

0.3333333333333333
[[[401314, 228366, 24424], [401253, 228342, 24408]]]
