# Setup

## Data Import

In [None]:
from itertools import combinations, permutations
from pathlib import Path

import numpy as np
import pandas as pd
from sqlite3 import connect

from matplotlib import pyplot as plt
import seaborn as sns

# Fix NumPy's global RNG so any randomised step below is reproducible run-to-run
np.random.seed(seed=707260)

# Master switch for figure generation: leave True to skip every plotting cell
skip_plots = True

In [None]:
from contextlib import closing

# Path to the SQLite database holding one table per study
db_source = 'results.db'

# List only real user tables. sqlite_master also lists indexes, views and
# triggers (which cannot be read as studies), and SQLite's own bookkeeping
# tables are named "sqlite_...". closing() guarantees the connection is
# released even if the query fails.
with closing(connect(db_source)) as con:
    tables = pd.read_sql(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'",
        con=con,
    ).loc[:, 'name']

In [None]:
def parse_dataset(data_label):
    """Summarise the dataset type encoded in a study's data label.

    Clinical-only labels map to ``'clin_only'``. Every other label becomes
    ``'<granularity>_<scope>'``: granularity is ``'slice'`` (per-slice
    metrics), ``'vert'`` (vertebral metrics) or ``'unkData'``; scope is
    ``'full'`` (clinical plus image-derived) or ``'img_only'``.
    """
    if 'clinical' in data_label:
        return 'clin_only'

    if 'perslice' in data_label:
        granularity = 'slice'
    elif 'vertebral' in data_label:
        granularity = 'vert'
    else:
        granularity = 'unkData'

    scope = 'full' if 'full' in data_label else 'img_only'
    return f"{granularity}_{scope}"

In [None]:
def parse_vert_range(data_label):
    """Return the vertebral range in a data label: 'C2C6', 'C2C7', or 'None'.

    'c2c6' is checked before 'c2c7'. Note that the fallback is the *string*
    'None', not the ``None`` object.
    """
    for token in ('c2c6', 'c2c7'):
        if token in data_label:
            return token.upper()
    return 'None'

In [None]:
def parse_seg_algo(data_label):
    """Return the segmentation algorithm named in a data label.

    Clinical-only labels have no segmentation (``'none'``). Otherwise
    ``'binary'`` takes precedence over ``'soft'``, and anything else is
    reported as ``'unkSeg'``.
    """
    if 'clinical' in data_label:
        return 'none'
    return next(
        (algo for algo in ('binary', 'soft') if algo in data_label),
        'unkSeg',
    )

In [None]:
def parse_contrast(data_label):
    """Return the MRI contrast named in a data label (``'T1'`` or ``'T2'``).

    Clinical-only labels yield ``'none'``. ``'T1w'`` is checked before
    ``'T2w'``, and labels naming neither yield ``'unkContrast'``.
    """
    if 'clinical' in data_label:
        return 'none'
    for marker, contrast in (('T1w', 'T1'), ('T2w', 'T2')):
        if marker in data_label:
            return contrast
    return 'unkContrast'

In [None]:
def parse_orientation(data_label):
    """Return the image orientation named in a data label.

    Clinical-only labels yield ``'none'``. ``'sag'`` is checked before
    ``'axial'``, and anything else is ``'unkOri'``.
    """
    if 'clinical' in data_label:
        return 'none'
    if 'sag' in data_label:
        return 'sag'
    return 'axial' if 'axial' in data_label else 'unkOri'

In [None]:
def parse_pre_processing(data_label):
    """Return the feature pre-processing pipeline named in a data label.

    Combined pipelines are checked before their single-step parts, so that
    e.g. ``'rfe_pca'`` is not mistaken for plain ``'rfe'``. ``'noprep'``
    maps to ``'none'``; anything unrecognised is ``'unkPrep'``.
    """
    ordered_markers = (
        ('rfe_pca', 'rfe_pca'),
        ('pca_rfe', 'pca_rfe'),
        ('rfe', 'rfe'),
        ('pca', 'pca'),
        ('noprep', 'none'),
    )
    for marker, prep in ordered_markers:
        if marker in data_label:
            return prep
    return 'unkPrep'

In [None]:
from contextlib import closing

# Completed studies, keyed by "<model>_<data description>"
df_map = {}

# Tables skipped because they could not be read, were incomplete, or were mis-named
bad_vals = 0

# Columns describing which analytical variant produced each row
analysis_idx = ['seg_algo', 'dataset', 'vert_range', 'model', 'weight', 'ori', 'prep']

# Studies with fewer rows than this are treated as not run to completion
MIN_COMPLETED_ROWS = 1000

# Number of tables that could be read (complete or not)
n_studies = 0
with closing(connect(db_source)) as con:
    for t in tables:
        # Quote the table name as an SQL identifier (doubling embedded quotes)
        # so unusual names are read correctly instead of breaking the statement
        quoted_name = '"' + t.replace('"', '""') + '"'
        try:
            df = pd.read_sql(f"SELECT * FROM {quoted_name}", con=con)
        except Exception as err:
            # Narrower than a bare except: KeyboardInterrupt/SystemExit still propagate
            print(f"Failed to read table {t} ({err}), ignoring it")
            bad_vals += 1
            continue
        n_studies += 1

        # If the table represents a study which wasn't run to completion, skip and report it
        if df.shape[0] < MIN_COMPLETED_ROWS:
            print(f"Study {t} was not completed")
            bad_vals += 1
            continue

        # Study tags are '__'-separated: the model is the second component and
        # the dataset description is the last one
        label_comps = t.split('__')
        if len(label_comps) < 2:
            print(f"Table {t} does not follow the '__'-separated study naming, ignoring it")
            bad_vals += 1
            continue
        model = label_comps[1]
        data_description = label_comps[-1]

        # Interpret the data label bit by bit to annotate every row of the study
        df['seg_algo'] = parse_seg_algo(data_description)
        df['dataset'] = parse_dataset(data_description)
        df['vert_range'] = parse_vert_range(data_description)
        df['model'] = model
        df['weight'] = parse_contrast(data_description)
        df['ori'] = parse_orientation(data_description)
        df['prep'] = parse_pre_processing(data_description)

        study_key = model + '_' + data_description
        if study_key in df_map:
            # The later table replaces the earlier one; make that visible
            print(f"Study {t} reuses key {study_key}; replacing the earlier study")
        df_map[study_key] = df

print(f"\nTotal No. bad values: {bad_vals}")

## Performance Metric Stacking

All metrics in the index list below are tracked for every analysis, so they are safe to query (and stack) across all analytical permutations.

In [None]:
# Metric columns recorded by every study: the optimisation objective, then the
# validation metrics, then the test metrics (order matters for column layout)
shared_performance_metric_idxs = [
    "objective",
    *(f"{name} (validate)" for name in ("balanced_accuracy", "roc_auc", "log_loss")),
    *(f"{name} (test)" for name in (
        "balanced_accuracy",
        "sk_precision_perclass",
        "sk_recall_perclass",
        "sk_f1_perclass",
        "roc_auc",
        "log_loss",
        "importance_by_permutation",
    )),
]

In [None]:
# Columns locating a row within a study: its replicate and its trial number
study_idxs = ["replicate", "trial"]

In [None]:
def stack_performance_metrics():
    """Concatenate the shared metric columns of every study in ``df_map``.

    Each study contributes its analysis-descriptor columns, its replicate/trial
    identifiers and the metrics every study records. Rows are stacked in
    ``df_map`` order and keep their original (per-study) index labels.
    """
    wanted_cols = [*analysis_idx, *study_idxs, *shared_performance_metric_idxs]
    return pd.concat([study_df.loc[:, wanted_cols] for study_df in df_map.values()])

performance_metric_df = stack_performance_metrics()

### Per-Class Feature Extraction

In [None]:
# Per-class (dictionary-valued) test metrics, later expanded into one column per class
class_metrics = [f"sk_{kind}_perclass (test)" for kind in ("precision", "recall", "f1")]

In [None]:
def _parse_metric_map(text):
    """Parse a ``'{cls: value, ...}'`` string into a ``{cls: float}`` dict.

    Items are split on ``', '`` and key/value on ``': '``; ``'{}'`` yields ``{}``.
    """
    parsed = {}
    for item in text.strip('{').strip('}').split(', '):
        if not item:
            continue
        parts = item.split(': ')
        parsed[parts[0]] = float(parts[1])
    return parsed


def unpack_class_metrics(df, metric_col):
    """Expand a per-class metric column into one float column per class.

    ``df[metric_col]`` holds strings shaped like ``'{good: 0.8, fair: 0.6}'``
    (format inferred from the parsing here; confirm against the DB writer).
    For each class key ``k`` a column ``f"{metric_col} [{k}]"`` is added, in
    sorted key order; rows lacking that class get NaN. The original string
    column is removed.

    Returns a new DataFrame; ``df`` itself is not modified.
    """
    parsed = df[metric_col].map(_parse_metric_map)

    ret_df = df.copy()
    # Sorted so the new columns have the same order on every run (set order does not)
    class_keys = sorted({k for class_map in parsed for k in class_map})
    for k in class_keys:
        # A plain list is assigned positionally, so a duplicated index cannot misalign rows
        ret_df[f"{metric_col} [{k}]"] = [class_map.get(k, np.nan) for class_map in parsed]

    # DataFrame.drop returns a new frame; its result must be kept for the drop to take effect
    return ret_df.drop(columns=metric_col)

