# Imports

In [None]:
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
from scipy import stats
import plotly.express as px
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
import seaborn as sns
import pathlib
from sklearn.metrics import mean_absolute_error
import patchworklib as pw

# Init data

In [None]:
# Folder that receives the figures and tables written by the cells below.
path_save = "E:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/033_immuno_ml_draft_figures"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Folder the input workbooks are read from.
path_load = "E:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/021_ml_data/immuno"

# Stem of the data workbook to load.
# Alternative stem kept for reference:
#   "df_type(minmax_left(0.05)_right(0.95)_combat)_ctrl(550)_imp(fast_knn)_replace(quarter)"
fn = "260_imp(fast_knn)_replace(quarter)"

# Main samples table, keyed by its "index" column.
df = pd.read_excel(io=f"{path_load}/{fn}.xlsx", index_col="index")

# Feature names are taken from the "features" column of feats_con.xlsx.
feats_table = pd.read_excel(io=f"{path_load}/feats_con.xlsx", index_col="features")
feats = feats_table.index.values

# Optional restriction to a single region (currently disabled).
# df = df.loc[df["Region"] == "Central", :]

# ipAGE checking

In [None]:
# Recompute ipAGE as a linear combination of feature columns, using the
# (feature, coef) table stored in clock.xlsx, and compare it with Age.
df_ipage_model = pd.read_excel(f"E:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/011_immuno_part3_and_part4_check_clocks/legacy/Control/v22/clock.xlsx")
features = df_ipage_model['feature'].tolist()
coefs = df_ipage_model['coef'].tolist()

# The first coefficient is used as the constant term (its feature name is
# never looked up); every following row contributes column * coefficient.
df['ipAGE'] = np.full(len(df), coefs[0])
for feat_name, feat_coef in zip(features[1:], coefs[1:]):
    df['ipAGE'] += df.loc[:, feat_name].values * feat_coef

# Row selections for the error estimates below.
in_trn_set = df['ipAGE_trn_set'] == True
not_in_all_set = df['ipAGE_all_set'] == False

# Mean absolute error on all rows, on the rows flagged `ipAGE_trn_set`, and
# on the rows where `ipAGE_all_set` is False.
ipage_mae_all = mean_absolute_error(df['Age'], df['ipAGE'])
ipage_mae_trn = mean_absolute_error(
    df.loc[in_trn_set, 'Age'].values,
    df.loc[in_trn_set, 'ipAGE'].values,
)
ipage_mae_new = mean_absolute_error(
    df.loc[not_in_all_set, 'Age'].values,
    df.loc[not_in_all_set, 'ipAGE'].values,
)

# Data description (Participants) figures

In [None]:
# For every categorical feature: overlaid histograms of `feat_x`, one trace
# per category value, saved as a figure plus the underlying data table.
cat_feat_colors = {"Sex": [("F", "red"), ("M", "blue")]}
feat_x = "Age"
bin_size = 5
for feat, fields in cat_feat_colors.items():
    fig = go.Figure()
    for level, shade in fields:
        level_xs = df.loc[df[feat] == level, feat_x].values
        # Semi-transparent bars with a thin black outline so that the
        # overlaid groups stay distinguishable.
        bar_marker = dict(
            color=shade,
            opacity=0.75,
            line=dict(width=1, color="black"),
        )
        trace = go.Histogram(
            x=level_xs,
            name=f"{level} ({len(level_xs)})",  # legend shows the group size
            showlegend=True,
            marker=bar_marker,
            xbins=dict(size=bin_size),
        )
        fig.add_trace(trace)
    add_layout(fig, f"{feat_x}", "Count", "")
    fig.update_layout(
        margin=go.layout.Margin(l=90, r=20, b=75, t=50, pad=0),
        legend_font_size=20,
        legend={'itemsizing': 'constant'},
        barmode='overlay',
    )

    # Write the figure and the two columns it was drawn from.
    pathlib.Path(f"{path_save}/data_description_participants").mkdir(parents=True, exist_ok=True)
    save_figure(fig, f"{path_save}/data_description_participants/histogram_cont({feat_x})_cat({feat})")
    df_save = df.loc[:, [feat_x, feat]]
    df_save.to_excel(
        f"{path_save}/data_description_participants/histogram_cont({feat_x})_cat({feat}).xlsx",
        index=True,
    )

