In [None]:
import numpy as np
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

import pandas as pd
import joblib
import seaborn as sns
import matplotlib.pyplot as plt

from src import settings
from src.visualization.plot import line
from src.data.utils import get_organ_labels
from src.visualization.templates import cmap_qualitative, cmap_quantitative_list
from src.data.utils import get_label_mapping, get_pa_label_mapping
from src.utils.susi import ExperimentResults

In [None]:
def prepare_data(df: pd.DataFrame):
    """Relabel a results frame for plotting.

    Raw cell values are swapped for their display spelling (e.g. ``'inn'`` ->
    ``'cINN'``, ``'small_bowel'`` -> ``'small bowel'``), and raw column names
    for axis titles (e.g. ``'wavelength'`` -> ``'wavelength [nm]'``; both
    ``'dataset'`` and ``'source'`` become ``'data'``). The input frame is left
    untouched.

    Args:
        df: frame whose values and column names should be relabelled.

    Returns:
        A ``(prepared_df, mapper)`` tuple. ``mapper`` maps every raw column name
        and every raw cell value handled here to its display label, so callers
        can look up the new column names (``mapper.get('wavelength')``).
    """
    value_labels = {
        'simulated_sampled': 'simulated',
        'small_bowel': 'small bowel',
        'inn': 'cINN',
        'unit': 'UNIT',
        'real - inn': 'real - cINN',
        'real - unit': 'real - UNIT',
    }
    column_labels = {
        'wavelength': 'wavelength [nm]',
        'reflectance': 'reflectance [a.u.]',
        'dataset': 'data',
        'source': 'data',
        'pai_signal': 'PA signal [a.u.]',
        'waic': 'WAIC [a.u.]'
    }
    # Values are replaced first, then only the column labels drive the rename.
    prepared = df.replace(value_labels).rename(columns=column_labels)
    # Column labels first, value labels after, matching a dict.update merge.
    return prepared, {**column_labels, **value_labels}

# plot HSI spectra and classification metrics

In [None]:
# Figure: mean HSI reflectance spectrum per organ (left panel of each pair) next to
# the change in per-class AUROC of each adapted dataset relative to the 'simulated'
# baseline (right panel of each pair). Five organ pairs per row, two rows.
spectra_file = settings.figures_dir / 'semantic_reflectance.csv'
accuracy_file = settings.results_dir / 'rf' / 'rf_classifier_metrics.csv'
metric = ('AUC', 'per_class_auroc')  # (display name, column in the metrics CSV)
sources = ['real', 'simulated', 'cINN', 'UNIT']
organs_per_row = 5

metrics_df = pd.read_csv(accuracy_file, index_col=None, header=[0])
metrics_df = metrics_df[metrics_df.data != 'real']
metrics_df, _ = prepare_data(metrics_df)
df = pd.read_csv(spectra_file)
df = df[df.organ != 'gallbladder']

# Named aggregation yields flat columns (wavelength, dataset, organ, reflectance, sd)
# instead of a (reflectance, mean/std) MultiIndex that has to be unpicked by hand.
data = (
    df.groupby(['wavelength', 'dataset', 'organ'])['reflectance']
    .agg(reflectance='mean', sd='std')
    .reset_index()
)
data, mapper = prepare_data(data)

organs = data.organ.unique()
subplot_titles = np.array([(o, metric[0]) for o in organs]).flatten()
fig_specs = [[{"type": "xy"}, {"type": "bar"}] * organs_per_row,
             [{"type": "xy"}, {"type": "bar"}] * organs_per_row]
fig = make_subplots(
    rows=2, cols=2 * organs_per_row,
    specs=fig_specs,
    subplot_titles=subplot_titles,
    horizontal_spacing=0.03
)