In [None]:
# Expand every per-class metric column into one numeric column per class
for c in class_metrics:
    performance_metric_df = unpack_class_metrics(df=performance_metric_df, metric_col=c)

#### Save Results to File

In [None]:
# Persist the stacked, per-class-expanded metrics as tab-separated text (index included)
performance_metric_df.to_csv(path_or_buf='full_performance_metrics.tsv', sep='\t')

# Patient Metric Distributions

## Data Importing

In [None]:
# Per-patient clinical table (tab-separated)
clinical_metric_df = pd.read_csv(filepath_or_buffer="clinical_only.tsv", sep="\t")

## mJOA

Setup

In [None]:
def _bars_below(hist, bins, threshold):
    """Select the histogram bars whose bin starts below ``threshold``.

    ``bins`` are the edges given to ``np.histogram`` (consecutive integers
    offset by +0.1), so each bin ``[k - 0.9, k + 0.1)`` holds the integer ``k``.
    A bin is selected when ``k - 0.9 < threshold`` (``k <= threshold`` for an
    integer threshold). Returns ``(lefts, heights)``; ``lefts`` (``k - 0.5``)
    centres a unit-width, edge-aligned bar on ``k``.
    """
    lower_edges = np.asarray(bins)[:-1]
    keep = lower_edges < threshold
    return lower_edges[keep] + 0.4, np.asarray(hist)[keep]


def plot_distributions(data, cmap, legend_elements, xlabel, title, mean_offset=0, flip_mean_rot=False):
    """Plot a colour-banded histogram of integer-valued scores with a dashed mean line.

    Parameters
    ----------
    data : array-like of numbers
        Integer-valued scores; NaN entries are ignored.
    cmap : dict
        Maps a threshold ``t`` to a colour. Entries are drawn in order, each
        colouring every bar for a value ``<= t``. Later entries paint over
        earlier ones, so list thresholds from highest to lowest.
    legend_elements : list
        Handles passed to ``ax.legend``.
    xlabel, title : str
        X-axis label and figure title.
    mean_offset : float
        Distance of the mean label below the top of the y-axis.
    flip_mean_rot : bool
        Place the mean label left of the line, rotated +90 degrees, instead of
        to its right, rotated -90 degrees.

    Returns
    -------
    (fig, ax) of the new figure.

    Raises
    ------
    ValueError
        If ``data`` holds no non-NaN values.
    """
    values = np.asarray(data, dtype=float)
    values = values[~np.isnan(values)]
    if values.size == 0:
        raise ValueError("plot_distributions: no non-NaN values to plot")

    # One unit-wide bin per integer from min to max; edges sit at k + 0.1
    min_range = int(np.min(values)) - 1
    max_range = int(np.max(values)) + 1
    hist, bins = np.histogram(values, np.arange(min_range, max_range) + .1)

    fig, ax = plt.subplots()

    # Colour-code the bars band by band. Bar positions come from the actual
    # bins, so thresholds outside the data range select all bars or none
    # instead of producing mismatched x/height arrays.
    for threshold, color in cmap.items():
        lefts, heights = _bars_below(hist, bins, threshold)
        ax.bar(lefts, heights, width=1, color=color, align='edge', edgecolor='black')

    # Dashed mean line with a rotated label near the top of the axis
    data_mean = np.mean(values)
    ax.axvline(data_mean, ls='--', c='black')
    label_y = ax.get_ylim()[1] - mean_offset
    if flip_mean_rot:
        ax.text(data_mean - 0.5, label_y, f"Mean ({data_mean:.4})", rotation=90)
    else:
        ax.text(data_mean + 0.05, label_y, f"Mean ({data_mean:.4})", rotation=-90)

    ax.legend(handles=legend_elements)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.set_title(title)

    return fig, ax

In [None]:
from matplotlib.patches import Patch

# Shared axis limits covering both mJOA time points (missing scores ignored).
# NOTE(review): the plotting cells in this notebook do not apply these limits.
_all_mjoa = pd.concat([clinical_metric_df['mJOA initial'], clinical_metric_df['mJOA 12 months']])
xlim_min = int(_all_mjoa.min()) - 1
xlim_max = int(_all_mjoa.max()) + 1

# Tallest bar over both histograms (one bin per integer score), plus headroom
_mjoa_bins = np.arange(xlim_min, xlim_max) + .1
ylim_min = 0
ylim_max = int(max(
    np.histogram(clinical_metric_df[col].dropna(), _mjoa_bins)[0].max()
    for col in ('mJOA initial', 'mJOA 12 months')
)) + 5

# Colour threshold map: each colour covers scores up to its threshold; the
# entries are drawn in order, so later (lower) ones paint over earlier ones
severity_cmap = {
    18: 'blue',
    17: 'green',
    14: 'gold',
    11: 'red'
}

# Custom legend matching the severity colour bands
legend_elements = [
    Patch(facecolor='red', edgecolor='black', label='Severe'),
    Patch(facecolor='gold', edgecolor='black', label='Moderate'),
    Patch(facecolor='green', edgecolor='black', label='Mild'),
    Patch(facecolor='blue', edgecolor='black', label='Healthy'),
]


def _severity_from_mjoa(scores):
    """Label mJOA scores with DCM severity; missing scores stay missing (NaN).

    > 17 'Healthy', > 14 'Mild', > 11 'Moderate', otherwise 'Severe'.
    """
    labels = pd.Series('Severe', index=scores.index, dtype=object)
    labels.loc[scores > 11] = 'Moderate'
    labels.loc[scores > 14] = 'Mild'
    labels.loc[scores > 17] = 'Healthy'
    # A missing score compares False against every threshold and would otherwise read as 'Severe'
    labels.loc[scores.isna()] = np.nan
    return labels


# DCM severity at baseline and at 12 months
clinical_metric_df['DCM Severity initial'] = _severity_from_mjoa(clinical_metric_df['mJOA initial'])
clinical_metric_df['DCM Severity 12 months'] = _severity_from_mjoa(clinical_metric_df['mJOA 12 months'])

# Output directory for the mJOA figures (no check-then-create race)
mjoa_dist_out_path = Path('figures/mjoa_dist')
mjoa_dist_out_path.mkdir(parents=True, exist_ok=True)

### Initial

In [None]:
if not skip_plots:
    # Histogram of pre-surgical mJOA, colour-coded by severity band
    fig, ax = plot_distributions(
        clinical_metric_df['mJOA initial'], severity_cmap, legend_elements,
        'mJOA', 'Pre-Surgical mJOA Scores', 20
    )

    # Annotate each severity band with its patient count at fixed (data-unit)
    # positions; an absent class is shown as 0 instead of raising KeyError
    severity_counts = clinical_metric_df['DCM Severity initial'].value_counts()
    for label, x, y in (
        ('Severe', 9, 15),
        ('Moderate', 14, 44.5),
        ('Mild', 16.5, 33),
        ('Healthy', 18, 2.5),
    ):
        ax.text(x, y, f"({severity_counts.get(label, 0)})", c='black', size=12, horizontalalignment='center')

    # Save and show the result (already inside the skip_plots guard)
    fig.savefig(mjoa_dist_out_path / 'pre_treatment_mjoa.svg')
    plt.show()

### 12 Month

In [None]:
if not skip_plots:
    # Histogram of 12-month (post-surgical) mJOA, colour-coded by severity band,
    # with the mean label placed left of the mean line
    fig, ax = plot_distributions(
        clinical_metric_df['mJOA 12 months'], severity_cmap, legend_elements,
        'mJOA', 'Post-Surgical mJOA Scores', 20, flip_mean_rot=True
    )

    # Annotate each severity band with its patient count at fixed (data-unit)
    # positions; an absent class is shown as 0 instead of raising KeyError
    severity_counts = clinical_metric_df['DCM Severity 12 months'].value_counts()
    for label, x, y in (
        ('Severe', 8.5, 3),
        ('Moderate', 13.5, 36),
        ('Mild', 16, 45),
        ('Healthy', 18, 37),
    ):
        ax.text(x, y, f"({severity_counts.get(label, 0)})", c='black', size=12, horizontalalignment='center')

    # Save and show the result
    fig.savefig(mjoa_dist_out_path / 'post_treatment_mjoa.svg')
    plt.show()

### mJOA Delta

In [None]:
# Colour bands for the change in mJOA: <= 8 green, <= 0 white, <= -1 salmon
# (entries are drawn in order, so later ones paint over earlier ones)
delta_cmap = {8: 'springgreen', 0: 'white', -1: 'salmon'}

