In [1]:
%load_ext autoreload
%autoreload 2

from functools import partial

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity

from htc.evaluation.utils import aggregator_bootstrapping
from htc.utils.visualization import add_range_fill
from htc_projects.rat.settings_rat import settings_rat
from htc_projects.species.settings_species import settings_species
from htc_projects.species.tables import ischemic_clear_table, ischemic_table

# Unset the MathJax setting of the Kaleido export scope before any figure is written to disk.
# NOTE(review): presumably this avoids MathJax loading artifacts in the exported PDFs — confirm.
pio.kaleido.scope.mathjax = None

In [2]:
# Restrict both tables to the malperfused labels. In this notebook, `df` (clear table) is the
# data that gets aggregated and plotted whereas `df_all` is only used to fit the PCA.
labels = settings_species.malperfused_labels
# Explicit copy because new columns are assigned to `df` below (avoids chained-assignment warnings)
df = ischemic_clear_table().query("label_name in @labels").copy()
df_all = ischemic_table().query("label_name in @labels")

# Same PCA as in the perfusion performance figure: fit on all spectra, project only the clear subset
pca = PCA(n_components=2, whiten=True)
pca.fit(np.stack(df_all["median_normalized_spectrum"]))
X_pca = pca.transform(np.stack(df["median_normalized_spectrum"]))
df["pca_1"] = X_pca[:, 0]
df["pca_2"] = X_pca[:, 1]


def _kde_estimate(x: pd.Series, grid: "np.ndarray | None" = None, bandwidth: float = 0.03) -> list[float]:
    """
    Evaluate a kernel density estimate of the given samples on a fixed grid.

    Args:
        x: Samples the density is fitted on (treated as a single feature).
        grid: Positions at which the fitted density is evaluated. If None, the notebook-level
            `x_param` is used so that the function can be passed directly as aggregation function.
        bandwidth: Bandwidth passed to `KernelDensity`.

    Returns:
        Density values (not log-densities), one entry per grid position.
    """
    if grid is None:
        # Resolved at call time: `x_param` is defined after this function in the same cell
        grid = x_param

    kde = KernelDensity(bandwidth=bandwidth).fit(np.asarray(x).reshape(-1, 1))
    # score_samples() yields log-densities, hence the exponentiation
    log_density = kde.score_samples(np.asarray(grid).reshape(-1, 1))
    return list(np.exp(log_density))


# Grid on which the StO2 densities are evaluated (100 equidistant points in [0, 1])
x_param = np.linspace(0, 1, 100)

# Step 1: collapse all rows of a subject (per species, label and perfusion state) into
# a StO2 histogram, a StO2 density estimate and an average spectrum
per_subject_aggregations = {
    "param_hist": pd.NamedAgg(
        column="median_sto2",
        aggfunc=lambda sto2: list(np.histogram(sto2, bins=10, range=(0, 1))[0]),
    ),
    "param_kde": pd.NamedAgg(column="median_sto2", aggfunc=_kde_estimate),
    "median_normalized_spectrum": pd.NamedAgg(
        column="median_normalized_spectrum",
        aggfunc=lambda spectra: np.mean(np.stack(spectra), axis=0),
    ),
}
subject_groups = df.groupby(["subject_name", "species_name", "label_name", "perfusion_state"], as_index=False)
dfg = subject_groups.agg(**per_subject_aggregations)

# Step 2: bootstrap the per-subject aggregates within each species/label/perfusion state group
# NOTE(review): the global seed presumably makes the bootstrap samples reproducible — confirm
# that aggregator_bootstrapping() draws from NumPy's global random state.
np.random.seed(42)
bootstrap_aggregator = partial(aggregator_bootstrapping, columns=list(per_subject_aggregations))
dfg_bootstraps = dfg.groupby(["species_name", "label_name", "perfusion_state"], as_index=False).apply(
    bootstrap_aggregator,
    include_groups=False,
)
dfg_bootstraps

