In [1]:
import os
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import dsntnn

In [2]:
import matplotlib
font = {'family' : 'serif',
        'weight' : 'normal',
        'size'   : 22}

matplotlib.rc('font', **font)

In [3]:
ordered_verts = ['T4', 'T5', 'T6', 'T7', 'T8', 'T9', 'T10', 'T11', 'T12', 'L1', 'L2', 'L3', 'L4']
colors = ['r', 'b', 'g', 'c', 'm', 'y', 'orange', 'brown', 'pink', 'purple', 'white', 'gray', 'olive']

# Collect results:
<ol>
    <li>Patient score<br>
    <li>Level score<br>
<ol>

In [4]:
def sigmoid(x):
    return 1/(1+np.exp(-x))

### Get predictions

In [5]:
predictions = np.load('../outputs/mse_model_preds.npz')
ids, coords, dists, labels = predictions.values()
ids.shape, coords.shape, dists.shape, labels.shape

((46,), (46, 13), (46, 13, 512, 1), (46, 13))

Format predictions.<br> **Note**: Only interested in levels we know are present.

In [6]:
def format_preds(ids, coords, dists, labels):
    pred_dict = {}
    for i, name in enumerate(ids):
        tmp = coords[i]
        mask = labels[i]
        pred_dict[name] = {}
        for idx in range(len(mask)):
            if mask[idx] == 1:
                vert = ordered_verts[idx]
                coord = dsntnn.normalized_to_pixel_coordinates(tmp[idx], size=dists.shape[-1])
                pred_dict[name][vert] = coord
            else:
                continue
                
    dist_dict = {}
    for i, name in enumerate(ids):
        dist_dict[name] = dists[i]
    return pred_dict, dist_dict

In [7]:
pred_dict, dist_dict = format_preds(ids, coords, dists, labels)


### Get annotation info for converting to $mm$

In [8]:
from ast import literal_eval

In [9]:
pix_info = pd.read_csv(f'../images_sagittal/annotation_info.csv', index_col='Name')
pix_info.head()

Unnamed: 0_level_0,Padding,Pixel Scaling
Name,Unnamed: 1_level_1,Unnamed: 2_level_1
03_06_2014_389_Sag,"(42.0, 42.0)","(1.1640625, 1.1640625)"
03_06_2014_402_Sag,"(-96.0, -96.0)","(0.625, 0.625)"
03_06_2014_396_Sag,"(-63.0, -63.0)","(0.75200004576, 0.75200004576)"
03_06_2014_395_Sag,"(-90.0, -90.0)","(0.6480000019200001, 0.6480000019200001)"
03_06_2014_399_TS_Sag,"(-96.0, -96.0)","(0.625, 0.625)"


### Get ground-truth

In [10]:
def get_gt(id_):
    path = f'../data/testing/targets/'
    #image
    img_path = f'../data/testing/slices/sagittal/{id_}.npy'
    img = np.load(img_path)
    # Coordinates
    coord_path = path + f'coordinates/{id_}.csv'
    coords = pd.read_csv(coord_path, index_col='Level').sort_values(by='Coordinate')
    # Heatmaps
    dist_path = path + f'heatmaps/{id_}.npy'
    dists = np.load(dist_path)
    return coords, dists, img

In [11]:
coord_dict = {}
for name in ids:
    coords, dists, img = get_gt(name)
    coord_dict[name] = coords

In [12]:
def make_labels(coord_dict):
    # Use halfway point between neighbouring levels as boundary
    # Coord dict=  patient specific coordinates
    gt_dict = {}
    for i, (level, data) in enumerate(coord_dict.iterrows()):
        coord = data.to_list()[0]
        dif = coord_dict['Coordinate']-coord
        if i ==0:
            nn = coord_dict.iloc[(dif).abs().argsort()[1]]
            dist = np.abs(dif[nn.name])
            range_ = [coord-dist/2, coord, coord+dist/2]
            gt_dict[level] = range_
        elif i > 0 and i < len(coord_dict.index)-1:
            nn = coord_dict.iloc[(dif).abs().argsort()[1:3]]
            for vert in nn.index.to_list():
                rel_dist = ordered_verts.index(vert)-ordered_verts.index(level)
                if rel_dist < 0:
                    up_nn = np.abs(dif[vert])
                elif rel_dist > 0:
                    down_nn = np.abs(dif[vert])
                else:
                    print('Indexing has gone wrong')
                    break
            range_ = [coord - up_nn/2, coord, coord + down_nn/2]
            gt_dict[level] = range_
        else:
            nn = coord_dict.iloc[(dif).abs().argsort()[1]]
            dist = np.abs(dif[nn.name])
            range_ = [coord - dist/2, coord, coord+dist/2]
            gt_dict[level] = range_
    return gt_dict