# Matching legend, one patch per change category
delta_legend_elements = [
    Patch(facecolor=shade, edgecolor='black', label=name)
    for shade, name in (
        ('springgreen', 'Improved'),
        ('white', 'No Change'),
        ('salmon', 'Declined'),
    )
]

# Even ticks from -8 to 8, as two independent copies of the same list
xticks = tuple(list(range(-8, 9, 2)) for _ in range(2))

# Per-patient change in mJOA from baseline to 12 months
deltas = clinical_metric_df['mJOA 12 months'].sub(clinical_metric_df['mJOA initial'])

In [None]:
if not skip_plots:
    # Histogram of the per-patient mJOA change, colour-coded by direction
    fig, ax = plot_distributions(
        deltas, delta_cmap, delta_legend_elements,
        "mJOA Change", 'Change in mJOA 1 Year Post-Surgery', 20, flip_mean_rot=True
    )

    # Bucket the changes: (-20, -1] declined, (-1, 0] unchanged, (0, 20] improved
    change_bins = [-20, -1, 0, 20]
    change_labels = ['Declined', 'No Change', 'Improved']
    change_counts = pd.cut(deltas, change_bins, labels=change_labels).value_counts()

    # Write each bucket's count at a fixed (data-unit) position
    label_positions = {'Declined': (-4.5, 9), 'No Change': (-0.6, 40), 'Improved': (4, 32)}
    for name, (x, y) in label_positions.items():
        ax.text(x, y, f"({change_counts[name]})", c='black', size=12, verticalalignment='center')

    fig.savefig(mjoa_dist_out_path / 'treatment_mjoa_delta.svg')
    plt.show()

## Hirabayashi Recovery Ratio Distribution

Setup

In [None]:
from scipy.stats import gaussian_kde

def _kde_curve(values, bandwidth=0.15, n_points=200):
    """Evaluate a Gaussian KDE of ``values`` on an even grid spanning their range.

    ``bandwidth`` is passed as ``bw_method``, so it is used directly as the
    KDE factor (public API, instead of overriding ``covariance_factor`` and
    calling the private ``_compute_covariance``). The densities are scaled to
    unit Euclidean (L2) norm, not unit area. Returns ``(xs, ys)``.
    """
    kde = gaussian_kde(values, bw_method=bandwidth)
    xs = np.linspace(np.min(values), np.max(values), n_points)
    ys = kde(xs)
    return xs, ys / np.linalg.norm(ys)


# Plot the KDE distribution onto an existing plot
def plot_kde(ax, values, c='black', ls='-', label=None):
    """Draw the L2-normalised KDE curve of ``values`` on ``ax``.

    ``c`` and ``ls`` set the line colour and style. ``label`` is forwarded
    only when given, so unlabeled curves stay out of the legend.
    """
    xs, ys = _kde_curve(values)
    plot_kwargs = {} if label is None else {'label': label}
    ax.plot(xs, ys, ls=ls, c=c, **plot_kwargs)

# Clean out invalid values from the set
def clean_vals(df):
    """Drop missing and infinite (+inf or -inf) entries from ``df``.

    For a Series the offending entries are dropped. For a DataFrame every row
    containing any such value is dropped. The original labels are kept.
    """
    return df.replace([np.inf, -np.inf], np.nan).dropna()

def draw_line_references(ax):
    """Overlay the reference lines used on the HRR plot.

    A grey dash-dot vertical line marks HRR = 0.5 (the significant-improvement
    threshold). Light-grey dotted lines mark the y = 0 and x = 0 baselines.
    """
    ax.axvline(0.5, ls='-.', c='grey')
    baseline_style = dict(ls=":", c='lightgrey')
    ax.axhline(0, **baseline_style)
    ax.axvline(0, **baseline_style)

# The HRR equation as matplotlib mathtext, for display directly on the plot
hirabayashi_equation = (
    r"HRR = $\frac{\mathrm{mJOA (1 Year)} - \mathrm{mJOA (Initial)}}"
    r"{18 - \mathrm{mJOA (Initial)}}$"
)

In [None]:
if not skip_plots:
    # HRR with baseline severity for patients not already healthy at baseline
    # (they cannot improve). Missing/infinite HRRs are dropped once here, so
    # every curve and ratio below is computed over the same patients.
    hrr_data = clean_vals(
        clinical_metric_df.loc[
            clinical_metric_df['DCM Severity initial'] != "Healthy",
            ['HRR', 'DCM Severity initial'],
        ]
    )
    hrr_df = hrr_data['HRR']

    fig, ax = plt.subplots()
    draw_line_references(ax)

    # One dashed curve per baseline severity class. The mask is built from
    # hrr_data itself, so it shares hrr_data's index and no reindexing is needed.
    for severity, color in (('Severe', 'red'), ('Moderate', 'gold'), ('Mild', 'green')):
        plot_kde(
            ax, hrr_data.loc[hrr_data['DCM Severity initial'] == severity, 'HRR'],
            ls='--', c=color, label=severity
        )

    # Solid curve over all patients
    plot_kde(ax, hrr_df, c='blue', label='All')

    # Share of patients at/above vs below the HRR = 0.5 threshold
    good_ratio = np.sum(hrr_df >= 0.5) / hrr_df.shape[0]
    fair_ratio = np.sum(hrr_df < 0.5) / hrr_df.shape[0]
    ax.text(0.7, 0.238, f"{good_ratio: .2f}", c='purple')
    ax.text(-0.5, 0.238, f"{fair_ratio: .2f}", c='purple')

    ax.set_xlabel('Hirabayashi Recovery Ratio (HRR)')
    ax.set_ylabel('Normalized Kernel Density Estimate')
    ax.legend(title='Pre-Surgical DCM Severity')

    # HRR equation, drawn directly on the plot
    ax.text(-8, 0.15, hirabayashi_equation)
    ax.set_title("Distribution of Hirabayashi Recovery Ratio")

    plt.tight_layout()
    fig.savefig(mjoa_dist_out_path / 'hirabayashi_ratios.svg')
    plt.show()

## Demographics

### Continuous

In [None]:
def plot_continuous_demographics(col, **kwargs):
    """Histogram one numeric clinical column, then save it as SVG and show it.

    ``kwargs`` go straight to ``sns.displot`` (e.g. ``bins``). The file is
    ``figures/demo_dist/<col lower-cased, spaces -> '_'>_dist.svg``.
    """
    file_stem = col.lower().replace(' ', '_')
    sns.displot(clinical_metric_df, x=col, **kwargs)
    plt.title(f"Patient Distribution ({col})")
    plt.xlabel(col)
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(f"figures/demo_dist/{file_stem}_dist.svg")
    plt.show()

In [None]:
# Output directory for the demographic figures; exist_ok avoids a check-then-create race
out_path = Path("figures/demo_dist/")
out_path.mkdir(parents=True, exist_ok=True)

#### Age

In [None]:
# Age histogram with 5-year bins; edges run 20, 25, ..., 85.
# NOTE(review): ages above 85 would fall outside the binned range — confirm none exist.
if not skip_plots:
    plot_continuous_demographics(col="Age", bins=range(20, 90, 5))

#### BMI

In [None]:
# BMI histogram with 3-unit bins; edges run 15, 18, ..., 48.
# NOTE(review): BMI above 48 would fall outside the binned range — confirm none exist.
if not skip_plots:
    plot_continuous_demographics(col="BMI", bins=range(15, 51, 3))

### Categorical Columns

In [None]:
def plot_categorical_demographics(col):
    """Pie-chart the category frequencies of one clinical column, save it and show it.

    Wedges are annotated with their percentage, and the categories appear in
    a legend in ``value_counts`` order. The file is
    ``figures/demo_dist/<col lower-cased, spaces -> '_'>_dist.svg``.
    """
    counts = clinical_metric_df[col].value_counts()
    file_stem = col.lower().replace(' ', '_')

    plt.pie(counts, labels=None, autopct=lambda pct: f'{pct: .2f}%')
    plt.legend(labels=counts.index)
    plt.title(f"Patient Distribution ({col})")
    plt.tight_layout()
    plt.savefig(f"figures/demo_dist/{file_stem}_dist.svg")
    plt.show()

#### Sex

In [None]:
# Pie chart of patient sex
if not skip_plots:
    plot_categorical_demographics(col="Sex")

#### Work Status

In [None]:
# Pie chart of work-status category
if not skip_plots:
    plot_categorical_demographics(col="Work Status (Category)")

#### Symptom Duration

In [None]:
# Pie chart of symptom duration
if not skip_plots:
    plot_categorical_demographics(col="Symptom Duration")

# Best Across Trials

## Utility Functions

In [None]:
# Column layout for "best across trials" summaries: analysis descriptors plus summary stats
best_across_trials_idx = analysis_idx + ['Mean', 'STD']
best_across_trials_idx

