In [1]:
# Automatically reload edited modules (e.g. the htc project packages) before each cell runs,
# so changes to the project code do not require a kernel restart.
%load_ext autoreload
%autoreload 2

import math

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from htc import settings
from htc_projects.rat.settings_rat import settings_rat
from htc_projects.species.settings_species import settings_species

# Disable MathJax for Kaleido's static image export (used by fig.write_image below).
# NOTE(review): presumably the common workaround for a stray "Loading [MathJax]..." box in exported PDFs;
# no LaTeX is rendered in the figures of the visible cells. `pio.kaleido.scope` is deprecated in newer
# Plotly/Kaleido releases — confirm against the installed version.
pio.kaleido.scope.mathjax = None

In [2]:
# Load the linear mixed model (LMM) results of the species/organ analysis.
# NOTE(review): the labelling below assumes the CSV stores one contiguous block of rows per organ, in the
# order of `labels_csv`, with the same number of rows per organ — confirm against the script writing the CSV.
df = pd.read_csv(settings.results_dir / "lmm" / "lmm_results" / "species_organ.csv", index_col=0)

# Organ names in the order in which their row blocks appear in the CSV
labels_csv = [
    "bladder",
    "bone",
    "cartilage",
    "colon",
    "heart",
    "kidney",
    "kidney_with_Gerotas_fascia",
    "liver",
    "lung",
    "major_vein",
    "muscle",
    "omentum",
    "pancreas",
    "peritoneum",
    "skin",
    "small_bowel",
    "spleen",
    "stomach",
]

# Derive the block size from the labels we actually repeat (not from an unrelated label mapping) and fail
# with a clear message if the file cannot be split evenly, instead of silently mislabelling rows
n_rows_per_label, n_rows_remainder = divmod(len(df), len(labels_csv))
assert n_rows_remainder == 0, f"{len(df)} rows cannot be split evenly across {len(labels_csv)} labels"
df["label_name"] = np.repeat(np.asarray(labels_csv), n_rows_per_label)
df

Unnamed: 0,species.,speciesCI1,speciesCI2,ID.,IDCI1,IDCI2,residual1,residualCI1,residualCI2,angle.,angleCI1,angleCI2,Image.,ImageCI1,ImageCI2,label_name
1,0.731657,0.614897,0.821545,0.156110,0.074239,0.256507,0.003312,0.002328,0.004771,-0.000840,-0.000956,0.002681,0.109761,0.076042,0.154653,bladder
2,0.738877,0.606598,0.833324,0.152888,0.070219,0.265966,0.002903,0.002055,0.004014,-0.000679,-0.000916,0.002999,0.106011,0.075163,0.148739,bladder
3,0.738121,0.623657,0.830096,0.153330,0.066363,0.254280,0.002817,0.002019,0.004055,-0.000531,-0.000930,0.003290,0.106263,0.073759,0.154788,bladder
4,0.724865,0.581785,0.829512,0.163288,0.069282,0.286547,0.002772,0.001954,0.004063,-0.000383,-0.000891,0.003764,0.109457,0.075511,0.158953,bladder
5,0.692320,0.539458,0.805955,0.187822,0.092600,0.311463,0.002856,0.002008,0.004358,-0.000233,-0.001016,0.004788,0.117235,0.082169,0.178515,bladder
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1796,0.733478,0.564680,0.839371,0.176173,0.074448,0.319300,0.002051,0.001427,0.003014,-0.000673,-0.000822,0.002474,0.088970,0.059993,0.134465,stomach
1797,0.726413,0.573894,0.835481,0.176902,0.077623,0.315207,0.002250,0.001566,0.003165,-0.000727,-0.000905,0.001800,0.095162,0.063789,0.140955,stomach
1798,0.700990,0.537082,0.816804,0.194214,0.087822,0.335236,0.002511,0.001734,0.003838,-0.000769,-0.000930,0.002920,0.103054,0.070836,0.153715,stomach
1799,0.701086,0.546453,0.804236,0.200830,0.107076,0.337253,0.002485,0.001741,0.003874,-0.000648,-0.000860,0.002218,0.096247,0.065857,0.150701,stomach


