# Imports

In [None]:
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
from scipy import stats
import plotly.express as px
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
import seaborn as sns
import pathlib
from sklearn.metrics import mean_absolute_error
import patchworklib as pw

# Init data

In [None]:
# Output directory for every artefact written by this notebook; created up front.
path_save = "E:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/038_tai_report_immuno"
pathlib.Path(path_save).mkdir(exist_ok=True, parents=True)

# Input directory and the base name (without extension) of the main data table.
path_load = "E:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/021_ml_data/immuno"
fn = "260_imp(fast_knn)_replace(quarter)"

# Main data table, indexed by its "index" column.
df = pd.read_excel(
    f"{path_load}/{fn}.xlsx",
    index_col="index",
)

# Feature names are taken from the index of feats_con.xlsx (its "features" column).
feats = pd.read_excel(
    f"{path_load}/feats_con.xlsx",
    index_col="features",
).index.values

# Creating inference file with "data_part" field

In [None]:
# Predictions table of the baseline run; its first column is copied into `df`
# as "data_part" (presumably the train/validation split label - verify against the file).
parts_df = pd.read_excel(
    "E:/YandexDisk/Work/pydnameth/draft/06_somewhere/models/baseline/k_5/widedeep_tab_net/1/30/predictions.xlsx",
    index_col="index",
)
# Look the values up by `df`'s own index so the rows line up sample by sample.
df.loc[df.index, "data_part"] = parts_df.loc[df.index, parts_df.columns[0]]
# Persist the augmented table next to the other report outputs.
df.to_excel(
    f"{path_save}/inference.xlsx",
    index=True,
)

# Plots for best_model

In [None]:
# Age distribution of the best model's samples, drawn as overlaid histograms
# with one trace per data part.
df_best = pd.read_excel(
    f"{path_load}/models/immuno_inference_widedeep_tab_net/runs/2022-11-09_13-31-51/df.xlsx",
    index_col="index",
)

# "data_part" code -> legend label.
parts = {"trn": "Train", "val": "Test"}

# Bin width: the full age span split into 15 equal parts, shared by both traces.
ptp = np.ptp(df_best["Age"].values)
bin_size = ptp / 15

fig = go.Figure()
for part_id, part_name in parts.items():
    part_ages = df_best.loc[df_best["data_part"] == part_id, "Age"].values
    histogram = go.Histogram(
        x=part_ages,
        name=part_name,
        showlegend=True,
        marker={"opacity": 0.75, "line": {"width": 1}},
        xbins={"size": bin_size},
    )
    fig.add_trace(histogram)

add_layout(fig, "Age", "Count", "")
fig.update_layout(
    width=800,
    height=600,
    margin=go.layout.Margin(l=90, r=20, b=100, t=50, pad=0),
)
fig.update_layout(legend_font_size=24)
# Fixed x range instead of autoscaling.
fig.update_xaxes(autorange=False, range=[12, 100])
fig.update_layout(
    {"colorway": ["lime", "magenta"]},
    barmode="overlay",
    legend={"itemsizing": "constant"},
)

pathlib.Path(f"{path_save}/best_model").mkdir(exist_ok=True, parents=True)
save_figure(fig, f"{path_save}/best_model/Histogram")

# Chronological age vs. model estimation for the best model, one trace per data part,
# with the identity line drawn first as a reference.
fig = go.Figure()
add_scatter_trace(fig, [12, 100], [12, 100], "", mode="lines")
for part_id, part_name in (("trn", "Train"), ("val", "Test")):
    part_rows = df_best.loc[df_best["data_part"] == part_id]
    add_scatter_trace(
        fig,
        part_rows["Age"].values,
        part_rows["Estimation"].values,
        part_name,
        size=13,
    )

add_layout(fig, "Age", "SImAge", "")
# Colours follow trace order: reference line, train, test.
fig.update_layout(
    {"colorway": ["black", "lime", "magenta"]},
    legend={"itemsizing": "constant"},
)
fig.update_layout(legend_font_size=24)
# Same fixed range on both axes so the identity line runs corner to corner.
fig.update_xaxes(autorange=False, range=[12, 100])
fig.update_yaxes(autorange=False, range=[12, 100])
fig.update_layout(
    width=800,
    height=600,
    margin=go.layout.Margin(l=90, r=20, b=100, t=50, pad=0),
)

pathlib.Path(f"{path_save}/best_model").mkdir(exist_ok=True, parents=True)
save_figure(fig, f"{path_save}/best_model/Scatter")

# Plots for SHAP

## Global explainability

