In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from htc.settings_seg import settings_seg
from htc.utils.helper_functions import sort_labels
from htc.utils.visualization import boxplot_symbols
from htc_projects.species.settings_species import settings_species
from htc_projects.species.species_evaluation import baseline_performance

# Turn MathJax off for the Kaleido-based static export used by fig.write_image() further down.
# NOTE(review): presumably the usual workaround for the "Loading [MathJax]" box in exported PDFs — confirm.
# NOTE(review): `pio.kaleido.scope` looks deprecated in recent plotly/Kaleido releases (`pio.defaults.mathjax`
# appears to be the replacement) — verify against the pinned plotly version before upgrading.
pio.kaleido.scope.mathjax = None

In [2]:
# Load the per-label performance table. As the displayed columns show, each row holds the point estimates
# (dice_metric, surface_dice_metric_mean) and their bootstrap samples for one label of one network,
# together with the source and target species of that network.
# NOTE: species_comparison() reads this `df` as a global, so this cell must run before the plotting cells.
df = baseline_performance()
df

Unnamed: 0,label_index,dice_metric,dice_metric_bootstraps,surface_dice_metric_mean,surface_dice_metric_mean_bootstraps,label_name,network,source_species,target_species
0,1,0.317947,"[0.34074875820826483, 0.362413721895955, 0.315...",0.249723,"[0.2633802482715009, 0.2757965110866666, 0.246...",stomach,baseline_human,human,pig
1,2,0.621578,"[0.6773898810167376, 0.5772742601424683, 0.685...",0.371644,"[0.3776121832154154, 0.3553180572116855, 0.393...",small_bowel,baseline_human,human,pig
2,3,0.227294,"[0.19941976478418066, 0.223918239511965, 0.198...",0.171938,"[0.160747846172235, 0.16414551307462483, 0.143...",colon,baseline_human,human,pig
3,4,0.554063,"[0.49920446694225895, 0.6304400772439798, 0.56...",0.258349,"[0.24268291278031065, 0.2950893455816702, 0.27...",liver,baseline_human,human,pig
4,5,0.000093,"[9.349933825433254e-05, 0.0, 9.349933825433254...",0.001226,"[0.0012258368943418776, 0.0, 0.001225836894341...",pancreas,baseline_human,human,pig
...,...,...,...,...,...,...,...,...,...
139,8,0.888016,"[0.8856765283399728, 0.8988884758969858, 0.875...",0.786295,"[0.7914921903963246, 0.8094676639163111, 0.774...",omentum,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
140,9,0.862759,"[0.8484361171722412, 0.8544848998387655, 0.888...",0.730757,"[0.7052413533793555, 0.7371561147107019, 0.816...",lung,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
141,10,0.815231,"[0.7925396424124574, 0.8349951520241876, 0.824...",0.844757,"[0.8140680448519234, 0.8611372900407308, 0.864...",skin,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
142,11,0.612838,"[0.6067706619435445, 0.5486092639125251, 0.650...",0.654853,"[0.6474966068116444, 0.6145212012178757, 0.671...",peritoneum,joint_pig-p+rat-p2human,pig-p+rat-p2human,human


