In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import plotly.express as px
import plotly.io as pio

from htc.utils.helper_functions import sort_labels
from htc_projects.rat.settings_rat import settings_rat
from htc_projects.species.settings_species import settings_species
from htc_projects.species.tables import paper_table

# Disable MathJax in the Kaleido export scope used by fig.write_image().
# NOTE(review): presumably the usual workaround that keeps a MathJax loading
# message out of exported PDF files -- confirm it is still needed.
pio.kaleido.scope.mathjax = None

In [2]:
# Load the paper table; every row must be attributed to a species.
df = paper_table()
assert pd.isna(df["species_name"]).sum() == 0

# One row per (species, label, subject) with the number of distinct images.
df_labels = df.groupby(["species_name", "label_name", "subject_name"], as_index=False).agg(
    n_images=pd.NamedAgg(column="image_name", aggfunc="nunique")
)

# Categorical species column whose category order follows settings_species.species_colors.
species_order = list(settings_species.species_colors.keys())
df_labels["species_name"] = df_labels["species_name"].astype("category").cat.set_categories(species_order)
# set_categories() turns values that are missing from the category list into NaN,
# so check explicitly instead of losing those rows silently.
assert df_labels["species_name"].notna().all(), "Found species without an entry in settings_species.species_colors"

df_labels = sort_labels(df_labels, sorting_cols=["species_name", "subject_name", "label_name"])
labels = sort_labels(df_labels.label_name.unique())
# Paper names for the labels; labels without a renaming entry keep their name.
labels_plotly = [settings_rat.labels_paper_renaming.get(label, label) for label in labels]
# Rename only the label column: a frame-wide replace would also rewrite matching
# values in the species and subject columns.
df_labels["label_name"] = df_labels["label_name"].replace(labels, labels_plotly)

# Map every subject to the color of its species. The first row per subject
# decides, which matches a row-wise scan that only sets missing keys.
first_row_per_subject = df_labels.drop_duplicates(subset="subject_name")
colors = {
    subject: settings_species.species_colors[species]
    for subject, species in zip(first_row_per_subject["subject_name"], first_row_per_subject["species_name"])
}

# One facet per species (single column), one bar segment per subject and label.
# The species order is pinned explicitly because the annotation code further
# down derives the facet row of a species from settings_species.species_colors;
# without it, the facet order would depend on the row order of df_labels.
fig = px.bar(
    df_labels,
    x="label_name",
    y="n_images",
    color="subject_name",
    facet_col="species_name",
    facet_col_wrap=1,
    color_discrete_map=colors,
    category_orders={
        "label_name": labels_plotly,
        "species_name": list(settings_species.species_colors.keys()),
    },
)

# Write the number of distinct subjects on top of every bar.
# Only species that occur in df_labels are numbered, so a configured species
# without data cannot shift the annotations of the others to a wrong facet row.
species_with_data = [
    species for species in settings_species.species_colors.keys() if (df_labels.species_name == species).any()
]

# NOTE(review): the order is reversed presumably because facet rows are counted
# from the bottom while the first species is drawn at the top -- confirm.
for row_idx, species in enumerate(reversed(species_with_data), start=1):
    df_labels_species = df_labels[df_labels.species_name == species]
    # Subject count (annotation text) and image total (bar height) per label in one pass.
    df_bar_tops = df_labels_species.groupby("label_name", as_index=False).agg(
        n_subjects=pd.NamedAgg(column="subject_name", aggfunc="nunique"),
        n_images=pd.NamedAgg(column="n_images", aggfunc="sum"),
    )
    for bar in df_bar_tops.itertuples(index=False):
        fig.add_annotation(
            x=bar.label_name,
            y=bar.n_images,
            text=str(bar.n_subjects),
            xanchor="center",
            yanchor="bottom",
            xref="x",
            yref="y",
            showarrow=False,
            row=row_idx,
            col=1,
        )

# Blank every annotation whose text mentions "species_name"
# (presumably the automatically generated facet titles).
for facet_title in fig.layout.annotations:
    if "species_name" in facet_title.text:
        facet_title.text = ""

# y axes: unlinked ranges per facet, bold title, "," tick format.
fig.update_yaxes(matches=None, title="<b># images</b>", title_standoff=5, tickformat=",")
# x axis title only for row 1.
fig.update_xaxes(title="<b>label name</b>", row=1)

# Figure-wide layout: size, theme, font and no legend.
fig.update_layout(
    height=1200,
    width=1200,
    template="plotly_white",
    font_family="Libertinus Sans",
    font_size=16,
    showlegend=False,
)
# Applies to all annotations currently on the figure.
fig.update_annotations(font={"size": 20})

# Export the figure for the paper, then display it in the notebook.
fig.write_image(settings_species.paper_dir / "data_statistics.pdf")
fig.show()