In [None]:
# Make locally checked-out code importable before the project imports below run.
# NOTE(review): these are absolute, user-specific paths; presumably this is where
# ionpy / ese / UniverSeg live on the author's machine — confirm before reuse.
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

from ionpy.analysis import ResultsLoader
from ese.experiment.experiment import CalibrationExperiment
import seaborn as sns
import matplotlib.pyplot as plt
# Global seaborn theme for every figure in this notebook.
sns.set_style("darkgrid")

import os 
# DATAPATH is a ':'-separated list of dataset roots; with a single entry the join
# yields just that path.
# NOTE(review): presumably read by the dataset code at load time — confirm.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Single loader instance used below to read configs and metrics and to pick the
# best experiment; `root` is the directory under which the run folders live.
rs = ResultsLoader()
root = "/storage/vbutoi/scratch/ESE"

# IPython extensions: yamlmagic is loaded here (not used in the visible cells);
# autoreload mode 2 re-imports modules before each cell so library edits are
# picked up without restarting the kernel.
%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [None]:
# Folder containing every run of the WMH augmentation sweep.
path = f"{root}/WMH_aug_runs"

# Config table for the runs found under `path` (properties=False is passed through
# to the loader unchanged).
dfc = rs.load_configs(path, properties=False)

# Metrics corresponding to the loaded configs.
df = rs.load_metrics(dfc)

In [None]:

# Expose only physical GPU 1 to this process; it is then addressed as plain "cuda".
# This must run before anything initialises CUDA for the setting to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Let the results loader pick the best run from the metrics table and instantiate it
# as a CalibrationExperiment on the GPU, then build its dataloaders so that
# `val_dl` / `train_dl` are available to the inference cell below.
best_exp = rs.get_best_experiment(df=df, exp_class=CalibrationExperiment, device="cuda")
best_exp.build_dataloader()

In [None]:
# Sanity check on training: draw the selected experiment's loss curves
# (height=6 is forwarded to the plotting helper).
best_exp.vis_loss_curves(height=6)

In [None]:
import torch
import einops
import numpy as np
from tqdm.notebook import tqdm
from ionpy.util.torchutils import to_device
from ese.experiment.metrics.grouping.portion_loss import get_perpix_boundary_dist, get_perpix_group_size
from ese.experiment.metrics.grouping.regions import get_label_region_sizes

def get_dataset_perf(exp, split="val"):
    """Run ``exp.model`` over one data split and collect per-batch prediction maps.

    Args:
        exp: Experiment object exposing ``model``, ``device`` and the dataloaders
            ``val_dl`` / ``train_dl``.
        split: ``"val"`` to iterate ``exp.val_dl`` or ``"train"`` to iterate
            ``exp.train_dl``.

    Returns:
        A list with one dict per batch yielded by the dataloader, with keys:
            ``subject_id``            index of the batch in the dataloader
            ``image``                 input image(s) as a squeezed NumPy array
            ``label_map``             ground-truth label(s), squeezed NumPy array
            ``conf_map``              confidence of the predicted class, i.e.
                                      ``p`` where ``p >= 0.5`` and ``1 - p`` otherwise
            ``pred_map``              hard prediction (sigmoid output > 0.5)
            ``perpix_accuracies``     1.0 where ``pred_map == label_map``, else 0.0
            ``perf_per_dist``         ``get_perpix_boundary_dist`` of the hard prediction
            ``perf_per_regsize``      ``get_perpix_group_size`` of the hard prediction
            ``gt_lab_region_sizes``   ``get_label_region_sizes`` of the label map
            ``pred_lab_region_sizes`` ``get_label_region_sizes`` of the hard prediction

    Raises:
        ValueError: if ``split`` is neither ``"val"`` nor ``"train"``.

    Side effects:
        The model is switched to eval mode for the duration of the call and its
        previous train/eval state is restored afterwards.
    """
    # Fail loudly on a typo rather than silently evaluating the training split.
    if split == "val":
        dataloader = exp.val_dl
    elif split == "train":
        dataloader = exp.train_dl
    else:
        raise ValueError(f"Unknown split {split!r}; expected 'val' or 'train'.")

    # Inference must not run dropout / batch-norm in training behaviour.
    # NOTE(review): assumes exp.model is a torch.nn.Module — confirm.
    model = exp.model
    was_training = model.training
    model.eval()

    items = []
    try:
        with torch.no_grad():
            for subj_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
                # Image / label pair, moved to the experiment's device.
                x, y = to_device(batch, exp.device)
                # Fold the channel axis into the batch axis so every channel is
                # fed to the model as its own single-channel image.
                x = einops.rearrange(x, "b c h w -> (b c) 1 h w")
                y = einops.rearrange(y, "b c h w -> (b c) 1 h w")
                # Logits -> foreground probability -> hard mask.
                yhat = model(x)
                soft_pred = torch.sigmoid(yhat)
                hard_pred = (soft_pred > 0.5).float()
                # Turn the foreground probability into the confidence of the
                # predicted class: wherever background was predicted use 1 - p.
                below_half = soft_pred < 0.5
                soft_pred[below_half] = 1 - soft_pred[below_half]
                # Drop singleton axes and move everything to NumPy.
                # NOTE(review): squeeze() without a dim also drops the batch axis
                # when it has size 1 — confirm downstream code expects that.
                x = x.squeeze().cpu().numpy()
                y = y.squeeze().cpu().numpy()
                hard_pred = hard_pred.squeeze().cpu().numpy()
                soft_pred = soft_pred.squeeze().cpu().numpy()
                # Per-pixel correctness and the per-pixel grouping statistics.
                accuracy_map = (hard_pred == y).astype(np.float32)
                perf_per_dist = get_perpix_boundary_dist(y_pred=hard_pred)
                perf_per_regsize = get_perpix_group_size(y_pred=hard_pred)
                # Region sizes for the ground truth and for the prediction.
                gt_lab_region_sizes = get_label_region_sizes(label_map=y)
                pred_lab_region_sizes = get_label_region_sizes(label_map=hard_pred)
                items.append({
                    "subject_id": subj_idx,
                    "image": x,
                    "label_map": y,
                    "conf_map": soft_pred,
                    "pred_map": hard_pred,
                    "perpix_accuracies": accuracy_map,
                    "perf_per_dist": perf_per_dist,
                    "perf_per_regsize": perf_per_regsize,
                    "gt_lab_region_sizes": gt_lab_region_sizes,
                    "pred_lab_region_sizes": pred_lab_region_sizes
                })
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return items

