In [None]:
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
from scipy import stats
import plotly.express as px
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
from scipy.interpolate import interp1d
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import seaborn as sns
from glob import glob
import pathlib
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.decomposition import MiniBatchDictionaryLearning, FastICA
from sklearn.random_projection import GaussianRandomProjection, SparseRandomProjection
from sklearn.manifold import MDS, Isomap
from openTSNE import TSNE
from sklearn.metrics import mean_absolute_error
from scipy import stats
import patchworklib as pw
import os
import functools
from scipy.stats import iqr
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
import shap
from slugify import slugify
from src.models.tabular.widedeep.ft_transformer import WDFTTransformerModel
from art.estimators.regression.pytorch import PyTorchRegressor
from art.estimators.regression.blackbox import BlackBoxRegressor
from art.attacks.evasion import ProjectedGradientDescentNumpy, FastGradientMethod, BasicIterativeMethod, MomentumIterativeMethod
from art.defences.detector.evasion.binary_input_detector import BinaryInputDetector
import torch
from src.tasks.metrics import get_cls_pred_metrics, get_cls_prob_metrics, get_reg_metrics
import matplotlib.lines as mlines

from sdv.metadata import SingleTableMetadata
from sdv.lite import SingleTablePreset
from sdv.single_table import GaussianCopulaSynthesizer, CTGANSynthesizer, TVAESynthesizer, CopulaGANSynthesizer
from sdv.evaluation.single_table import evaluate_quality
from sdv.evaluation.single_table import get_column_plot
from sdv.evaluation.single_table import get_column_pair_plot


def conjunction(conditions):
    """Fold an iterable of conditions into one with ``np.logical_and``.

    The first condition is used as the starting value and every following
    one is AND-ed onto it, so a single-element input is returned unchanged.
    An empty input raises ``TypeError``.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for condition in remaining:
        combined = np.logical_and(combined, condition)
    return combined


def disjunction(conditions):
    """Fold an iterable of conditions into one with ``np.logical_or``.

    The first condition is used as the starting value and every following
    one is OR-ed onto it, so a single-element input is returned unchanged.
    An empty input raises ``TypeError``.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for condition in remaining:
        combined = np.logical_or(combined, condition)
    return combined

# 1. Adversarial examples for immunology data

## 1.1. Preparing original data, model and functions

In [None]:
# Machine-specific locations: dataset root, trained SImAge artefacts, output folder.
path = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN"
path_model = f"{path}/data/immuno/models/SImAge"
path_save = f"{path}/special/046_adversarial_robustness_toolbox/immunology"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Full data table and the ordered list of model input features.
df = pd.read_excel(f"{path_model}/data.xlsx", index_col='sample_id')
feats = pd.read_excel(f"{path_model}/feats_con_top10.xlsx", index_col=0).index.values
ids_feat = list(range(len(feats)))
col_trgt = 'Age'
col_pred = 'SImAge'

# Train / validation / test membership is taken from the 'fold_0002' column.
df_preds = pd.read_excel(f"{path_model}/results/predictions.xlsx", index_col=0)
ids_trn = df_preds.index[df_preds['fold_0002'] == 'trn'].values
ids_val = df_preds.index[df_preds['fold_0002'] == 'val'].values
ids_tst = df_preds.index[df_preds['fold_0002'] == 'tst_ctrl_central'].values
# Order-preserving de-duplication: iterating a set of string ids depends on the
# per-process hash seed, which would make the row order (and every file written
# from it) differ between kernel sessions.
ids_all = list(dict.fromkeys([*ids_trn, *ids_val, *ids_tst]))
ids_trn_val = list(dict.fromkeys([*ids_trn, *ids_val]))
df = df.loc[ids_all, :]

df_X = df.loc[ids_all, feats]

# Load the checkpoint for the same fold and put the model in inference mode.
model = WDFTTransformerModel.load_from_checkpoint(checkpoint_path=f"{path_model}/best_fold_0002.ckpt")
model.eval()
model.freeze()

def predict_func_regression(X):
    """Run the global ``model`` on a 2-D numpy array and return a numpy array.

    The columns listed in ``ids_feat`` are cast to float32 and passed under
    both the 'all' and 'continuous' keys; 'categorical' receives an empty
    int32 column selection. ``model.produce_probabilities`` is set to True
    on every call before the forward pass.
    """
    model.produce_probabilities = True

    def _feature_tensor():
        # A fresh array per key, so the two tensors never share memory.
        return torch.from_numpy(np.float32(X[:, ids_feat]))

    batch = dict()
    batch['all'] = _feature_tensor()
    batch['continuous'] = _feature_tensor()
    batch['categorical'] = torch.from_numpy(np.int32(X[:, []]))

    output = model(batch)
    return output.cpu().detach().numpy()

