In [1]:
%load_ext autoreload
%autoreload 2

from functools import partial

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity

from htc.evaluation.utils import aggregator_bootstrapping
from htc.utils.visualization import add_range_fill
from htc_projects.rat.settings_rat import settings_rat
from htc_projects.species.settings_species import settings_species
from htc_projects.species.tables import icg_table, ischemic_clear_table

# Unset the MathJax setting of the kaleido export scope before any figure is written to disk.
# NOTE(review): presumably this avoids the "Loading [MathJax]" artifact in exported PDFs — confirm.
pio.kaleido.scope.mathjax = None

In [2]:
# Restrict both tables to the labels listed in settings_species.malperfused_labels; from the
# ICG table, only rows where icg_clear is truthy are kept.
labels = settings_species.malperfused_labels
df = ischemic_clear_table().query("label_name in @labels")
df = pd.concat([df, icg_table().query("label_name in @labels and icg_clear")])

# Stack the per-row spectra once and reuse the matrix for fitting and projecting
spectra_all = np.stack(df["median_normalized_spectrum"])

# fit() returns the estimator itself, so its return value is not stored; the projected
# coordinates come from transform() (positional assignment, one row per row of df)
pca = PCA(n_components=2, whiten=True)
pca.fit(spectra_all)
X_pca = pca.transform(spectra_all)
df["pca_1"] = X_pca[:, 0]
df["pca_2"] = X_pca[:, 1]


def _kde_estimate(x: pd.Series) -> list[float]:
    """
    Kernel density estimate of the given values, evaluated on the module-level grid `x_param`.

    Note: `x_param` is a notebook-level variable which must be defined before this function is called.

    Args:
        x: Values for which the density should be estimated.

    Returns: Density values (one per entry of `x_param`).
    """
    samples = x.values.reshape(-1, 1)
    grid = x_param.reshape(-1, 1)

    # score_samples() yields log-densities, hence the exponentiation
    log_density = KernelDensity(bandwidth=0.03).fit(samples).score_samples(grid)
    return list(np.exp(log_density))


# Evaluation grid for the StO2 densities (read by _kde_estimate and used as x-axis when plotting)
x_param = np.linspace(start=0, stop=1, num=100)

# One row per subject, species, label and perfusion state
subject_keys = ["subject_name", "species_name", "label_name", "perfusion_state"]
dfg = df.groupby(subject_keys, as_index=False).agg(
    # 10-bin histogram counts of the StO2 values in [0, 1]
    param_hist=("median_sto2", lambda x: list(np.histogram(x, bins=10, range=(0, 1))[0])),
    # Density of the StO2 values evaluated on x_param
    param_kde=("median_sto2", _kde_estimate),
    # Element-wise mean over all spectra of the group
    median_normalized_spectrum=("median_normalized_spectrum", lambda c: np.mean(np.stack(c), axis=0)),
)

# Fix the seed of NumPy's global random state before the bootstrapping
# NOTE(review): presumably aggregator_bootstrapping draws from the global NumPy RNG — confirm
np.random.seed(42)

# Aggregate the subject-level rows per species, label and perfusion state
_bootstrap_subjects = partial(
    aggregator_bootstrapping,
    columns=["param_hist", "param_kde", "median_normalized_spectrum"],
)
dfg_bootstraps = dfg.groupby(
    ["species_name", "label_name", "perfusion_state"],
    as_index=False,
).apply(_bootstrap_subjects, include_groups=False)
dfg_bootstraps