In [None]:
# Run the selected experiment over its validation loader. The result is a list with
# one dict per batch holding the image, label map, confidence map, hard prediction
# and the per-pixel accuracy / boundary-distance / region-size information.
predictions_list = get_dataset_perf(exp=best_exp, split="val")

In [None]:
from ese.experiment.analysis.inference import get_pixelinfo_df

# Build the table that the plotting cells below read; they use its `conf`,
# `accuracy`, `label`, `dist_to_boundary` and `group_size` columns.
# NOTE(review): presumably one row per pixel — confirm against get_pixelinfo_df.
pixel_preds_df = get_pixelinfo_df(predictions_list)

## Distribution of Pixel Confidences

In [None]:
# get_dataset_perf folds every confidence into [0.5, 1] (values below 0.5 are
# replaced by 1 - p), so the density is only evaluated on that range — the same
# range the per-label confidence plot uses. Evaluating down to 0 would only show
# kernel-smoothing bleed into values that cannot occur.
# NOTE(review): assumes the `conf` column is taken from `conf_map` — confirm.
plt.figure(figsize=(15, 10))  # Adjust the width and height as needed
g = sns.kdeplot(data=pixel_preds_df, x="conf", clip=(0.5, 1))
g.set_title("Distribution of Pixel Confidences")
plt.show()

## Distribution of Pixel Accuracies

In [None]:
# One bar: the `accuracy` column aggregated over all rows of the table.
plt.figure(figsize=(5, 10))
g = sns.barplot(y="accuracy", data=pixel_preds_df)
g.set(title="Average Pixel-wise Accuracy")
plt.show()

In [None]:
# Inspect which distinct confidence values occur in the table.
# NOTE(review): this can print a very large array; looks like a debugging aid.
pixel_preds_df['conf'].unique()

## Distribution of Pixel Confidences vs Label 

In [None]:
# One density curve per label. common_norm=False normalises each curve on its own,
# so a rare label is not flattened by a frequent one. The title says "per Label"
# so this figure is distinguishable from the un-split confidence plot.
plt.figure(figsize=(15, 10))  # Adjust the width and height as needed
g = sns.kdeplot(data=pixel_preds_df, x="conf", hue="label", common_norm=False, clip=(0.5, 1))
g.set_title("Distribution of Pixel Confidences per Label")
plt.show()

## Distribution of Pixel Accuracies vs Label 

