In [1]:
%load_ext autoreload
%autoreload 2

from collections import ChainMap

import fastkde
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from matplotlib.colors import to_rgb
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA

from htc.settings import settings
from htc.utils.colors import lighten_color
from htc.utils.parallel import p_map
from htc.utils.visualization import add_range_fill, compress_html
from htc_projects.species.settings_species import settings_species
from htc_projects.species.species_evaluation import parameter_comparison
from htc_projects.species.tables import ischemic_clear_table, ischemic_table

# Disable MathJax for the kaleido image export used by fig.write_image() below
# NOTE(review): presumably the common workaround against the MathJax loading box in exported PDFs — confirm
pio.kaleido.scope.mathjax = None

In [2]:
def load_species_data(target_species: str, source_species: str) -> dict[tuple[str, str], dict[str, pd.DataFrame]]:
    """
    Load the performance comparison tables for one target/source species pair.

    Two training runs are compared via `parameter_comparison()`: the baseline run of the target species
    (network name `in-species`) and the run trained with projections from the source species (network name
    `xeno-learning`). Both run directories are passed as paths which contain a `*` wildcard.

    Note: The compared parameter is read from the module-level variable `param_name` which must be defined before
    this function is called.

    Args:
        target_species: Name of the target species (selects the test table and both runs).
        source_species: Name of the source species (selects the projected run).

    Returns: Dictionary with the single key `(target_species, source_species)`. The value is a dictionary with the
        two tables returned by `parameter_comparison()` stored under the keys `df` and `df_raw`.
    """
    image_runs_dir = settings.training_dir / "image"
    timestamp = settings_species.model_timestamp
    runs = {
        "in-species": image_runs_dir / f"{timestamp}_baseline_{target_species}_nested-*-2",
        "xeno-learning": image_runs_dir / f"{timestamp}_projected_{source_species}2{target_species}_nested-*-2",
    }

    df, df_raw = parameter_comparison(
        runs,
        f"test_table_{target_species}_perfusion",
        param_name,
        target_labels=settings_species.malperfused_labels,
    )

    # Keyed by the species pair so that the results of several calls can be merged into one dictionary
    return {(target_species, source_species): {"df": df, "df_raw": df_raw}}


# Parameter shown on the x-axis of the DSC plots (also read as a global by load_species_data())
param_name = "median_sto2"
labels = settings_species.malperfused_labels

# Average the spectra so that there is one row per subject, species, label and perfusion state
df_ischemic = ischemic_clear_table().query("label_name in @labels")
df_ischemic = df_ischemic.groupby(
    ["subject_name", "species_name", "label_name", "perfusion_state"], as_index=False
).agg(
    median_normalized_spectrum=pd.NamedAgg(
        column="median_normalized_spectrum", aggfunc=lambda c: np.mean(np.stack(c), axis=0)
    ),
)

# Load the comparison tables for every (target, source) pair and merge the per-pair dictionaries
targets = ["pig", "rat", "human", "human"]
sources = ["rat", "pig", "rat", "pig"]
species_data = p_map(load_species_data, targets, sources, use_threads=True)
species_data = dict(ChainMap(*species_data))

# Explicit copy because new columns are assigned to the filtered table below
df_all = ischemic_table().query("label_name in @labels").copy()
# The projections are read only once (they are not needed before this point)
df_projections = pd.read_feather(settings_species.results_dir / "projections" / "projections_clear.feather")

# The PCA is fitted on the measured spectra only; the projected spectra are mapped into the same space
pca = PCA(n_components=2, whiten=True)
pca_projections = pca.fit_transform(np.stack(df_all["median_normalized_spectrum"]))
df_all["pca_1"] = pca_projections[:, 0]
df_all["pca_2"] = pca_projections[:, 1]

pca_projections = pca.transform(np.stack(df_projections["median_normalized_spectrum"]))
df_projections["pca_1"] = pca_projections[:, 0]
df_projections["pca_2"] = pca_projections[:, 1]

Output()

In [3]:
# Line/marker color per network name (the "in-species" entry is set inside add_species_learning())
network_colors = {"xeno-learning": settings_species.xeno_learning_color}

# 4 rows (one per species pair) x 3 columns; the last column is half as wide and gets a small negative left padding
fig = make_subplots(
    rows=4,
    cols=3,
    vertical_spacing=0.07,
    horizontal_spacing=0.06,
    column_widths=[2, 2, 1],
    specs=[[{}, {}, {"l": -0.01}] for _ in range(4)],
)

# x-values for the spectra plots: 100 values starting at 502.5 with a step of 5
# NOTE(review): presumably the centers of 5 nm wide spectral channels — confirm
wavelengths = np.arange(500, 1000, 5) + 2.5