Unnamed: 0,species_name,label_name,perfusion_state,param_hist_mean,param_hist_q025,param_hist_q975,param_kde_mean,param_kde_q025,param_kde_q975,median_normalized_spectrum_mean,median_normalized_spectrum_q025,median_normalized_spectrum_q975
0,human,kidney,malperfused,"[0.0, 1.5251666666666663, 4.168166666666671, 0...","[0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 3.8333333333333335, 5.166666666666667, 0...","[0.0005798965152263106, 0.002123578047762019, ...","[1.1828145523805617e-10, 1.097130760680854e-09...","[0.0012739302457214598, 0.00461711432794694, 0...","[0.0032022055, 0.0031176112, 0.0030127605, 0.0...","[0.0023930239316541702, 0.0022902311698999255,...","[0.003892703913152218, 0.0038251106336247174, ..."
1,human,kidney,physiological,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4736875, 1.02...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.5, 0.7...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9375, 1.5625,...","[2.6865632215636145e-96, 2.85880372561055e-93,...","[5.16999456523108e-104, 7.449237139496253e-101...","[7.617904592006329e-96, 8.106311669881179e-93,...","[0.0033815834, 0.0032942235, 0.0031783853, 0.0...","[0.003171373630175367, 0.003086777200223878, 0...","[0.003583240235457197, 0.003485814010491595, 0..."
2,pig,kidney,malperfused,"[0.4786666666666665, 10.319000000000003, 10.93...","[0.0, 8.0, 7.58125, 0.5833333333333334, 0.0, 0...","[1.0833333333333333, 12.833333333333334, 14.25...","[0.004984213537119698, 0.013629639537219242, 0...","[0.0013503707654039273, 0.004427869115141204, ...","[0.011591542652095523, 0.03008458328639306, 0....","[0.004277626, 0.0041130614, 0.003935362, 0.003...","[0.004053064179606736, 0.003917291900143028, 0...","[0.004497976053971797, 0.004313652007840574, 0..."
3,pig,kidney,physiological,"[0.0, 0.0, 0.0, 0.45035294117647473, 0.7447058...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.3529411764705883, ...","[0.0, 0.0, 0.0, 1.411764705882353, 2.117647058...","[9.24604350922778e-34, 5.045640557429611e-32, ...","[2.7590599745769656e-62, 7.353505782935886e-60...","[2.8984462410561567e-33, 1.581705503816008e-31...","[0.003857146, 0.0036512616, 0.0034511588, 0.00...","[0.003602763294475153, 0.0033809958607889713, ...","[0.004106546868570149, 0.003912668081466108, 0..."
4,rat,kidney,malperfused,"[0.0, 4.147826086956521, 22.703434782608692, 2...","[0.0, 2.0, 19.347826086956523, 0.3913043478260...","[0.0, 6.436956521739129, 26.696739130434782, 5...","[0.00019260174189891799, 0.0006414484186636815...","[2.711103243240741e-07, 1.447589046530106e-06,...","[0.0005299139035696286, 0.0017501615627768335,...","[0.0033999572, 0.0029458532, 0.0026013795, 0.0...","[0.003298487770371139, 0.002869922516401857, 0...","[0.0035068094439338893, 0.003029434074414894, ..."
5,rat,kidney,physiological,"[0.0, 0.0, 0.0, 0.041625000000000016, 0.344624...","[0.0, 0.0, 0.0, 0.0, 0.041666666666666664, 0.3...","[0.0, 0.0, 0.0, 0.16666666666666666, 0.7083333...","[6.273780389642713e-24, 1.8211505233728557e-22...","[6.997214476285816e-53, 1.17836045646572e-50, ...","[2.5120241800371057e-23, 7.291893987478966e-22...","[0.0036077828, 0.0031189765, 0.0027206673, 0.0...","[0.003466355614364147, 0.0030054003931581975, ...","[0.003743799263611436, 0.0032317895907908677, ..."


In [3]:
# Figure with one row per entry in settings_species.species_colors:
# left column = bootstrapped StO2 density, right column = mean spectrum (± std) across subjects
n_rows = 3
n_cols = 2
fig = make_subplots(
    n_rows,
    n_cols,
    subplot_titles=["StO<sub>2</sub> distribution", "median spectra"],
    shared_xaxes=True,
    vertical_spacing=0.03,
    horizontal_spacing=0.11,
)
wavelengths = np.arange(500, 1000, 5) + 2.5
# NOTE(review): not referenced anywhere in this cell (the species color is used for both states)
perfusion_colors = {
    "physiological": "gray",
    "malperfused": "indianred",
}

# Indices into `wavelengths` at which an arrow from the physiological to the malperfused
# spectrum is drawn (642.5 nm, 757.5 nm and 877.5 nm)
arrow_band_indices = (28, 51, 75)

arrows = []
for row_idx, species_name in enumerate(settings_species.species_colors.keys()):
    row = row_idx + 1
    color = settings_species.species_colors[species_name]
    # Per band index: [physiological value, malperfused value] (filled in loop order)
    arrow_positions = {c: [] for c in arrow_band_indices}

    for perfusion_state in ["physiological", "malperfused"]:
        bootstrap_row = dfg_bootstraps[
            (dfg_bootstraps.species_name == species_name) & (dfg_bootstraps.perfusion_state == perfusion_state)
        ]
        subject_rows = dfg[(dfg.species_name == species_name) & (dfg.perfusion_state == perfusion_state)]
        dash = "solid" if perfusion_state == "physiological" else "dot"

        # StO2 distribution (.item() requires exactly one bootstrap row per species and state)
        add_range_fill(
            fig=fig,
            x=x_param,
            top_line=bootstrap_row["param_kde_q975"].item(),
            mid_line=bootstrap_row["param_kde_mean"].item(),
            bottom_line=bootstrap_row["param_kde_q025"].item(),
            color=color,
            line_dash=dash,
            name=perfusion_state,
            row=row,
            col=1,
        )

        spec = np.stack(subject_rows["median_normalized_spectrum"])  # Spectra from all subjects
        spec_mean = np.mean(spec, axis=0)
        spec_std = np.std(spec, axis=0)

        for c in arrow_band_indices:
            arrow_positions[c].append(spec_mean[c])

        # Median spectra
        add_range_fill(
            fig=fig,
            x=wavelengths,
            top_line=spec_mean + spec_std,
            mid_line=spec_mean,
            bottom_line=spec_mean - spec_std,
            color=color,
            line_dash=dash,
            name=perfusion_state,
            row=row,
            col=2,
        )

    # The axis titles do not depend on the perfusion state, so they are set only once per row
    fig.update_yaxes(title="<b>density [a.u.]</b>", title_standoff=0, row=row, col=1)
    fig.update_yaxes(title="<b>L1 normalized<br>reflectance [a.u.]</b>", row=row, col=2)
    if row_idx == n_rows - 1:
        # Only the bottom row gets x-axis titles
        fig.update_xaxes(title="<b>StO<sub>2</sub> [a.u.]</b>", row=row, col=1)
        fig.update_xaxes(title="<b>wavelength [nm]</b>", row=row, col=2)

    arrows.append(arrow_positions)

for row_idx, arrow in enumerate(arrows):
    # Axis number of the right-hand subplot in this row (x2/y2, x4/y4, x6/y6)
    axis_id = row_idx * n_cols + 2
    for c, positions in arrow.items():
        # Extend the arrow by a fixed amount at both ends, away from the two mean spectra
        y_overfill = 0.0002 if positions[0] > positions[1] else -0.0002
        fig.add_annotation(
            ax=wavelengths[c],  # xstart
            ay=positions[0] + y_overfill,  # ystart
            x=wavelengths[c],  # xend
            y=positions[1] - y_overfill,  # yend
            xref=f"x{axis_id}",
            yref=f"y{axis_id}",
            axref=f"x{axis_id}",
            ayref=f"y{axis_id}",
            showarrow=True,
            arrowhead=3,
            arrowwidth=2,
            arrowcolor="black",
            text="",
            row=row_idx + 1,
            col=2,
        )

fig.update_layout(height=800, width=1200, template="plotly_white")
fig.update_annotations(font=dict(size=20))
# Replace annotation texts which have a paper-specific name (all other texts are kept as they are)
for annotation in fig.layout.annotations:
    annotation.text = settings_rat.labels_paper_renaming.get(annotation.text, annotation.text)
fig.update_yaxes(title_standoff=5)
fig.update_layout(font_family="Libertinus Sans", font_size=16)
fig.update_layout(showlegend=False)
fig.write_image(settings_species.paper_dir / "malperfused_spectra_base.pdf")
fig

In [4]:
# PCA scatter figure with one column per label: individual samples, per-state centroids and
# an arrow from the physiological to the malperfused centroid of every species
n_rows = 1
n_cols = len(labels)
fig = make_subplots(
    n_rows,
    n_cols,
    shared_xaxes=True,
    shared_yaxes=True,
    vertical_spacing=0.03,
    horizontal_spacing=0.03,
)

point_traces = []
centroid_traces = []
for col_idx, label_name in enumerate(labels):
    col = col_idx + 1
    for species_name in df.species_name.unique():
        # Centroid of the previous perfusion state (start point of the arrow)
        prev_xc = None
        prev_yc = None
        for perfusion_state in ["physiological", "malperfused"]:
            df_state = df[
                (df.species_name == species_name)
                & (df.label_name == label_name)
                & (df.perfusion_state == perfusion_state)
            ]
            marker = dict(
                color=settings_species.species_colors[species_name],
                symbol="circle" if perfusion_state == "physiological" else "star",
            )

            point_traces.append(
                dict(
                    trace=go.Scatter(
                        x=df_state["pca_1"],
                        y=df_state["pca_2"],
                        mode="markers",
                        marker=marker,
                        name=f"{species_name}#{perfusion_state}",
                        legendgroup=species_name,
                        opacity=0.5,
                    ),
                    row=1,
                    col=col,
                )
            )

            # Hierarchical aggregation of the centroids (first per subject, then across subjects)
            df_state_pca = df_state.groupby("subject_name")[["pca_1", "pca_2"]].mean()
            xc = df_state_pca["pca_1"].mean()
            yc = df_state_pca["pca_2"].mean()

            centroid_traces.append(
                dict(
                    trace=go.Scatter(
                        x=[xc],
                        y=[yc],
                        mode="markers",
                        marker=marker,
                        name=f"{species_name}#{perfusion_state}",
                        legendgroup=species_name,
                        marker_size=15,
                        marker_line_width=1.5,
                    ),
                    row=1,
                    col=col,
                )
            )

            if prev_xc is None:
                prev_xc = xc
                prev_yc = yc
            else:
                # The arrow must reference the axes of the current column, otherwise the arrows
                # of every label end up in the first subplot
                fig.add_annotation(
                    ax=prev_xc,  # xstart
                    ay=prev_yc,  # ystart
                    x=xc,  # xend
                    y=yc,  # yend
                    xref=f"x{col}",
                    yref=f"y{col}",
                    axref=f"x{col}",
                    ayref=f"y{col}",
                    showarrow=True,
                    arrowhead=2,
                    arrowwidth=1.5,
                    arrowcolor="black",
                    text="",
                    row=1,
                    col=col,
                )

# Add the centroid traces last so that they are drawn on top of the point traces
for t in point_traces:
    fig.add_trace(**t)
for t in centroid_traces:
    fig.add_trace(**t)

pc1_title = f"<b>PC 1 ({pca.explained_variance_ratio_[0] * 100:.0f} %) [a.u.]</b>"
pc2_title = f"<b>PC 2 ({pca.explained_variance_ratio_[1] * 100:.0f} %) [a.u.]</b>"
# Every column gets an x-axis title; the y-axis is shared, so only the first column is labelled
fig.update_xaxes(title=pc1_title)
fig.update_yaxes(title=pc2_title, row=1, col=1)

fig.update_annotations(font=dict(size=20))
# Replace annotation texts which have a paper-specific name (all other texts are kept as they are)
for annotation in fig.layout.annotations:
    annotation.text = settings_rat.labels_paper_renaming.get(annotation.text, annotation.text)
fig.update_layout(template="plotly_white", width=1200, height=600)
fig.update_layout(font_family="Libertinus Sans", font_size=16)
fig.update_layout(showlegend=False)

fig.write_image(settings_species.paper_dir / "malperfused_pca_base.pdf")
fig