# Adam optimizer configured from the hyper-parameters stored with the model.
art_optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=model.hparams.optimizer_lr,
    weight_decay=model.hparams.optimizer_weight_decay
)

# ART wrapper around the loaded model; the attacks below use it as their estimator.
art_regressor = PyTorchRegressor(
    model=model,
    loss=model.loss_fn,
    input_shape=[len(feats)],  # one value per input feature
    optimizer=art_optimizer,
    use_amp=False,
    opt_level="O1",
    loss_scale="dynamic",
    channels_first=True,
    clip_values=None,  # no clipping range is imposed on the inputs
    preprocessing_defences=None,
    postprocessing_defences=None,
    preprocessing=(0.0, 1.0),
    device_type="cpu",
)

## 1.2. Dimensionality reduction models

In [None]:
# All projections are fitted on the train + validation samples only.
data_dim_red = df.loc[ids_trn_val, feats].values

# Axis labels used for the two components of every projection.
dim_red_labels = {
    'PCA': ['PC 1', 'PC 2'],
    'SVD': ['SVD 1', 'SVD 2'],
    't-SNE': ['t-SNE 1', 't-SNE 2'],
    'MDS': ['MDS 1', 'MDS 2'],
    'GRP': ['GRP 1', 'GRP 2'],
    'SRP': ['SRP 1', 'SRP 2'],
    'IsoMap': ['IsoMap 1', 'IsoMap 2'],
    'MBDL': ['MBDL 1', 'MBDL 2'],
    'ICA': ['ICA 1', 'ICA 2'],
}

# Every entry must be a *fitted* object exposing `.transform(X)`, because the
# evasion section projects clean, adversarial and background points with it.
dim_red_models = {
    'PCA': PCA(n_components=2, whiten=False).fit(data_dim_red),
    'SVD': TruncatedSVD(n_components=2, algorithm='randomized', n_iter=5).fit(data_dim_red),
    't-SNE': TSNE(n_components=2).fit(data_dim_red),
    # NOTE: 'MDS' is intentionally not registered here. scikit-learn's MDS only
    # offers fit / fit_transform and cannot project unseen points, so the
    # `drm.transform(...)` calls downstream would fail on it. Its labels are
    # kept in `dim_red_labels` in case an out-of-sample variant is added later.
    'GRP': GaussianRandomProjection(n_components=2, eps=0.5).fit(data_dim_red),
    'SRP': SparseRandomProjection(n_components=2, density='auto', eps=0.5, dense_output=False).fit(data_dim_red),
    'IsoMap': Isomap(n_components=2, n_neighbors=5).fit(data_dim_red),
    'MBDL': MiniBatchDictionaryLearning(n_components=2, batch_size=100, alpha=1, n_iter=25).fit(data_dim_red),
    # ICA was previously left unfitted, which makes `.transform` raise NotFittedError.
    'ICA': FastICA(n_components=2, algorithm='parallel', whiten=True, tol=1e-3, max_iter=1000).fit(data_dim_red),
}

## 1.3. Generate augmented data

In [None]:
# One Light24 colour per augmentation method, assigned in listing order.
color_augs = {
    aug_name: px.colors.qualitative.Light24[color_id]
    for color_id, aug_name in enumerate([
        'FAST_ML',
        'GaussianCopula',
        'CTGANSynthesizer',
        'TVAESynthesizer',
        'CopulaGANSynthesizer',
    ])
}

In [None]:
# Naive augmented data: the same distribution of original feats ===============
# Every column is sampled independently from its own histogram over the
# train + validation samples, so marginals are kept but correlations are not.
path_curr = f"{path_save}/Augmentation/Naive"
pathlib.Path(path_curr).mkdir(parents=True, exist_ok=True)

n_bins = 100
n_smps = 10000

# 'Age' is sampled together with the features: it was previously declared as a
# column but never filled, which left 'SImAge Error' entirely NaN.
cols_aug_naive = list(feats) + ['Age']
df_aug_naive = pd.DataFrame(columns=cols_aug_naive)
for f in cols_aug_naive:
    f_vals = df.loc[ids_trn_val, f].values
    counts, bin_edges = np.histogram(f_vals, bins=n_bins)
    # Draw left bin edges with probability proportional to the bin counts.
    df_aug_naive[f] = np.random.choice(bin_edges[:-1], size=n_smps, p=counts / counts.sum())