for i, organ in enumerate(organs):
    # Organ i uses the odd column of its pair for spectra and the even one for bars.
    row = 1 if i < organs_per_row else 2
    spectra_col = 2 * (i % organs_per_row) + 1
    bar_col = spectra_col + 1

    organ_metrics = metrics_df[metrics_df.organ == organ]
    base_value = organ_metrics[organ_metrics.data == 'simulated'][metric[1]].values[0]
    organ_metrics = organ_metrics[~organ_metrics['data'].isin(['simulated', 'real'])]

    metric_diff_value = []
    for source in sources:
        spectra = data[(data.organ == organ) & (data['data'] == source)]
        fig.add_trace(
            go.Scatter(
                x=spectra[mapper.get('wavelength')],
                y=spectra[mapper.get('reflectance')],
                line=dict(color=cmap_qualitative.get(source)),
                name=source,
                legendgroup=source,
                showlegend=(i == 0),
            ),
            row=row, col=spectra_col,
        )

        source_metrics = organ_metrics[organ_metrics.data == source]
        # Subtract out of place: with pandas copy-on-write, `.values` may be a
        # read-only view that cannot be modified in place.
        diff = source_metrics[metric[1]].to_numpy() - base_value
        metric_diff_value.extend(diff.tolist())
        fig.add_trace(
            go.Bar(
                x=source_metrics['data'],
                y=diff,
                name=source,
                marker_color=cmap_qualitative.get(source),
                legendgroup=source,
                showlegend=False,
                width=0.5,
            ),
            row=row, col=bar_col,
        )

    # Dotted zero line marks the 'simulated' baseline; drawn once per organ.
    fig.add_hline(
        y=0,
        line_color=cmap_qualitative.get('simulated'),
        line_dash="dot",
        row=row,
        col=bar_col,
    )
    # Symmetric range around zero so gains and losses are comparable; skipped
    # when the organ has no adapted-data metrics at all.
    if metric_diff_value:
        limit = max(abs(v) for v in metric_diff_value)
        fig.update_yaxes(range=(-limit, limit), row=row, col=bar_col)
fig.update_layout(template="plotly_white", width=1000, height=400,
                  margin=dict(l=20, r=20, t=20, b=20))
fig.update_layout(font=dict(size=10, family="Libertinus Serif"))
fig.update_xaxes(title_font=dict(size=10, family="Libertinus Serif"), tickangle=90)
fig.update_yaxes(title_font=dict(size=10, family="Libertinus Serif"))
fig.write_image(settings.figures_dir / 'manuscript' / 'semantic_reflectance_acc.png', scale=3)
fig.write_html(settings.figures_dir / 'manuscript' / 'semantic_reflectance_acc.html')

In [None]:
# Figure: per-organ change of the RF classifier's per-class AUROC for the adapted
# datasets (cINN, UNIT) relative to the 'simulated' baseline, one facet per organ.
metrics_file = settings.results_dir / 'rf' / 'rf_classifier_metrics.csv'
metrics_df = pd.read_csv(metrics_file, index_col=None, header=[0])
metrics_df, _ = prepare_data(metrics_df)
organs = metrics_df.organ.unique()
metric = ('AUC', 'per_class_auroc')  # (display name, column in the metrics CSV)
facet_col_wrap = 5
results = ExperimentResults()
metric_diff_value = []  # one [cINN diff, UNIT diff] list per organ, in `organs` order
for organ in organs:
    organ_metrics = metrics_df[metrics_df.organ == organ]
    base_value = organ_metrics[organ_metrics.data == 'simulated'][metric[1]].values[0]
    values = []
    for source in ['cINN', 'UNIT']:
        metric_value = organ_metrics[organ_metrics['data'] == source][metric[1]].values[0]
        diff = metric_value - base_value
        values.append(float(diff))
        results.append(name="organ", value=organ)
        results.append(name="data", value=source)
        results.append(name=metric[0], value=diff)
    metric_diff_value.append(values)
df = results.get_df()
df = df[~df['data'].isin(['simulated', 'real'])]
df, _ = prepare_data(df)
fig = px.bar(
    data_frame=df,
    x="data",
    y=metric[0],
    color="data",
    template="plotly_white",
    facet_col="organ",
    facet_col_wrap=facet_col_wrap,
    color_discrete_map=cmap_qualitative,
    facet_col_spacing=0.1,
    category_orders={'organ': list(organs)}
)
fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))
fig.update_layout(template="plotly_white",
                  margin=dict(l=20, r=20, t=20, b=20)
                  )
