In [1]:
%load_ext autoreload
%autoreload 2

import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from htc.settings import settings
from htc_projects.rat.settings_rat import settings_rat
from htc_projects.species.settings_species import settings_species
from htc_projects.species.species_evaluation import icg_performance

# Disable MathJax for kaleido's static image export (used by `fig.write_image` below).
# NOTE(review): presumably this avoids the "Loading [MathJax]" box that can appear in exported
# PDFs — confirm against the installed plotly/kaleido version.
pio.kaleido.scope.mathjax = None

In [2]:
# Collect the metrics of the two baseline runs and the two projected-ICG (xeno-learning) runs;
# the result has one row per subject, label and network (cf. the output below).
df = icg_performance(
    [
        settings.training_dir / "image" / f"{settings_species.model_timestamp}_{run_name}_nested-0-2"
        for run_name in ("baseline_pig", "baseline_rat", "projected-ICG_rat2pig", "projected-ICG_pig2rat")
    ]
)
df

Unnamed: 0,subject_name,label_index,dice_metric,label_name,network,species
0,P005,1,0.000000,stomach,baseline_pig,pig
1,P113,1,0.203253,stomach,baseline_pig,pig
2,P105,1,0.608035,stomach,baseline_pig,pig
3,P104,1,0.996596,stomach,baseline_pig,pig
4,P102,1,0.999346,stomach,baseline_pig,pig
...,...,...,...,...,...,...
221,R046,11,0.999982,peritoneum,projected-ICG_pig2rat,rat
222,R044,11,0.906012,peritoneum,projected-ICG_pig2rat,rat
223,R043,11,0.999922,peritoneum,projected-ICG_pig2rat,rat
224,R047,11,0.989899,peritoneum,projected-ICG_pig2rat,rat


In [3]:
# Two vertically stacked panels (one per species) sharing a single categorical x-axis of labels.
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.15)

# Column of `df` drawn on the y-axis.
metric_name = "dice_metric"

# Map the species-specific run names onto the two method names used in the figure, so that the
# pig and rat runs of the same method end up in the same box offset group.
network_renaming = {
    **dict.fromkeys(("baseline_pig", "baseline_rat"), "baseline"),
    **dict.fromkeys(("projected-ICG_rat2pig", "projected-ICG_pig2rat"), "xeno-learning"),
}


def add_species(species: str, row: int, col: int) -> go.Figure:
    """Add the box plots of one species to the subplot ``row``/``col`` of the global ``fig``.

    Reads the notebook-level ``df`` and draws its ``metric_name`` column as one box trace per
    (network, label) combination that has data for ``species``. Networks whose name contains
    ``"baseline"`` use the species color, all others the xeno-learning color. Each network is
    mapped through ``network_renaming`` and that display name is used as offset group, legend
    group and trace name, so species-specific runs of the same method share a box position.

    Args:
        species: Value of the ``species`` column to plot (e.g. ``"pig"`` or ``"rat"``).
        row: Subplot row of ``fig`` to draw into (1-based, as in ``make_subplots``).
        col: Subplot column of ``fig`` to draw into (1-based, as in ``make_subplots``).

    Returns:
        The global ``fig``, modified in place.
    """
    df_species = df[df["species"] == species]

    # Only iterate over networks that have rows for this species; the networks of the other
    # species would otherwise be added as empty (invisible) box traces.
    for network in df_species["network"].unique():
        if "baseline" in network:
            color = settings_species.species_colors[species]
        else:
            color = settings_species.xeno_learning_color

        network_short = network_renaming.get(network, network)
        df_network = df_species[df_species["network"] == network]

        # Iterate over the labels of the whole frame so both subplots use the same label order.
        for label_name in df["label_name"].unique():
            df_label = df_network[df_network["label_name"] == label_name]
            if df_label.empty:
                # No data for this label/network combination; an empty trace would draw nothing.
                continue

            fig.add_trace(
                go.Box(
                    y=df_label[metric_name],
                    x=[settings_rat.labels_paper_renaming.get(label_name, label_name)] * len(df_label),
                    offsetgroup=network_short,
                    text=df_label["label_name"],
                    boxpoints="all",
                    boxmean=True,
                    name=network_short,
                    marker_color=color,
                    hovertext=df_label["subject_name"],
                    legendgroup=network_short,
                    showlegend=False,
                ),
                row=row,
                col=col,
            )

    return fig


add_species("pig", row=1, col=1)
add_species("rat", row=2, col=1)

# DSC range applied to BOTH subplots. Setting `yaxis_range` via `update_layout` would only affect
# the first (pig) subplot and leave the rat subplot auto-ranged, so the panels would not be comparable.
dsc_range = [-0.05, 1.05]

# x-range in category units around the label boxes.
# NOTE(review): assumes settings_species.icg_labels has as many entries as labels plotted
# (the boxes use the label names from `df`) — TODO confirm.
xmin = -0.52
xmax = len(settings_species.icg_labels) - 0.6

# make_subplots is called without subplot titles, so this currently affects no annotations.
fig.update_annotations(xanchor="left", x=0)
# Show tick labels on every x-axis (shared_xaxes=True would otherwise hide them on the upper subplot).
fig.update_xaxes(tickfont=dict(size=16), showticklabels=True)
fig.update_yaxes(
    title="<b>DSC</b>",
    title_standoff=10,
    tickfont=dict(size=16),
    range=dsc_range,
)
fig.update_layout(
    # The x-axes are shared, so configuring the first one is sufficient.
    xaxis_range=[xmin, xmax],
    xaxis_autorangeoptions_minallowed=xmin,
    xaxis_autorangeoptions_maxallowed=xmax,
    boxmode="group",
    boxgap=0.2,
    boxgroupgap=0.45,
    # All traces are added with showlegend=False, so this only takes effect if legend entries are enabled.
    legend=dict(orientation="h", yanchor="bottom", y=1.01, xanchor="center", x=0.5),
    template="plotly_white",
    height=700,
    width=1000,
    font_family="Libertinus Sans",
    font_size=16,
    margin=dict(l=0, r=0, b=0, t=0),
)
fig.write_image(settings_species.paper_dir / "icg_performance.pdf")
fig