df_aug_naive["SImAge"] = model(torch.from_numpy(np.float32(df_aug_naive.loc[:, feats].values))).cpu().detach().numpy().ravel()
df_aug_naive["SImAge Error"] = df_aug_naive["SImAge"] - df_aug_naive["Age"]
df_aug_naive["abs(SImAge Error)"] = df_aug_naive["SImAge Error"].abs()
df_aug_naive.to_excel(f"{path_curr}/df.xlsx", index_label='index')

In [None]:
# Augmented data with Synthetic Data Vault (SDV) ==============================
# For every synthesizer: fit on the real table, sample synthetic rows, save the
# SDV quality report (tables + figures) and the model predictions on the sample.
n_smps = 10000

feats_plot = np.concatenate((feats, ['Age']))
df_aug_sdv_input = df.loc[:, feats_plot]

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data=df_aug_sdv_input)

synthesizers = {
    'FAST_ML': SingleTablePreset(metadata, name='FAST_ML'),
    'GaussianCopula': GaussianCopulaSynthesizer(metadata),
    'CTGANSynthesizer': CTGANSynthesizer(metadata),
    'TVAESynthesizer': TVAESynthesizer(metadata),
    'CopulaGANSynthesizer': CopulaGANSynthesizer(metadata),
}
for s_name, s in synthesizers.items():
    path_curr = f"{path_save}/Augmentation/{s_name}"
    pathlib.Path(path_curr).mkdir(parents=True, exist_ok=True)

    s.fit(
        data=df_aug_sdv_input
    )
    s.save(
        filepath=f"{path_curr}/synthesizer.pkl"
    )
    df_aug_sdv = s.sample(
        num_rows=n_smps
    )
    quality_report = evaluate_quality(
        df_aug_sdv_input,
        df_aug_sdv,
        metadata
    )

    q_rep_prop = quality_report.get_properties()
    q_rep_prop.set_index('Property', inplace=True)

    # --- Column shapes: per-column score as a horizontal bar plot ------------
    df_col_shapes = quality_report.get_details(property_name='Column Shapes')
    df_col_shapes.sort_values(["Score"], ascending=[False], inplace=True)
    df_col_shapes.to_excel(f"{path_curr}/ColumnShapes.xlsx", index=False)
    fig = plt.figure(figsize=(3, 5))
    sns.set_theme(style='whitegrid')
    barplot = sns.barplot(
        data=df_col_shapes,
        x="Score",
        y="Column",
        edgecolor='black',
        color=color_augs[s_name],
        dodge=False,
        orient='h'
    )
    barplot.set_title(f"{s_name} Average Score: {q_rep_prop.at['Column Shapes', 'Score']:0.2f}")
    barplot.set_xlabel(f"KSComplement")
    barplot.set_ylabel(f"Features")
    plt.savefig(f"{path_curr}/ColumnShapes.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_curr}/ColumnShapes.pdf", bbox_inches='tight')
    plt.close(fig)

    # --- Column pair trends: collect pairwise scores and correlations --------
    df_col_pair_trends = quality_report.get_details(property_name='Column Pair Trends')
    df_col_pair_trends.to_excel(f"{path_curr}/ColumnPairTrends.xlsx", index=False)
    # Upper triangle holds the real correlation, lower triangle the synthetic one.
    df_corr_mtx = pd.DataFrame(data=np.zeros(shape=(len(feats_plot), len(feats_plot))), index=feats_plot, columns=feats_plot)
    # Created as float (NaN-filled) so seaborn receives a numeric matrix without
    # relying on `fillna` to downcast an object-dtype frame.
    df_pair_mtx = pd.DataFrame(index=feats_plot, columns=feats_plot, dtype=float)
    for _, row in df_col_pair_trends.iterrows():
        df_corr_mtx.at[row['Column 1'], row['Column 2']] = row['Real Correlation']
        df_corr_mtx.at[row['Column 2'], row['Column 1']] = row['Synthetic Correlation']
        df_pair_mtx.at[row['Column 1'], row['Column 2']] = row['Score']
        df_pair_mtx.at[row['Column 2'], row['Column 1']] = row['Score']

    fig = plt.figure()
    sns.set_theme(style='whitegrid')
    heatmap = sns.heatmap(
        data=df_pair_mtx,
        cmap='plasma',
        annot=True,
        fmt="0.2f",
        cbar_kws={'label': "Correlation Similarity"},
        mask=df_pair_mtx.isnull()  # cells without a reported pair score stay blank
    )
    heatmap.set(xlabel="", ylabel="")
    heatmap.tick_params(axis='x', rotation=90)
    heatmap.set_title(f"{s_name} Average Score: {q_rep_prop.at['Column Pair Trends', 'Score']:0.2f}")
    plt.savefig(f"{path_curr}/ColumnPairTrends.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_curr}/ColumnPairTrends.pdf", bbox_inches='tight')
    plt.close(fig)

    # --- Real vs synthetic correlations in one split matrix ------------------
    sns.set_theme(style='whitegrid')
    mtx_to_plot = df_corr_mtx.to_numpy()
    mtx_triu = np.triu(mtx_to_plot, +1)
    mtx_triu_mask = np.ma.masked_array(mtx_triu, mtx_triu==0)
    cmap_triu = plt.get_cmap("seismic").copy()
    mtx_tril = np.tril(mtx_to_plot, -1)
    mtx_tril_mask = np.ma.masked_array(mtx_tril, mtx_tril==0)
    cmap_tril = plt.get_cmap("PRGn").copy()
    fig, ax = plt.subplots()
    im_triu = ax.imshow(mtx_triu_mask, cmap=cmap_triu, vmin=-1, vmax=1)
    cbar_triu = ax.figure.colorbar(im_triu, ax=ax, location='right', shrink=0.7, pad=0.1)
    cbar_triu.ax.tick_params(labelsize=10)
    cbar_triu.set_label("Real Correlation", horizontalalignment='center', fontsize=12)
    im_tril = ax.imshow(mtx_tril_mask, cmap=cmap_tril, vmin=-1, vmax=1)
    cbar_tril = ax.figure.colorbar(im_tril, ax=ax, location='right', shrink=0.7, pad=0.1)
    cbar_tril.ax.tick_params(labelsize=10)
    cbar_tril.set_label("Synthetic Correlation", horizontalalignment='center', fontsize=12)
    ax.grid(None)
    ax.set_aspect("equal")
    ax.set_xticks(np.arange(df_corr_mtx.shape[1]))
    ax.set_yticks(np.arange(df_corr_mtx.shape[0]))
    ax.set_xticklabels(df_corr_mtx.columns.values)
    ax.set_yticklabels(df_corr_mtx.index.values)
    plt.setp(ax.get_xticklabels(), rotation=90)
    ax.tick_params(axis='both', which='major', labelsize=10)
    ax.tick_params(axis='both', which='minor', labelsize=10)
    for i in range(df_corr_mtx.shape[0]):
        for j in range(df_corr_mtx.shape[1]):
            if i == j:
                continue
            # White text on strongly coloured cells in BOTH triangles; checking
            # only the lower triangle left strong real correlations unreadable.
            color = 'white' if np.abs(mtx_to_plot[i, j]) > 0.5 else 'black'
            ax.text(j, i, f"{mtx_to_plot[i, j]:0.2f}", ha="center", va="center", color=color, fontsize=7)
    fig.tight_layout()
    plt.savefig(f"{path_curr}/Correlations.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_curr}/Correlations.pdf", bbox_inches='tight')
    # Close instead of clear, so figures do not accumulate across synthesizers.
    plt.close(fig)

    # --- Model predictions on the synthetic sample ---------------------------
    df_aug_sdv["SImAge"] = model(torch.from_numpy(np.float32(df_aug_sdv.loc[:, feats].values))).cpu().detach().numpy().ravel()
    df_aug_sdv["SImAge Error"] = df_aug_sdv["SImAge"] - df_aug_sdv["Age"]
    df_aug_sdv["abs(SImAge Error)"] = df_aug_sdv["SImAge Error"].abs()
    df_aug_sdv.to_excel(f"{path_curr}/df.xlsx", index_label='index')