# Plotly Express numbers wrapped facet rows from the bottom, so the first
# `facet_col_wrap` organs sit in the highest row index. Deriving the row count
# from the number of organs keeps the mapping right for any organ count.
n_facet_rows = -(-len(organs) // facet_col_wrap)  # ceiling division
for i, values in enumerate(metric_diff_value):
    row = n_facet_rows - i // facet_col_wrap
    col = i % facet_col_wrap + 1
    limit = max(abs(v) for v in values)
    # A zero (or NaN) limit would collapse the axis; leave autorange instead.
    if limit > 0:
        fig.update_yaxes(range=[-limit, limit], row=row, col=col)
fig.update_traces(width=0.5)
fig.update_layout(font=dict(size=12, family="Libertinus Serif"))
fig.update_xaxes(title_font=dict(size=12, family="Libertinus Serif"), tickangle=90)
fig.update_yaxes(title_font=dict(size=12, family="Libertinus Serif"), matches=None)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.write_image(settings.figures_dir / 'manuscript' / f'semantic_{metric[0]}.png', scale=3)
fig.write_image(settings.figures_dir / 'manuscript' / f'semantic_{metric[0]}.pdf')
fig.write_html(settings.figures_dir / 'manuscript' / f'semantic_{metric[0]}.html')

# plot HSI spectra

In [None]:
# Figure: HSI reflectance spectra per organ (gallbladder excluded), one facet per
# organ, comparing real, simulated and adapted (UNIT, cINN) data.
spectra_file = settings.figures_dir / 'semantic_reflectance.csv'
df = pd.read_csv(spectra_file)
organs = [o for o in get_organ_labels()['organ_labels'] if o != 'gallbladder']
# Keeping only the listed organs also drops gallbladder rows.
df = df[df['organ'].isin(organs)]
# Match the display spelling produced by prepare_data (e.g. 'small bowel').
organs = [o.replace('_', ' ') for o in organs]
df_prepared, mapper = prepare_data(df)
fig, plot_data = line(
    data_frame=df_prepared,
    x=mapper.get("wavelength"),
    y=mapper.get('reflectance'),
    facet_col="organ",
    color=mapper.get("dataset"),
    facet_col_wrap=5,
    template="plotly_white",
    width=800,
    height=400,
    category_orders=dict(organ=organs, data=['real', 'simulated', 'UNIT', 'cINN']),
    facet_row_spacing=0.2,
    facet_col_spacing=0.05,
    color_discrete_map=cmap_qualitative,
    # NOTE(review): fixed zoom on 900-1000 nm / 0.005-0.015 — confirm this crop is
    # intended for the manuscript figure and not a leftover inspection setting.
    range_x=(900, 1000),
    range_y=(0.005, 0.015)
)

fig.update_layout(font=dict(size=12, family="Libertinus Serif"),
                  margin=dict(l=20, r=20, t=20, b=20))
fig.update_xaxes(title_font=dict(size=12, family="Libertinus Serif"))
fig.update_yaxes(title_font=dict(size=12, family="Libertinus Serif"))
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.write_image(settings.figures_dir / 'manuscript' / 'semantic_reflectance.pdf')
fig.write_image(settings.figures_dir / 'manuscript' / 'semantic_reflectance.png', scale=2)

# plot HSI confusion matrices

In [None]:
# Figures: RF organ-classifier confusion matrix for each training stage, colour
# range fixed to [0, 1], written as PDF and PNG.
stages = ['real', 'simulated', 'cINN', 'UNIT']
mapping = get_label_mapping()
for stage in stages:
    matrix_file = settings.results_dir / 'rf' / f"rf_classifier_matrix_{stage}.npz"
    # np.load on an .npz keeps the archive open; the context manager closes it
    # once both arrays have been read into memory.
    with np.load(matrix_file) as npz:
        matrix = npz['matrix']
        labels = npz['labels']
    names = [mapping.get(str(l)).replace('_', ' ') for l in labels]
    fig = px.imshow(matrix,
                    text_auto='.2f',
                    color_continuous_scale=cmap_quantitative_list.get(stage),
                    zmin=0,
                    zmax=1,
                    template='plotly_white'
                    )
    # Replace numeric tick positions with the organ names.
    axis_ticks = dict(
        tickmode='array',
        tickvals=np.arange(0, len(names)),
        ticktext=names
    )
    fig.update_layout(
        xaxis=axis_ticks,
        yaxis=axis_ticks,
        coloraxis_colorbar=dict(
            title="probability",
            x=0.85,
            ticks="outside",
            ticksuffix="",
        )
    )
    fig.update_xaxes(title="predicted class", title_font=dict(size=16, family="Libertinus Serif"))
    fig.update_yaxes(title="true class", title_font=dict(size=16, family="Libertinus Serif"))
    fig.update_layout(font=dict(size=16, family="Libertinus Serif"),
                      margin=dict(l=10, r=0, t=10, b=10))
    fig.write_image(settings.figures_dir / 'manuscript' / f'semantic_rf_confusion_matrix_{stage}.pdf')
    fig.write_image(settings.figures_dir / 'manuscript' / f'semantic_rf_confusion_matrix_{stage}.png', scale=2)

# plot HSI spectral differences

In [None]:
# Figure: distribution of spectral differences between real data and each
# simulated / adapted dataset, one violin facet per organ (gallbladder excluded).
df = pd.read_csv(settings.figures_dir / 'semantic_diff.csv')
organs = [o for o in get_organ_labels()['organ_labels'] if o != 'gallbladder']
# .copy() so the in-place percent scaling below acts on an independent frame
# instead of a slice of the CSV frame (avoids SettingWithCopyWarning).
df = df[df['organ'].isin(organs)].copy()
organs = [o.replace('_', ' ') for o in organs]
df['difference [%]'] *= 100  # presumably stored as a fraction; scale to percent
df_prepared, mapper = prepare_data(df)
fig = px.violin(data_frame=df_prepared,
                x="data",
                y="difference [%]",
                color="data",
                facet_col="organ",
                facet_col_wrap=5,
                color_discrete_map=cmap_qualitative,
                template="plotly_white",
                category_orders=dict(organ=organs, data=['real - simulated', 'real - UNIT', 'real - cINN']),
                facet_row_spacing=0.2,
                facet_col_spacing=0.05,
                width=800,
                height=400
                )
fig.update_traces(scalemode='width', meanline_visible=True)
# A distinct scale group per trace so each violin is sized independently.
for trace_idx, trace in enumerate(fig.data):
    trace.scalegroup = str(trace_idx)
fig.update_layout(font=dict(size=12, family="Libertinus Serif"),
                  margin=dict(l=20, r=20, t=20, b=20)
                  )
fig.update_xaxes(title_font=dict(size=12, family="Libertinus Serif"))
fig.update_yaxes(title_font=dict(size=12, family="Libertinus Serif"), range=(0, 50))
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.write_image(settings.figures_dir / 'manuscript' / 'semantic_diff.pdf')
fig.write_image(settings.figures_dir / 'manuscript' / 'semantic_diff.png', scale=2)

# plot HSI PCA

In [None]:
# Figures: KDE joint plot of the first two PCA components per organ, coloured by
# dataset; axis labels carry each component's explained-variance percentage.
df = pd.read_csv(settings.figures_dir / 'semantic_pca.csv')
df, mapper = prepare_data(df)
sns.set_style('whitegrid', {"grid.color": "ebf0f8ff", "grid.linewidth": 1})
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Libertinus Serif"]

pca_figures_dir = settings.figures_dir / 'manuscript' / 'semantic_pca'
pca_figures_dir.mkdir(parents=True, exist_ok=True)

for organ in df.organ.unique():
    # Model files use underscores where display names use spaces ('small bowel').
    model_file = settings.results_dir / 'pca' / f"semantic_pca_{organ.replace(' ', '_')}.joblib"
    model = joblib.load(model_file)
    ratios = model.explained_variance_ratio_
    pc1_label = f"PC 1 [{round(ratios[0]*100)}%]"
    pc2_label = f"PC 2 [{round(ratios[1]*100)}%]"
    tmp = df[df['organ'] == organ].rename({'pc_1': pc1_label, 'pc_2': pc2_label}, axis=1)
    sns.jointplot(data=tmp,
                  x=pc1_label,
                  y=pc2_label,
                  hue="data",
                  kind="kde",
                  fill=True,
                  alpha=0.4,
                  marginal_kws={'common_norm': False},
                  palette=cmap_qualitative,
                  levels=10)
    sns.despine(left=True, bottom=True)
    plt.tight_layout()
    plt.savefig(pca_figures_dir / f'semantic_pca_{organ}.pdf')
    plt.savefig(pca_figures_dir / f'semantic_pca_{organ}.png', dpi=300)
    # jointplot opens a new figure each iteration; close it (clf alone keeps it
    # alive and it is rendered blank at the end of the cell).
    plt.close()

# plot PAI signal

In [None]:
# Figures: PA signal over wavelength, one figure per tissue, comparing real,
# simulated and adapted (UNIT, cINN) data; saved as PDF and PNG.
df = pd.read_csv(settings.figures_dir / 'pai_signal.csv')
df, mapper = prepare_data(df)
manuscript_dir = settings.figures_dir / 'manuscript'
serif_font = dict(size=12, family="Libertinus Serif")
for tissue in df.tissue.unique():
    tmp = df[df.tissue == tissue]
    fig, _ = line(
        data_frame=tmp,
        x=mapper.get("wavelength"),
        y=mapper.get('pai_signal'),
        facet_col=None,
        color="data",
        template="plotly_white",
        color_discrete_map=cmap_qualitative,
        width=800,
        height=400,
        category_orders=dict(data=['real', 'simulated', 'UNIT', 'cINN']),
    )
    fig.update_layout(
        font=serif_font,
        legend=dict(orientation="h", xanchor="center", x=0.5, y=1),
        margin=dict(l=20, r=10, t=10, b=20),
    )
    fig.update_xaxes(title_font=serif_font)
    fig.update_yaxes(title_font=serif_font)
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    # Vector version first, then a 2x raster.
    for extension, export_kwargs in (('pdf', {}), ('png', {'scale': 2})):
        fig.write_image(manuscript_dir / f'pai_signal_{tissue}.{extension}', **export_kwargs)

# plot PAI signal differences

In [None]:
# Figures: distribution of PA-signal differences between real data and each
# simulated / adapted dataset, one violin figure per tissue.
df = pd.read_csv(settings.figures_dir / 'pai_diff.csv')
df['difference [%]'] *= 100  # presumably stored as a fraction; scale to percent
df, mapper = prepare_data(df)
tissues = df.tissue.unique()
for tissue in tissues:
    tmp = df[df.tissue == tissue]
    fig = px.violin(data_frame=tmp,
                    x="data",
                    y="difference [%]",
                    color="data",
                    color_discrete_map=cmap_qualitative,
                    template="plotly_white",
                    category_orders=dict(data=['real - simulated', 'real - UNIT', 'real - cINN']),
                    )
    fig.update_traces(scalemode='width', meanline_visible=True)
    # A distinct scale group per trace so each violin is sized independently.
    for trace_idx, trace in enumerate(fig.data):
        trace.scalegroup = str(trace_idx)
    fig.update_layout(font=dict(size=12, family="Libertinus Serif"))
    fig.update_xaxes(title_font=dict(size=12, family="Libertinus Serif"), tickangle=90)
    fig.update_yaxes(title_font=dict(size=12, family="Libertinus Serif"))
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    fig.write_image(settings.figures_dir / 'manuscript' / f'pai_diff_{tissue}.pdf')
    fig.write_image(settings.figures_dir / 'manuscript' / f'pai_diff_{tissue}.png', scale=2)

# plot PAI PCA

In [None]:
# Figures: KDE joint plot of the first two PCA components of the PA signal per
# tissue, coloured by dataset, with a fixed zoom on the joint axes.
df = pd.read_csv(settings.results_dir / 'pca' / 'pai_pca.csv')
df, mapper = prepare_data(df)
sns.set_style('whitegrid', {"grid.color": "ebf0f8ff", "grid.linewidth": 1})
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Libertinus Serif"]

for tissue in df.tissue.unique():
    model = joblib.load(settings.results_dir / 'pca' / f'pai_pca_{tissue}.joblib')
    ratios = model.explained_variance_ratio_
    pc1_label = f"PC 1 [{round(ratios[0]*100)}%]"
    pc2_label = f"PC 2 [{round(ratios[1]*100)}%]"
    tmp = df[df['tissue'] == tissue].rename({'pc_1': pc1_label, 'pc_2': pc2_label}, axis=1)
    g = sns.jointplot(data=tmp,
                      x=pc1_label,
                      y=pc2_label,
                      hue="data",
                      kind="kde",
                      fill=True,
                      alpha=0.4,
                      marginal_kws={'common_norm': False},
                      palette=cmap_qualitative,
                      levels=10)
    # Zoom the joint (central) axes explicitly rather than whichever axes is current.
    g.ax_joint.set_xlim(-0.3, 0.3)
    g.ax_joint.set_ylim(-0.1, 0.1)
    sns.despine(left=True, bottom=True)
    plt.tight_layout()
    plt.savefig(settings.figures_dir / 'manuscript' / f'pai_pca_{tissue}.pdf')
    plt.savefig(settings.figures_dir / 'manuscript' / f'pai_pca_{tissue}.png', dpi=300)
    # jointplot opens a new figure per tissue; close it so figures do not pile up.
    plt.close()

# plot PAI confusion matrices

In [None]:
# Figures: RF classifier confusion matrix on PA data for each training stage,
# colour range fixed to [0, 1], written as PDF and PNG.
stages = ['real', 'simulated', 'cINN', 'UNIT']
mapping = get_pa_label_mapping()
for stage in stages:
    matrix_file = settings.results_dir / 'rf_pa' / f"rf_pa_classifier_matrix_{stage}.npz"
    # np.load on an .npz keeps the archive open; the context manager closes it
    # once both arrays have been read into memory.
    with np.load(matrix_file) as npz:
        matrix = npz['matrix']
        labels = npz['labels']
    names = [mapping.get(l).replace('_', ' ') for l in labels]
    fig = px.imshow(matrix,
                    text_auto='.2f',
                    color_continuous_scale=cmap_quantitative_list.get(stage),
                    zmin=0,
                    zmax=1,
                    template='plotly_white'
                    )
    # Replace numeric tick positions with the tissue class names.
    axis_ticks = dict(
        tickmode='array',
        tickvals=np.arange(0, len(names)),
        ticktext=names
    )
    fig.update_layout(
        xaxis=axis_ticks,
        yaxis=axis_ticks,
        coloraxis_colorbar=dict(
            title=dict(text="", side="right"),
            x=0.85,
            ticks="outside",
            ticksuffix="",
        )
    )
    fig.update_xaxes(title="predicted class", title_font=dict(size=40, family="Libertinus Serif"))
    fig.update_yaxes(title="true class", title_font=dict(size=40, family="Libertinus Serif"))
    fig.update_layout(font=dict(size=40, family="Libertinus Serif"))
    fig.write_image(settings.figures_dir / 'manuscript' / f'pai_rf_confusion_matrix_pa_{stage}.pdf')
    fig.write_image(settings.figures_dir / 'manuscript' / f'pai_rf_confusion_matrix_pa_{stage}.png', scale=2)