#### Make labels (range = dist. to NN/2)

Each entry: Min Range, Center point, Max Range

In [13]:
all_gt = {}
for name in ids:
    gt_dict = make_labels(coord_dict[name])
    all_gt[name] = gt_dict

## Evaluate predictions
(`all_gt` & `pred_dict`)

In [14]:
def plot_preds(name, pred_dict):
    #Inputs are patient-specific
    gt, pred = all_gt[name], pred_dict[name]
    img = np.load(f'../data/testing/slices/sagittal/{name}.npy')
    fig, ax = plt.subplots(1, 1, figsize=(15, 15))
    ax.imshow(img)
        
    plot_gt = np.zeros((512, 512), dtype=np.uint8)
    
    # Plot ground-truth
    for i, (level, range_) in enumerate(gt.items()):
        channel = ordered_verts.index(level)
        min_range, center, max_range = range_
        rec = patches.Rectangle((0, min_range), 512, max_range-min_range, facecolor=None, 
                                edgecolor='w', linewidth=4, alpha=0.8,fill=False, hatch=None)
        ax.add_patch(rec)
        ax.text(490, center, level, color='w')
        pred_coord = pred[level]
        if pred_coord > min_range and pred_coord < max_range:
            ax.axhline(pred_coord, ls='--', c='lawngreen', lw=4)
            ax.text(5, pred_coord-5, level, color='lawngreen')
        else:
            ax.axhline(pred_coord, ls='--', c='orangered', lw=4)
            ax.text(5, pred_coord-5, level, color='orangered')
    fig.savefig(f'../outputs/predictions/{name}.png')
    plt.close()
    

In [15]:
plot_preds('fr_552_TS_Sag', pred_dict)

In [16]:
def get_scores(name, pred_dict, plot=False):
    if plot:
        plot_preds(name, pred_dict)
    #Inputs are patient-specific
    gt, pred = all_gt[name], pred_dict[name]
    img = np.load(f'../data/testing/slices/sagittal/{name}.npy')
    plot_gt = np.zeros((512, 512), dtype=np.uint8)
    level_list = []
    correct_list = []
    dist_list = []
    range_list = []
    for i, (level, range_) in enumerate(gt.items()):
        channel = ordered_verts.index(level)
        min_range, center, max_range = range_
        if level not in pred.keys():
            scores_dict[level] = {'Correct': 'Not Predicted',
                                  'Distance': 'N/a'}
            level_list.append(level)
            correct_list.append('Not Predicted')
            dist_list.append('N/A')
            range_list.append(pix_scale[0]*np.abs(max_range-min_range))
            continue
        pred_coord = pred[level]
        if pred_coord > min_range and pred_coord < max_range:
            level_list.append(level)
            correct_list.append(True)
            dist = np.abs(pred_coord - center)
            pix_scale = literal_eval(pix_info.loc[name, 'Pixel Scaling'])
            dist /= pix_scale[0]
            dist_list.append(dist)
            range_list.append(pix_scale[0]*np.abs(max_range-min_range))
        else:
            level_list.append(level)
            correct_list.append(False)
            dist = np.abs(pred_coord - center)
            pix_scale = literal_eval(pix_info.loc[name, 'Pixel Scaling'])
            dist /= pix_scale[0]
            if dist > 250:
                print(name)
            dist_list.append(dist)
            range_list.append(pix_scale[0]*np.abs(max_range-min_range))
    return level_list, correct_list, dist_list, range_list

Automate making dataframe w/ scores