In [3]:
def species_comparison(metric_name: str, metric_name_short: str) -> go.Figure:
    """
    Draw one box per (source species, target species) pair for a metric and export the figure as PDF.

    The figure has three subplot columns: pig as source species (column 1), rat as source species (column 2)
    and the human-target comparisons (column 3, sources `human` and `pig-p+rat-p2human`). Every box is built
    from the per-label rows of the notebook-level `df` table, so `df` must be defined before calling this.

    Args:
        metric_name: Column of `df` with the per-label metric values. The column `f"{metric_name}_bootstraps"`
            must hold the corresponding bootstrap samples.
        metric_name_short: Short metric name used as y-axis title and in the name of the exported PDF.

    Returns: The assembled figure. It is also written to
        `settings_species.paper_dir / f"domain_shift_performance_{metric_name_short}.pdf"`.

    Raises:
        ValueError: If `df` contains no rows for one of the expected species pairs.
        AssertionError: If the species pairs do not all cover the same labels.
    """
    fig = make_subplots(
        rows=1,
        cols=3,
        shared_yaxes=True,
        vertical_spacing=0.1,
        horizontal_spacing=0.03,
        column_widths=[3, 3, 2],
    )

    # (subplot column, source species) -> [(box position inside the subplot, target species), ...]
    targets = {
        (1, "pig"): [(0, "pig"), (1, "rat"), (2, "human")],
        (2, "rat"): [(0, "rat"), (1, "pig"), (2, "human")],
        (3, "human"): [(0, "human")],
        (3, "pig-p+rat-p2human"): [(1, "human")],
    }
    ci_levels = [0.025, 0.975]  # bounds of the 95 % bootstrap interval
    label_names = None

    for (col, source_species), source_targets in targets.items():
        for box_index, target_species in source_targets:
            selection = (df["source_species"] == source_species) & (df["target_species"] == target_species)

            # Explicit copy: sort_labels() is called for its effect on the frame (its return value is not
            # used), so it must operate on an independent frame rather than on a slice of the shared df
            df_box = df[selection].copy()
            if df_box.empty:
                # Fail with a clear message instead of the "need at least one array to stack" error of np.stack
                raise ValueError(f"No results found in df for {source_species}2{target_species}")
            sort_labels(df_box)

            # All boxes must be based on the same labels (in the same order); background is optional
            label_names_comparison = df_box["label_name"].tolist()
            if "background" not in label_names_comparison:
                label_names_comparison.append("background")
            if label_names is None:
                label_names = label_names_comparison
            else:
                assert (
                    label_names == label_names_comparison
                ), f"Label names do not match: {label_names} vs {label_names_comparison}"

            # Stack the bootstrap samples once (one row per label) and aggregate across organs per bootstrap sample
            bootstraps = np.stack(df_box[f"{metric_name}_bootstraps"])
            bootstraps_mean = np.mean(bootstraps, axis=0)
            bootstraps_median = np.median(bootstraps, axis=0)

            boxplot_symbols(
                fig,
                df_box[metric_name],
                df_box["label_name"],
                ci_mean=tuple(np.quantile(bootstraps_mean, q=ci_levels)),
                ci_median=tuple(np.quantile(bootstraps_median, q=ci_levels)),
                trace_name=f"{source_species}2{target_species}",
                box_index=box_index,
                color=settings_species.species_colors[target_species],
                showlegend=col == 1,
                row=1,
                col=col,
            )

    fig.update_layout(
        height=400,
        width=1200,
        template="plotly_white",
        font_family="Libertinus Sans",
        font_size=16,
        showlegend=False,
        margin=dict(l=0, r=0, t=0, b=0),
    )
    fig.update_annotations(font=dict(size=20))
    fig.update_yaxes(title=f"<b>{metric_name_short}</b>", title_standoff=12, col=1)
    fig.update_xaxes(showticklabels=False)

    # Legend placement is kept for convenience; with showlegend=False above the legends are hidden, so these
    # settings only become visible when that flag is switched on
    fig.update_layout(
        legend=dict(title=None, orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.2),
        legend2=dict(title=None, orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.8),
    )

    fig.write_image(settings_species.paper_dir / f"domain_shift_performance_{metric_name_short}.pdf")

    return fig


# DSC figure; the call also writes domain_shift_performance_DSC.pdf to settings_species.paper_dir
species_comparison("dice_metric", "DSC")

In [4]:
# Same figure for the NSD; the metric column name is taken from settings_seg.nsd_aggregation_short
# (presumably "surface_dice_metric_mean", the other metric column of df — confirm)
species_comparison(settings_seg.nsd_aggregation_short, "NSD")