In [None]:
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
from scipy import stats
import plotly.express as px
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import seaborn as sns
from glob import glob
import pathlib
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import MDS
from openTSNE import TSNE
from sklearn.metrics import mean_absolute_error
from scipy import stats
import patchworklib as pw
import os
import functools
from scipy.stats import iqr
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
import shap
from slugify import slugify
from src.models.tabular.widedeep.ft_transformer import WDFTTransformerModel
from art.estimators.regression.pytorch import PyTorchRegressor
from art.estimators.regression.blackbox import BlackBoxRegressor
from art.attacks.evasion import ProjectedGradientDescentNumpy
from art.defences.detector.evasion.binary_input_detector import BinaryInputDetector
import torch
from src.tasks.metrics import get_cls_pred_metrics, get_cls_prob_metrics, get_reg_metrics
import matplotlib.lines as mlines


def conjunction(conditions):
    """Combine boolean masks with element-wise AND.

    The masks are folded left to right with ``np.logical_and``; a single mask is
    returned unchanged. An empty ``conditions`` raises ``TypeError``.
    """
    masks = iter(conditions)
    try:
        combined = next(masks)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for mask in masks:
        combined = np.logical_and(combined, mask)
    return combined


def disjunction(conditions):
    """Combine boolean masks with element-wise OR.

    The masks are folded left to right with ``np.logical_or``; a single mask is
    returned unchanged. An empty ``conditions`` raises ``TypeError``.
    """
    masks = iter(conditions)
    try:
        combined = next(masks)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for mask in masks:
        combined = np.logical_or(combined, mask)
    return combined

# 1. Adversarial examples for immunology data

## 1.1. Preparing the data, the model, and black-box prediction functions

In [None]:
path = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special"
path_model = f"{path}/044_small_immuno_clocks_revision/models/10_trn_val_tst/widedeep_ft_transformer_trn_val_tst/multiruns/2023-05-07_19-40-40_1337/64"
path_save = f"{path}/046_adversarial_robustness_toolbox/immuno"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path}/044_small_immuno_clocks_revision/figure_simage/df.xlsx", index_col=0)
feats = pd.read_excel(f"{path}/044_small_immuno_clocks_revision/feats_con_top10_new.xlsx", index_col=0).index.values
ids_feat = list(range(len(feats)))
target = 'Age'

# Sample split of fold_0002 -- the same fold whose checkpoint is loaded below.
df_preds = pd.read_excel(f"{path_model}/predictions.xlsx", index_col=0)
fold_split = df_preds['fold_0002']
ids_trn = df_preds.index[fold_split == 'trn'].values
ids_val = df_preds.index[fold_split == 'val'].values
ids_tst = df_preds.index[fold_split == 'tst_ctrl_central'].values
# Order-preserving, de-duplicated unions. A set of string ids iterates in an order
# that depends on hash randomization, which would reorder every downstream table
# (and the t-SNE input) between kernel restarts.
ids_all = list(dict.fromkeys([*ids_trn, *ids_val, *ids_tst]))
ids_trn_val = list(dict.fromkeys([*ids_trn, *ids_val]))
df = df.loc[ids_all, :]

df_X = df.loc[ids_all, feats]

# Per-feature PGD budget derived from the training IQR of each feature:
# eps = 40% of the IQR, eps_step = 10% of the IQR.
iqr_trn = iqr(df.loc[ids_trn, feats].values, axis=0)
eps = 0.4 * iqr_trn
eps_step = 0.1 * iqr_trn

# Load the trained FT-Transformer in inference mode with frozen weights.
model = WDFTTransformerModel.load_from_checkpoint(checkpoint_path=f"{path_model}/best_fold_0002.ckpt")
model.eval()
model.freeze()

def predict_func_regression(X):
    """Black-box prediction wrapper around the global ``model``.

    Selects the ``ids_feat`` columns of ``X``. It passes them to the model as float32
    'all' and 'continuous' tensors, together with an empty int32 'categorical'
    tensor, and returns the model output as a numpy array.

    Side effect: sets ``model.produce_probabilities = True`` on every call.
    """
    model.produce_probabilities = True
    continuous_part = torch.from_numpy(np.float32(X[:, ids_feat]))
    all_part = torch.from_numpy(np.float32(X[:, ids_feat]))
    no_categorical = torch.from_numpy(np.int32(X[:, []]))  # zero categorical columns
    output = model(dict(all=all_part, continuous=continuous_part, categorical=no_categorical))
    return output.cpu().detach().numpy()