In [None]:
# One bar per label value: the `accuracy` column aggregated within each label.
plt.figure(figsize=(15, 10))
g = sns.barplot(x="label", y="accuracy", data=pixel_preds_df)
g.set(title="Average Pixel-wise Accuracy per Label")
plt.show()


## Distribution of Predicted Region Sizes Over Images

In [None]:
from ese.experiment.analysis.err_diagrams import viz_region_size_distribution

# Region-size distribution over the collected items, coloured by `label_map_type`.
# NOTE(review): presumably that column separates ground-truth from predicted maps — confirm.
viz_region_size_distribution(data_points=predictions_list, hue="label_map_type")

In [None]:
# Import repeated so this cell can be run on its own.
from ese.experiment.analysis.err_diagrams import viz_region_size_distribution

# Same distribution as the previous cell, additionally split into one row per label.
viz_region_size_distribution(data_points=predictions_list, hue="label_map_type", row_split="label")

## Distribution of Pixel Accuracies vs Distance to a Boundary

In [None]:
# Accuracy as a function of distance to a boundary, over all labels together.
# seaborn's lineplot aggregates rows that share an x value (by default the mean,
# drawn with a confidence band).
sns.lineplot(data=pixel_preds_df, x="dist_to_boundary", y="accuracy")

In [None]:
# One accuracy-vs-boundary-distance figure per label.
# Series.unique() returns labels in order of first appearance, so they are sorted
# before colours are assigned by position: a given label then always gets the same
# tab10 colour (lowest label blue, next orange) no matter how the rows are ordered.
unique_labels = pixel_preds_df['label'].unique()
label_palette = sns.color_palette("tab10")
for l_idx, lab in enumerate(sorted(unique_labels)):
    sns.lineplot(
        data=pixel_preds_df[pixel_preds_df['label']==lab],
        x="dist_to_boundary",
        y="accuracy",
        color=label_palette[l_idx]
        )
    plt.title(f"Label {lab}")
    plt.show()

## Distribution of Pixel Accuracies vs Predicted Size of Region (that pixel is in)

In [None]:
# Accuracy as a function of `group_size`, over all labels together.
# seaborn's lineplot aggregates rows that share an x value (by default the mean,
# drawn with a confidence band).
sns.lineplot(data=pixel_preds_df, x="group_size", y="accuracy")

In [None]:
# One accuracy-vs-group-size figure per label.
# Series.unique() returns labels in order of first appearance, so they are sorted
# before colours are assigned by position; each label then keeps the same tab10
# colour regardless of row order.
unique_labels = pixel_preds_df['label'].unique()
label_palette = sns.color_palette("tab10")
for l_idx, lab in enumerate(sorted(unique_labels)):
    sns.lineplot(
        data=pixel_preds_df[pixel_preds_df['label']==lab],
        x="group_size",
        y="accuracy",
        color=label_palette[l_idx]
        )
    plt.title(f"Label {lab}")
    plt.show()

## Super Weird Plot Time: Distance to Boundary vs Parent Size, hue accuracy

In [None]:
# One scatter per label: group size (x) against distance to boundary (y), with each
# point coloured by its `accuracy` value. The loop index was never used, so the
# labels are iterated directly.
# NOTE(review): this draws one marker per table row, which may be slow for large tables.
unique_labels = pixel_preds_df['label'].unique()
for lab in unique_labels:
    sns.scatterplot(
        data=pixel_preds_df[pixel_preds_df['label']==lab],
        x="group_size",
        y="dist_to_boundary",
        hue="accuracy"
        )
    plt.title(f"Label {lab}")
    plt.show()

## From ToySquare, we want to look at distance to boundary vs. (confidence) and (accuracy) on the same plot.

In [None]:
# One figure per label with confidence and accuracy drawn against distance to
# boundary on shared axes. The colour dictionary and loop index that used to be
# defined here were never read, so they are gone; seaborn picks the two line
# colours from the `hue` column.
unique_labels = pixel_preds_df['label'].unique()
for lab in unique_labels:
    plt.figure(figsize=(15, 10))  # Adjust the width and height as needed
    label_df = pixel_preds_df[pixel_preds_df['label']==lab]
    # Long format: one row per (dist_to_boundary, metric) pair, where column `y`
    # names the metric ('conf' or 'accuracy') and `value` holds its value.
    label_df_melted = label_df.melt(
        id_vars='dist_to_boundary',
        value_vars=['conf', 'accuracy'],
        var_name='y',
        value_name='value'
    )
    sns.lineplot(
        data=label_df_melted,
        x="dist_to_boundary",
        y="value",
        hue='y'
        )
    plt.title(f"Label {lab}")
    plt.show()