## 1.4. Evasion with defences

In [None]:
def _scatter_real_and_attack(df_real, df_attack, x_col, y_col, hue_col, hue_norm, legend_handles):
    """Draw clean ('o') and adversarial ('X') samples on the current axes.

    Both layers are coloured by ``hue_col`` through the shared ``hue_norm`` so
    marker colours agree with each other and with the figure colorbar. One grey
    proxy handle per layer is appended to ``legend_handles``.
    """
    for data, marker, label in ((df_real, 'o', 'Real'), (df_attack, 'X', 'Attack')):
        scatter = sns.scatterplot(
            data=data,
            x=x_col,
            y=y_col,
            palette='spring',
            hue=hue_col,
            hue_norm=hue_norm,
            linewidth=0.5,
            alpha=0.75,
            edgecolor="k",
            marker=marker,
            s=50,
        )
        auto_legend = scatter.get_legend()
        if auto_legend is not None:
            auto_legend.remove()
        legend_handles.append(mlines.Line2D([], [], marker=marker, linestyle='None', markeredgecolor='k', markerfacecolor='lightgrey', markersize=10, label=label))


ids_trgt = ids_all

epsilons_plot = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5]
# Round before de-duplicating: np.linspace can yield values such as
# 0.30000000000000004 that would survive a plain set union as near-duplicates
# and map to the same 'eps_0.xxxx' output folder.
epsilons = sorted({
    float(np.round(e, 4))
    for e in [*epsilons_plot, *np.linspace(0.1, 1.0, 10), *np.linspace(0.01, 0.1, 10)]
})
df_eps = pd.DataFrame(index=epsilons)