In [None]:
# Gets the values of one column when the value of another is among the n-highest (default to n=1)
def get_peak_at_max_other(target_col, other_col, df=None, n=1, ascending=None, group_cols=None) -> pd.DataFrame:
    """Mean/STD of ``target_col`` over replicates, at each replicate's best ``other_col`` rows.

    Within every (analysis variant, replicate) group, rows are sorted by
    ``other_col`` and the last ``n`` rows are kept. ``target_col`` is then
    summarised per analysis variant.

    Parameters
    ----------
    target_col : column whose values are summarised.
    other_col : column name, or list of names, used to rank rows.
    df : source frame; defaults to the *current* global ``performance_metric_df``,
        looked up at call time rather than frozen when the function is defined.
    n : number of top-ranked rows kept per (variant, replicate).
    ascending : bool or list of bools for ``sort_values``. Defaults to True,
        i.e. the largest values are kept.
    group_cols : columns identifying an analysis variant; defaults to ``analysis_idx``.

    Returns
    -------
    DataFrame indexed by the variant columns, with ``'Mean'`` and ``'STD'``
    (pandas' sample standard deviation, ddof=1).
    """
    if df is None:
        df = performance_metric_df
    if ascending is None:
        ascending = True
    group_cols = list(analysis_idx if group_cols is None else group_cols)

    # Sort so the best rows come last, then keep the tail of each group. NaNs
    # are placed first so they are never chosen over a real value, and tied
    # rows keep their original order (stable sort), so reruns agree.
    peak_value_df = (
        df.sort_values(by=other_col, ascending=ascending, kind='stable', na_position='first')
        .groupby([*group_cols, 'replicate'])
        .tail(n)
    )

    target_by_variant = peak_value_df.groupby(group_cols)[target_col]
    return pd.DataFrame({'Mean': target_by_variant.mean(), 'STD': target_by_variant.std()})

## Prep

In [None]:
# Force the balanced-accuracy columns to float: occasional null entries leave
# them with a non-numeric (object) dtype. Assign with df[cols] = ... so the
# columns are replaced. With pandas >= 2, df.loc[:, cols] = ... writes the
# values into the existing columns in place and keeps their old dtype.
bacc_cols = ['balanced_accuracy (test)', 'balanced_accuracy (validate)']
performance_metric_df[bacc_cols] = performance_metric_df[bacc_cols].astype('float32')

## Balanced Accuracy

### Test @ Peak Validation **[MAIN RESULT]**

In [None]:
# Rank trials by validation balanced accuracy (ascending), breaking ties on
# objective (descending)
other_index = ['balanced_accuracy (validate)', 'objective']
other_ascend = [True, False]
# `ascending` is bound to the very same list because later cells read it under that name
ascending = other_ascend

In [None]:
# Ten best variants by mean test balanced accuracy at peak validation balanced accuracy
peak_testing_at_validation_df = (
    get_peak_at_max_other('balanced_accuracy (test)', other_index, ascending=other_ascend)
    .sort_values(by='Mean')
    .tail(10)
)
peak_testing_at_validation_df

### Test @ Peak Test [Theoretical Potential]

In [None]:
# Rank trials by *test* balanced accuracy (ascending), breaking ties on
# objective (descending): an optimistic, test-selected upper bound
other_index = ['balanced_accuracy (test)', 'objective']
other_ascend = [True, False]
# `ascending` is bound to the very same list because later cells read it under that name
ascending = other_ascend

In [None]:
# Ten best variants by mean test balanced accuracy when selecting on test accuracy itself.
# Pass other_ascend explicitly rather than the `ascending` alias left by a chained assignment.
get_peak_at_max_other('balanced_accuracy (test)', other_index, ascending=other_ascend).sort_values(by='Mean').tail(10)

### Other Performance Metrics

In [None]:
# Back to validation-based selection for the remaining metrics: validation
# balanced accuracy (ascending), ties broken on objective (descending)
other_index = ['balanced_accuracy (validate)', 'objective']
other_ascend = [True, False]
# `ascending` is bound to the very same list because later cells read it under that name
ascending = other_ascend

#### Precision (Good)

In [None]:
# Ten best variants by mean test precision for the 'good' class, at peak validation.
# Pass other_ascend explicitly rather than the `ascending` alias left by a chained assignment.
get_peak_at_max_other('sk_precision_perclass (test) [good]', other_index, ascending=other_ascend).sort_values(by='Mean').tail(10)

#### Recall (Good)

In [None]:
# Ten best variants by mean test recall for the 'good' class, at peak validation.
# Pass other_ascend explicitly rather than the `ascending` alias left by a chained assignment.
get_peak_at_max_other('sk_recall_perclass (test) [good]', other_index, ascending=other_ascend).sort_values(by='Mean').tail(10)

#### F1-Score (Good)

In [None]:
# Ten best variants by mean test F1 for the 'good' class, at peak validation.
# Pass other_ascend explicitly rather than the `ascending` alias left by a chained assignment.
get_peak_at_max_other('sk_f1_perclass (test) [good]', other_index, ascending=other_ascend).sort_values(by='Mean').tail(10)

#### Precision (Fair)

In [None]:
# Ten best variants by mean test precision for the 'fair' class, at peak validation.
# Pass other_ascend explicitly rather than the `ascending` alias left by a chained assignment.
get_peak_at_max_other('sk_precision_perclass (test) [fair]', other_index, ascending=other_ascend).sort_values(by='Mean').tail(10)

#### Recall (Fair)

In [None]:
# Ten best variants by mean test recall for the 'fair' class, at peak validation.
# Pass other_ascend explicitly rather than the `ascending` alias left by a chained assignment.
get_peak_at_max_other('sk_recall_perclass (test) [fair]', other_index, ascending=other_ascend).sort_values(by='Mean').tail(10)

#### F1-Score (Fair)

In [None]:
# Ten best variants by mean test F1 for the 'fair' class, at peak validation.
# Pass other_ascend explicitly rather than the `ascending` alias left by a chained assignment.
get_peak_at_max_other('sk_f1_perclass (test) [fair]', other_index, ascending=other_ascend).sort_values(by='Mean').tail(10)

#### ROC AUC

In [None]:
# Cast test ROC AUC to float32, replacing the column so the new dtype takes effect
_roc_col = 'roc_auc (test)'
performance_metric_df[_roc_col] = performance_metric_df[_roc_col].astype('float32')

In [None]:
# Ten best variants by mean test ROC AUC, at peak validation balanced accuracy.
# Pass other_ascend explicitly rather than the `ascending` alias left by a chained assignment.
get_peak_at_max_other('roc_auc (test)', other_index, ascending=other_ascend).sort_values(by='Mean').tail(10)

# Performance Across Trials

## Utility Functions

In [None]:
def plot_average_performance_across_trials(df, metric, grouping, fpath):
    """Plot ``metric`` against trial number, one line per value of ``grouping``.

    seaborn aggregates all rows sharing a trial (and grouping value) into a
    mean line with an error band. NOTE(review): that band is seaborn's default
    error estimate (a bootstrapped confidence interval in recent releases),
    not a standard deviation — confirm which is wanted. The figure is saved
    to ``fpath`` and then shown.
    """
    sns.lineplot(x='trial', y=metric, hue=grouping, data=df)
    plt.title(f'By {grouping.capitalize()} (Average)')
    plt.tight_layout()

    plt.savefig(fpath)
    plt.show()

## Balanced Accuracy (Test)

In [None]:
if not skip_plots:
    # Output directory for the balanced-accuracy figures (no check-then-create race)
    output_dir = Path("figures/bacc_performance/")
    output_dir.mkdir(parents=True, exist_ok=True)

    # One figure per analysis descriptor column
    for i in analysis_idx:
        plot_average_performance_across_trials(
            performance_metric_df, 'balanced_accuracy (test)', i, output_dir / f'bacc_avg_by_{i}.png'
        )

## Balanced Accuracy (Test) at Peak Balanced Accuracy (Validate)

In [None]:
def plot_metric_at_peak_other_across_trials(df, metric, other, grouping, fpath,
                                            title_detail='B.Acc Test @ Peak Validation'):
    """Plot ``metric`` at the row that maximises ``other``, per trial and grouping value.

    For every (replicate, trial, ``grouping`` value) only the row with the
    largest ``other`` is kept. NaNs are sorted first, so they are never picked
    over a real value. seaborn then draws one line per grouping value across
    trials, aggregating over replicates.

    ``title_detail`` fills the parenthesised part of the title. Its default
    matches the balanced-accuracy use, so pass a new one for other metrics.
    The figure is saved to ``fpath`` and shown.
    """
    best_rows = (
        df.sort_values(other, kind='stable', na_position='first')
        .groupby(['replicate', 'trial', grouping])
        .tail(1)
        .reset_index()
    )

    sns.lineplot(data=best_rows, x='trial', y=metric, hue=grouping)

    plt.title(f'By {grouping.capitalize()} ({title_detail})')
    plt.tight_layout()

    plt.savefig(fpath)
    plt.show()

In [None]:
if not skip_plots:
    # One figure per analysis descriptor: test B.Acc at peak validation B.Acc, by trial
    for i in analysis_idx:
        plot_metric_at_peak_other_across_trials(
            df=performance_metric_df,
            metric='balanced_accuracy (test)',
            other='balanced_accuracy (validate)',
            grouping=i,
            fpath=output_dir / f'bacc_test_at_peak_validate_by_{i}.png',
        )

