In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from htc.settings_seg import settings_seg
from htc.utils.helper_functions import sort_labels
from htc.utils.visualization import boxplot_symbols
from htc_projects.species.settings_species import settings_species
from htc_projects.species.species_evaluation import baseline_performance

# Unset the MathJax setting of the Kaleido export scope before any figure is written
# NOTE(review): presumably this keeps the "Loading [MathJax]" box out of the exported PDFs; confirm
pio.kaleido.scope.mathjax = None

In [2]:
# Performance table with one metric value (plus bootstrap samples) per row
# It is read as the notebook-level `df` inside species_comparison() below
df = baseline_performance()
df

Unnamed: 0,label_index,dice_metric,dice_metric_bootstraps,surface_dice_metric_mean,surface_dice_metric_mean_bootstraps,label_name,network,source_species,target_species
0,1,0.313875,"[0.33624719538943526, 0.3495526881545618, 0.31...",0.244484,"[0.25158693587530234, 0.2507578224890751, 0.25...",stomach,baseline_human,human,pig
1,2,0.618449,"[0.6668461833569955, 0.5755083175330398, 0.676...",0.383367,"[0.38216757144117564, 0.36171737972862295, 0.4...",small_bowel,baseline_human,human,pig
2,3,0.399226,"[0.3637829237616209, 0.3761668796248114, 0.356...",0.277434,"[0.25444409700727133, 0.26577364487947175, 0.2...",colon,baseline_human,human,pig
3,4,0.537611,"[0.5177921226446983, 0.6639741052093328, 0.543...",0.290353,"[0.27865982791702704, 0.3263526230365636, 0.28...",liver,baseline_human,human,pig
4,5,0.006627,"[0.006627236945288521, 0.0, 0.0066272369452885...",0.005879,"[0.005879223346710205, 0.0, 0.0058792233467102...",pancreas,baseline_human,human,pig
...,...,...,...,...,...,...,...,...,...
139,8,0.884005,"[0.8754188400654349, 0.8928557882362288, 0.886...",0.786107,"[0.7868702215869043, 0.7932600932576384, 0.796...",omentum,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
140,9,0.852342,"[0.8735243572129143, 0.8496070424715679, 0.920...",0.745890,"[0.7891469269328647, 0.6919352936744688, 0.782...",lung,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
141,10,0.810152,"[0.7936052020890538, 0.8183816260234429, 0.808...",0.843698,"[0.8274412400151724, 0.8362806841445268, 0.850...",skin,joint_pig-p+rat-p2human,pig-p+rat-p2human,human
142,11,0.613972,"[0.5338582265026429, 0.5842160611565597, 0.581...",0.644649,"[0.5738497746313542, 0.6150082722375818, 0.600...",peritoneum,joint_pig-p+rat-p2human,pig-p+rat-p2human,human


In [3]:
def species_comparison(metric_name: str, metric_name_short: str, df_results=None) -> go.Figure:
    """
    Plot the per-label scores for every source/target species combination and export the figure as PDF.

    One subplot column is used per source group: pig, rat and a third column for the two networks which are
    evaluated on human data (`human` and `pig-p+rat-p2human`). Inside a column, one box is drawn per target species
    via `boxplot_symbols`. The confidence intervals handed to each box are the 2.5 % and 97.5 % quantiles of the
    bootstrap distribution of the mean (respectively median) across labels.

    Args:
        metric_name: Column of the performance table with the metric values. The column
            `f"{metric_name}_bootstraps"` must contain the corresponding bootstrap samples.
        metric_name_short: Short name of the metric, used as y-axis title and in the name of the exported PDF.
        df_results: Performance table with at least the columns `source_species`, `target_species`, `label_name`,
            `metric_name` and the bootstraps column. Defaults to the notebook-level `df`.

    Returns: The created figure. It is also written to `settings_species.paper_dir`.

    Raises:
        ValueError: If the table contains no rows for one of the source/target combinations.
        AssertionError: If the label names are not identical across all combinations.
    """
    # Fall back to the notebook-level table so that existing two-argument calls keep working
    data = df if df_results is None else df_results

    fig = make_subplots(
        rows=1,
        cols=3,
        shared_yaxes=True,
        vertical_spacing=0.1,
        horizontal_spacing=0.03,
        column_widths=[3, 3, 2],
    )
    label_names = None

    # (subplot column, source species) -> [(box index inside the column, target species), ...]
    targets = {
        (1, "pig"): [(0, "pig"), (1, "rat"), (2, "human")],
        (2, "rat"): [(0, "rat"), (1, "pig"), (2, "human")],
        (3, "human"): [(0, "human")],
        (3, "pig-p+rat-p2human"): [(1, "human")],
    }
    for (col, source_species), source_targets in targets.items():
        for box_index, target_species in source_targets:
            selection = (data["source_species"] == source_species) & (data["target_species"] == target_species)
            if not selection.any():
                # Fail early with a readable message instead of numpy's "need at least one array to stack"
                raise ValueError(
                    f"No results found for source species '{source_species}' and target species '{target_species}'"
                )

            # Independent copy so that the label sorting can never touch the shared table
            df_box = data[selection].copy()
            # NOTE(review): the return value is discarded, so this relies on sort_labels reordering df_box in place
            sort_labels(df_box)

            label_names_comparison = df_box["label_name"].tolist()
            if "background" not in label_names_comparison:
                label_names_comparison.append("background")
            if label_names is None:
                label_names = label_names_comparison
            else:
                assert label_names == label_names_comparison, (
                    f"Label names do not match: {label_names} vs {label_names_comparison}"
                )

            # One row per label after stacking (assumes every row stores the same number of bootstrap samples)
            bootstraps = np.stack(df_box[f"{metric_name}_bootstraps"])

            # Aggregate across labels per bootstrap sample
            bootstraps_mean = np.mean(bootstraps, axis=0)
            bootstraps_median = np.median(bootstraps, axis=0)

            boxplot_symbols(
                fig,
                df_box[metric_name],
                df_box["label_name"],
                ci_mean=(np.quantile(bootstraps_mean, q=0.025), np.quantile(bootstraps_mean, q=0.975)),
                ci_median=(np.quantile(bootstraps_median, q=0.025), np.quantile(bootstraps_median, q=0.975)),
                trace_name=f"{source_species}2{target_species}",
                box_index=box_index,
                color=settings_species.species_colors[target_species],
                showlegend=col == 1,
                row=1,
                col=col,
            )

    fig.update_layout(height=400, width=1200, template="plotly_white")
    fig.update_layout(font_family="Libertinus Sans", font_size=16)
    fig.update_annotations(font=dict(size=20))
    # Legends are switched off globally here, so the per-trace showlegend flag and the legend placement below
    # only take effect if this line is removed
    fig.update_layout(showlegend=False)
    fig.update_yaxes(zeroline=False)
    # Only the first column gets a y-axis title
    fig.update_yaxes(title=f"<b>{metric_name_short}</b>", title_standoff=12, col=1)
    fig.update_xaxes(showticklabels=False)
    fig.update_layout(legend=dict(title=None, orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.2))
    fig.update_layout(legend2=dict(title=None, orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.8))
    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
    fig.write_image(settings_species.paper_dir / f"domain_shift_performance_{metric_name_short}.pdf")

    return fig


# Figure for the `dice_metric` column (exported as domain_shift_performance_DSC.pdf)
species_comparison("dice_metric", "DSC")

In [4]:
# Same figure for the column named by settings_seg.nsd_aggregation_short (exported as domain_shift_performance_NSD.pdf)
species_comparison(settings_seg.nsd_aggregation_short, "NSD")