In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from htc.context.models.context_evaluation import best_run_data
from htc.context.settings_context import settings_context
from htc.settings_seg import settings_seg

# Disable MathJax for Kaleido's static image export (used by fig.write_image in the cells below).
# NOTE(review): presumably done to avoid the "Loading [MathJax]..." box that can show up in exported PDFs —
# confirm this attribute is still supported by the installed plotly/Kaleido version (may be deprecated in newer ones).
pio.kaleido.scope.mathjax = None

In [2]:
df = best_run_data(test=True)

# Repeat the reference results ("semantic") under the dataset name "semantic2" so that they can be shown a second
# time as reference next to the removal scenario. Built without in-place mutation so the cell stays idempotent.
semantic2 = df.query("dataset == 'semantic'").assign(dataset="semantic2")
assert not semantic2.empty, "No results found for the reference dataset 'semantic'"
df = pd.concat([df, semantic2], ignore_index=True)

df

Unnamed: 0,network,dataset,label_index,dice_metric,surface_distance_metric,surface_dice_metric_mean,label_name,modality
0,baseline,semantic,6,0.861474,9.262870,0.790171,stomach,HSI
1,baseline,semantic,5,0.957207,4.624576,0.791536,small_bowel,HSI
2,baseline,semantic,4,0.948589,3.187332,0.924161,colon,HSI
3,baseline,semantic,3,0.955587,3.820706,0.645928,liver,HSI
4,baseline,semantic,8,0.812361,2.863776,0.494350,gallbladder,HSI
...,...,...,...,...,...,...,...,...
663,organ_transplantation,semantic2,13,0.704187,38.340746,0.723531,fat_subcutaneous,RGB
664,organ_transplantation,semantic2,11,0.766200,22.827456,0.825938,peritoneum,RGB
665,organ_transplantation,semantic2,17,0.904380,2.448071,0.964088,major_vein,RGB
666,organ_transplantation,semantic2,18,0.888198,5.064728,0.485892,kidney_with_Gerotas_fascia,RGB


In [3]:
def dataset_comparison(metric_name: str, results: "pd.DataFrame | None" = None) -> go.Figure:
    """Create a box-plot comparison of one metric across all context datasets for both modalities.

    The figure has two rows (HSI on top, RGB below) and three columns: the isolation scenario, the removal scenario
    and all remaining scenarios. Within a column, every dataset is one x category with one box per network (grouped
    via `offsetgroup`), each box containing the per-label values of the metric.

    Args:
        metric_name: Column of the results table plotted on the y-axis. The y-axis range is fixed to [-0.05, 1.05],
            so this is meant for metrics in [0, 1] (e.g. DSC or NSD).
        results: Table with (at least) the columns `modality`, `dataset`, `network`, `label_name` and `metric_name`.
            Defaults to the notebook-level `df` so existing calls keep working; pass it explicitly to avoid relying
            on global notebook state.

    Returns:
        The configured plotly figure (the y-axis title is left to the caller).
    """
    if results is None:
        results = df

    # Number of x categories (datasets) per subplot column. Used both as relative column widths and to derive the
    # visible x range of each column so that the two cannot drift apart.
    # NOTE(review): assumes 4/3/2 datasets in the isolation/removal/other columns — TODO derive from settings_context
    n_categories_per_col = [4, 3, 2]

    fig = make_subplots(
        rows=2,
        cols=3,
        subplot_titles=[
            "<b>image#HSI</b>",
            None,
            None,
            "<b>image#RGB</b>",
            None,
            None,
        ],
        shared_xaxes=True,
        shared_yaxes=True,
        vertical_spacing=0.1,
        horizontal_spacing=0.03,
        column_widths=n_categories_per_col,
    )

    network_renaming = {
        "baseline": "Baseline",
        "organ_transplantation": "Organ Transplantation",
    }
    # Subplot column per scenario; every other scenario ends up in the last column
    scenario_columns = {"isolation": 1, "removal": 2}

    # The duplicated reference "semantic2" is placed directly in front of the first removal dataset
    datasets = list(settings_context.task_name_mapping.keys())
    datasets.insert(datasets.index("removal_0"), "semantic2")

    def add_modality(modality: str, row: int) -> go.Figure:
        """Add one box trace per (dataset, network) pair of `modality` to subplot row `row` and return the figure."""
        df_m = results.query("modality == @modality")
        networks = df_m["network"].unique()
        for dataset in datasets:
            dataset_name = "original2" if dataset == "semantic2" else settings_context.task_name_mapping[dataset]
            col = scenario_columns.get(settings_context.scenario_mapping[dataset], 3)

            for network in networks:
                df_box = df_m.query("dataset == @dataset and network == @network")
                fig.add_trace(
                    go.Box(
                        y=df_box[metric_name],
                        x=[dataset_name] * len(df_box),
                        offsetgroup=network,
                        text=df_box["label_name"],
                        boxpoints="all",
                        boxmean=True,
                        name=network_renaming.get(network, network),
                        marker_color=settings_context.network_colors[network],
                        legendgroup=network,
                        # One legend entry per network: only the traces of "isolation_0" in the top row show up
                        showlegend=row == 1 and col == 1 and dataset == "isolation_0",
                    ),
                    row=row,
                    col=col,
                )

        return fig

    add_modality("HSI", row=1)
    add_modality("RGB", row=2)

    # Subplot titles: larger font and left-aligned at the figure border
    fig.update_annotations(font_size=20, xanchor="left", x=0)
    fig.update_yaxes(title_standoff=12, tickfont=dict(size=16))
    fig.update_xaxes(tickfont=dict(size=16))

    # Category i sits at x=i, so [xmin, n - xmax_margin] shows exactly the n datasets of a column; the autorange
    # limits are set to the same bounds
    xmin = -0.6
    xmax_margin = 0.6
    for col, n_categories in enumerate(n_categories_per_col, start=1):
        xmax = n_categories - xmax_margin
        fig.update_xaxes(
            range=[xmin, xmax],
            autorangeoptions_minallowed=xmin,
            autorangeoptions_maxallowed=xmax,
            col=col,
        )

    # The metrics plotted with this function (DSC, NSD) lie in [0, 1]
    fig.update_yaxes(range=[-0.05, 1.05])
    fig.update_layout(boxmode="group", boxgap=0.2, boxgroupgap=0.4)
    # Horizontal legend placed in the vertical gap between the two rows
    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=0.46, xanchor="center", x=0.5))
    fig.update_layout(
        template="plotly_white",
        height=510,
        width=1100,
        font_family="Libertinus Serif",
        font_size=20,
        margin=dict(l=0, r=0, b=0, t=20),
    )

    return fig


# Dice similarity coefficient per dataset; exported as PDF for the paper
fig = dataset_comparison("dice_metric").update_yaxes(title="<b>DSC</b>", col=1)
dsc_pdf_path = settings_context.paper_extended_dir / "task_performance_DSC.pdf"
fig.write_image(dsc_pdf_path)
fig

In [4]:
# Normalized surface dice per dataset; exported as PDF for the paper
fig = dataset_comparison(settings_seg.nsd_aggregation_short).update_yaxes(title="<b>NSD</b>", col=1)
nsd_pdf_path = settings_context.paper_extended_dir / "task_performance_NSD.pdf"
fig.write_image(nsd_pdf_path)
fig