In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import plotly.io as pio
from plotly import graph_objects as go
from plotly.subplots import make_subplots

from htc.utils.helper_functions import sort_labels
from htc_projects.rat.settings_rat import settings_rat
from htc_projects.species.settings_species import settings_species
from htc_projects.species.tables import paper_table

# Disable MathJax in plotly's kaleido static-export backend (presumably to avoid the
# "Loading [MathJax]" box that kaleido can render into exported PDFs — confirm).
# NOTE(review): `pio.kaleido.scope` may be deprecated in newer plotly/kaleido releases — check installed version.
pio.kaleido.scope.mathjax = None

In [2]:
# Load the paper table; every row must have a species since the figure below is split by species.
df = paper_table()
assert df["species_name"].notna().all()

In [3]:
# Data statistics figure: one subplot row per species. Bars are stacked per label with one trace
# per (perfusion state, subject); the annotation above each stack reads
# "(number of images; number of subjects)" for that species and label.
PERFUSION_STATES = ["physiological", "unclear", "malperfused", "icg"]

species_names = df["species_name"].unique()
n_species = len(species_names)

# Images with a perfusion state outside PERFUSION_STATES would silently be missing from the bars
# while still being counted in the annotations, so fail loudly instead
unknown_states = set(df["perfusion_state"].unique()) - set(PERFUSION_STATES)
if unknown_states:
    raise ValueError(f"Unknown state(s): {unknown_states}")


def marker_style(species_name: str, state: str) -> dict:
    """Return the go.Bar marker keyword arguments for one (species, perfusion state) combination."""
    style = {
        "marker_line_width": 0.1,
        "marker_line_color": "black" if species_name == "pig" else "gray",
    }
    if state == "icg":
        # Same fill as physiological bars; ICG bars are distinguished by a thicker, ICG-colored outline
        style["marker_color"] = settings_species.species_colors[species_name]
        style["marker_line_width"] = 1
        style["marker_line_color"] = settings_species.icg_color
    elif state == "malperfused":
        style["marker_color"] = settings_species.species_colors_malperfused[species_name]
    elif state in ("physiological", "unclear"):
        style["marker_color"] = settings_species.species_colors[species_name]
    else:
        raise ValueError(f"Unknown state: {state}")
    return style


fig = make_subplots(rows=n_species, cols=1, shared_xaxes=True)

# Bar plot: one trace per (species, state, subject), stacked via barmode="stack"
for row, species_name in enumerate(species_names, start=1):
    df_species = df[df["species_name"] == species_name]
    for state in PERFUSION_STATES:
        df_state = df_species[df_species["perfusion_state"] == state]
        # sort=False keeps subjects in order of first appearance (same as .unique())
        for subject_name, df_subject in df_state.groupby("subject_name", sort=False):
            df_counts = df_subject.groupby("label_name", as_index=False).agg(
                n_images=pd.NamedAgg(column="image_name", aggfunc="nunique")
            )
            fig.add_trace(
                go.Bar(
                    x=df_counts["label_name"],
                    y=df_counts["n_images"],
                    name=state,
                    **marker_style(species_name, state),
                ),
                row=row,
                col=1,
            )

# Totals above each stack: distinct images and distinct subjects per (species, label)
df_totals = df.groupby(["species_name", "label_name"], as_index=False).agg(
    n_images=pd.NamedAgg(column="image_name", aggfunc="nunique"),
    n_subjects=pd.NamedAgg(column="subject_name", aggfunc="nunique"),
)
for row, species_name in enumerate(species_names, start=1):
    for totals in df_totals[df_totals["species_name"] == species_name].itertuples(index=False):
        # row/col make plotly set xref/yref to this subplot's axes, so no explicit refs are needed
        fig.add_annotation(
            x=totals.label_name,
            y=totals.n_images,
            text=f"({totals.n_images:,}; {totals.n_subjects:,})",
            xanchor="center",
            yanchor="bottom",
            showarrow=False,
            row=row,
            col=1,
        )

# Keep the raw label names as category values (bars and annotations use them) but display the
# paper-friendly names as tick text.
# NOTE(review): labels_plotly was previously computed but unused; revert ticktext if raw names are wanted.
labels = sort_labels(df["label_name"].unique())
labels_plotly = [settings_rat.labels_paper_renaming.get(l, l) for l in labels]
fig.update_xaxes(
    categoryorder="array",
    categoryarray=labels,
    tickmode="array",
    tickvals=labels,
    ticktext=labels_plotly,
)
fig.update_xaxes(title="<b>label name</b>", row=n_species)  # title only below the bottom subplot
fig.update_yaxes(title="<b># images</b>", title_standoff=5, tickformat=",")

fig.update_layout(
    barmode="stack",
    height=400 * n_species,  # 400 px per species row (1200 px for three species)
    width=1200,
    template="plotly_white",
    font_family="Libertinus Sans",
    font_size=16,
    showlegend=False,
)
fig.update_annotations(font=dict(size=20))

fig.write_image(settings_species.paper_dir / "data_statistics.pdf")
fig.show()