## Balanced Accuracy (Test) Weighted by Balanced Accuracy (Validate)

In [None]:
def weighted_std(vals, weights):
    """Weighted standard deviation of ``vals`` with a count-based n/(n-1) correction.

    The weighted variance is scaled by ``n / (n - 1)``, where ``n`` is the
    number of samples (not an effective sample size derived from the weights).
    Fewer than two samples give NaN rather than a ZeroDivisionError, matching
    pandas' ``Series.std`` for a single value.
    """
    vals = np.asarray(vals, dtype=float)
    weights = np.asarray(weights, dtype=float)
    n = vals.shape[0]
    if n < 2:
        return np.nan
    mean_val = np.average(vals, weights=weights)
    var_val = np.average((vals - mean_val) ** 2, weights=weights) * (n / (n - 1))
    return np.sqrt(var_val)

In [None]:
def metric_weighted_by_other(df, metric, weight, grouping, fpath):
    """Plot the ``weight``-weighted mean of ``metric`` per trial, one line per ``grouping`` value.

    Rows are grouped by (``grouping`` value, trial). Each group contributes
    ``np.average(metric, weights=weight)`` and a +/- one weighted-STD band
    computed with ``weighted_std``. The figure is saved to ``fpath`` and shown.
    """
    # Weighted mean and STD for every (grouping value, trial) pair
    grouped = df.loc[:, [grouping, *study_idxs, metric, weight]].groupby([grouping, 'trial'])
    stats = pd.DataFrame({
        'Mean': grouped.apply(lambda x: np.average(x[metric], weights=x[weight]), include_groups=False),
        'STD': grouped.apply(lambda x: weighted_std(x[metric], x[weight]), include_groups=False),
    }).reset_index()

    fig, ax = plt.subplots(1)
    # Sorted so each group gets the same colour on every run (set order is not stable)
    for i, g in enumerate(sorted(df[grouping].dropna().unique())):
        # Boolean mask rather than DataFrame.query, so values containing quotes are safe
        per_trial = stats.loc[stats[grouping] == g].groupby('trial')
        y_mean = per_trial['Mean'].mean()
        y_std = per_trial['STD'].mean()
        ax.plot(y_mean, label=g)
        # ax.plot uses the Series index (trial) as x; use the same x for the band
        # so it stays aligned with its line even if trials do not start at 0
        ax.fill_between(y_mean.index, y_mean + y_std, y_mean - y_std, facecolor=f'C{i}', alpha=0.2)

    plt.xlabel('Trial')
    plt.ylabel('Weighted Average')
    plt.legend(title=grouping)
    # Save to fpath, as the other plotting helpers do
    fig.savefig(fpath)
    plt.show()

In [None]:
if not skip_plots:
    # One figure per analysis descriptor: test B.Acc weighted by validation B.Acc
    for i in analysis_idx:
        metric_weighted_by_other(
            df=performance_metric_df,
            metric='balanced_accuracy (test)',
            weight='balanced_accuracy (validate)',
            grouping=i,
            fpath=output_dir / f'bacc_weighted_avg_by_{i}.png',
        )

# Statistical Tests

## Setup

In [None]:
from itertools import permutations

from scipy.stats import ranksums, kruskal, false_discovery_control

Target-metric gathering functions

In [None]:
# Absolute peak values by replicate
def get_best_per_replicate(target_value):
    """Return, for every study and replicate, the row with the highest ``target_value``.

    The result has one row per (study in ``df_map``, replicate), with a
    ``replicate`` column, the analysis descriptor columns, ``trial`` and
    ``target_value``. Whole rows are kept: ``groupby(...).last()`` takes each
    column's last *non-null* value independently, which can pair one row's
    trial with another row's score. Missing targets are sorted first, so a
    real value is always preferred.
    """
    component_dfs = []
    for df in df_map.values():
        peak_df = (
            df.sort_values(by=target_value, na_position='first', kind='stable')
            .groupby('replicate')
            .tail(1)
            .set_index('replicate')
            .sort_index()
        )
        component_dfs.append(peak_df.loc[:, [*analysis_idx, 'trial', target_value]])
    return pd.concat(component_dfs).reset_index()

In [None]:
# Values of one metric, sampled at the peak value of another, per replicate
def get_val_at_best_other_per_replicate(target, other, fallback=None, ascending=True, fallback_ascending=True):
    """Collect ``target`` at each replicate's best ``other`` value, for every study.

    For each study in ``df_map`` and each replicate, rows are ranked by
    ``other`` sorted with ``ascending``, and the best row is the last one: the
    maximum by default, the minimum with ``ascending=False``. Missing values
    are never selected. If ``fallback`` is given, rows tied on ``other`` are
    narrowed to the best ``fallback`` value (sorted with
    ``fallback_ascending``); without this step the tie-breaker would have no
    effect. Rows tied on every ranking column are all kept.

    Returns
    -------
    DataFrame with the analysis descriptor columns, ``replicate``, ``target``
    and ``other``. The source row labels are kept in an ``index`` column.
    """
    rank_cols = [other] if fallback is None else [other, fallback]
    rank_order = [ascending] if fallback is None else [ascending, fallback_ascending]

    component_dfs = []
    for df in df_map.values():
        # Best-ranked row of each replicate; NaNs sort first so they never come last
        winners = (
            df.sort_values(by=rank_cols, ascending=rank_order, na_position='first')
            .groupby('replicate')
            .tail(1)
            .sort_values('replicate')
        )

        # Every row of the replicate that matches the winner on all ranking columns
        comp_dfs = []
        for rep, *best_vals in zip(winners['replicate'], *(winners[col] for col in rank_cols)):
            keep = df['replicate'] == rep
            for col, val in zip(rank_cols, best_vals):
                keep &= df[col] == val
            comp_dfs.append(df.loc[keep])

        peak_df = pd.concat(comp_dfs).loc[:, [*analysis_idx, 'replicate', target, other]]
        component_dfs.append(peak_df)
    return pd.concat(component_dfs).reset_index()

In [None]:
# Symbol used in comparison labels for each `alternative` accepted by scipy's ranksums
alt_keys = {
    'two-sided': '!=',
    'greater':   '>',
    'less':      '<'
}


def paired_rankedsum(df, query, target, alternative='two-sided'):
    """Wilcoxon rank-sum p-values for every ordered pair of groups in ``df[query]``.

    For each ordered pair ``(g1, g2)`` of distinct group labels, ``target`` of
    ``g1`` is tested against ``target`` of ``g2`` using ``alternative`` (e.g.
    ``'greater'``: g1 tends to be larger). Groups are visited in sorted order,
    and missing labels are ignored.

    Returns a DataFrame indexed by ``'Comparison'`` labels such as ``'a > b'``,
    with one float column ``'p'``. It is empty when there are fewer than two groups.
    """
    group_col = df[query]
    groups = sorted(group_col.dropna().unique())
    # Boolean masks rather than string-built query(), so labels with quotes are safe
    samples = {g: df.loc[group_col == g, target] for g in groups}

    symbol = alt_keys[alternative]
    pvals = {
        f"{g1} {symbol} {g2}": ranksums(samples[g1], samples[g2], alternative=alternative).pvalue
        for g1, g2 in permutations(groups, 2)
    }

    return_df = pd.DataFrame({'p': pd.Series(pvals, dtype='float64')})
    return_df.index.name = 'Comparison'
    return return_df

In [None]:
def evaluate_kw(df, grouping, target):
    """Kruskal-Wallis H-test p-value for ``target`` across the groups of ``df[grouping]``.

    Groups are taken in sorted order, and missing labels are ignored. With
    fewer than two groups there is nothing to compare, so NaN is returned
    instead of raising.
    """
    group_col = df[grouping]
    # Boolean masks rather than string-built query(), so labels with quotes are safe
    samples = [df.loc[group_col == g, target] for g in sorted(group_col.dropna().unique())]
    if len(samples) < 2:
        return np.nan
    return kruskal(*samples).pvalue

## Testing Balanced Accuracy

### Testing @ Peak Validation

#### Raw Performance

In [None]:
# Test balanced accuracy at each replicate's peak validation balanced accuracy,
# with 'objective' (descending) passed as the tie-breaker
target = 'balanced_accuracy (test)'
other = 'balanced_accuracy (validate)'
fallback = 'objective'
replicate_test_at_peak_bacc_df = get_val_at_best_other_per_replicate(
    target, other, fallback=fallback, fallback_ascending=False
)

#### Ranked-Sum Grouping Comparisons

In [None]:
# One-sided rank-sum p-values ("variant A beats variant B") within every
# analysis dimension, at peak validation; smallest p first
sub_dfs = [
    paired_rankedsum(replicate_test_at_peak_bacc_df, k, target, alternative='greater')
    for k in analysis_idx
]
sig_test_at_peak_valid_df = pd.concat(sub_dfs).sort_values('p')

Bonferroni Multiple-Comparison Correction **[Extremely Conservative]**