In [17]:
def make_score_df(pred_dict):
    levels = []
    correct = []
    distances = []
    range_list = []
    for name in ids:
        level, cor, dist, ranges = get_scores(name, pred_dict, plot=False)
        levels.extend(level)
        correct.extend(cor)
        distances.extend(dist)
        range_list.extend(ranges)
    scores_df = pd.DataFrame(columns=['Level', 'Correct', 'Distance', 'Range'])
    scores_df['Level'] = levels
    scores_df['Correct'] = correct
    scores_df['Distance'] = distances
    scores_df['Range'] = range_list
    scores_df.head()
    scores_df['Level'] = pd.Categorical(scores_df['Level'],
                                   categories=ordered_verts,
                                   ordered=True)
    table = pd.DataFrame(scores_df.groupby('Level')['Correct'].value_counts(normalize=True)*100).transpose()
    table.index = ['Accuracy']
    display(table)
    accuracy = {level: table[level, True] for level in ordered_verts}
    accuracy = pd.DataFrame(accuracy).T
    return scores_df, accuracy

In [18]:
scores_df, accuracy = make_score_df(pred_dict)

04_06_2014_431_Sag
04_06_2014_431_Sag
04_06_2014_431_Sag
fr_553_LS_Sag
fr_553_LS_Sag
fr_553_LS_Sag
fr_553_LS_Sag
16_05_2014_249_Sag
16_05_2014_249_Sag
16_05_2014_249_Sag
16_05_2014_249_Sag
16_05_2014_249_Sag
16_05_2014_249_Sag
16_05_2014_249_Sag
16_05_2014_249_Sag
fr_583_TS_Sag
fr_583_TS_Sag
fr_583_TS_Sag
fr_583_TS_Sag
fr_583_TS_Sag
fr_583_TS_Sag
fr_583_TS_Sag
fr_583_TS_Sag
04_06_2014_414_LS_Sag_3mm
04_06_2014_414_LS_Sag_3mm
04_06_2014_414_LS_Sag_3mm
fr_540_LS_Sag
fr_540_LS_Sag
fr_540_LS_Sag
fr_540_LS_Sag
fr_540_LS_Sag
fr_540_LS_Sag
19_05_2014_227_Sag
19_05_2014_227_Sag
19_05_2014_227_Sag
19_05_2014_227_Sag
19_05_2014_227_Sag
19_05_2014_227_Sag
19_05_2014_227_Sag
28_05_2014_80_Sag_LS
28_05_2014_80_Sag_LS
28_05_2014_80_Sag_LS
fr_555_LS_Sag
fr_555_LS_Sag
fr_555_LS_Sag
fr_555_LS_Sag
fr_523_TS_Sag
fr_523_TS_Sag
fr_523_TS_Sag
fr_523_TS_Sag
fr_523_TS_Sag
fr_523_TS_Sag
fr_523_TS_Sag
fr_523_TS_Sag
fr_523_TS_Sag
fr_574_TS_Sag
fr_574_TS_Sag
fr_574_TS_Sag
fr_574_TS_Sag
fr_574_TS_Sag
fr_574_TS_Sag

Level,T4,T5,T6,T7,T8,T9,T10,T11,T12,L1,L2,L3,L4
Correct,False,False,False,False,False,False,False,False,False,False,False,False,False
Accuracy,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0


KeyError: ('T4', True)

In [None]:
def plot_distances_w_acc(scores, accuracy):
    plt.style.use('seaborn-bright')
    fig, ax = plt.subplots(1, 1,figsize=(20, 10))

    sns.color_palette("hls", 2)
    sns.boxplot(data=scores, x='Level', y='Distance', hue='Correct', linewidth=1.5,  order=ordered_verts, palette='muted', ax=ax)
    g = sns.lineplot(data=scores, x= 'Level', y='Range', ci='sd', err_style="band", estimator='min', ax=ax, color='gray')
    ax.set_ylabel('Absolute distance to center ($mm$)')
    ax2 = ax.twinx()
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_ylim([-0.2, 100.2])
    sns.lineplot(data=accuracy, x=accuracy.index, y='Accuracy', ax=ax2, legend=False, ls='--', lw=3, color='limegreen' )
    #sns.catplot(dara=scores_df, x='Level', y=, hue='Correct', kind='hist')
    g.grid(True, which='both')
    #g.axhline(y=0, color='k')


    # Get the handles and labels. For this example it'll be 2 tuples
    # of length 4 each.
    handles, labels = g.get_legend_handles_labels()

    # When creating the legend, only use the first two elements
    # to effectively remove the last two.
    l = ax.legend(handles[0:2], labels[0:2], loc='upper left')

    ax2.spines['right'].set_color('limegreen')
    #ax2.yaxis.label.set_color('limegreen')
    ax2.tick_params(axis='y', colors='limegreen')
    l.get_texts()[0].set_text('Incorrect')
    l.get_texts()[1].set_text('Correct')