# Data description (Features) figures

## Generate data for figure

In [None]:
# Pairwise Pearson statistics between Age and every feature, packed into a
# single square table:
#   * upper triangle  -> correlation coefficient,
#   * lower triangle  -> p-value (replaced in `df_corr_fdr` by
#                        -log10 of the Benjamini-Hochberg adjusted p-value),
#   * diagonal        -> NaN.
feats_plot = ["Age"] + list(feats)
n_feats = len(feats_plot)
df_corr = pd.DataFrame(data=np.zeros(shape=(n_feats, n_feats)), index=feats_plot, columns=feats_plot)
for f_id_1, f_1 in enumerate(feats_plot):
    df_corr.at[f_1, f_1] = np.nan
    vals_1 = df.loc[:, f_1].values
    # Only pairs with f_2 after f_1 are visited, so each pair is tested once.
    for f_2 in feats_plot[f_id_1 + 1:]:
        vals_2 = df.loc[:, f_2].values
        corr, pval = stats.pearsonr(vals_1, vals_2)
        df_corr.at[f_2, f_1] = pval
        df_corr.at[f_1, f_2] = corr

# Strictly-lower-triangle selector. The builtin `bool` is used because the
# `np.bool` alias was deprecated and later removed from NumPy.
selection = np.tri(df_corr.shape[0], df_corr.shape[1], -1, dtype=bool)

# Long table of the lower-triangle p-values. The explicit dropna() removes
# the cells masked out by `where` even on pandas versions whose stack() keeps
# NaN entries, so only real p-values reach the multiple-testing correction.
df_fdr = df_corr.where(selection).stack().dropna().reset_index()
df_fdr.columns = ['row', 'col', 'pval']
_, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, alpha=0.05, method='fdr_bh')

# Same table as `df_corr`, with each lower-triangle p-value replaced by
# -log10 of its adjusted value.
df_corr_fdr = df_corr.copy()
for rec in df_fdr.itertuples(index=False):
    df_corr_fdr.loc[rec.row, rec.col] = -np.log10(rec.pval_fdr_bh)

## Plot correlation matrix

In [None]:
# Heat map of `df_corr_fdr`: the upper triangle is drawn with a diverging
# colour map on [-1, 1], the lower triangle with a sequential colour map
# starting at -log10(0.05); each triangle gets its own colour bar.
df_to_plot = df_corr_fdr.copy()
mtx_to_plot = df_to_plot.to_numpy()
n_rows, n_cols = df_to_plot.shape

# Upper triangle (above the diagonal); zero cells are masked out.
mtx_triu = np.triu(mtx_to_plot, k=1)
max_corr = mtx_triu.max()
min_corr = mtx_triu.min()
mtx_triu_mask = np.ma.masked_array(mtx_triu, mask=(mtx_triu == 0))
cmap_triu = plt.get_cmap("bwr").copy()

# Lower triangle (below the diagonal); values under vmin are drawn black.
mtx_tril = np.tril(mtx_to_plot, k=-1)
mtx_tril_mask = np.ma.masked_array(mtx_tril, mask=(mtx_tril == 0))
cmap_tril = plt.get_cmap("viridis").copy()
cmap_tril.set_under('black')

fig, ax = plt.subplots()

im_triu = ax.imshow(mtx_triu_mask, cmap=cmap_triu, vmin=-1, vmax=1)
cbar_triu = fig.colorbar(im_triu, ax=ax, location='right')
cbar_triu.set_label(r"$\mathrm{Correlation\:coefficient}$", horizontalalignment='center', fontsize=10)

im_tril = ax.imshow(mtx_tril_mask, cmap=cmap_tril, vmin=-np.log10(0.05))
cbar_tril = fig.colorbar(im_tril, ax=ax, location='right')
cbar_tril.set_label(r"$-\log_{10}(\mathrm{p-value})$", horizontalalignment='center', fontsize=10)

ax.grid(None)
ax.set_aspect("equal")
ax.set_xticks(np.arange(n_cols))
ax.set_yticks(np.arange(n_rows))
ax.set_xticklabels(df_to_plot.columns.values)
ax.set_yticklabels(df_to_plot.index.values)
plt.setp(ax.get_xticklabels(), rotation=90)

