In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from htc.settings_seg import settings_seg
from htc.utils.helper_functions import sort_labels
from htc.utils.visualization import boxplot_symbols
from htc_projects.species.settings_species import settings_species
from htc_projects.species.species_evaluation import baseline_performance

# NOTE(review): presumably switches off MathJax for the kaleido image export so that no
# "Loading [MathJax]" box ends up in the exported PDFs — confirm against the installed kaleido version
pio.kaleido.scope.mathjax = None

In [2]:
# Notebook-level performance table which is read by species_comparison() below. As displayed, each row
# carries the scores of one label for one network together with the source and target species, and the
# *_bootstraps columns hold a list of bootstrap values per row.
df = baseline_performance()
df

Unnamed: 0,label_index,dice_metric,dice_metric_bootstraps,surface_dice_metric_mean,surface_dice_metric_mean_bootstraps,label_name,network,source_species,target_species
0,1,0.289921,"[0.3071158116976992, 0.32515001276416633, 0.28...",0.229507,"[0.2386613830710378, 0.2482425972753955, 0.223...",stomach,baseline_human,human,pig
1,2,0.605336,"[0.6644174919511754, 0.5524962330706954, 0.663...",0.368218,"[0.3724041706106172, 0.34863130418807103, 0.38...",small_bowel,baseline_human,human,pig
2,3,0.323029,"[0.2747438077832579, 0.3178767478404278, 0.272...",0.244206,"[0.22228894259760032, 0.23346208199123283, 0.2...",colon,baseline_human,human,pig
3,4,0.614768,"[0.5747257457272265, 0.6713602712507882, 0.605...",0.293276,"[0.2784028802597699, 0.3183857076161617, 0.304...",liver,baseline_human,human,pig
4,5,0.000000,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.000000,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",pancreas,baseline_human,human,pig
...,...,...,...,...,...,...,...,...,...
139,8,0.885607,"[0.8856873640977315, 0.8977617739893754, 0.874...",0.786933,"[0.7932642738471432, 0.8103319464540271, 0.776...",omentum,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
140,9,0.868888,"[0.8768972320026821, 0.8646927743487888, 0.877...",0.755025,"[0.7511523003048369, 0.7593832235866123, 0.815...",lung,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
141,10,0.816514,"[0.7865177496658932, 0.8398494560677418, 0.829...",0.854466,"[0.819408978756749, 0.8812234833558145, 0.8749...",skin,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
142,11,0.625651,"[0.6273911029406511, 0.5925937585247442, 0.669...",0.653304,"[0.6474005020521824, 0.615219682115578, 0.6916...",peritoneum,joint_pig-p+rat-p2human,pig-p+rat-p2human,human


In [3]:
def species_comparison(metric_name: str, metric_name_short: str) -> go.Figure:
    """
    Plot the per-label scores for every source/target species combination and export the figure as PDF.

    The figure consists of one row with three subplots: source species pig (column 1), source species rat
    (column 2) and the networks evaluated on human data, i.e. source species human and pig-p+rat-p2human
    (column 3). Every box is built from the rows of the notebook-level table `df` which match the
    respective source/target species combination.

    Args:
        metric_name: Column of `df` with the per-label scores. A column `{metric_name}_bootstraps` with the
            bootstrap values of each label must exist as well.
        metric_name_short: Short name of the metric. Used as title of the y-axis and as suffix for the
            name of the exported file.

    Returns:
        The figure which has also been written to
        `settings_species.paper_dir / domain_shift_performance_{metric_name_short}.pdf`.

    Raises:
        ValueError: If `df` has no rows for one of the source/target species combinations.
        AssertionError: If the label names are not identical across all combinations.
    """
    fig = make_subplots(
        rows=1,
        cols=3,
        shared_yaxes=True,
        vertical_spacing=0.1,
        horizontal_spacing=0.03,
        column_widths=[3, 3, 2],
    )
    label_names = None

    # (subplot column, source species) -> [(box index inside the subplot, target species), ...]
    targets = {
        (1, "pig"): [(0, "pig"), (1, "rat"), (2, "human")],
        (2, "rat"): [(0, "rat"), (1, "pig"), (2, "human")],
        (3, "human"): [(0, "human")],
        (3, "pig-p+rat-p2human"): [(1, "human")],
    }
    for (i, source_species), source_targets in targets.items():
        for j, target_species in source_targets:
            selection = (df["source_species"] == source_species) & (df["target_species"] == target_species)
            if not selection.any():
                # Fail early with the name of the missing combination instead of an opaque numpy error
                raise ValueError(f"No results found for {source_species}2{target_species}")

            # Explicit copy so that the label sorting works on an independent frame and never on a
            # slice of the notebook-level table
            df_box = df[selection].copy()
            # NOTE(review): the return value is unused, so this presumably sorts df_box in-place — confirm
            sort_labels(df_box)

            # Every combination must cover the same labels (background is added if it is not part of the table)
            label_names_comparison = df_box["label_name"].tolist()
            if "background" not in label_names_comparison:
                label_names_comparison.append("background")
            if label_names is None:
                label_names = label_names_comparison
            else:
                assert label_names == label_names_comparison, (
                    f"Label names do not match: {label_names} vs {label_names_comparison}"
                )

            # Average across organs per bootstrap sample (the stacked matrix is built only once)
            bootstraps = np.stack(df_box[f"{metric_name}_bootstraps"])
            bootstraps_mean = np.mean(bootstraps, axis=0)
            bootstraps_median = np.median(bootstraps, axis=0)

            boxplot_symbols(
                fig,
                df_box[metric_name],
                df_box["label_name"],
                # 2.5 % and 97.5 % quantiles of the aggregated bootstrap values
                ci_mean=(np.quantile(bootstraps_mean, q=0.025), np.quantile(bootstraps_mean, q=0.975)),
                ci_median=(np.quantile(bootstraps_median, q=0.025), np.quantile(bootstraps_median, q=0.975)),
                trace_name=f"{source_species}2{target_species}",
                box_index=j,
                color=settings_species.species_colors[target_species],
                showlegend=i == 1,
                row=1,
                col=i,
            )

    fig.update_layout(height=400, width=1200, template="plotly_white")
    fig.update_layout(font_family="Libertinus Sans", font_size=16)
    fig.update_annotations(font=dict(size=20))
    # Legends are disabled for the whole figure, the legend positions below only matter if this is reverted
    fig.update_layout(showlegend=False)
    fig.update_yaxes(zeroline=False)
    # The y-axis is shared, hence only the first column receives a title
    fig.update_yaxes(title=f"<b>{metric_name_short}</b>", title_standoff=12, col=1)
    fig.update_xaxes(showticklabels=False)
    fig.update_layout(legend=dict(title=None, orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.2))
    fig.update_layout(legend2=dict(title=None, orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.8))
    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
    fig.write_image(settings_species.paper_dir / f"domain_shift_performance_{metric_name_short}.pdf")

    return fig


# Figure for the scores of the dice_metric column, exported as domain_shift_performance_DSC.pdf
species_comparison("dice_metric", "DSC")

In [4]:
# Same figure for the NSD, exported as domain_shift_performance_NSD.pdf. The column name is taken from
# settings_seg.nsd_aggregation_short (presumably surface_dice_metric_mean as shown in the table above — confirm)
species_comparison(settings_seg.nsd_aggregation_short, "NSD")