In [None]:
plot_distances_w_acc(scores_df, accuracy)

## Update predictions

Apply post-processing to predictions.

In [None]:
from scipy.special import softmax

In [None]:
def linear_expectation(probs, values):
    # Assumes normalised probs + single channel
    expectation = []
    for i in range(len(values)):
        expectation.append(probs[i]*values[i])
    return np.sum(expectation)

In [None]:
def update_preds(name, dist_dict, pred_dict, num_iters=0):
    eps=1e-24
    # k = number of iterations
    data = dist_dict[name]
    x = np.arange(0, data.shape[-1])
    # Get predictions afs dict
    updated_pred = {}
    if num_iters==0:
        print('Not updating...')
        return pred_dict[name]
    else:
        # Normalise across channels & update
        for k in range(num_iters):
            channel_norm = data/(data.sum(axis=0, keepdims=True)+eps)
            data = data*channel_norm**3
        for key in pred_dict[name].keys():
            channel = ordered_verts.index(key)
            norm_data = data[channel]/np.sum(data[channel])
            coords = linear_expectation(norm_data, x)
            #print(coords)
            updated_pred[key] = coords 
        return updated_pred

In [None]:
updated_preds = {}
for name in ids:
    updated_preds[name] = update_preds(name, dist_dict, pred_dict, num_iters=1)

In [None]:
upd_scores, upd_accuracy = make_score_df(updated_preds)

In [None]:
plot_distances_w_acc(upd_scores, upd_accuracy)

In [None]:
fig, ax = plt.subplots(1, 1,figsize=(20, 10))
sns.lineplot(data=accuracy, x=accuracy.index, y='Accuracy', ax=ax, legend=False, ls='--', lw=3, color='Red' )
sns.lineplot(data=upd_accuracy, x=upd_accuracy.index, y='Accuracy', ax=ax, legend=False, ls='--', lw=3, color='limegreen')
ax.set_xlabel('Level')

In [None]:
fig, ax = plt.subplots(1, 1,figsize=(20, 10))
concatenated = pd.concat([scores_df.assign(dataset='original'), upd_scores.assign(dataset='updated')])
sns.violinplot(data=concatenated, x='Level', y='Distance', hue='dataset', linewidth=3, order=ordered_verts, palette='muted', ax=ax, split=True)
plt.legend().remove()

In [None]:
plot_preds(ids[43], updated_preds)

In [None]:
plot_preds('fr_552_TS_Sag', pred_dict)

### Smaller bandwith (sigma=10)

In [None]:
small_band_preds = np.load('../outputs/small_band_model_preds.npz')
sb_ids, sb_coords, sb_dists, sb_labels = small_band_preds.values()
sb_ids.shape, sb_coords.shape, sb_dists.shape, sb_labels.shape

In [None]:
sb_pred, sb_dist = format_preds(sb_ids, sb_coords,sb_dists, sb_labels)

In [None]:
sb_scores, sb_acc = make_score_df(sb_pred)

In [None]:
plot_distances_w_acc(sb_scores, sb_acc)

In [None]:
small_band_preds = np.load('../outputs/small_band_model_preds.npz')
sb_ids, sb_coords, sb_dists, sb_labels = small_band_preds.values()
sb_ids.shape, sb_coords.shape, sb_dists.shape, sb_labels.shape

In [None]:
sb_pred, sb_dist = format_preds(sb_ids, sb_coords,sb_dists, sb_labels)

In [None]:
sb_scores, sb_acc = make_score_df(sb_pred)

In [None]:
plot_distances_w_acc(sb_scores, sb_acc)