In [None]:
# Global SHAP feature importance: mean absolute SHAP value of every feature over all
# samples, sorted in descending order and saved as a bar chart (PNG and PDF).
df_best = pd.read_excel(f"{path_load}/models/immuno_inference_widedeep_tab_net/runs/2022-11-09_20-16-40/df.xlsx", index_col="index")
df_shap = pd.read_excel(f"{path_load}/models/immuno_inference_widedeep_tab_net/runs/2022-11-09_20-16-40/shap/all/shap.xlsx", index_col="index")
# Suffix the SHAP columns so they cannot be confused with raw feature values.
df_shap.rename(columns={x: f"{x}_shap" for x in feats}, inplace=True)

# One importance score per feature, in the order of `feats`.
shap_mean_abs = [
    np.mean(np.abs(df_shap.loc[:, f"{feat}_shap"].values))
    for feat in feats
]
df_feats_importance = pd.DataFrame({"Mean(|SHAP|)": shap_mean_abs}, index=feats)
df_feats_importance.index.name = "Features"
df_feats_importance.sort_values(by="Mean(|SHAP|)", ascending=False, inplace=True)

# The theme must be set BEFORE the figure/axes are created: seaborn themes work through
# matplotlib rcParams, which are read when artists are created. Setting it afterwards
# left the first execution of this cell with the previous style and font sizes.
sns.set_theme(style='white', font_scale=3)
plt.figure(figsize=(34, 10))
sns.barplot(data=df_feats_importance, x=df_feats_importance.index, y="Mean(|SHAP|)")
# Rotate the feature names once the categorical ticks exist.
plt.xticks(rotation=90)
pathlib.Path(f"{path_save}/feature_importance").mkdir(parents=True, exist_ok=True)
plt.savefig(f"{path_save}/feature_importance/bar.png", bbox_inches='tight')
plt.savefig(f"{path_save}/feature_importance/bar.pdf", bbox_inches='tight')

## Local explainability

In [None]:
# Inputs for the local explainability plots: per-sample SHAP values joined with the
# raw feature values and the gap between estimation and age.
df_best = pd.read_excel(
    f"{path_load}/models/immuno_inference_widedeep_tab_net/runs/2022-11-09_20-16-40/df.xlsx",
    index_col="index",
)
# Keep age, estimation and the features; features get a "_values" suffix.
df_best = df_best.loc[:, ["Age", "Estimation", *feats]].rename(
    columns={feat_name: f"{feat_name}_values" for feat_name in feats}
)
df_best["Diff"] = df_best["Estimation"] - df_best["Age"]

df_shap = pd.read_excel(
    f"{path_load}/models/immuno_inference_widedeep_tab_net/runs/2022-11-09_20-16-40/shap/all/shap.xlsx",
    index_col="index",
)
df_shap = df_shap.rename(
    columns={feat_name: f"{feat_name}_shap" for feat_name in feats}
)
# Index-on-index merge: each row ends up with SHAP values, feature values and Diff.
df_shap = df_shap.merge(df_best, left_index=True, right_index=True)

pathlib.Path(f"{path_save}/SHAP").mkdir(exist_ok=True, parents=True)
df_shap.to_excel(f"{path_save}/SHAP/data.xlsx", index=True)

# 10% / 90% quantiles of Diff are the thresholds of the two tail groups.
neg, pos = np.quantile(df_best["Diff"], [0.1, 0.9])
print(f"Pos: {pos}, neg: {neg}")

# Samples strictly above the upper threshold and strictly below the lower one.
diff_sign = {
    'positive': df_shap[df_shap["Diff"] > pos],
    'negative': df_shap[df_shap["Diff"] < neg],
}

# Number of most important features shown in each local plot.
n_top_features = 10

def _beeswarm_offsets(xs, row_height=0.40, nbins=20):
    """Return vertical offsets that spread the points of one beeswarm row.

    The x values are quantised into bins; inside a bin successive points are
    stacked at 0, +1, -1, +2, -2, ... and the stack is then rescaled so that it
    fits within ``0.9 * row_height`` of the row centre.

    Args:
        xs: 1-D float array of x coordinates (must be non-empty).
        row_height: Half-height available to one row.
        nbins: Number of quantisation steps along x.

    Returns:
        Array of offsets with the same length as ``xs``. The result is not
        deterministic: a tiny random term decides the order of points that
        fall into the same bin.
    """
    n_points = len(xs)
    # 1e-8 keeps the division finite when all x values are equal.
    quant = np.round(nbins * (xs - np.min(xs)) / (np.max(xs) - np.min(xs) + 1e-8))
    # Random tie-breaker so points of one bin are visited in shuffled order.
    inds = np.argsort(quant + np.random.randn(n_points) * 1e-6)
    layer = 0
    last_bin = -1
    ys = np.zeros(n_points)
    for ind in inds:
        if quant[ind] != last_bin:
            layer = 0  # new bin: restart stacking from the row centre
        ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
        layer += 1
        last_bin = quant[ind]
    ys *= 0.9 * (row_height / np.max(ys + 1))
    return ys