In [None]:
# Bonferroni: scale each p by the number of comparisons, then star corrected
# p-values below 0.05 / 0.01 / 0.001 with * / ** / ***
n_comparisons = sig_test_at_peak_valid_df.shape[0]
bf_ps = sig_test_at_peak_valid_df['p'].mul(n_comparisons)
sig_test_at_peak_valid_df['significance (bonferonni)'] = ''
for stars, alpha in (('*', 0.05), ('**', 0.01), ('***', 0.001)):
    sig_test_at_peak_valid_df.loc[bf_ps < alpha, 'significance (bonferonni)'] = stars

Benjamini-Yekutieli False-Discovery-Rate Correction **[Less conservative; chosen over Benjamini-Hochberg because our tests are not fully independent]**

In [None]:
# Benjamini-Yekutieli adjusted p-values, starred at 0.05 / 0.01 / 0.001
by_ps = false_discovery_control(sig_test_at_peak_valid_df['p'], method='by')
sig_test_at_peak_valid_df['significance (benjaminini-yekutieli)'] = ''
for stars, alpha in (('*', 0.05), ('**', 0.01), ('***', 0.001)):
    sig_test_at_peak_valid_df.loc[by_ps < alpha, 'significance (benjaminini-yekutieli)'] = stars

### Result

In [None]:
# Show the 40 most significant pairwise comparisons, with 'Comparison' as a column
sig_test_at_peak_valid_df.head(40).reset_index()

#### Kruskal-Wallis

In [None]:
# Kruskal-Wallis per analysis dimension: do that dimension's variants differ
# in best-case (peak-validation) test balanced accuracy?
kw_df = pd.DataFrame({
    'p': pd.Series({
        dim: evaluate_kw(replicate_test_at_peak_bacc_df, dim, 'balanced_accuracy (test)')
        for dim in analysis_idx
    })
})

Bonferroni

In [None]:
# Bonferroni-corrected stars for the Kruskal-Wallis p-values:
# * / ** / *** when p x (number of tests) is below 0.05 / 0.01 / 0.001
kw_bf_ps = kw_df['p'].mul(len(kw_df))
kw_df['significance (bonferonni)'] = ''
for stars, alpha in (('*', 0.05), ('**', 0.01), ('***', 0.001)):
    kw_df.loc[kw_bf_ps < alpha, 'significance (bonferonni)'] = stars

Benjamini-Yekutieli

In [None]:
# Benjamini-Yekutieli-adjusted Kruskal-Wallis p-values. This column must use
# false_discovery_control(method='by'); scaling p by the number of tests would
# be the Bonferroni correction, not BY.
kw_df['significance (benjaminini-yekutieli)'] = ''
kw_by_ps = false_discovery_control(kw_df['p'], method='by')
for i, t in enumerate([0.05, 0.01, 0.001]):
    kw_df.loc[kw_by_ps < t, 'significance (benjaminini-yekutieli)'] = '*' * (i + 1)

### Result

In [None]:
# Display the Kruskal-Wallis results table (p-values plus significance-marker columns)
kw_df

### Testing @ Peak Objective

#### Raw Performance

In [None]:
# Test balanced accuracy at each replicate's best objective (ascending=False:
# the smallest objective value is taken as best)
target = 'balanced_accuracy (test)'
other = 'objective'
replicate_test_at_peak_obj_df = get_val_at_best_other_per_replicate(
    target=target, other=other, ascending=False
)

#### Ranked-Sum Grouping Comparisons

In [None]:
# One-sided rank-sum p-values ("variant A beats variant B") within every
# analysis dimension, at peak objective; smallest p first
sub_dfs = [
    paired_rankedsum(replicate_test_at_peak_obj_df, k, target, alternative='greater')
    for k in analysis_idx
]
sig_test_at_peak_obj_df = pd.concat(sub_dfs).sort_values('p')

Bonferroni Multiple-Comparison Correction **[Extremely Conservative]**

In [None]:
# Bonferroni: scale each p by the number of comparisons, then star corrected
# p-values below 0.05 / 0.01 / 0.001 with * / ** / ***
n_comparisons = sig_test_at_peak_obj_df.shape[0]
bf_ps = sig_test_at_peak_obj_df['p'].mul(n_comparisons)
sig_test_at_peak_obj_df['significance (bonferonni)'] = ''
for stars, alpha in (('*', 0.05), ('**', 0.01), ('***', 0.001)):
    sig_test_at_peak_obj_df.loc[bf_ps < alpha, 'significance (bonferonni)'] = stars

Benjamini-Yekutieli False-Discovery-Rate Correction **[Less conservative; chosen over Benjamini-Hochberg because our tests are not fully independent]**

In [None]:
# Benjamini-Yekutieli adjusted p-values, starred at 0.05 / 0.01 / 0.001
by_ps = false_discovery_control(sig_test_at_peak_obj_df['p'], method='by')
sig_test_at_peak_obj_df['significance (benjaminini-yekutieli)'] = ''
for stars, alpha in (('*', 0.05), ('**', 0.01), ('***', 0.001)):
    sig_test_at_peak_obj_df.loc[by_ps < alpha, 'significance (benjaminini-yekutieli)'] = stars

### Result

In [None]:
# Show the 25 most significant pairwise comparisons, with 'Comparison' as a column
sig_test_at_peak_obj_df.head(25).reset_index()

#### Kruskal-Wallis

In [None]:
# Kruskal-Wallis per analysis dimension: do that dimension's variants differ
# in test balanced accuracy at peak objective?
kw_df = pd.DataFrame({
    'p': pd.Series({
        dim: evaluate_kw(replicate_test_at_peak_obj_df, dim, 'balanced_accuracy (test)')
        for dim in analysis_idx
    })
})

#### Bonferroni

In [None]:
# Bonferroni-corrected stars for the Kruskal-Wallis p-values:
# * / ** / *** when p x (number of tests) is below 0.05 / 0.01 / 0.001
kw_bf_ps = kw_df['p'].mul(len(kw_df))
kw_df['significance (bonferonni)'] = ''
for stars, alpha in (('*', 0.05), ('**', 0.01), ('***', 0.001)):
    kw_df.loc[kw_bf_ps < alpha, 'significance (bonferonni)'] = stars

#### Benjamini-Yekutieli

In [None]:
# Benjamini-Yekutieli-adjusted Kruskal-Wallis p-values. This column must use
# false_discovery_control(method='by'); scaling p by the number of tests would
# be the Bonferroni correction, not BY.
kw_df['significance (benjaminini-yekutieli)'] = ''
kw_by_ps = false_discovery_control(kw_df['p'], method='by')
for i, t in enumerate([0.05, 0.01, 0.001]):
    kw_df.loc[kw_by_ps < t, 'significance (benjaminini-yekutieli)'] = '*' * (i + 1)

### Result

In [None]:
# Display the Kruskal-Wallis results table (p-values plus significance-marker columns)
kw_df

# Feature Importance

## Utility Functions

In [None]:
def format_feature_imp(val):
    """Parse a stringified ``{name: importance, ...}`` mapping into a dict.

    The stored text is assumed to be a bracket-enclosed list of ``name: value`` entries
    separated by ``', '`` (e.g. ``"{feat_a: 0.12, feat_b: -0.03}"``), possibly followed by
    trailing whitespace. The importance is the text after the *last* ``': '`` of each
    entry, so feature names that themselves contain ``': '`` are kept intact.

    Parameters
    ----------
    val : str
        Raw importance string as read from the results table.

    Returns
    -------
    dict
        Feature name -> importance (float). An empty mapping such as ``"{}"`` yields ``{}``.
    """
    # Remove surrounding whitespace, then exactly one opening and one closing bracket.
    # The previous ``val[1:-2]`` slice always removed two trailing characters, which
    # truncated the final importance value unless an extra character followed the bracket.
    body = val.strip()
    if body[:1] in ('{', '[', '('):
        body = body[1:]
    if body[-1:] in ('}', ']', ')'):
        body = body[:-1]

    # Create a dictionary from the remaining components
    imp_dict = dict()
    if not body.strip():
        # Empty mapping: previously crashed on float('')
        return imp_dict

    for entry in body.split(', '):
        # rpartition splits on the LAST ': ', matching the old ': '.join(parts[:-1]) logic
        name, _, score = entry.rpartition(': ')
        imp_dict[name] = float(score)

    return imp_dict