# Per-feature IQR over the training samples; independent of eps, so computed once.
feats_iqr = np.array([iqr(df.loc[ids_trn, feat].values) for feat in feats])

for eps_raw in epsilons:

    # Perturbation budget and step size are expressed as fractions of each feature's IQR.
    eps = eps_raw * feats_iqr
    eps_step = 0.2 * eps_raw * feats_iqr

    attacks = {
        'MI': MomentumIterativeMethod(
            estimator=art_regressor,
            norm=np.inf,
            eps=eps,
            eps_step=eps_step,
            decay=0.1,
            max_iter=100,
            targeted=False,
            batch_size=512,
            verbose=True
        ),
        'BI': BasicIterativeMethod(
            estimator=art_regressor,
            eps=eps,
            eps_step=eps_step,
            max_iter=100,
            targeted=False,
            batch_size=512,
            verbose=True
        ),
        'PGD': ProjectedGradientDescentNumpy(
            estimator=art_regressor,
            norm=np.inf,
            eps=eps,
            eps_step=eps_step,
            decay=None,
            max_iter=100,
            targeted=False,
            num_random_init=0,
            batch_size=512,
            random_eps=False,
            summary_writer=False,
            verbose=True
        ),
        'FG': FastGradientMethod(
            estimator=art_regressor,
            norm=np.inf,
            eps=eps,
            eps_step=eps_step,
            targeted=False,
            num_random_init=0,
            batch_size=512,
            minimal=False,
            summary_writer=False,
        ),
    }

    for attack_type, attack in attacks.items():

        path_out = f"{path_save}/eps_{eps_raw:0.4f}/{attack_type}"
        pathlib.Path(path_out).mkdir(parents=True, exist_ok=True)

        # Adversarial features for all samples. Generated once and reused below
        # for the target subset (all attacks are configured without random init).
        df_adv_feats = pd.DataFrame(data=attack.generate(np.float32(df_X.loc[ids_all, :].values)), columns=feats, index=ids_all)

        # Save adversarial and clean samples for binary input detector
        df_adv = df_adv_feats.copy()
        df_adv.loc[ids_all, 'Dataset'] = df.loc[ids_all, 'Dataset']
        df_adv['Index'] = [f"{sample}_adv" for sample in ids_all]
        df_adv.set_index('Index', inplace=True)
        df_adv['DataType'] = 'Adversarial'
        df_cln = df.loc[ids_all, list(feats) + ['Dataset']].copy()
        df_cln['DataType'] = 'Clean'
        df_detector = pd.concat([df_cln, df_adv])
        df_detector.to_excel(f"{path_out}/data_detector.xlsx", index_label='index')

        X = df_X.loc[ids_trgt, :].values
        X_adv = np.float32(df_adv_feats.loc[ids_trgt, :].values)

        # `col_trgt` holds the target column name ('Age'); the model input is
        # cast to float32 like every other model call in this notebook.
        y_real = np.float32(df.loc[ids_trgt, col_trgt].values)
        y_pred = model(torch.from_numpy(np.float32(X))).cpu().detach().numpy().ravel()
        y_pred_adv = model(torch.from_numpy(X_adv)).cpu().detach().numpy().ravel()

        # Regression metrics before ('Origin') and after ('Attack') the perturbation.
        metrics = get_reg_metrics()
        df_metrics = pd.DataFrame(index=[m for m in metrics])
        for m in metrics:
            m_val = float(metrics[m][0](torch.from_numpy(y_pred), torch.from_numpy(y_real)).numpy())
            df_metrics.at[m, 'Origin'] = m_val
            metrics[m][0].reset()
            m_val = float(metrics[m][0](torch.from_numpy(y_pred_adv), torch.from_numpy(y_real)).numpy())
            df_metrics.at[m, 'Attack'] = m_val
            metrics[m][0].reset()
        df_metrics.to_excel(f"{path_out}/metrics.xlsx")

        df_eps.loc[eps_raw, f"{attack_type}_MAE_Origin"] = df_metrics.at['mean_absolute_error', 'Origin']
        df_eps.loc[eps_raw, f"{attack_type}_MAE_Attack"] = df_metrics.at['mean_absolute_error', 'Attack']

        # Per-sample shift of the prediction error caused by the attack.
        df_error = df.loc[ids_trgt, ["Age"]].copy()
        df_error.loc[ids_trgt, "Error Origin"] = y_pred - y_real
        df_error.loc[ids_trgt, "Error Attack"] = y_pred_adv - y_real
        error_diff = df_error["Error Attack"] - df_error["Error Origin"]
        for sample in ids_trgt:
            df_eps.loc[eps_raw, f"{attack_type}_{sample}_ErrorDiff"] = error_diff.loc[sample]

        if eps_raw in epsilons_plot:

            # Background cloud: every feature sampled independently from its
            # train + validation histogram, then scored by the model.
            n_bins = 100
            n_bckgrnd = 100000
            df_bckgrnd = pd.DataFrame(columns=feats)
            for feat in feats:
                f_vals = df.loc[ids_trn_val, feat].values
                counts, bin_edges = np.histogram(f_vals, bins=n_bins)
                df_bckgrnd[feat] = np.random.choice(bin_edges[:-1], size=n_bckgrnd, p=counts/len(f_vals))
            X_bckgrnd = df_bckgrnd.loc[:, feats].values
            df_bckgrnd["SImAge"] = model(torch.from_numpy(np.float32(X_bckgrnd))).cpu().detach().numpy().ravel()

            df_dim_red = df.loc[ids_trgt, ['Age']].copy()
            df_dim_red.loc[ids_trgt, "SImAge"] = y_pred
            df_dim_red.loc[ids_trgt, "Error"] = df_dim_red.loc[ids_trgt, "SImAge"].values - df_dim_red.loc[ids_trgt, "Age"].values
            df_dim_red_adv = df.loc[ids_trgt, ['Age']].copy()
            df_dim_red_adv.loc[ids_trgt, "SImAge"] = y_pred_adv
            df_dim_red_adv.loc[ids_trgt, "Error"] = df_dim_red_adv.loc[ids_trgt, "SImAge"].values - df_dim_red_adv.loc[ids_trgt, "Age"].values
            df_dim_red_adv['index'] = [f"{sample}_adv" for sample in df_dim_red_adv.index.values]
            df_dim_red_adv.set_index('index', inplace=True)

            # Project clean, adversarial and background points with every fitted model.
            for m, drm in dim_red_models.items():
                x_col, y_col = dim_red_labels[m]
                dim_red_res = drm.transform(X)
                df_dim_red.loc[:, x_col] = dim_red_res[:, 0]
                df_dim_red.loc[:, y_col] = dim_red_res[:, 1]
                dim_red_res_adv = drm.transform(X_adv)
                df_dim_red_adv.loc[:, x_col] = dim_red_res_adv[:, 0]
                df_dim_red_adv.loc[:, y_col] = dim_red_res_adv[:, 1]
                dim_red_res_bckgrnd = drm.transform(X_bckgrnd)
                df_bckgrnd.loc[:, x_col] = dim_red_res_bckgrnd[:, 0]
                df_bckgrnd.loc[:, y_col] = dim_red_res_bckgrnd[:, 1]
            df_dim_red_all = pd.concat([df_dim_red, df_dim_red_adv])
            df_dim_red_w_bckgrnd = pd.concat([df_dim_red, df_dim_red_adv, df_bckgrnd])
            df_dim_red_all.to_excel(f"{path_out}/df_dim_red.xlsx")

            for trgt in ["Age", "SImAge", "Error"]:
                for m in dim_red_models:
                    x_col, y_col = dim_red_labels[m]

                    # Figure 1: clean vs adversarial samples only.
                    legend_handles = []
                    norm = plt.Normalize(df_dim_red_all[trgt].min(), df_dim_red_all[trgt].max())
                    sm = plt.cm.ScalarMappable(cmap="spring", norm=norm)
                    sm.set_array([])
                    fig = plt.figure(figsize=(8, 6))
                    sns.set_theme(style='whitegrid')

                    _scatter_real_and_attack(df_dim_red, df_dim_red_adv, x_col, y_col, trgt, norm, legend_handles)

                    plt.legend(handles=legend_handles, title="Samples", bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left", mode="expand", borderaxespad=0, ncol=3, frameon=False)
                    # `ax` is given explicitly: the mappable is not attached to any axes.
                    fig.colorbar(sm, ax=plt.gca(), label=trgt)
                    plt.savefig(f"{path_out}/{trgt}_{m}.png", bbox_inches='tight', dpi=400)
                    plt.savefig(f"{path_out}/{trgt}_{m}.pdf", bbox_inches='tight')
                    plt.close(fig)

                    if trgt == "SImAge":

                        # Figure 2: same scatter on top of a binned mean-prediction background.
                        n_bins = 100
                        x_xtd = (df_dim_red_w_bckgrnd[x_col].max() - df_dim_red_w_bckgrnd[x_col].min()) * 0.075
                        x_min = df_dim_red_w_bckgrnd[x_col].min() - x_xtd
                        x_max = df_dim_red_w_bckgrnd[x_col].max() + x_xtd
                        x_shift = (x_max - x_min) / n_bins
                        x_bin_centers = np.linspace(
                            start=x_min + 0.5 * x_shift,
                            stop=x_max - 0.5 * x_shift,
                            num=n_bins
                        )
                        y_xtd = (df_dim_red_w_bckgrnd[y_col].max() - df_dim_red_w_bckgrnd[y_col].min()) * 0.075
                        y_min = df_dim_red_w_bckgrnd[y_col].min() - y_xtd
                        y_max = df_dim_red_w_bckgrnd[y_col].max() + y_xtd
                        y_shift = (y_max - y_min) / n_bins
                        y_bin_centers = np.linspace(
                            start=y_min + 0.5 * y_shift,
                            stop=y_max - 0.5 * y_shift,
                            num=n_bins
                        )

                        xs = df_bckgrnd.loc[:, x_col].values
                        xs_ids = np.floor((xs - x_min) / (x_shift + 1e-10)).astype(int)
                        ys = df_bckgrnd.loc[:, y_col].values
                        ys_ids = np.floor((ys - y_min) / (y_shift + 1e-10)).astype(int)
                        zs = df_bckgrnd.loc[:, trgt].values
                        # Unbuffered accumulation per (x, y) bin; replaces a Python
                        # loop over all background points.
                        heatmap_sum = np.zeros((n_bins, n_bins))
                        heatmap_cnt = np.zeros((n_bins, n_bins))
                        np.add.at(heatmap_sum, (xs_ids, ys_ids), zs)
                        np.add.at(heatmap_cnt, (xs_ids, ys_ids), 1)
                        # Empty bins give 0/0 -> NaN on purpose (left blank by imshow).
                        with np.errstate(invalid='ignore', divide='ignore'):
                            heatmap_mean = heatmap_sum / heatmap_cnt
                        df_heatmap = pd.DataFrame(data=heatmap_mean, columns=y_bin_centers, index=x_bin_centers)
                        df_heatmap.to_excel(f"{path_out}/heatmap_{trgt}_{m}.xlsx")

                        legend_handles = []
                        norm = plt.Normalize(df_dim_red_w_bckgrnd[trgt].min(), df_dim_red_w_bckgrnd[trgt].max())
                        sm = plt.cm.ScalarMappable(cmap="spring", norm=norm)
                        sm.set_array([])
                        fig = plt.figure(figsize=(8, 6))
                        sns.set_theme(style='whitegrid')

                        plt.gca().imshow(
                            X=df_heatmap.transpose().iloc[::-1].values,
                            extent=[x_min, x_max, y_min, y_max],
                            vmin=df_dim_red_w_bckgrnd[trgt].min(),
                            vmax=df_dim_red_w_bckgrnd[trgt].max(),
                            aspect=x_shift/y_shift,
                            cmap="spring",
                            alpha=0.75
                        )
                        legend_handles.append(mlines.Line2D([], [], marker='s', linestyle='None',markeredgewidth=0, markerfacecolor='lightgrey', markersize=10, label='Background'))

                        _scatter_real_and_attack(df_dim_red, df_dim_red_adv, x_col, y_col, trgt, norm, legend_handles)

                        plt.legend(handles=legend_handles, title="Samples", bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left", mode="expand", borderaxespad=0, ncol=3, frameon=False)
                        fig.colorbar(sm, ax=plt.gca(), label=trgt)
                        plt.savefig(f"{path_out}/{trgt}_{m}_w_bckgrnd.png", bbox_inches='tight', dpi=400)
                        plt.savefig(f"{path_out}/{trgt}_{m}_w_bckgrnd.pdf", bbox_inches='tight')
                        plt.close(fig)

df_eps.to_excel(f"{path_save}/df_eps.xlsx", index_label='eps')

### MAE as a function of eps

In [None]:
# Short attack codes (as used in the df_eps column names) -> display names.
attacks = dict(
    MI="Momentum Iterative",
    BI="Basic Iterative",
    PGD="Projected Gradient Descent",
    FG="Fast Gradient",
)
# One D3 colour per attack, in the order the attacks are listed above.
attacks_palette = {
    attack_label: px.colors.qualitative.D3[color_id]
    for color_id, attack_label in enumerate(attacks.values())
}

# Wide table: one MAE-under-attack column per method, indexed by eps.
mae_columns = {f"{code}_MAE_Attack": label for code, label in attacks.items()}
df_fig = df_eps.loc[:, list(mae_columns)].rename(columns=mae_columns)
df_fig.to_excel(f"{path_save}/mae_vs_eps.xlsx", index_label='Eps')

# Long format for seaborn: (Eps, Method, MAE).
df_fig = df_fig.assign(Eps=df_fig.index.values).melt(id_vars="Eps", var_name='Method', value_name="MAE")

fig = plt.figure()
sns.set_theme(style='whitegrid', font_scale=1)
lines = sns.lineplot(
    data=df_fig,
    x='Eps',
    y="MAE",
    hue="Method",
    style="Method",
    palette=attacks_palette,
    hue_order=list(attacks_palette),
    markers=True,
    dashes=False,
)
plt.xscale('log')

# Dashed horizontal reference: MAE on unperturbed data, read from the eps=0.01 row.
x_min, x_max = 0.009, 1.05
mae_basic = df_eps.at[0.01, "MI_MAE_Origin"]
lines.set_xlim(x_min, x_max)
plt.gca().plot(
    [x_min, x_max],
    [mae_basic] * 2,
    color='k',
    linestyle='dashed',
    linewidth=1
)

for file_ext, save_kwargs in (("png", {"dpi": 400}), ("pdf", {})):
    plt.savefig(f"{path_save}/mae_vs_eps.{file_ext}", bbox_inches='tight', **save_kwargs)
plt.close(fig)

In [None]:
# Short attack codes (as used in the df_eps column names) -> display names.
attacks = {
    'MI': "Momentum Iterative",
    'BI': "Basic Iterative",
    'PGD': "Projected Gradient Descent",
    'FG': "Fast Gradient"
}
# One figure per attack: a thin line per sample showing how the prediction
# error shifts (attack minus origin) as eps grows.
for attack in attacks:
    cols_diff = [f"{attack}_{sample}_ErrorDiff" for sample in ids_trgt]
    # The former per-sample cubic `interp1d` was evaluated at its own knots,
    # which only reproduces the stored values; it is dropped as redundant work.
    df_fig = df_eps.loc[:, cols_diff].copy()

    df_fig['Eps'] = df_fig.index.values
    df_fig = df_fig.melt(id_vars="Eps", var_name='ID', value_name="Error(Attack) - Error(Origin)")
    fig = plt.figure()
    sns.set_theme(style='whitegrid', font_scale=1)
    lines = sns.lineplot(
        data=df_fig,
        x='Eps',
        y="Error(Attack) - Error(Origin)",
        hue="ID",
        markers=False,
        dashes=False,
        legend=False,
        linewidth=0.2,
        alpha=0.7
    )
    plt.xscale('log')
    plt.savefig(f"{path_save}/{attack}.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path_save}/{attack}.pdf", bbox_inches='tight')
    plt.close(fig)