# Lower-triangle cells below half of the value range get white text.
threshold = np.ptp(mtx_tril) * 0.5
for which in ('major', 'minor'):
    ax.tick_params(axis='both', which=which, labelsize=5)
textcolors = ("black", "white")

# Print every finite cell value into its cell; NaN/inf cells get empty text.
for i, j in np.ndindex(n_rows, n_cols):
    cell = mtx_to_plot[i, j]
    if i > j:
        color = textcolors[int(mtx_tril[i, j] < threshold)]
    else:
        color = "black"
    label = f"{cell:0.2f}" if np.isfinite(cell) else ""
    text = ax.text(j, i, label, ha="center", va="center", color=color, fontsize=1.3)

fig.tight_layout()
pathlib.Path(f"{path_save}/data_description_features").mkdir(parents=True, exist_ok=True)
plt.savefig(f"{path_save}/data_description_features/corr_mtx_fdr.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_save}/data_description_features/corr_mtx_fdr.pdf", bbox_inches='tight', dpi=400)
plt.clf()

# Table behind the figure.
df_save = df_corr_fdr
df_save.to_excel(f"{path_save}/data_description_features/corr_mtx_fdr.xlsx", index=True)

## Scatter plot of biomarkers vs age

In [None]:
# Grid of 2-D kernel density plots (feature vs. Age), one panel per feature,
# five panels per row; unused trailing panels are switched off.
feats_plot = list(feats)
num_cols = 5
num_rows = int(np.ceil(len(feats_plot) / num_cols))
sns.set_style("whitegrid")
# squeeze=False keeps `axes` two-dimensional even when only one row is
# needed; otherwise a single-row grid comes back as a 1-D array and
# per-panel 2-D indexing fails.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(16,22), squeeze=False)
for rc_id, ax in enumerate(axes.flat):
    if rc_id >= len(feats_plot):
        ax.axis('off')
        continue

    f = feats_plot[rc_id]
    f_vals = df.loc[:, f].values

    # Vertical clip range: when the maximum is more than twice the 99th
    # percentile, cut the panel at 1.3 x that percentile instead of the
    # maximum so a few extreme values do not flatten the density.
    y_beg = 0
    y_pctl = np.percentile(f_vals, [99])[0]
    y_max = np.max(f_vals)
    if y_max > 2*y_pctl:
        y_end = y_pctl * 1.3
    else:
        y_end = y_max

    sns.kdeplot(
        ax=ax,
        data=df, x="Age",
        y=f,
        #hue="Sex",
        #palette={"F": "red", "M": "blue"},
        legend=False,
        clip=((0, 110), (y_beg, y_end)),
        fill=True,
        cbar=True,
        cmap='magma',
        cbar_kws={'format': '%0.2e'}
    )

fig.tight_layout()
pathlib.Path(f"{path_save}/data_description_features").mkdir(parents=True, exist_ok=True)
plt.savefig(f"{path_save}/data_description_features/kdes_feats_vs_Age.png", bbox_inches='tight')
plt.savefig(f"{path_save}/data_description_features/kdes_feats_vs_Age.pdf", bbox_inches='tight')
plt.clf()

# Data behind the figure.
df_save = df.loc[:, ["Age"] + list(feats)]
df_save.to_excel(f"{path_save}/data_description_features/kdes_feats_vs_Age.xlsx", index=True)

In [None]:
# Grid of scatter plots (feature vs. Age), one subplot per feature and five
# subplots per row, with one trace per sex in every subplot.
feats_plot = list(feats)
num_cols = 5
num_rows = int(np.ceil(len(feats_plot) / num_cols))

fig = make_subplots(rows=num_rows, cols=num_cols, shared_yaxes=False)

# Styling shared by the x and y axis of every subplot. `title_font` is used
# instead of the deprecated `titlefont` alias, matching the `title_font_*`
# spelling used elsewhere in this file.
axis_style = dict(
    autorange=False,
    showgrid=True,
    zeroline=False,
    linecolor='black',
    showline=True,
    gridcolor='gainsboro',
    gridwidth=0.05,
    mirror=True,
    ticks='outside',
    title_font=dict(color='black', size=20),
    showticklabels=True,
    tickangle=0,
    tickfont=dict(color='black', size=20),
    exponentformat='e',
    showexponent='all',
)
# (value of the "Sex" column, marker colour), in drawing order.
sex_colors = (("F", "red"), ("M", "blue"))