def add_species_learning(target_species: str, source_species: str, row: int):
    """
    Add the three panels of one target/source species pair to one row of the module-level figure `fig`.

    - Column 1: `dice_metric` over `param_name` for every image (one marker symbol per subject) plus the mean
      curve with its q025/q975 band for every network which has an entry in `network_colors`.
    - Column 2: 100 randomly drawn projected spectra (faint dashed lines) and the mean ± standard deviation
      spectrum for the physiological and malperfused perfusion states.
    - Column 3: density estimate (fastkde) of the projected spectra in PCA space as heatmap and the measured
      images of the target species on top, shaded by their `param_name` value.

    The function reads the module-level objects `fig`, `species_data`, `network_colors`, `df_projections`,
    `df_ischemic`, `df_all`, `pca`, `wavelengths` and `param_name`. It sets `network_colors["in-species"]` to the
    color of the target species and reseeds the global numpy random state with 0.

    Args:
        target_species: Name of the target species (selects the species color, the measured images and the
            projections).
        source_species: Name of the source species of the projections.
        row: Row of the subplot grid (1-based) which receives the traces.
    """
    pair_data = species_data[(target_species, source_species)]
    df = pair_data["df"]
    df_raw = pair_data["df_raw"]

    species_color = settings_species.species_colors[target_species]
    xeno_color = network_colors["xeno-learning"]
    # The in-species network is always drawn in the color of the current target species
    network_colors["in-species"] = species_color

    # --- Column 1: performance over the parameter ---
    for network in df["network"].unique():
        if network not in network_colors:
            # Networks without an assigned color are not part of the figure
            continue
        color = network_colors[network]
        df_network = df[df["network"] == network]
        df_network_raw = df_raw[df_raw["network"] == network]

        for i, subject_name in enumerate(sorted(df_network_raw["subject_name"].unique())):
            df_subject = df_network_raw[df_network_raw["subject_name"] == subject_name]
            fig.add_trace(
                go.Scatter(
                    x=df_subject[param_name],
                    y=df_subject["dice_metric"],
                    hovertext=df_subject["image_name"],
                    marker_color=color,
                    # The first 13 markers are okish to distinguish but after that it gets difficult
                    # We start again after the first 13 and use open markers (+100)
                    marker_symbol=i if i < 13 else i % 13 + 100,
                    marker_size=8,
                    mode="markers",
                    name=network,
                    opacity=0.15,
                ),
                row=row,
                col=1,
            )

        # Only positions with a finite mean value are passed to the range plot
        mid_line = df_network[f"dice_metric_{param_name}_mean"].values.astype(float)
        valid = np.isfinite(mid_line)
        mid_line = mid_line[valid]
        add_range_fill(
            fig=fig,
            x=df_network[param_name][valid],
            top_line=df_network[f"dice_metric_{param_name}_q975"][valid],
            mid_line=mid_line,
            bottom_line=df_network[f"dice_metric_{param_name}_q025"][valid],
            color=color,
            row=row,
            col=1,
        )

    # --- Column 2: spectra ---
    # Projections of this species pair (also used for the density estimate in column 3)
    is_pair = (df_projections.source_species == source_species) & (df_projections.target_species == target_species)
    df_ps = df_projections[is_pair]

    df_states = df_ps.copy()
    df_states["perfusion_state"] = "projected"
    df_states["species_name"] = df_states["target_species"]
    df_states = pd.concat([df_states, df_ischemic], ignore_index=True)

    # Draw randomly selected transformed spectra lines
    df_projected = df_states[df_states.perfusion_state == "projected"]
    n_random_spectra = 100
    np.random.seed(0)  # Same selection of spectra on every run
    for i in np.random.randint(0, len(df_projected), n_random_spectra):
        random_spectra = df_projected.iloc[i].median_normalized_spectrum
        fig.add_trace(
            go.Scatter(
                x=wavelengths,
                y=random_spectra,
                mode="lines",
                line_color=xeno_color,
                line_dash="dash",
                showlegend=False,
                hoverinfo="skip",
                opacity=0.1,
            ),
            row=row,
            col=2,
        )

    # NOTE(review): df_ischemic is concatenated without a species filter so the statistics below include every
    # species contained in it and not only target_species — confirm that this is intended
    for state in ["physiological", "malperfused"]:
        dfs = df_states[df_states.perfusion_state == state]

        spec = np.stack(dfs["median_normalized_spectrum"])  # Spectra from all subjects
        spec_mean = np.mean(spec, axis=0)
        spec_std = np.std(spec, axis=0)

        # Mean spectrum ± one standard deviation (both states use the species color, only the dash style differs)
        add_range_fill(
            fig=fig,
            x=wavelengths,
            top_line=spec_mean + spec_std,
            mid_line=spec_mean,
            bottom_line=spec_mean - spec_std,
            color=species_color,
            line_dash="solid" if state == "physiological" else "dot",
            name=state,
            row=row,
            col=2,
        )

    # --- Column 3: PCA plot ---
    myPDF, (v1, v2) = fastkde.pdf(
        df_ps["pca_1"].values, df_ps["pca_2"].values, var_names=["x", "y"], use_xarray=False, num_points_per_sigma=50
    )

    # to_rgb() yields floats in [0, 1] but the channels of an rgba() color string are expected in [0, 255]
    red, green, blue = (round(255 * channel) for channel in to_rgb(xeno_color))
    rgb = f"{red}, {green}, {blue}"
    fig.add_trace(
        go.Heatmap(
            x=v1,
            y=v2,
            z=myPDF,
            # Fully transparent for (nearly) zero density and a constant faint color everywhere else
            colorscale=[
                [0, f"rgba({rgb}, 0)"],
                [0.01, f"rgba({rgb}, 0.09)"],
                [1, f"rgba({rgb}, 0.09)"],
            ],
            showlegend=True,
            showscale=False,
        ),
        row=row,
        col=3,
    )

    # Measured images of the target species which are part of the evaluation
    df_as = df_all[(df_all.species_name == target_species) & (df_all.image_name.isin(df_raw.image_name))]
    param_min = df_as[param_name].min()
    param_max = df_as[param_name].max()
    param_range = param_max - param_min
    max_lighten = 0.65
    print(target_species, lighten_color(species_color, max_lighten))

    # Relative position of each image in the parameter range: the minimum gets the lightest color and the
    # maximum the unchanged species color
    if param_range > 0:
        relative_param = (df_as[param_name] - param_min) / param_range
    else:
        # Avoid a division by zero if all images share the same parameter value
        relative_param = pd.Series(0.0, index=df_as.index)
    colors = [lighten_color(species_color, max_lighten * (1 - x)) for x in relative_param]

    df_raw_xeno = df_raw[df_raw.network == "xeno-learning"].set_index("image_name", verify_integrity=True)

    fig.add_trace(
        go.Scatter(
            x=df_as["pca_1"],
            y=df_as["pca_2"],
            text=[
                f"{param_name}: {x[param_name]:.2f}<br>phase_type: {x['phase_type']}<br>dice_metric:"
                f" {df_raw_xeno.loc[x['image_name'], 'dice_metric']:.2f}"
                for _, x in df_as.iterrows()
            ],
            mode="markers",
            marker_color=colors,
            marker_size=4,
        ),
        row=row,
        col=3,
    )

    # The axes cover the projected as well as the measured points plus a small margin
    margin = 0.2
    fig.update_xaxes(
        range=[
            min(df_ps["pca_1"].min(), df_as["pca_1"].min()) - margin,
            max(df_ps["pca_1"].max(), df_as["pca_1"].max()) + margin,
        ],
        title=f"<b>PC 1 ({pca.explained_variance_ratio_[0] * 100:.0f} %) [a.u.]</b>",
        title_standoff=4,
        row=row,
        col=3,
    )
    fig.update_yaxes(
        range=[
            min(df_ps["pca_2"].min(), df_as["pca_2"].min()) - margin,
            max(df_ps["pca_2"].max(), df_as["pca_2"].max()) + margin,
        ],
        title=f"<b>PC 2 ({pca.explained_variance_ratio_[1] * 100:.0f} %) [a.u.]</b>",
        title_standoff=2,
        row=row,
        col=3,
    )


