In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set_style("darkgrid")
sns.set_context("talk")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

from ese.experiment.analysis.analyze_inf import load_cal_inference_stats
# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml results_cfg 

# Parsed by the %%yaml magic into the `results_cfg` variable.
# `log` names a root directory and the list of inference groups under it
# (presumably the runs that get loaded — confirm in the loader).
log:
    root: /storage/vbutoi/scratch/ESE/inference
    inference_groups: 
        - '06_12_24_WMH_CorrectedMultiAnno'

# NOTE(review): the effect of each flag below is defined by the code that
# consumes `results_cfg`, which is not visible here — confirm there.
options:
    add_dice_loss_rows: True
    drop_nan_metric_rows: True 
    remove_shared_columns: False
    equal_rows_per_cfg_assert: False 

In [None]:
# Load the inference statistics described by `results_cfg` into a DataFrame.
# load_cached=False is passed explicitly (presumably to skip any cached copy —
# confirm in load_cal_inference_stats).
inference_df = load_cal_inference_stats(results_cfg=results_cfg, load_cached=False)

In [None]:
# Only these columns matter for the volume comparison.
exp_columns = [
    "annotator", "data_id",
    "gt_volume", "hard_volume", "soft_volume",
    "pretrained_seed", "slice_idx", "split", "task",
]
# Project onto those columns, then collapse rows that became identical.
experiment_df = inference_df.loc[:, exp_columns].drop_duplicates()

In [None]:
# One row per (data_id, annotator, task, pretrained_seed, split): the three
# volume columns are summed over every row sharing those keys.
vol_id_keys = ["data_id", "annotator", "task", "pretrained_seed", "split"]
volume_cols = ["gt_volume", "hard_volume", "soft_volume"]
volume_df = experiment_df.groupby(vol_id_keys)[volume_cols].sum().reset_index()

In [None]:
# Signed volume errors: prediction minus ground truth, for the soft and the
# hard prediction respectively.
reference_volume = volume_df['gt_volume']
volume_df['soft_volume_error'] = volume_df['soft_volume'].sub(reference_volume)
volume_df['hard_volume_error'] = volume_df['hard_volume'].sub(reference_volume)

In [None]:
# Reshape to long form: one row per (volume row, error type), with the error
# type in `volume_type` and its value in `error`.
# `split` is carried along as an id column because it is one of the keys the
# volumes were grouped by; without it the melted rows of different splits
# can no longer be told apart.
melted_error_df = pd.melt(
    volume_df,
    id_vars=[
        "annotator", "data_id", "pretrained_seed", "task", "split",
        "gt_volume", "soft_volume", "hard_volume",
    ],
    value_vars=["soft_volume_error", "hard_volume_error"],
    var_name="volume_type",
    value_name="error",
)

In [None]:
import numpy as np

# Extra columns that are convenient for plotting.
# abs_error is the magnitude of the signed error.
melted_error_df['abs_error'] = melted_error_df['error'].abs()
# log(1 + |error|), computed in one vectorised call on the column above
# instead of a per-row Python lambda. It is 0 exactly when the error is 0.
melted_error_df['log_abs_error'] = np.log1p(melted_error_df['abs_error'])

# Experiment 1: Looking at one annotator on Amsterdam, let's see what the volumetric comparison looks like.

In [None]:
# Rows for a single annotator on the Amsterdam task.
# NOTE(review): `.select` is not a stock pandas method; presumably it is
# registered by the ionpy import — confirm.
exp_1_filter = {'annotator': 'observer_o12', 'task': 'Amsterdam'}
exp_1_df = melted_error_df.select(**exp_1_filter)

In [None]:
# Signed volume error per data_id, coloured by error type (soft vs. hard).
catplot_kwargs = dict(
    x="data_id",
    y="error",
    hue="volume_type",
    aspect=3,
    height=8,
    sharey=False,
)
g = sns.catplot(exp_1_df, **catplot_kwargs)