# Wrap the trained PyTorch model as an ART regressor; the PGD attack below uses it
# as its estimator, with model.loss_fn as the loss.
# NOTE(review): an optimizer IS passed even though the model is already trained and
# nothing in this notebook calls art_regressor.fit(); it is presumably only needed
# if the wrapper is ever trained -- confirm against the ART PyTorchRegressor docs.
art_regressor = PyTorchRegressor(
    model=model,
    loss=model.loss_fn,
    input_shape=[len(feats)],  # one sample = flat vector of len(feats) features
    optimizer=torch.optim.Adam(
        params=model.parameters(),
        lr=model.hparams.optimizer_lr,
        weight_decay=model.hparams.optimizer_weight_decay
    ),
    use_amp=False,
    opt_level="O1",  # NOTE(review): presumably only relevant when use_amp=True
    loss_scale="dynamic",  # NOTE(review): presumably only relevant when use_amp=True
    channels_first=True,  # NOTE(review): likely irrelevant for 1-D tabular inputs
    clip_values=None,  # no clipping range passed for feature values
    preprocessing_defences=None,
    postprocessing_defences=None,
    preprocessing=(0.0, 1.0),  # (mean, std) pair -- presumably identity preprocessing; confirm
    device_type="cpu",
)

## 1.2. Dimensionality reduction models

In [None]:
# Axis labels for the 2-D projections; keys are also the keys of dim_red_models.
dim_red_labels = {
    'PCA': ['PC 1', 'PC 2'],
    'SVD': ['SVD 1', 'SVD 2'],
    't-SNE': ['t-SNE 1', 't-SNE 2'],
}
# Fixed seed for the stochastic reducers (randomized SVD, t-SNE) so that embeddings
# and every plot derived from them are reproducible between runs.
seed_dim_red = 1337
dim_red_models = {}
# Projections are fitted on train+val samples only.
data_dim_red = df.loc[ids_trn_val, feats].values

pca = PCA(n_components=2, whiten=False, random_state=seed_dim_red)
pca.fit(data_dim_red)
dim_red_models['PCA'] = pca
svd = TruncatedSVD(n_components=2, algorithm='randomized', n_iter=5, random_state=seed_dim_red)
svd.fit(data_dim_red)
dim_red_models['SVD'] = svd
tsne = TSNE(n_components=2, random_state=seed_dim_red)
# openTSNE returns a fitted embedding object; its .transform() projects new samples.
tsne_emb = tsne.fit(data_dim_red)
dim_red_models['t-SNE'] = tsne_emb

## 1.3. Evasion with defences

In [None]:
# Attack every sample (train, val and test).
ids_trgt = ids_all

# Untargeted L-inf PGD with the per-feature budgets eps / eps_step.
attack = ProjectedGradientDescentNumpy(
    estimator=art_regressor,
    norm=np.inf,
    eps=eps,
    eps_step=eps_step,
    decay=None,
    max_iter=100,
    targeted=False,
    num_random_init=0,
    batch_size=512,
    random_eps=False,
    summary_writer=False,
    verbose=True
)

# Cast once to float32 so the clean matrix matches the dtype of the attack input and
# of the other model inputs in this notebook (all fed through np.float32).
X = np.float32(df_X.loc[ids_trgt, :].values)
# Pass a copy so the clean matrix is never touched by the attack.
X_adv = attack.generate(X.copy())

# Adversarial and clean samples for the binary input detector.
# The attack has no random initialisation (num_random_init=0, random_eps=False), so
# when the targets are all samples the result above is reused instead of re-running
# the 100-iteration attack.
if list(ids_trgt) == list(ids_all):
    X_adv_all = X_adv
else:
    X_adv_all = attack.generate(np.float32(df_X.loc[ids_all, :].values))

df_adv = pd.DataFrame(data=X_adv_all, columns=feats, index=ids_all)
df_adv['Dataset'] = df.loc[ids_all, 'Dataset']
df_adv['Index'] = [f"{sample}_adv" for sample in ids_all]
df_adv = df_adv.set_index('Index')
df_adv['DataType'] = 'Adversarial'
# .copy(): df_cln is modified below and must not be a view of df.
df_cln = df.loc[ids_all, list(feats) + ['Dataset']].copy()
df_cln['DataType'] = 'Clean'
df_detector = pd.concat([df_cln, df_adv])
df_detector.to_excel(f"{path_save}/data_detector.xlsx", index_label='index')

# Synthetic "background" cloud: each feature is drawn independently from a
# n_bins-bin histogram of its train+val values, then scored by the model.
n_bins = 100
n_bckgrnd = 1000000
bckgrnd_batch_size = 50000
rng_bckgrnd = np.random.default_rng(1337)  # seeded so the background is reproducible