for rc_id, f in enumerate(feats_plot):
    r_id, c_id = divmod(rc_id, num_cols)
    row, col = r_id + 1, c_id + 1  # plotly subplot coordinates are 1-based

    # Only the traces of the first subplot create legend entries, so each
    # sex appears once in the legend.
    show_legend = rc_id == 0

    # Vertical range: when the maximum is more than twice the 90th
    # percentile, cut the axis at 1.3 x that percentile instead of the
    # maximum.
    f_vals = df.loc[:, f].values
    y_beg = 0
    y_pctl = np.percentile(f_vals, [90])[0]
    y_max = np.max(f_vals)
    if y_max > 2*y_pctl:
        y_end = y_pctl * 1.3
    else:
        y_end = y_max

    for sex, color in sex_colors:
        is_sex = df["Sex"] == sex
        x_sex = df.loc[is_sex, 'Age'].values
        y_sex = df.loc[is_sex, f].values
        fig.add_trace(
            go.Scatter(
                x=x_sex,
                y=y_sex,
                showlegend=show_legend,
                name=f"{sex} ({len(x_sex)})",
                mode='markers',
                marker=dict(
                    size=10,
                    opacity=0.7,
                    line=dict(
                        width=0.1
                    ),
                    color=color
                )
            ),
            row=row,
            col=col
        )

    fig.update_xaxes(
        title_text="Age",
        range=[10, 100],
        row=row,
        col=col,
        **axis_style
    )
    fig.update_yaxes(
        title_text=f"{f}",
        range=[y_beg, y_end],
        row=row,
        col=col,
        **axis_style
    )

fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.01,
        xanchor="center",
        x=0.5,
        itemsizing='constant',
        font_size=50
    ),
    title=dict(
        text="",
        font=dict(size=25)
    ),
    template="none",
    autosize=False,
    width=3000,
    height=4000,
    margin=go.layout.Margin(
        l=100,
        r=40,
        b=100,
        t=100,
        pad=0
    ),
)
pathlib.Path(f"{path_save}/data_description_features").mkdir(parents=True, exist_ok=True)
# The figure is written once; the original repeated this exact call.
save_figure(fig, f"{path_save}/data_description_features/scatters_feats_vs_Age")

# Feature selection and dimensionality reduction

## Frequency of features occurring in Top-10

In [None]:
# Count, over all cross-validation folds of the selected model runs, how
# often each feature ranks among the `n_top` most important ones.
path_models = "E:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/021_ml_data/immuno/models/feature_selection"
# Model name -> run folder (timestamp) whose importances are read.
models_best = {
    "danet": "2022-09-17_15-16-02",
    "lightgbm": "2022-09-17_17-14-05",
    "widedeep_tab_net": "2022-09-17_14-34-04"
}
n_top = 10

# One zero-initialised counter per feature.
df_feats_freqs = pd.DataFrame({"Frequency": np.zeros(len(feats))}, index=feats)
df_feats_freqs.index.name = "Features"

for model, date_time in models_best.items():
    # After transposing, rows are features and columns are folds.
    df_feat_imps = pd.read_excel(
        f"{path_models}/{model}/{date_time}/feature_importances_cv.xlsx",
        index_col="fold",
    ).T
    for fold_id in df_feat_imps.columns.values:
        # Rank the features by this fold's importance, highest first.
        df_feat_imps = df_feat_imps.sort_values(by=fold_id, ascending=False)
        ranked_names = df_feat_imps.index.values
        for rank in range(n_top):
            df_feats_freqs.at[ranked_names[rank], "Frequency"] += 1

# Most frequently selected features first.
df_feats_freqs = df_feats_freqs.sort_values(by="Frequency", ascending=False)