Unnamed: 0,species_name,label_name,perfusion_state,param_hist_mean,param_hist_q025,param_hist_q975,param_kde_mean,param_kde_q025,param_kde_q975,median_normalized_spectrum_mean,median_normalized_spectrum_q025,median_normalized_spectrum_q975
0,human,kidney,malperfused,"[0.0, 1.5251666666666663, 4.168166666666671, 0...","[0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 3.8333333333333335, 5.166666666666667, 0...","[0.0005798965152263106, 0.002123578047762019, ...","[1.1828145523805617e-10, 1.097130760680854e-09...","[0.0012739302457214598, 0.00461711432794694, 0...","[0.0032022055, 0.0031176112, 0.0030127605, 0.0...","[0.002393024, 0.0022902312, 0.0021753488, 0.00...","[0.003892704, 0.0038251106, 0.003739298, 0.003..."
1,human,kidney,physiological,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4736875, 1.02...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.5, 0.7...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9375, 1.5625,...","[2.6865632215636145e-96, 2.85880372561055e-93,...","[5.16999456523108e-104, 7.449237139496253e-101...","[7.617904592006329e-96, 8.106311669881179e-93,...","[0.0033815834, 0.0032942235, 0.0031783853, 0.0...","[0.0031713736, 0.0030867772, 0.0029705372, 0.0...","[0.0035832403, 0.003485814, 0.0033698967, 0.00..."
2,pig,kidney,icg,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.2039999999999...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.85714285...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.2857142857142...","[1.9595288856168166e-101, 2.5905908467350745e-...","[3.386359982432477e-123, 9.533672909908047e-12...","[4.439009790570694e-101, 5.774727520715077e-98...","[0.0050422843, 0.0046110135, 0.0041757366, 0.0...","[0.0045675286, 0.0041021863, 0.003793956, 0.00...","[0.0055288477, 0.0050482987, 0.004492626, 0.00..."
3,pig,kidney,malperfused,"[0.4866666666666671, 10.344583333333329, 10.94...","[0.0, 8.166666666666666, 7.5, 0.58333333333333...","[1.0833333333333333, 12.918749999999998, 14.58...","[0.005026809698805409, 0.013741381271797007, 0...","[0.001414590426453701, 0.004578282320052979, 0...","[0.011456980089019117, 0.029743583414107103, 0...","[0.004275267, 0.0041099847, 0.0039318395, 0.00...","[0.0040294416, 0.0038954532, 0.0037229406, 0.0...","[0.004521439, 0.0043267733, 0.004149998, 0.003..."
4,pig,kidney,physiological,"[0.0, 0.0, 0.0, 0.4484705882352982, 0.73705882...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.411764705882353, 3...","[0.0, 0.0, 0.0, 1.411764705882353, 2.117647058...","[9.20739755932901e-34, 5.024551150694175e-32, ...","[2.49309867227565e-62, 6.658204491033526e-60, ...","[2.8984462410561567e-33, 1.5817055038160076e-3...","[0.0038558412, 0.003650433, 0.0034497273, 0.00...","[0.003608887, 0.0034027814, 0.0032037941, 0.00...","[0.004132814, 0.0039318814, 0.0037367179, 0.00..."
5,rat,kidney,icg,"[0.0, 0.0, 9.806999999999992, 8.82800000000000...","[0.0, 0.0, 5.329166666666667, 6.33333333333333...","[0.0, 0.0, 15.004166666666663, 11.666666666666...","[1.6045033089444734e-13, 1.862976445809403e-12...","[6.571090344496185e-16, 1.0065665503801044e-14...","[4.7108101970915e-13, 5.454639838854299e-12, 5...","[0.0037707423, 0.003342745, 0.0029389416, 0.00...","[0.0034905705, 0.00311628, 0.002752502, 0.0023...","[0.00395021, 0.0035056574, 0.003076112, 0.0026..."
6,rat,kidney,malperfused,"[0.0, 4.135173913043476, 22.68552173913042, 2....","[0.0, 1.8695652173913044, 19.389130434782608, ...","[0.0, 6.479347826086956, 26.347826086956523, 5...","[0.0002005870303335467, 0.0006675206262315268,...","[1.938102031567194e-07, 1.090336967215549e-06,...","[0.0005489003226360436, 0.0018100977093636712,...","[0.0034022, 0.0029476148, 0.0026028526, 0.0022...","[0.00329854, 0.0028684214, 0.002536526, 0.0022...","[0.0035085343, 0.0030323197, 0.002674208, 0.00..."
7,rat,kidney,physiological,"[0.0, 0.0, 0.0, 0.041624999999999995, 0.345166...","[0.0, 0.0, 0.0, 0.0, 0.041666666666666664, 0.3...","[0.0, 0.0, 0.0, 0.125, 0.75, 1.958333333333333...","[6.2737803896427124e-24, 1.8211505233728526e-2...","[7.056422199231646e-53, 1.1891404102957045e-50...","[1.8840181350278293e-23, 5.468920490609225e-22...","[0.0036082522, 0.003119371, 0.0027209397, 0.00...","[0.0034693044, 0.0030092457, 0.0026292389, 0.0...","[0.0037483706, 0.0032304367, 0.0028147385, 0.0..."


In [3]:
# Subplot grid: StO2 densities in the left column, spectra in the right column
# (only the first two subplots receive a title)
n_rows, n_cols = 3, 2
fig = make_subplots(
    rows=n_rows,
    cols=n_cols,
    subplot_titles=["StO<sub>2</sub> distribution", "median spectra"],
    shared_xaxes=True,
    vertical_spacing=0.03,
    horizontal_spacing=0.11,
)

# 100 values from 502.5 to 997.5 in steps of 5 (x-axis of the spectra plots)
wavelengths = np.arange(500, 1000, 5) + 2.5

# Line style per perfusion state
dash_styles = dict(physiological="solid", malperfused="dot", icg="dash")

# Index into `wavelengths` (807.5) at which the arrow for the ICG state is drawn
icg_position = 61

# One subplot row per species: StO2 densities (left) and subject-averaged spectra (right).
# For every row, the spectrum values which the arrows should connect are collected as well.
arrows = []
for row_idx, species_name in enumerate(settings_species.species_colors.keys()):
    row = row_idx + 1
    color = settings_species.species_colors[species_name]

    # Channel index -> mean spectrum values of the two states which are connected by an arrow
    arrow_positions = {
        28: [],
        51: [],
        75: [],
    }

    states = ["physiological", "malperfused"]
    if species_name != "human":
        # The icg state is only plotted for the non-human species
        states.append("icg")
        arrow_positions[icg_position] = []

    for perfusion_state in states:
        dash = dash_styles[perfusion_state]

        if perfusion_state != "icg":
            # StO2 distribution: param_kde_mean as line, param_kde_q025/q975 as limits of the band
            # (.item() requires exactly one bootstrap row per species and state)
            df_state = dfg_bootstraps[
                (dfg_bootstraps.species_name == species_name) & (dfg_bootstraps.perfusion_state == perfusion_state)
            ]
            add_range_fill(
                fig=fig,
                x=x_param,
                top_line=df_state["param_kde_q975"].item(),
                mid_line=df_state["param_kde_mean"].item(),
                bottom_line=df_state["param_kde_q025"].item(),
                color=color,
                line_dash=dash,
                name=perfusion_state,
                row=row,
                col=1,
            )

        # Spectra from all subjects of this species and state
        df_state_spectrum = dfg[(dfg.species_name == species_name) & (dfg.perfusion_state == perfusion_state)]
        spec = np.stack(df_state_spectrum["median_normalized_spectrum"])
        spec_mean = np.mean(spec, axis=0)
        spec_std = np.std(spec, axis=0)

        for c, values in arrow_positions.items():
            # The arrow at the ICG channel connects physiological and icg, all other arrows
            # connect physiological and malperfused
            excluded_state = "malperfused" if c == icg_position else "icg"
            if perfusion_state != excluded_state:
                values.append(spec_mean[c])

        # Median spectra: mean over subjects with a band of one standard deviation
        add_range_fill(
            fig=fig,
            x=wavelengths,
            top_line=spec_mean + spec_std,
            mid_line=spec_mean,
            bottom_line=spec_mean - spec_std,
            color=color,
            line_dash=dash,
            name=perfusion_state,
            row=row,
            col=2,
        )

    # The axis titles do not depend on the perfusion state, so they are set once per row
    fig.update_yaxes(title="<b>density [a.u.]</b>", title_standoff=0, row=row, col=1)
    fig.update_yaxes(title="<b>L1 normalized<br>reflectance [a.u.]</b>", row=row, col=2)
    if row_idx == n_rows - 1:
        # x-axis titles only for the bottom row
        fig.update_xaxes(title="<b>StO<sub>2</sub> [a.u.]</b>", row=row, col=1)
        fig.update_xaxes(title="<b>wavelength [nm]</b>", row=row, col=2)

    arrows.append(arrow_positions)

# Draw the collected arrows into the spectra subplot (second column) of each row
for row_idx, arrow in enumerate(arrows):
    # Axis id of the subplot in the second column of this row
    axis_id = row_idx * n_cols + 2
    x_axis = f"x{axis_id}"
    y_axis = f"y{axis_id}"

    for c, positions in arrow.items():
        y_start, y_end = positions[0], positions[1]

        # Lengthen the arrow by 0.0002 beyond both end points
        if y_start > y_end:
            y_overfill = 0.0002
        else:
            y_overfill = -0.0002

        fig.add_annotation(
            ax=wavelengths[c],  # xstart
            ay=y_start + y_overfill,  # ystart
            x=wavelengths[c],  # xend
            y=y_end - y_overfill,  # yend
            xref=x_axis,
            yref=y_axis,
            axref=x_axis,
            ayref=y_axis,
            showarrow=True,
            arrowhead=3,
            arrowwidth=2,
            arrowcolor="gray" if c == icg_position else "black",
            text="",
            row=row_idx + 1,
            col=2,
        )

# Global figure size and theme
fig.update_layout(height=800, width=1200, template="plotly_white")
fig.update_annotations(font=dict(size=20))

# Replace annotation texts which have an entry in the paper renaming table (others are kept as is)
for annotation in fig.layout.annotations:
    annotation.text = settings_rat.labels_paper_renaming.get(annotation.text, annotation.text)

# Applies to all y-axes, including the ones which received title_standoff=0 before
fig.update_yaxes(title_standoff=5)
fig.update_layout(font_family="Libertinus Sans", font_size=16, showlegend=False)

# Export the figure and display it as cell output
fig.write_image(settings_species.paper_dir / "malperfused_spectra_base.pdf")
fig

In [4]:
def pca_figure(show_malperfused: bool = True) -> go.Figure:
    """
    Scatter plot of the first two principal components with one subplot column per label.

    For every species and perfusion state, the individual points (columns `pca_1` and `pca_2` of the notebook-level
    table `df`) and a centroid (mean of the per-subject means) are drawn. Arrows point from the centroid of the
    first state (physiological) to the centroids of the other states.

    Args:
        show_malperfused: If False, all non-physiological points and centroids as well as the arrows are still added
            to the figure but with an opacity of 0.

    Returns: The figure with all traces and annotations.
    """
    n_rows = 1
    n_cols = len(labels)
    fig = make_subplots(
        n_rows,
        n_cols,
        shared_xaxes=True,
        shared_yaxes=True,
        vertical_spacing=0.03,
        horizontal_spacing=0.03,
    )
    symbol_mapping = {
        "physiological": "circle",
        "malperfused": "star",
        "icg": "triangle-down",
    }

    point_traces = []
    centroid_traces = []
    for col_idx, label_name in enumerate(labels):
        col = col_idx + 1
        # Axis ids of the subplot of this label so that the arrows end up in the same subplot as the centroids
        # (for the first column this is "x1"/"y1")
        x_axis = f"x{col}"
        y_axis = f"y{col}"

        for species_name in df.species_name.unique():
            # Centroid of the first state (start point of all arrows of this species)
            prev_xc = None
            prev_yc = None

            states = ["physiological", "malperfused"]
            if species_name != "human":
                states.append("icg")

            for perfusion_state in states:
                df_state = df[
                    (df.species_name == species_name)
                    & (df.label_name == label_name)
                    & (df.perfusion_state == perfusion_state)
                ]
                marker = dict(
                    color=settings_species.species_colors[species_name],
                    symbol=symbol_mapping[perfusion_state],
                )
                # Hidden traces are added with an opacity of 0 instead of being left out
                # NOTE(review): presumably to keep the axis ranges identical between both figure variants — confirm
                visible = show_malperfused or perfusion_state == "physiological"

                point_traces.append(
                    dict(
                        trace=go.Scatter(
                            x=df_state["pca_1"],
                            y=df_state["pca_2"],
                            mode="markers",
                            marker=marker,
                            name=f"{species_name}#{perfusion_state}",
                            legendgroup=species_name,
                            opacity=0.5 if visible else 0,
                        ),
                        row=1,
                        col=col,
                    )
                )

                # Hierarchical aggregation of the centroids (first per subject, then across subjects)
                df_state_pca = df_state.groupby("subject_name")[["pca_1", "pca_2"]].mean()
                xc = df_state_pca["pca_1"].mean()
                yc = df_state_pca["pca_2"].mean()

                centroid_traces.append(
                    dict(
                        trace=go.Scatter(
                            x=[xc],
                            y=[yc],
                            mode="markers",
                            marker=marker,
                            name=f"{species_name}#{perfusion_state}",
                            legendgroup=species_name,
                            marker_size=15,
                            marker_line_width=1.5,
                            opacity=None if visible else 0,
                        ),
                        row=1,
                        col=col,
                    )
                )

                if prev_xc is None:
                    prev_xc = xc
                    prev_yc = yc
                else:
                    fig.add_annotation(
                        ax=prev_xc,  # xstart
                        ay=prev_yc,  # ystart
                        x=xc,  # xend
                        y=yc,  # yend
                        xref=x_axis,
                        yref=y_axis,
                        axref=x_axis,
                        ayref=y_axis,
                        showarrow=True,
                        arrowhead=2,
                        arrowwidth=1.5,
                        arrowcolor="black",
                        text="",
                        opacity=None if show_malperfused else 0,
                    )

    # Add the point traces first so that the centroid traces (added afterwards) are drawn on top of them
    for t in point_traces:
        fig.add_trace(**t)
    for t in centroid_traces:
        fig.add_trace(**t)

    # Every column gets an x-axis title, only the first column a y-axis title
    fig.update_xaxes(title=f"<b>PC 1 ({pca.explained_variance_ratio_[0] * 100:.0f} %) [a.u.]</b>")
    fig.update_yaxes(title=f"<b>PC 2 ({pca.explained_variance_ratio_[1] * 100:.0f} %) [a.u.]</b>", row=1, col=1)

    fig.update_annotations(font=dict(size=20))
    for annotation in fig.layout.annotations:
        annotation.text = settings_rat.labels_paper_renaming.get(annotation.text, annotation.text)
    fig.update_layout(template="plotly_white", width=1200, height=600)
    fig.update_layout(font_family="Libertinus Sans", font_size=16)
    fig.update_layout(showlegend=False)

    return fig


# Figure variant with all perfusion states visible: export and display as cell output
fig = pca_figure(show_malperfused=True)
fig.write_image(settings_species.paper_dir / "malperfused_pca_base.pdf")
fig

In [5]:
# Figure variant in which only the physiological points and centroids are visible (all other
# traces and the arrows get an opacity of 0): export and display as cell output
fig = pca_figure(show_malperfused=False)
fig.write_image(settings_species.paper_dir / "malperfused_pca_physiological_base.pdf")
fig