# One figure row per (target, source) species pair
species_pairs = [("pig", "rat"), ("rat", "pig"), ("human", "rat"), ("human", "pig")]
for pair_row, (pair_target, pair_source) in enumerate(species_pairs, start=1):
    add_species_learning(target_species=pair_target, source_species=pair_source, row=pair_row)

# Axis titles per column
fig.update_yaxes(title="<b>DSC</b>", title_standoff=12, col=1)
fig.update_yaxes(title="<b>L1 normalized reflectance [a.u.]</b>", title_font_size=16, title_standoff=0, col=2)
fig.update_xaxes(title="<b>StO<sub>2</sub> [a.u.]</b>", showticklabels=True, title_standoff=4, col=1)
fig.update_xaxes(title="<b>wavelength [nm]</b>", title_standoff=4, col=2)
# Must come after the per-column updates because it overwrites the standoff of the last row
fig.update_xaxes(title_standoff=9, row=4)

# Same parameter range in the first column of every row
param_values = df_all[param_name]
fig.update_xaxes(range=[param_values.min(), param_values.max()], col=1)

# Global layout
fig.update_layout(
    showlegend=False,
    height=1300,
    width=1200,
    template="plotly_white",
    font_family="Libertinus Sans",
    font_size=17,
    margin=dict(l=0, r=0, t=22, b=0),
)
fig.update_annotations(font=dict(size=21))

# Export as static PDF and as compressed HTML
fig.write_image(settings_species.paper_dir / "perfusion_performance.pdf")
compress_html(settings_species.paper_dir / "perfusion_performance.html", fig)

pig #bee1db
rat #d6edf9
human #f3edcf
human #f3edcf