In [None]:
# Bar chart of the "Frequency" column of `df_feats_freqs` (one bar per index
# entry), saved as PNG and PDF together with the table itself.
# NOTE(review): the pyplot calls below act on implicit global state, so their
# order is kept as is.
plt.figure(figsize=(34, 10))
# Rotate the x tick labels by 90 degrees.
plt.xticks(rotation=90)
sns.set_theme(style='white', font_scale=3)
sns.barplot(data=df_feats_freqs, x=df_feats_freqs.index, y="Frequency")
# Make sure the output folder exists before writing into it.
pathlib.Path(f"{path_save}/feature_importance").mkdir(parents=True, exist_ok=True)
plt.savefig(f"{path_save}/feature_importance/bar.png", bbox_inches='tight')
plt.savefig(f"{path_save}/feature_importance/bar.pdf", bbox_inches='tight')
# Data behind the figure.
df_feats_freqs.to_excel(f"{path_save}/feature_importance/bar.xlsx", index=True)

## Plot supplementary correlation scatter matrix

In [None]:
# Pair grid over Age and the features listed in feats_con_top10.xlsx:
# scatter plots above the diagonal, filled KDEs below it, histograms on it.
feats_top10 = pd.read_excel(f"{path_load}/feats_con_top10.xlsx", index_col="features").index.values

feats_plot = ["Age"] + list(feats_top10)
sns.set_theme(style="whitegrid", font_scale=2)
# pair_grid = sns.PairGrid(df, hue='Sex', vars=feats_plot, palette={"F": "red", "M": "blue"})
pair_grid = sns.PairGrid(df, vars=feats_plot)
pair_grid.map_upper(sns.scatterplot, color='darkred', s=100, alpha=0.5)
pair_grid.map_lower(sns.kdeplot, fill=True, cbar=False, cmap='rocket_r', thresh=-0.1)
pair_grid.map_diag(plt.hist, bins=15, color='darkred', edgecolor='k')
# NOTE(review): no `hue` is set on the grid, so this legend presumably has no
# entries - confirm whether the call is still wanted.
pair_grid.add_legend()
# Shift the y-axis labels to the left of the tick labels.
for ax in pair_grid.axes.flatten():
    ax.get_yaxis().set_label_coords(-0.5, 0.5)
plt.tick_params(axis='both')
# Create the output folder here as well, so that this cell does not depend on
# an earlier cell having created it (every other saving cell does the same).
pathlib.Path(f"{path_save}/feature_importance").mkdir(parents=True, exist_ok=True)
plt.savefig(f"{path_save}/feature_importance/corr_scatter_mtx.png", bbox_inches='tight')
plt.savefig(f"{path_save}/feature_importance/corr_scatter_mtx.pdf", bbox_inches='tight')
plt.clf()
# Data behind the figure.
df_save = df.loc[:, ["Age"] + list(feats_top10)]
df_save.to_excel(f"{path_save}/feature_importance/corr_scatter_mtx.xlsx", index=True)

## Combine the frequency bar plot and the pair grid into one figure (patchworklib)

In [None]:
# Stack two seaborn grids into one PDF with patchworklib:
#   g0 - bar plot of the feature frequencies (features listed in
#        feats_con_top10.xlsx are drawn red, all others pink),
#   g1 - pair grid over Age and those listed features.
feats_top10 = pd.read_excel(f"{path_load}/feats_con_top10.xlsx", index_col="features").index.values

# Must run before the seaborn grids are created so patchworklib can load them.
pw.overwrite_axisgrid()

sns.set_theme(style='white', font_scale=3)
# Note: this adds the helper columns "Features" and "Color" to the shared
# `df_feats_freqs` table in place.
df_feats_freqs['Features'] = df_feats_freqs.index
df_feats_freqs['Color'] = 'pink'
df_feats_freqs.loc[df_feats_freqs["Features"].isin(feats_top10), 'Color'] = 'red'
g0 = sns.FacetGrid(df_feats_freqs)
g0.map(sns.barplot, 'Features', "Frequency", palette=df_feats_freqs['Color'].values, order=df_feats_freqs['Features'].values)
for ax in g0.axes.flat:
    for label in ax.get_xticklabels():
        label.set_rotation(90)
# Panel letters, positioned in the data coordinates of the bar plot.
for ax in g0.axes.flat:
    ax.text(-4, 75,'(a)', fontsize=72)
    ax.text(-4, -23,'(b)', fontsize=72)
g0 = pw.load_seaborngrid(g0, figsize=(30, 10))