bckgrnd_cols = {}
for feat in feats:
    f_vals = df.loc[ids_trn_val, feat].values
    counts, bin_edges = np.histogram(f_vals, bins=n_bins)
    # Choose a bin with probability proportional to its count, then draw uniformly
    # inside it, so samples follow the histogram density rather than sitting only on
    # left bin edges (which biased values down by half a bin).
    bin_ids = rng_bckgrnd.choice(n_bins, size=n_bckgrnd, p=counts / counts.sum())
    bckgrnd_cols[feat] = rng_bckgrnd.uniform(bin_edges[bin_ids], bin_edges[bin_ids + 1])
df_bckgrnd = pd.DataFrame(bckgrnd_cols, columns=feats)
X_bckgrnd = df_bckgrnd.loc[:, feats].values

# Score the background in chunks instead of one 1e6-row forward pass.
# Assumes per-sample outputs do not depend on batch composition (model.eval() is
# called when the model is loaded) -- TODO confirm for this architecture.
bckgrnd_preds = []
for start in range(0, len(X_bckgrnd), bckgrnd_batch_size):
    x_chunk = torch.from_numpy(np.float32(X_bckgrnd[start:start + bckgrnd_batch_size]))
    bckgrnd_preds.append(model(x_chunk).cpu().detach().numpy().ravel())
df_bckgrnd["SImAge"] = np.concatenate(bckgrnd_preds)

y_real = np.float32(df.loc[ids_trgt, target].values)
# Explicit float32 casts: the model is fed float32 tensors everywhere else in this
# notebook, and torch.from_numpy keeps float64 if the array is float64.
y_pred = model(torch.from_numpy(np.float32(X))).cpu().detach().numpy().ravel()
y_pred_adv = model(torch.from_numpy(np.float32(X_adv))).cpu().detach().numpy().ravel()

# Regression metrics on clean ('Origin') and attacked ('PGD') predictions.
metrics = get_reg_metrics()
df_metrics = pd.DataFrame(index=list(metrics), columns=['Origin', 'PGD'])
y_real_t = torch.from_numpy(y_real)
for m in metrics:
    metric_fn = metrics[m][0]
    for col, preds in (('Origin', y_pred), ('PGD', y_pred_adv)):
        df_metrics.at[m, col] = float(metric_fn(torch.from_numpy(preds), y_real_t).numpy())
        # Stateful metric object: reset so the next evaluation starts fresh.
        metric_fn.reset()
df_metrics.to_excel(f"{path_save}/metrics.xlsx")

def _prediction_frame(predictions):
    """Age, predicted age ("SImAge") and signed error (SImAge - Age) for ids_trgt."""
    frame = df.loc[ids_trgt, ['Age']].copy()
    frame.loc[ids_trgt, "SImAge"] = predictions
    frame.loc[ids_trgt, "Error"] = frame.loc[ids_trgt, "SImAge"].values - frame.loc[ids_trgt, "Age"].values
    return frame


df_dim_red = _prediction_frame(y_pred)
# Attacked copies get an "_adv" suffix so they can be stacked with the clean rows.
df_dim_red_adv = _prediction_frame(y_pred_adv)
df_dim_red_adv = df_dim_red_adv.set_index(pd.Index(df_dim_red_adv.index.values + '_adv', name='index'))

# Project clean, adversarial and background samples with every fitted reducer.
for reducer_name, reducer in dim_red_models.items():
    label_x, label_y = dim_red_labels[reducer_name]
    for out_frame, source in ((df_dim_red, X), (df_dim_red_adv, X_adv), (df_bckgrnd, X_bckgrnd)):
        embedding = reducer.transform(source)
        out_frame.loc[:, label_x] = embedding[:, 0]
        out_frame.loc[:, label_y] = embedding[:, 1]

df_dim_red_all = pd.concat([df_dim_red, df_dim_red_adv])
df_dim_red_w_bckgrnd = pd.concat([df_dim_red, df_dim_red_adv, df_bckgrnd])
df_dim_red_all.to_excel(f"{path_save}/df_dim_red.xlsx")

def _scalar_mappable(vmin, vmax):
    """Return (norm, mappable) for the 'spring' colormap on the fixed range [vmin, vmax]."""
    norm = plt.Normalize(vmin, vmax)
    sm = plt.cm.ScalarMappable(cmap="spring", norm=norm)
    sm.set_array([])
    return norm, sm


def _scatter_real_and_adv(ax, m, trgt, norm):
    """Draw clean ('o') and adversarial ('X') samples of projection `m` on `ax`.

    Both point clouds are colored by column `trgt` with the SAME `norm`. `hue_norm`
    is passed explicitly because seaborn otherwise derives the hue range from each
    call's own data, so clean points, attacked points and the colorbar would each
    use a different color scale. Returns legend handles for the two sample types.
    """
    handles = []
    for data, marker, label in ((df_dim_red, 'o', 'Real'), (df_dim_red_adv, 'X', 'Attack')):
        sns.scatterplot(
            data=data,
            x=dim_red_labels[m][0],
            y=dim_red_labels[m][1],
            hue=trgt,
            palette='spring',
            hue_norm=norm,
            legend=False,
            linewidth=0.5,
            alpha=0.75,
            edgecolor="k",
            marker=marker,
            s=50,
            ax=ax,
        )
        handles.append(mlines.Line2D([], [], marker=marker, linestyle='None', markeredgecolor='k', markerfacecolor='lightgrey', markersize=10, label=label))
    return handles