In [None]:
def feature_imp_report(df: pd.DataFrame, feature_col, weight_col) -> pd.DataFrame:
    """Summarise per-feature importances across the rows (runs) of ``df``.

    Each cell of ``df[feature_col]`` holds a ``{feature: importance}`` dict. A feature that
    is absent from a row counts as importance 0 for that row. For every feature seen in any
    row, the report gives the plain mean and standard deviation across rows, plus the mean
    and standard deviation weighted by ``df[weight_col]`` (cast to float64).

    Parameters
    ----------
    df : pd.DataFrame
        One row per run to aggregate.
    feature_col : str
        Column whose cells are ``{feature: importance}`` dicts.
    weight_col : str
        Column used as per-row weights for the weighted statistics.

    Returns
    -------
    pd.DataFrame
        Indexed by feature name, columns ``['Mean', 'STD', 'Weighted Mean', 'Weighted STD']``.
        ``STD`` is numpy's default population standard deviation (ddof=0). When ``df`` has
        no rows, an empty frame with those columns is returned.
    """
    return_cols = ['Mean', 'STD', 'Weighted Mean', 'Weighted STD']

    # A query that matched no runs used to crash inside pd.concat([]); report nothing instead
    if len(df) == 0:
        return pd.DataFrame(columns=return_cols)

    # One row per run, one column per feature seen anywhere. Building from the list of dicts
    # keeps a (all-NaN -> 0) row even for a run whose dict is empty, so the rows stay aligned
    # with the weights below; stacking per-row frames silently dropped such runs.
    raw_feature_imps = pd.DataFrame(list(df[feature_col])).fillna(0)

    # Query the weights a single time to avoid repeated querying expense
    weights = df[weight_col].astype('float64')

    # For each feature, calculate our desired statistics
    return_df_dict = {}
    for c in raw_feature_imps.columns:
        # Single query of the dataframe, as pandas can be slow w/ repeated queries
        samples = raw_feature_imps[c]
        return_df_dict[c] = [
            np.mean(samples),                         # raw mean
            np.std(samples),                          # raw STD
            np.average(samples, weights=weights),     # weighted mean
            weighted_std(samples, weights),           # weighted STD (helper defined elsewhere)
        ]

    # Return the result as a dataframe
    return pd.DataFrame.from_dict(return_df_dict, columns=return_cols, orient='index')

## Setup

In [None]:
# Isolate and stack the information relative to the value
sub_dfs = []

for df in df_map.values():
    tmp_df = df.loc[:, [*study_idxs, *analysis_idx, 'balanced_accuracy (test)', 'importance_by_permutation (test)']]
    sub_dfs.append(tmp_df)

feature_imp_df = pd.concat(sub_dfs)

# Isolate only the best trial from each replicate
feature_imp_df = feature_imp_df.sort_values('balanced_accuracy (test)').groupby([*analysis_idx, 'replicate']).tail(1).set_index(analysis_idx)

# Parse the feature importance list into a cleaner dictionary
feature_imp_df['importance_by_permutation (test)'] = feature_imp_df['importance_by_permutation (test)'].apply(format_feature_imp)
feature_imp_df

In [None]:
# Isolate PCA-derived features from the rest
pca_feature_imp_df = feature_imp_df.reset_index().loc[feature_imp_df.reset_index()['prep'].apply(lambda x: 'pca' in x), :].set_index([*analysis_idx])
pca_feature_imp_df

In [None]:
# Complement of the PCA subset: remove every row whose index label appears among the PCA rows.
# NOTE(review): this is a label-based drop, so it only isolates the non-PCA rows exactly if
# 'prep' is one of the index levels (analysis_idx) -- confirm.
nonpca_feature_imp_df = feature_imp_df.drop(index=pca_feature_imp_df.index)
nonpca_feature_imp_df

## Un-transformed Features

### Full dataset (C2C6, Vertebral + Clinical)