# For each tail group ("positive" / "negative" Diff) draw a beeswarm of SHAP values
# and a bar chart of mean |SHAP| for the group's most important features.
for sign in diff_sign:
    df_sign = diff_sign[sign]
    if df_sign.empty:
        # Nothing to plot; np.min/np.max below would fail on empty arrays.
        print(f"No samples in the '{sign}' group, skipping its SHAP plots")
        continue

    shap_mean_abs = [
        np.mean(np.abs(df_sign.loc[:, f"{feat}_shap"].values))
        for feat in feats
    ]

    # Indices of the top features in ascending importance (most important last, i.e.
    # at the top of the plot). The start is clamped at 0: with fewer features than
    # n_top_features an unclamped negative start would silently drop features.
    first_kept = max(len(shap_mean_abs) - n_top_features, 0)
    order = np.argsort(shap_mean_abs)[first_kept:]
    feats_sorted = feats[order]
    shap_mean_abs = np.array(shap_mean_abs)[order]

    # --- Beeswarm: one row per feature, points coloured by the feature value ---
    fig = go.Figure()
    for feat_id, feat in enumerate(feats_sorted):
        xs = np.array(df_sign.loc[:, f"{feat}_shap"].values, dtype=float)
        colors = df_sign.loc[:, f"{feat}_values"].values
        # Row centre is the feature's position; offsets spread the points around it.
        ys = feat_id + _beeswarm_offsets(xs)

        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                showlegend=False,
                mode='markers',
                marker=dict(
                    size=12,
                    opacity=0.5,
                    line=dict(
                        width=0.00
                    ),
                    color=colors,
                    colorscale=px.colors.sequential.Rainbow,
                    # Only the first trace shows a colorbar, labelled Min/Max.
                    showscale=(feat_id == 0),
                    colorbar=dict(
                        title=dict(text="", font=dict(size=26)),
                        tickfont=dict(size=26),
                        tickmode="array",
                        tickvals=[min(colors), max(colors)],
                        ticktext=["Min", "Max"],
                        x=1.03,
                        y=0.5,
                        len=0.99
                    )
                ),
            )
        )

    add_layout(fig, "SHAP values", "", f"")
    fig.update_layout(legend_font_size=20)
    fig.update_layout(showlegend=False)
    fig.update_xaxes(zeroline=True, zerolinewidth=1, zerolinecolor='grey')
    # Tick labels are hidden here; the companion bar chart carries the feature names.
    fig.update_layout(
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(feats_sorted))),
            ticktext=feats_sorted,
            showticklabels=False
        )
    )
    fig.update_yaxes(autorange=False)
    fig.update_layout(yaxis_range=[-0.5, len(feats_sorted) - 0.5])
    fig.update_yaxes(tickfont_size=26)
    fig.update_xaxes(tickfont_size=26)
    fig.update_xaxes(title_font_size=26)
    fig.update_xaxes(nticks=6)
    fig.update_layout(
        xaxis=dict(showgrid=False),
        yaxis=dict(showgrid=False)
    )
    fig.update_layout(
        autosize=False,
        width=700,
        height=800,
        margin=go.layout.Margin(
            l=20,
            r=100,
            b=80,
            t=20,
            pad=0
        )
    )
    save_figure(fig, f"{path_save}/SHAP/{sign}_beeswarm")

    # --- Horizontal bar chart of mean |SHAP|, rows aligned with the beeswarm ---
    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=shap_mean_abs,
            y=list(range(len(shap_mean_abs))),
            orientation='h',
            marker=dict(color='red', opacity=1.0)
        )
    )
    add_layout(fig, "Mean(|SHAP values|)", "", f"")
    fig.update_layout(legend_font_size=20)
    fig.update_layout(showlegend=False)
    fig.update_layout(
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(feats_sorted))),
            ticktext=feats_sorted
        )
    )
    fig.update_yaxes(autorange=False)
    fig.update_layout(yaxis_range=[-0.5, len(feats_sorted) - 0.5])
    fig.update_yaxes(tickfont_size=26)
    fig.update_xaxes(tickfont_size=26)
    fig.update_xaxes(title_font_size=26)
    fig.update_xaxes(nticks=6)
    fig.update_layout(
        autosize=False,
        width=500,
        height=800,
        margin=go.layout.Margin(
            l=120,
            r=100,
            b=80,
            t=20,
            pad=0
        )
    )
    save_figure(fig, f"{path_save}/SHAP/{sign}_bar")