In [3]:
# Explained variance per LMM factor as a function of wavelength, one subplot per organ. Each factor is drawn as
# a line (point estimate) plus a shaded band between its confidence bounds (columns <factor>CI1 / <factor>CI2).
labels = settings_species.label_mapping_organs.label_names()
n_cols = 3
n_rows = math.ceil(len(labels) / n_cols)
fig = make_subplots(
    n_rows,
    n_cols,
    # Show the paper names as titles; the raw label names are still used below to select the rows
    subplot_titles=[settings_rat.labels_paper_renaming.get(label, label) for label in labels],
    shared_yaxes="all",
    vertical_spacing=0.05,
    horizontal_spacing=0.03,
)
# Centres of the 5 nm bins from 500 to 1000 nm; every label is expected to have exactly one row per bin
wavelengths = list(np.arange(500, 1000, 5) + 2.5)
factor_colors = {
    "species": "#66bbc4",
    "ID": "#A52A2A",
    "Image": "#8A2BE2",
    "angle": "#006400",
    "residual": "#A9A9A9",
}
factor_renaming = {
    "ID": "subject",
    "Image": "image",
    "residual": "residuals",
}

for i, label in enumerate(labels):
    row_idx = i // n_cols
    col_idx = i % n_cols
    df_label = df[df["label_name"] == label]
    # A label missing from the CSV (or a different binning) would otherwise give an empty or misaligned subplot
    assert len(df_label) == len(wavelengths), (
        f"Expected {len(wavelengths)} rows for the label {label}, found {len(df_label)}"
    )

    for factor, color in factor_colors.items():
        # The residual point estimate is stored in the column "residual1", all other factors in "<factor>."
        mid_line = df_label[f"{factor}1" if factor == "residual" else f"{factor}."].to_numpy()
        top_line = df_label[f"{factor}CI2"].to_numpy()
        bottom_line = df_label[f"{factor}CI1"].to_numpy()
        name = factor_renaming.get(factor, factor)

        fig.add_trace(
            go.Scatter(
                x=wavelengths,
                y=mid_line,
                name=name,
                line_color=color,
                legendgroup=name,
                showlegend=i == 0,  # one legend entry per factor
            ),
            row=row_idx + 1,
            col=col_idx + 1,
        )
        # Closed polygon: forward along the upper bound, then backwards along the lower bound
        fig.add_trace(
            go.Scatter(
                x=wavelengths + wavelengths[::-1],
                y=np.concatenate([top_line, bottom_line[::-1]]),
                fill="toself",
                fillcolor=color,
                line_color=color,
                opacity=0.15,
                name=name,
                legendgroup=name,
                showlegend=False,
                hoverinfo="skip",
            ),
            row=row_idx + 1,
            col=col_idx + 1,
        )

    if col_idx == 0:
        fig.update_yaxes(
            title="<b>explained<br>variance [a.u.]</b>", title_standoff=0, row=row_idx + 1, col=col_idx + 1
        )

# Only the lowest occupied subplot of every column shows wavelength ticks and the axis title. The position is
# derived from the grid so it stays correct for any number of labels (the last row may be only partially filled)
fig.update_xaxes(showticklabels=False)
ticks = [600, 700, 800, 900]
for col in range(min(n_cols, len(labels))):
    bottom_row = (len(labels) - 1 - col) // n_cols
    fig.update_xaxes(
        tickmode="array",
        tickvals=ticks,
        ticktext=ticks,
        showticklabels=True,
        title="<b>wavelength [nm]</b>",
        row=bottom_row + 1,
        col=col + 1,
    )

fig.update_annotations(font=dict(size=20))  # subplot titles
fig.update_layout(
    height=800,
    width=1200,
    template="plotly_white",
    font_family="Libertinus Sans",
    font_size=16,
    margin=dict(t=20),
    # NOTE(review): the legend is hidden in this "base" export, presumably added separately for the paper
    showlegend=False,
)

fig.write_image(settings_species.paper_dir / "mixed_effect_analysis_base.pdf")
fig