In [None]:
# Non-PCA pipelines, C2C6 range, vertebral metrics + clinical features ('vert_full')
query_df = nonpca_feature_imp_df.query("vert_range == 'C2C6'").query("dataset == 'vert_full'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
# NOTE(review): this cell ranks by "Weighted Mean" while the sibling cells rank by "Mean" -- confirm intended
query_report.sort_values("Weighted Mean", ascending=False).head(10)

### Image-derived features only (C2C6, Vertebral Only)

In [None]:
# Non-PCA pipelines, C2C6 range, image-derived vertebral metrics only ('vert_img_only')
query_df = nonpca_feature_imp_df.query("vert_range == 'C2C6'").query("dataset == 'vert_img_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Full dataset (C2C6, Slice-Based + Clinical)

In [None]:
# Non-PCA pipelines, C2C6 range, per-slice metrics + clinical features ('slice_full')
query_df = nonpca_feature_imp_df.query("vert_range == 'C2C6'").query("dataset == 'slice_full'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Image-derived features only (C2C6, Slice-Based Only)

In [None]:
# Non-PCA pipelines, C2C6 range, image-derived per-slice metrics only ('slice_img_only')
query_df = nonpca_feature_imp_df.query("vert_range == 'C2C6'").query("dataset == 'slice_img_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Full dataset (C2C7, Vertebral + Clinical)

In [None]:
# Non-PCA pipelines, C2C7 range, vertebral metrics + clinical features ('vert_full')
query_df = nonpca_feature_imp_df.query("vert_range == 'C2C7'").query("dataset == 'vert_full'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Image-derived features only (C2C7, Vertebral Only)

In [None]:
# Non-PCA pipelines, C2C7 range, image-derived vertebral metrics only ('vert_img_only')
query_df = nonpca_feature_imp_df.query("vert_range == 'C2C7'").query("dataset == 'vert_img_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Full dataset (C2C7, Slice-Based + Clinical)

In [None]:
# Non-PCA pipelines, C2C7 range, per-slice metrics + clinical features ('slice_full')
query_df = nonpca_feature_imp_df.query("vert_range == 'C2C7'").query("dataset == 'slice_full'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Image-derived features only (C2C7, Slice-Based Only)

In [None]:
# Non-PCA pipelines, C2C7 range, image-derived per-slice metrics only ('slice_img_only')
query_df = nonpca_feature_imp_df.query("vert_range == 'C2C7'").query("dataset == 'slice_img_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Clinical Features

In [None]:
# Non-PCA pipelines trained on clinical features alone ('clin_only')
query_df = nonpca_feature_imp_df.query("dataset == 'clin_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

## PCA

### Full dataset (C2C6, Vertebral + Clinical)

In [None]:
# PCA pipelines, C2C6 range, vertebral metrics + clinical features ('vert_full')
query_df = pca_feature_imp_df.query("vert_range == 'C2C6'").query("dataset == 'vert_full'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Image-derived features only (C2C6, Vertebral Only)

In [None]:
# PCA pipelines, C2C6 range, image-derived vertebral metrics only ('vert_img_only')
query_df = pca_feature_imp_df.query("vert_range == 'C2C6'").query("dataset == 'vert_img_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Full dataset (C2C6, Slice-Based + Clinical)

In [None]:
# PCA pipelines, C2C6 range, per-slice metrics + clinical features ('slice_full')
query_df = pca_feature_imp_df.query("vert_range == 'C2C6'").query("dataset == 'slice_full'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Image-derived features only (C2C6, Slice-Based Only)

In [None]:
# PCA pipelines, C2C6 range, image-derived per-slice metrics only ('slice_img_only')
query_df = pca_feature_imp_df.query("vert_range == 'C2C6'").query("dataset == 'slice_img_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Full dataset (C2C7, Vertebral + Clinical)

In [None]:
# PCA pipelines, C2C7 range, vertebral metrics + clinical features ('vert_full')
query_df = pca_feature_imp_df.query("vert_range == 'C2C7'").query("dataset == 'vert_full'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Image-derived features only (C2C7, Vertebral Only)

In [None]:
# PCA pipelines, C2C7 range, image-derived vertebral metrics only ('vert_img_only')
query_df = pca_feature_imp_df.query("vert_range == 'C2C7'").query("dataset == 'vert_img_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Full dataset (C2C7, Slice-Based + Clinical)

In [None]:
# PCA pipelines, C2C7 range, per-slice metrics + clinical features ('slice_full')
query_df = pca_feature_imp_df.query("vert_range == 'C2C7'").query("dataset == 'slice_full'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Image-derived features only (C2C7, Slice-Based Only)

In [None]:
# PCA pipelines, C2C7 range, image-derived per-slice metrics only ('slice_img_only')
query_df = pca_feature_imp_df.query("vert_range == 'C2C7'").query("dataset == 'slice_img_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

### Clinical Features

In [None]:
# PCA pipelines trained on clinical features alone ('clin_only')
query_df = pca_feature_imp_df.query("dataset == 'clin_only'")
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

#### Feature Importance of Best Performing Pipeline

In [None]:
# Runs of one specific configuration: C2C6 vertebral + clinical features, random forest,
# binary segmentation, T1 weighting, sagittal orientation, 'rfe' preprocessing
query_df = (
    nonpca_feature_imp_df
    .query("vert_range == 'C2C6'")
    .query("dataset == 'vert_full'")
    .query("model == 'RandomForestClassifier'")
    .query("seg_algo == 'binary'")
    .query("weight == 'T1'")
    .query("ori == 'sag'")
    .query("prep == 'rfe'")
)
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

#### Feature Importance of the 10 Best-Performing Runs (non-PCA; one row per pipeline replicate)

In [None]:
# The ten non-PCA rows with the highest test balanced accuracy (last ten after an ascending sort)
query_df = nonpca_feature_imp_df.sort_values('balanced_accuracy (test)').tail(10)
query_report = feature_imp_report(
    query_df,
    feature_col='importance_by_permutation (test)',
    weight_col='balanced_accuracy (test)',
)
query_report.sort_values("Mean", ascending=False).head(10)

## Best Model Performance-over-Time Plots

In [None]:
# For every results table: within each replicate keep the trial with the highest validation
# balanced accuracy (ties go to the lowest 'objective'), then average those trials' test
# balanced accuracy.
model_vals = {}
for k, v in df_map.items():
    ranked = v.sort_values(
        ['balanced_accuracy (validate)', 'objective'],
        ascending=[True, False],
    )
    best_model = ranked.groupby('replicate').tail(1)
    model_mean = best_model['balanced_accuracy (test)'].astype('float64').mean()
    model_vals[k] = model_mean

# The ten tables with the largest mean test balanced accuracy, best first
best_models = sorted(model_vals.items(), key=lambda item: item[1], reverse=True)[:10]

In [None]:
# Display the top-10 (df_map key, mean test balanced accuracy) pairs, highest first
best_models

#### Performance over Time

In [None]:
if not skip_plots:
    for v in best_models:
        # Organize the data: per-trial summary of test balanced accuracy across replicates.
        # Cast a local copy instead of overwriting df_map's columns in place, so running this
        # optional cell no longer changes the dtypes that later cells sort on.
        k = v[0]
        df = df_map[k]
        scores = df[['trial', 'balanced_accuracy (test)']].astype({'balanced_accuracy (test)': 'float64'})
        trial_grouped = scores.groupby('trial')['balanced_accuracy (test)']
        mean_by_trial = trial_grouped.mean()
        max_by_trial = trial_grouped.max()
        min_by_trial = trial_grouped.min()
        std_by_trial = trial_grouped.std()

        # Initiate the plot (pyplot is already imported at the top of the notebook)
        fig, ax = plt.subplots(1)

        # Mean curve with a +/- 1 STD band. Every series is indexed by trial number, so the band
        # is drawn against that same index; the old np.arange(n) x-values misplaced it whenever
        # the trial numbers were not exactly 0..n-1.
        c = "C0"
        trials = mean_by_trial.index
        ax.plot(trials, mean_by_trial, color=c)
        ax.fill_between(trials, mean_by_trial + std_by_trial, mean_by_trial - std_by_trial, alpha=0.2, color=c)

        # Plot the max and min lines
        ax.plot(trials, max_by_trial, color=c, linestyle=':')
        ax.plot(trials, min_by_trial, color=c, linestyle=':')

        # Add title and axis labels
        ax.set_title(k)
        ax.set_xlabel('Trial')
        ax.set_ylabel('Balanced Accuracy (Testing)')

        # Display the plot
        plt.show()

#### Best/Worst Patients

In [None]:
from collections import Counter


def _parse_sample_ids(cell):
    """Return the sample IDs listed in a stringified array cell such as "[a b\n c]".

    Brackets are replaced by spaces and the text is split on any run of whitespace, so
    repeated spaces, embedded newlines and empty arrays ("[]") never yield '' IDs. (The old
    split(' ') produced '' entries, and deleting '\n' could fuse two adjacent IDs.)
    """
    return cell.replace('[', ' ').replace(']', ' ').split()


def _tally_samples(cells, weight, weight_totals, counts):
    """Add `weight` to the running total and bump the count of every ID found in `cells`."""
    for cell in cells:
        for sample_id in _parse_sample_ids(cell):
            weight_totals[sample_id] = weight_totals.get(sample_id, 0) + weight
            counts[sample_id] += 1


correct_weights = dict()
correct_count = Counter()
incorrect_weights = dict()
incorrect_count = Counter()

for v in best_models:
    df = df_map[v[0]]
    w = v[1] - 0.5  # Delta from the "naive" (0.5) performance
    # Best trial per replicate: highest test balanced accuracy, ties -> lowest test log loss.
    # NOTE(review): best_models ranked tables by *validation* BA, but this picks trials by
    # *test* BA -- confirm that is intended.
    best_trials_df = df.sort_values(['balanced_accuracy (test)', 'log_loss (test)'], ascending=[True, False]).groupby('replicate').tail(1)
    # Correctly predicted patients
    _tally_samples(best_trials_df['correct_samples (test)'], w, correct_weights, correct_count)
    # Incorrectly predicted patients
    _tally_samples(best_trials_df['incorrect_samples (test)'], w, incorrect_weights, incorrect_count)

In [None]:
# Turn each patient's accumulated weight into an average per appearance among correctly
# predicted samples. Divide first and round only the final value (the old code rounded the
# running total before dividing), then rank by that average -- previously the ranking used
# the totals while the displayed values were averages, so the list did not appear sorted.
correct_weights = {
    pid: float(np.round(total / correct_count[pid], 3))
    for pid, total in correct_weights.items()
}
correct_weights = dict(sorted(correct_weights.items(), key=lambda x: x[1]))
# Ten patients with the highest average weight when predicted correctly
list(correct_weights.items())[-10:]

In [None]:
# Same treatment for incorrectly predicted patients: average weight per appearance (divide,
# then round), ranked by that average.
incorrect_weights = {
    pid: float(np.round(total / incorrect_count[pid], 3))
    for pid, total in incorrect_weights.items()
}
incorrect_weights = dict(sorted(incorrect_weights.items(), key=lambda x: x[1]))
# Ten patients with the highest average weight when predicted incorrectly
# (this line previously displayed correct_weights by mistake)
list(incorrect_weights.items())[-10:]

In [None]:
# Join the manual MSCC table with the automated one ('mscc_c2c7_binary.tsv') on 'GRP';
# overlapping column names get ' [manual]' / ' [automated]' suffixes
mscc_df = pd.read_csv('manual_mscc.tsv', sep='\t')
tmp_df = pd.read_csv('mscc_c2c7_binary.tsv', sep='\t')
mscc_df = mscc_df.merge(
    tmp_df,
    left_on='GRP',
    right_on='GRP',
    suffixes=[' [manual]', ' [automated]'],
).set_index('GRP')
mscc_df

In [None]:
# Attach each patient's averaged correct-prediction weight to the MSCC table
# (merge's default inner join on GRP, so only patients present in both are kept)
correct_presence_df = pd.DataFrame.from_dict(
    {patient: [score] for patient, score in correct_weights.items()},
    orient='index',
    columns=['avg_bacc_of_presence'],
)
correct_df = mscc_df.merge(correct_presence_df, left_on='GRP', right_index=True)

In [None]:
# Mean correct-prediction weight per acquisition / weighting / manually labelled vertebral level
(
    correct_df
    .groupby(['acq', 'weight', 'VertLevel [manual]'])['avg_bacc_of_presence']
    .mean()
)

In [None]:
# Mean correct-prediction weight per acquisition / weighting / automatically labelled vertebral level
(
    correct_df
    .groupby(['acq', 'weight', 'VertLevel [automated]'])['avg_bacc_of_presence']
    .mean()
)

In [None]:
# Attach each patient's averaged incorrect-prediction weight to the MSCC table
# (merge's default inner join on GRP, so only patients present in both are kept)
incorrect_presence_df = pd.DataFrame.from_dict(
    {patient: [score] for patient, score in incorrect_weights.items()},
    orient='index',
    columns=['avg_bacc_of_presence'],
)
incorrect_df = mscc_df.merge(incorrect_presence_df, left_on='GRP', right_index=True)

In [None]:
# Mean incorrect-prediction weight per acquisition / weighting / manually labelled vertebral level
(
    incorrect_df
    .groupby(['acq', 'weight', 'VertLevel [manual]'])['avg_bacc_of_presence']
    .mean()
)

In [None]:
# Mean incorrect-prediction weight per acquisition / weighting / automatically labelled vertebral level
(
    incorrect_df
    .groupby(['acq', 'weight', 'VertLevel [automated]'])['avg_bacc_of_presence']
    .mean()
)

In [None]:
# Label each row with its protocol, formatted '<acq>-<weight>'
mscc_df['Protocol'] = mscc_df['acq'] + "-" + mscc_df['weight']
# Leave out the 'unk-T1w' and 'axial-T1w' protocols from the distribution plots
sub_mscc_df = mscc_df.query("Protocol != 'unk-T1w'").query("Protocol != 'axial-T1w'")

if not skip_plots:
    # Count of rows per manually labelled vertebral level, split by protocol
    manual_levels = sorted(set(sub_mscc_df["VertLevel [manual]"]))
    sns.catplot(data=sub_mscc_df, x="VertLevel [manual]", kind="count", hue="Protocol", order=manual_levels)
    plt.savefig("mscc_dist_manual.svg")

In [None]:
if not skip_plots:
    # Count of rows per automatically labelled vertebral level, split by protocol
    automated_levels = sorted(set(sub_mscc_df["VertLevel [automated]"]))
    sns.catplot(data=sub_mscc_df, x="VertLevel [automated]", kind="count", hue="Protocol", order=automated_levels)
    plt.savefig("mscc_dist_automated.svg")