feats_plot = ["Age"] + list(feats_top10)
sns.set_theme(style="whitegrid", font_scale=2.5)
g1 = sns.PairGrid(df, vars=feats_plot)
g1.map_diag(plt.hist, bins=15, color='darkred', edgecolor='k')
g1.map_upper(sns.scatterplot, color='darkred', s=100, alpha=0.5)
g1.map_lower(sns.kdeplot, fill=True, cbar=False, cmap='rocket_r', thresh=-0.1)
g1.add_legend()
for ax in g1.axes.flatten():
    ax.get_yaxis().set_label_coords(-0.6, 0.5)
g1 = pw.load_seaborngrid(g1, figsize=(30, 30))

# Create the output folder here as well, so that this cell does not depend on
# an earlier cell having created it (every other saving cell does the same).
pathlib.Path(f"{path_save}/feature_importance").mkdir(parents=True, exist_ok=True)
# `g0/g1` places g0 above g1.
(g0/g1).savefig(f"{path_save}/feature_importance/together.pdf")

# Plots for SImAge

In [None]:
# Two figures for the SImAge model output table:
#   1) overlaid Age histograms of the "trn" and "val" parts,
#   2) scatter of Estimation vs. Age for both parts with the identity line.
df_simage = pd.read_excel(f"{path_load}/models/small/danet/32_inference/df.xlsx", index_col="index")

# Value of the "part" column -> legend name.
parts = {"trn": "Train", "val": "Test"}
# Bin width: one fifteenth of the observed Age range.
ptp = np.ptp(df_simage.loc[:, "Age"].values)
bin_size = ptp / 15

fig = go.Figure()
for part_id, part_name in parts.items():
    part_ages = df_simage.loc[df_simage["part"] == part_id, "Age"].values
    fig.add_trace(
        go.Histogram(
            x=part_ages,
            name=part_name,
            showlegend=True,
            marker=dict(opacity=0.75, line=dict(width=1)),
            xbins=dict(size=bin_size),
        )
    )
add_layout(fig, f"Age", "Count", "")
fig.update_layout(
    margin=go.layout.Margin(l=90, r=20, b=100, t=50, pad=0),
    width=800,
    height=600,
)
fig.update_layout(legend_font_size=24)
fig.update_xaxes(autorange=False, range=[12, 100])
fig.update_layout({'colorway': ["lime", "magenta"]}, barmode='overlay', legend={'itemsizing': 'constant'})
pathlib.Path(f"{path_save}/SImAge").mkdir(parents=True, exist_ok=True)
save_figure(fig, f"{path_save}/SImAge/Histogram")

# Row selections for the two parts.
is_trn = df_simage["part"] == "trn"
is_val = df_simage["part"] == "val"

fig = go.Figure()
# Identity line (estimate == age) first, then the two parts.
add_scatter_trace(fig, [12, 100], [12, 100], "", mode="lines")
add_scatter_trace(fig, df_simage.loc[is_trn, "Age"].values, df_simage.loc[is_trn, "Estimation"].values, f"Train", size=13)
add_scatter_trace(fig, df_simage.loc[is_val, "Age"].values, df_simage.loc[is_val, "Estimation"].values, f"Test", size=13)
add_layout(fig, "Age", f"SImAge", f"")
fig.update_layout({'colorway': ["black", "lime", "magenta"]}, legend={'itemsizing': 'constant'})
fig.update_layout(legend_font_size=24)
fig.update_xaxes(autorange=False, range=[12, 100])
fig.update_yaxes(autorange=False, range=[12, 100])
fig.update_layout(
    margin=go.layout.Margin(l=90, r=20, b=100, t=50, pad=0),
    width=800,
    height=600,
)
pathlib.Path(f"{path_save}/SImAge").mkdir(parents=True, exist_ok=True)
save_figure(fig, f"{path_save}/SImAge/Scatter")

# Plots for SHAP

In [None]:
# Assemble one table with, per sample: the SHAP value of each listed feature
# ("<feature>_shap"), the feature value itself ("<feature>_values"), Age,
# Estimation and their difference ("Diff").
feats_top10 = pd.read_excel(f"{path_load}/feats_con_top10.xlsx", index_col="features").index.values