In [None]:
fig, ax = plt.subplots(1, 1,figsize=(20, 10))
sns.lineplot(data=accuracy, x=accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Red', label='Original')
sns.lineplot(data=upd_accuracy, x=upd_accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='limegreen', label='Updated Original')
sns.lineplot(data=sb_acc, x=sb_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Blue', label='Small Width')
ax.set_xlabel('Level')
ax.legend()

In [None]:
sb_upd_preds = {}
for name in ids:
    sb_upd_preds[name] = update_preds(name, sb_dist, sb_pred, num_iters=2)
sb_upd_scores, sb_upd_acc = make_score_df(sb_upd_preds)
fig, ax = plt.subplots(1, 1,figsize=(20, 10))
sns.lineplot(data=accuracy, x=accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Red', label='Original')
sns.lineplot(data=upd_accuracy, x=upd_accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='limegreen', label='Updated Original')
sns.lineplot(data=sb_acc, x=sb_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Blue', label='Small Width')
sns.lineplot(data=sb_upd_acc, x=sb_upd_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='K', label='Updated Small Width')
ax.set_xlabel('Level')
ax.legend()

### Number of counts per level

In [None]:
scores_df['Level'].value_counts().reindex(ordered_verts)

## KL Divergence

In [None]:
kl_preds = np.load('../outputs/small_band_model_preds.npz')
kl_ids, kl_coords, kl_dists, kl_labels = kl_preds.values()
kl_ids.shape, kl_coords.shape, kl_dists.shape, kl_labels.shape

In [None]:
kl_pred, kl_dist = format_preds(kl_ids, kl_coords, kl_dists, kl_labels)

In [None]:
kl_scores, kl_acc = make_score_df(kl_pred)

In [None]:
plot_distances_w_acc(kl_scores, kl_acc)

In [None]:
sb_upd_preds = {}
for name in ids:
    sb_upd_preds[name] = update_preds(name, sb_dist, sb_pred, num_iters=2)
sb_upd_scores, sb_upd_acc = make_score_df(sb_upd_preds)
fig, ax = plt.subplots(1, 1,figsize=(20, 10))
sns.lineplot(data=accuracy, x=accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Red', label='Original')
sns.lineplot(data=upd_accuracy, x=upd_accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='limegreen', label='Updated Original')
sns.lineplot(data=sb_acc, x=sb_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Blue', label='Small Width')
sns.lineplot(data=sb_upd_acc, x=sb_upd_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='K', label='Updated Small Width')
sns.lineplot(data=kl_acc, x=kl_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='m', label='KL')
ax.set_xlabel('Level')
ax.legend()

### Add augmentations

In [None]:
aug_preds = np.load('../outputs/aug_preds.npz')
aug_ids, aug_coords, aug_dists, aug_labels = aug_preds.values()
aug_ids.shape, aug_coords.shape, aug_dists.shape, aug_labels.shape

In [None]:
aug_pred, aug_dist = format_preds(aug_ids, aug_coords, aug_dists, aug_labels)

In [None]:
aug_scores, aug_acc = make_score_df(aug_pred)

In [None]:
plot_distances_w_acc(aug_scores, aug_acc)

In [None]:
fig, ax = plt.subplots(1, 1,figsize=(20, 10))
sns.lineplot(data=accuracy, x=accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Red', label='Original')
sns.lineplot(data=upd_accuracy, x=upd_accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='limegreen', label='Updated Original')
sns.lineplot(data=sb_acc, x=sb_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Blue', label='Small Width')
sns.lineplot(data=sb_upd_acc, x=sb_upd_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='K', label='Updated Small Width')
sns.lineplot(data=kl_acc, x=kl_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='m', label='KL')
sns.lineplot(data=aug_acc, x=aug_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='orange', label='KL+Augmentations')
ax.set_xlabel('Level')
ax.legend()

## Sharpen Heatmaps

In [19]:
sh_preds = np.load('../outputs/w_classifier_preds.npz')
sh_ids, sh_coords, sh_dists, sh_labels = sh_preds.values()
sh_ids.shape, sh_coords.shape, sh_dists.shape, sh_labels.shape

((46,), (46, 13), (46, 13, 512), (46, 13))

In [20]:
sh_pred, sh_dist = format_preds(sh_ids, sh_coords, sh_dists, labels)

In [21]:
print(sh_dist["04_06_2014_431_Sag"].shape)

(13, 512)


In [22]:
sh_scores, sh_acc = make_score_df(sh_pred)

04_06_2014_416_CS_Sag
04_06_2014_416_CS_Sag


Level,T4,T4,T5,T5,T6,T6,T7,T7,T8,T8,...,T12,T12,L1,L1,L2,L2,L3,L3,L4,L4
Correct,False,True,False,True,False,True,False,True,False,True,...,True,False,True,False,True,False,True,False,True,False
Accuracy,85.714286,14.285714,80.952381,19.047619,76.190476,23.809524,70.0,30.0,76.190476,23.809524,...,71.875,28.125,63.333333,36.666667,65.517241,34.482759,82.758621,17.241379,92.592593,7.407407


In [23]:
sh_scores.head()

Unnamed: 0,Level,Correct,Distance,Range
0,T10,True,3.138128,16.191523
1,T11,True,8.073654,17.017223
2,T12,True,0.622825,18.565876
3,L1,True,4.690262,20.152865
4,L2,True,4.333403,21.391757


In [24]:
for name in sh_pred.keys():
    plot_preds(name, sh_pred)

In [None]:
plot_distances_w_acc(sh_scores, sh_acc)

In [None]:
fig, ax = plt.subplots(1, 1,figsize=(20, 10))
#sns.lineplot(data=accuracy, x=accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Red', label='Original')
# sns.lineplot(data=upd_accuracy, x=upd_accuracy.index, y='Accuracy', ax=ax, legend=False, lw=3, color='limegreen', label='Updated Original')
# sns.lineplot(data=sb_acc, x=sb_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='Blue', label='Small Width')
# sns.lineplot(data=sb_upd_acc, x=sb_upd_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='K', label='Updated Small Width')
# sns.lineplot(data=kl_acc, x=kl_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='m', label='KL')
sns.lineplot(data=aug_acc, x=aug_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='orange', label='KL+Augmentations')
sns.lineplot(data=sh_acc, x=sh_acc.index, y='Accuracy', ax=ax, legend=False, lw=3, color='grey', label='W/ Classifier')
ax.set_xlabel('Level')
ax.legend()

### Update

In [None]:
sh_upd = {}
for name in ids:
    sh_upd[name] = update_preds(name, sh_dist, sh_pred, num_iters=1)

In [None]:
sh_upd_scores, sh_upd_acc = make_score_df(sh_upd)

In [None]:
plot_distances_w_acc(sh_upd_scores, sh_upd_acc)

### Compare CDFs

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))
pal = sns.cubehelix_palette(n_colors=13, start=3, rot=-0.5, gamma=0.7)
sns.ecdfplot(data=aug_scores, x='Distance', hue='Level', ax=ax, palette=pal, lw=3)
print(scores_df.groupby('Level')['Distance'].median())
scores_df.groupby('Level')['Distance'].median().mean()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 10))
pal = sns.cubehelix_palette(n_colors=13, start=3, rot=-0.5, gamma=0.7)
sns.ecdfplot(data=sh_scores, x='Distance', hue='Level', ax=ax, palette=pal, lw=3)
print(aug_scores.groupby('Level')['Distance'].median())
aug_scores.groupby('Level')['Distance'].median().mean()

## Distance vs. Target Range

In [None]:
%matplotlib inline

In [None]:
import ipywidgets as widgets
from ipywidgets import interact

In [None]:
scores=sh_scores
acc_df = sh_acc
@interact
def plot_scatter(level=ordered_verts):
    subset_df = scores[scores['Level'] == level]
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    sns.scatterplot(data=subset_df, y='Distance', x='Range', ax=ax, hue='Level', style='Correct', s=100, alpha=0.8, linewidth=1.5, edgecolor='k')
    ax.set_title(acc_df.loc[level])
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.axhline(5, ls='--', c='k')
    ax.set_ylabel('Distance ($mm$)')
    ax.set_xlabel('Target Range ($mm$)')
    ax.set_ylim([-0.3, 100])