# Dashed red reference at y = 0 (no volume error) on every subplot.
for ax in g.axes.flat:
    ax.axhline(0, ls='--', color='red')

# Leave headroom at the top for the figure-level title.
plt.subplots_adjust(top=0.85)
g.fig.suptitle('Soft/Hard Volumetric Error', fontsize=30)

# Render the figure.
plt.show()

In [None]:
# log(|error| + 1) per data_id, coloured by error type (soft vs. hard).
catplot_kwargs = dict(
    x="data_id",
    y="log_abs_error",
    hue="volume_type",
    aspect=3,
    height=8,
    sharey=False,
)
g = sns.catplot(exp_1_df, **catplot_kwargs)

# Dashed red reference at y = 0 (zero absolute error) on every subplot.
for ax in g.axes.flat:
    ax.axhline(0, ls='--', color='red')

# Leave headroom at the top for the figure-level title.
plt.subplots_adjust(top=0.85)
g.fig.suptitle('Absolute Soft/Hard Volumetric Log Error', fontsize=30)

# Render the figure.
plt.show()

# Experiment 2: Let's look at the same thing but this time also for Singapore.

In [None]:
# Rows for a single annotator on the Singapore task.
# NOTE(review): `.select` is not a stock pandas method; presumably it is
# registered by the ionpy import — confirm.
exp_2_filter = {'annotator': 'observer_o12', 'task': 'Singapore'}
exp_2_df = melted_error_df.select(**exp_2_filter)

In [None]:
# Signed volume error per data_id, coloured by error type (soft vs. hard).
catplot_kwargs = dict(
    x="data_id",
    y="error",
    hue="volume_type",
    aspect=3,
    height=8,
    sharey=False,
)
g = sns.catplot(exp_2_df, **catplot_kwargs)

# Dashed red reference at y = 0 (no volume error) on every subplot.
for ax in g.axes.flat:
    ax.axhline(0, ls='--', color='red')

# Leave headroom at the top for the figure-level title.
plt.subplots_adjust(top=0.85)
g.fig.suptitle('Singapore Soft/Hard Volumetric Error', fontsize=30)

# Render the figure.
plt.show()

In [None]:
# log(|error| + 1) per data_id, coloured by error type (soft vs. hard).
catplot_kwargs = dict(
    x="data_id",
    y="log_abs_error",
    hue="volume_type",
    aspect=3,
    height=8,
    sharey=False,
)
g = sns.catplot(exp_2_df, **catplot_kwargs)

# Dashed red reference at y = 0 (zero absolute error) on every subplot.
for ax in g.axes.flat:
    ax.axhline(0, ls='--', color='red')

# Leave headroom at the top for the figure-level title.
plt.subplots_adjust(top=0.85)
g.fig.suptitle('Singapore Absolute Soft/Hard Volumetric Log Error', fontsize=30)

# Render the figure.
plt.show()

# Experiment 3: The interesting thing about WMH is that we have multiple annotations per data_id (for some of the data_ids). Let's gather all of the data_ids that have all three annotators.

In [None]:
# Keep only the rows whose data_id was labelled by exactly three distinct annotators.

# Number of distinct annotators seen for each data_id.
unique_counts = melted_error_df.groupby('data_id')['annotator'].nunique()
# The data_ids for which that number is exactly 3.
filtered_data_ids = unique_counts.loc[unique_counts.eq(3)].index
# Row mask over the full frame, then the matching subset.
has_three_annotators = melted_error_df['data_id'].isin(filtered_data_ids)
multianno_melted_error_df = melted_error_df.loc[has_three_annotators]

In [None]:
# How many distinct data_ids survive the three-annotator filter.
len(multianno_melted_error_df['data_id'].unique())

In [None]:
# Spot-check one data_id / seed combination from the multi-annotator subset.
# NOTE(review): `.select` is not a stock pandas method; presumably it is
# registered by the ionpy import — confirm.
spot_check_df = multianno_melted_error_df.select(data_id='101', pretrained_seed=40)
spot_check_df