# Model output: keep Age, Estimation and the listed features only.
df_simage = pd.read_excel(f"E:/YandexDisk/Work/pydnameth/draft/06_somewhere/models/small/danet/32_inference/df.xlsx", index_col="index")
kept_columns = ["Age", "Estimation"] + list(feats_top10)
df_simage = df_simage.loc[:, kept_columns].rename(
    columns={name: f"{name}_values" for name in feats_top10}
)
df_simage["Diff"] = df_simage["Estimation"] - df_simage["Age"]

# SHAP values, joined to the model output on the shared index.
df_shap = pd.read_excel(f"E:/YandexDisk/Work/pydnameth/draft/06_somewhere/models/shap/1/shap/all/shap.xlsx", index_col="index")
df_shap = df_shap.rename(columns={name: f"{name}_shap" for name in feats_top10})
df_shap = df_shap.merge(df_simage, left_index=True, right_index=True)

pathlib.Path(f"{path_save}/SHAP").mkdir(parents=True, exist_ok=True)
df_shap.to_excel(f"{path_save}/SHAP/data.xlsx", index=True)

# Split the samples into three groups by the sign and size of
# Diff = Estimation - Age relative to `lim`, plot the stacked histogram of
# Diff per group, and prepare the per-group tables and bar colours.
lim = 6.28

label_neutral = "|Acceleration| < MAE"
label_positive = "Acceleration > MAE"
label_negative = "Acceleration < -MAE"

# NOTE(review): all three comparisons are strict, so a row whose Diff equals
# exactly +lim or -lim receives no label and ends up in none of the groups.
diff = df_shap["Diff"]
df_shap.loc[(diff < lim) & (diff > -lim), "Part"] = label_neutral
df_shap.loc[diff > lim, "Part"] = label_positive
df_shap.loc[diff < -lim, "Part"] = label_negative
df_shap["Acceleration"] = df_shap["Diff"]

# Group label -> histogram colour.
palette = {
    label_neutral: 'green',
    label_positive: 'red',
    label_negative: 'blue',
}

sns.set_theme(style='whitegrid')
sns.histplot(data=df_shap, x="Acceleration", hue="Part", palette=palette, multiple="stack")
pathlib.Path(f"{path_save}/SHAP").mkdir(parents=True, exist_ok=True)
plt.savefig(f"{path_save}/SHAP/hist.png", bbox_inches='tight')
plt.savefig(f"{path_save}/SHAP/hist.pdf", bbox_inches='tight')
plt.clf()

# Short group key -> rows of that group.
group_labels = {
    'neutral': label_neutral,
    'positive': label_positive,
    'negative': label_negative,
}
diff_sign = {
    key: df_shap.loc[df_shap["Part"] == group_label, :]
    for key, group_label in group_labels.items()
}
# Short group key -> bar colour (same colours as in the histogram).
colors_bar = {
    'neutral': 'green',
    'positive': 'red',
    'negative': 'blue',
}