def _save_sample_figure(fig, ax, sm, handles, trgt, fn_base):
    """Add the sample legend and colorbar, write `fn_base`.png/.pdf and close the figure."""
    ax.legend(handles=handles, title="Samples", bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left", mode="expand", borderaxespad=0, ncol=3, frameon=False)
    fig.colorbar(sm, ax=ax, label=trgt)
    fig.savefig(f"{fn_base}.png", bbox_inches='tight', dpi=400)
    fig.savefig(f"{fn_base}.pdf", bbox_inches='tight')
    plt.close(fig)


def _background_heatmap(m, trgt, n_bins=100, margin=0.075):
    """Mean of df_bckgrnd[trgt] on an n_bins x n_bins grid over projection `m`.

    The grid spans the range of clean, adversarial and background points, padded on
    each side by `margin` times that range. Rows of the returned frame are x-bin
    centers, columns are y-bin centers, and cells without background points are NaN.
    Returns (heatmap frame, [x_min, x_max, y_min, y_max], x_step / y_step).
    """
    grid = []
    bin_ids = []
    for label in dim_red_labels[m]:
        col = df_dim_red_w_bckgrnd[label]
        pad = (col.max() - col.min()) * margin
        lo = col.min() - pad
        hi = col.max() + pad
        step = (hi - lo) / n_bins
        centers = np.linspace(start=lo + 0.5 * step, stop=hi - 0.5 * step, num=n_bins)
        ids = np.floor((df_bckgrnd.loc[:, label].values - lo) / (step + 1e-10)).astype(int)
        bin_ids.append(np.clip(ids, 0, n_bins - 1))
        grid.append((lo, hi, step, centers))
    (x_min, x_max, x_step, x_centers), (y_min, y_max, y_step, y_centers) = grid

    # Vectorized per-cell sum and count (row-major: cell index = x_id * n_bins + y_id).
    flat_ids = bin_ids[0] * n_bins + bin_ids[1]
    sums = np.bincount(flat_ids, weights=df_bckgrnd.loc[:, trgt].values, minlength=n_bins * n_bins)
    counts = np.bincount(flat_ids, minlength=n_bins * n_bins)
    with np.errstate(divide='ignore', invalid='ignore'):
        means = (sums / counts).reshape(n_bins, n_bins)  # empty cells -> NaN (transparent)
    df_heatmap = pd.DataFrame(data=means, index=x_centers, columns=y_centers)
    return df_heatmap, [x_min, x_max, y_min, y_max], x_step / y_step


sns.set_theme(style='whitegrid')
for trgt in ["Age", "SImAge", "Error"]:
    for m in dim_red_models:
        # Clean vs attacked samples, one color scale over both.
        norm, sm = _scalar_mappable(df_dim_red_all[trgt].min(), df_dim_red_all[trgt].max())
        fig, ax = plt.subplots(figsize=(8, 6))
        legend_handles = _scatter_real_and_adv(ax, m, trgt, norm)
        _save_sample_figure(fig, ax, sm, legend_handles, trgt, f"{path_save}/{trgt}_{m}")

        if trgt == "SImAge":
            # Same samples on top of the model's mean prediction over the background cloud.
            df_heatmap, extent, aspect = _background_heatmap(m, trgt)
            df_heatmap.to_excel(f"{path_save}/heatmap_{trgt}_{m}.xlsx")

            vmin = df_dim_red_w_bckgrnd[trgt].min()
            vmax = df_dim_red_w_bckgrnd[trgt].max()
            norm, sm = _scalar_mappable(vmin, vmax)
            fig, ax = plt.subplots(figsize=(8, 6))
            # Transpose to (y, x) and flip rows so the first image row is y_max.
            ax.imshow(
                df_heatmap.transpose().iloc[::-1].values,
                extent=extent,
                vmin=vmin,
                vmax=vmax,
                aspect=aspect,
                cmap="spring",
                alpha=0.75
            )
            legend_handles = [mlines.Line2D([], [], marker='s', linestyle='None', markeredgewidth=0, markerfacecolor='lightgrey', markersize=10, label='Background')]
            legend_handles += _scatter_real_and_adv(ax, m, trgt, norm)
            _save_sample_figure(fig, ax, sm, legend_handles, trgt, f"{path_save}/{trgt}_{m}_w_bckgrnd")