# For every group in `diff_sign`, write three files:
#   <sign>_beeswarm  - one row of points per feature: x = SHAP value,
#                      colour = feature value;
#   <sign>_bar       - horizontal bars of mean(|SHAP value|) per feature;
#   <sign>_data.xlsx - the rows of that group.
for sign in diff_sign:

    # Mean absolute SHAP value of each feature within this group.
    shap_mean_abs = []
    for feat in feats_top10:
        shap_mean_abs.append(np.mean(np.abs(diff_sign[sign].loc[:, f"{feat}_shap"].values)))

    # Ascending order: the feature with the smallest mean |SHAP| gets
    # position 0 on the y axis, the largest one the highest position.
    order = np.argsort(shap_mean_abs)
    feats_sorted = feats_top10[order]
    shap_mean_abs = np.array(shap_mean_abs)[order]

    fig = go.Figure()
    for feat_id, feat in enumerate(feats_sorted):
        # Only the first trace carries the colour bar.
        showscale = True if feat_id == 0 else False
        xs = diff_sign[sign].loc[:, f"{feat}_shap"].values
        colors = diff_sign[sign].loc[:, f"{feat}_values"].values

        # Vertical placement of the points inside this feature's row:
        # SHAP values are quantised into `nbins` bins; within a bin the
        # points are stacked alternately above and below the row centre.
        N = len(xs)
        row_height = 0.40
        nbins = 20
        xs = list(xs)
        xs = np.array(xs, dtype=float)
        # Bin index of every point; the 1e-8 term avoids division by zero
        # when all values are equal.
        quant = np.round(nbins * (xs - np.min(xs)) / (np.max(xs) - np.min(xs) + 1e-8))
        # Visit points bin by bin; the tiny random jitter only decides the
        # order of points that share a bin (not seeded, so the stacking order
        # within a bin can differ between runs).
        inds = np.argsort(quant + np.random.randn(N) * 1e-6)
        layer = 0
        last_bin = -1
        ys = np.zeros(N)
        for ind in inds:
            # A new bin restarts the stack at the row centre.
            if quant[ind] != last_bin:
                layer = 0
            # layer 0, 1, 2, 3, 4, ... -> offset 0, +1, -1, +2, -2, ...
            ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
            layer += 1
            last_bin = quant[ind]
        # Rescale the integer offsets relative to `row_height`, then shift the
        # row to this feature's y position.
        ys *= 0.9 * (row_height / np.max(ys + 1))
        ys = feat_id + ys

        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                showlegend=False,
                mode='markers',
                marker=dict(
                    size=12,
                    opacity=0.5,
                    line=dict(
                        width=0.00
                    ),
                    color=colors,
                    colorscale=px.colors.sequential.Rainbow,
                    showscale=showscale,
                    # Colour bar labelled only at the two ends ("Min"/"Max")
                    # of this feature's value range.
                    colorbar=dict(
                        title=dict(text="", font=dict(size=26)),
                        tickfont=dict(size=26),
                        tickmode="array",
                        tickvals=[min(colors), max(colors)],
                        ticktext=["Min", "Max"],
                        x=1.03,
                        y=0.5,
                        len=0.99
                    )
                ),
            )
        )

    add_layout(fig, "SHAP values", "", f"")
    fig.update_layout(legend_font_size=20)
    fig.update_layout(showlegend=False)
    fig.update_xaxes(zeroline=True, zerolinewidth=1, zerolinecolor='grey')
    # One tick per feature row; the labels are hidden in this figure.
    fig.update_layout(
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(feats_sorted))),
            ticktext=feats_sorted,
            showticklabels=False
        )
    )
    fig.update_yaxes(autorange=False)
    fig.update_layout(yaxis_range=[-0.5, len(feats_sorted) - 0.5])
    fig.update_yaxes(tickfont_size=26)
    fig.update_xaxes(tickfont_size=26)
    fig.update_xaxes(title_font_size=26)
    fig.update_xaxes(nticks=6)
    fig.update_layout(
        xaxis=dict(showgrid=False),
        yaxis=dict(showgrid=False)
    )
    fig.update_layout(
        autosize=False,
        width=700,
        height=800,
        margin=go.layout.Margin(
            l=20,
            r=100,
            b=80,
            t=20,
            pad=0
        )
    )
    save_figure(fig, f"{path_save}/SHAP/{sign}_beeswarm")

    # Horizontal bar chart of mean |SHAP| using the same feature order and the
    # same y range as the beeswarm figure; here the feature names are shown.
    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=shap_mean_abs,
            y=list(range(len(shap_mean_abs))),
            orientation='h',
            marker=dict(color=colors_bar[sign], opacity=1.0)
        )
    )
    add_layout(fig, "Mean(|SHAP values|)", "", f"")
    fig.update_layout(legend_font_size=20)
    fig.update_layout(showlegend=False)
    fig.update_layout(
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(feats_sorted))),
            ticktext=feats_sorted
        )
    )
    fig.update_yaxes(autorange=False)
    fig.update_layout(yaxis_range=[-0.5, len(feats_sorted) - 0.5])
    fig.update_yaxes(tickfont_size=26)
    fig.update_xaxes(tickfont_size=26)
    fig.update_xaxes(title_font_size=26)
    fig.update_xaxes(nticks=6)
    fig.update_layout(
        autosize=False,
        width=500,
        height=800,
        margin=go.layout.Margin(
            l=120,
            r=100,
            b=80,
            t=20,
            pad=0
        )
    )
    save_figure(fig, f"{path_save}/SHAP/{sign}_bar")

    # Rows of this group, i.e. the data behind both figures.
    diff_sign[sign].to_excel(f"{path_save}/SHAP/{sign}_data.xlsx", index=True)