# Enable autoreload (for debugging)

In [None]:
# Enable IPython's autoreload extension so that edits to locally imported modules
# (e.g. the project's `src` package) are picked up without restarting the kernel.
# See the IPython autoreload documentation for the meaning of the mode number.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
import pandas as pd
import numpy as np
from src.utils.outliers.iqr import add_iqr_outs_to_df, plot_iqr_outs, plot_iqr_outs_regression_error
from src.utils.outliers.pyod import add_pyod_outs_to_df, plot_pyod_outs, plot_pyod_outs_regression_error
from plotly.subplots import make_subplots
from scipy import stats
import plotly.express as px
import plotly.io as pio
import importlib
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
from scipy.interpolate import interp1d
from src.utils.verbose import NoStdStreams
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
from matplotlib import colors
from omegaconf import OmegaConf
from tqdm import tqdm
import seaborn as sns
from glob import glob
import pathlib
import patchworklib as pw
import os
import functools
from scipy.stats import iqr
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
import shap
from slugify import slugify
from src.models.tabular.widedeep.ft_transformer import WDFTTransformerModel
from art.estimators.regression.pytorch import PyTorchRegressor
from art.estimators.classification import PyTorchClassifier
from art.estimators.regression.blackbox import BlackBoxRegressor
from art.attacks.evasion import ProjectedGradientDescentNumpy, FastGradientMethod, BasicIterativeMethod, MomentumIterativeMethod
from art.attacks.evasion import ZooAttack, CarliniL2Method, ElasticNet, NewtonFool
import torch
from src.tasks.metrics import get_cls_pred_metrics, get_cls_prob_metrics, get_reg_metrics
import matplotlib.lines as mlines

from pyod.models.ecod import ECOD
from pyod.models.copod import COPOD
from pyod.models.sos import SOS
from pyod.models.qmcd import QMCD as QMCDOD
from pyod.models.sampling import Sampling
from pyod.models.gmm import GMM
from pyod.models.pca import PCA
from pyod.models.mcd import MCD
from pyod.models.cd import CD
from pyod.models.lmdd import LMDD
from pyod.models.lof import LOF
from pyod.models.cof import COF
from pyod.models.cblof import CBLOF
from pyod.models.hbos import HBOS
from pyod.models.knn import KNN
from pyod.models.sod import SOD
from pyod.models.rod import ROD
from pyod.models.iforest import IForest
from pyod.models.inne import INNE
from pyod.models.dif import DIF
from pyod.models.feature_bagging import FeatureBagging
from pyod.models.loda import LODA
from pyod.models.lunar import LUNAR

from pythresh.thresholds.iqr import IQR
from pythresh.thresholds.mad import MAD
from pythresh.thresholds.fwfm import FWFM
from pythresh.thresholds.yj import YJ
from pythresh.thresholds.zscore import ZSCORE
from pythresh.thresholds.aucp import AUCP
from pythresh.thresholds.qmcd import QMCD
from pythresh.thresholds.fgd import FGD
from pythresh.thresholds.dsn import DSN
from pythresh.thresholds.clf import CLF
from pythresh.thresholds.filter import FILTER
from pythresh.thresholds.wind import WIND
from pythresh.thresholds.eb import EB
from pythresh.thresholds.regr import REGR
from pythresh.thresholds.boot import BOOT
from pythresh.thresholds.mcst import MCST
from pythresh.thresholds.hist import HIST
from pythresh.thresholds.moll import MOLL
from pythresh.thresholds.chau import CHAU
from pythresh.thresholds.gesd import GESD
from pythresh.thresholds.mtt import MTT
from pythresh.thresholds.karch import KARCH
from pythresh.thresholds.ocsvm import OCSVM
from pythresh.thresholds.clust import CLUST
from pythresh.thresholds.decomp import DECOMP
from pythresh.thresholds.meta import META
from pythresh.thresholds.vae import VAE
from pythresh.thresholds.cpd import CPD
from pythresh.thresholds.gamgmm import GAMGMM
from pythresh.thresholds.mixmod import MIXMOD

from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular import model_sweep
import warnings


def conjunction(conditions):
    """Combine boolean masks with element-wise AND (``np.logical_and``), left to right.

    A single-element input is returned unchanged. An empty input raises
    ``TypeError``, exactly like ``functools.reduce`` without an initial value.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for mask in remaining:
        combined = np.logical_and(combined, mask)
    return combined


def disjunction(conditions):
    """Combine boolean masks with element-wise OR (``np.logical_or``), left to right.

    A single-element input is returned unchanged. An empty input raises
    ``TypeError``, exactly like ``functools.reduce`` without an initial value.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for mask in remaining:
        combined = np.logical_or(combined, mask)
    return combined



# Load data and model, define the PyTorchRegressor, and set up colors

In [None]:
# Input/output locations: all SImAge artifacts read below (data, feature list,
# predictions, checkpoint) live under `path_model`; outputs go to `path_save`.
path = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN"
path_model = f"{path}/data/immuno/models/SImAge"
path_save = f"{path}/special/064_tai_report_4/immuno"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_model}/data.xlsx", index_col='sample_id')
feats = pd.read_excel(f"{path_model}/feats_con_top10.xlsx", index_col=0).index.values
# Positions 0..len(feats)-1; assumes model input matrices have their columns ordered like `feats`.
ids_feat = list(range(len(feats)))
col_trgt = 'Age'
col_pred = 'SImAge'

# Split membership of every sample for fold 0002, as stored in the predictions table.
df_preds = pd.read_excel(f"{path_model}/results/predictions.xlsx", index_col=0)
ids_trn = df_preds.index[df_preds['fold_0002'] == 'trn'].values
ids_val = df_preds.index[df_preds['fold_0002'] == 'val'].values
ids_tst = df_preds.index[df_preds['fold_0002'] == 'tst_ctrl_central'].values
ids_all = df_preds.index[df_preds['fold_0002'].isin(['trn', 'val', 'tst_ctrl_central'])].values
ids_trn_val = df_preds.index[df_preds['fold_0002'].isin(['trn', 'val'])].values
ids_dict = {
    'all': ids_all,
    'trn_val': ids_trn_val,
    'tst': ids_tst,
}

# Keep only samples from the splits above and add the signed / absolute prediction error
# (prediction minus target), using the column-name constants defined above.
df = df.loc[ids_all, :]
df[f"{col_pred} Error"] = df[col_pred] - df[col_trgt]
df[f"|{col_pred} Error|"] = df[f"{col_pred} Error"].abs()
df['Data'] = 'Real'
df['Eps'] = 'Origin'

model = WDFTTransformerModel.load_from_checkpoint(checkpoint_path=f"{path_model}/best_fold_0002.ckpt")
model.eval()
model.freeze()

def predict_func_regression(X):
    """Run the global SImAge ``model`` on a 2-D array and return its output as a numpy array.

    The columns of ``X`` selected by ``ids_feat`` are cast to float32 and passed
    as both the 'all' and the 'continuous' input. 'categorical' receives an
    empty int32 column slice. The function sets ``model.produce_probabilities = True``
    before the forward pass.
    """
    model.produce_probabilities = True
    inputs = dict(
        all=torch.from_numpy(np.float32(X[:, ids_feat])),
        continuous=torch.from_numpy(np.float32(X[:, ids_feat])),
        categorical=torch.from_numpy(np.int32(X[:, []])),
    )
    prediction = model(inputs)
    return prediction.cpu().detach().numpy()

# Adam optimizer for the ART wrapper, using the learning rate and weight decay
# stored in the loaded model's hyper-parameters.
art_optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=model.hparams.optimizer_lr,
    weight_decay=model.hparams.optimizer_weight_decay,
)

# ART regression wrapper around the loaded model. It runs on CPU with AMP disabled,
# no clip values and no pre/post-processing defences. preprocessing=(0.0, 1.0) is
# presumably an identity transform (NOTE(review): confirm against ART docs).
art_regressor = PyTorchRegressor(
    model=model,
    loss=model.loss_fn,
    input_shape=[len(feats)],
    optimizer=art_optimizer,
    use_amp=False,
    opt_level="O1",
    loss_scale="dynamic",
    channels_first=True,
    clip_values=None,
    preprocessing_defences=None,
    postprocessing_defences=None,
    preprocessing=(0.0, 1.0),
    device_type="cpu",
)

# Fixed plot color per attack (Plotly D3 palette entries 0, 1 and 3).
colors_atks = {
    atk_name: px.colors.qualitative.D3[palette_id]
    for atk_name, palette_id in (
        ("MomentumIterative", 0),
        ("BasicIterative", 1),
        ("FastGradient", 3),
    )
}

# Snapshot of the clean data (with error columns) that the attacks start from.
df.to_excel(f"{path_save}/df_origin.xlsx", index_label='sample_id')

# Create pyod and pythresh models

In [None]:
# pyod outlier detectors with default hyper-parameters, keyed by the label shown
# as the heatmap row name. Entries commented out are currently excluded.
classifiers = {
    'ECDF-Based (ECOD)': ECOD(),
    'Copula-Based (COPOD)': COPOD(),
    # 'Stochastic (SOS)': SOS(),
    'Quasi-Monte Carlo Discrepancy (QMCD)': QMCDOD(),
    'Rapid distance-based via Sampling': Sampling(),
    'Probabilistic Mixture Modeling (GMM)': GMM(),
    'Principal Component Analysis (PCA)': PCA(),
    'Minimum Covariance Determinant (MCD)': MCD(),
    'Cook\'s Distance (CD)': CD(),
    'Deviation-based Outlier Detection (LMDD)': LMDD(),
    'Local Outlier Factor (LOF)': LOF(),
    'Connectivity-Based Outlier Factor (COF)': COF(),
    'Clustering-Based Local Outlier Factor (CBLOF)': CBLOF(),
    'Histogram-based Outlier Score (HBOS)': HBOS(),
    'k Nearest Neighbors (kNN)': KNN(),
    'Subspace Outlier Detection (SOD)': SOD(),
    # 'Rotation-based Outlier Detection (ROD)': ROD(),
    'Isolation Forest': IForest(),
    # 'Isolation-Based with Nearest-Neighbor Ensembles (INNE)': INNE(),
    'Deep Isolation Forest for Anomaly Detection (DIF)': DIF(),
    'Feature Bagging': FeatureBagging(),
    'Lightweight On-line Detector of Anomalies (LODA)': LODA(),
    'LUNAR': LUNAR(),
}

# pythresh thresholders: each maps a detector's score vector to labels whose sum is
# turned into an outlier percentage below. Keys become heatmap column names.
thresholders = {
    'Inter-Quartile Region (IQR)': IQR(),
    'Median Absolute Deviation (MAD)': MAD(),
    'Full Width at Full Minimum (FWFM)': FWFM(),
    'Yeo-Johnson Transformation (YJ)': YJ(),
    'Z Score (ZSCORE)': ZSCORE(),
    'AUC Percentage (AUCP)': AUCP(),
    'Quasi-Monte Carlo Discrepancy (QMCD)': QMCD(),
    'Fixed Gradient Descent (FGD)': FGD(),
    'Distance Shift from Normal (DSN)': DSN(),
    'Trained Classifier (CLF)': CLF(),
    'Filtering Based (FILTER)': FILTER(),
    'Topological Winding Number (WIND)': WIND(),
    'Elliptical Boundary (EB)': EB(),
    'Regression Intercept (REGR)': REGR(),
    # 'Bootstrap Method (BOOT)': BOOT(),
    'Monte Carlo Statistical Tests (MCST)': MCST(),
    # 'Histogram Based Methods (HIST)': HIST(),
    'Mollifier (MOLL)': MOLL(),
    "Chauvenet's Criterion (CHAU)": CHAU(),
    'Generalized Extreme Studentized Deviate (GESD)': GESD(),
    # 'Modified Thompson Tau Test (MTT)': MTT(),
    'Karcher Mean (KARCH)': KARCH(),
    'One-Class SVM (OCSVM)': OCSVM(),
    'Clustering (CLUST)': CLUST(),
    'Decomposition (DECOMP)': DECOMP(),
    'Meta-model (META)': META(),
    'Variational Autoencoder (VAE)': VAE(),
    'Change Point Detection (CPD)': CPD(),
    # 'Bayesian Gamma GMM (GAMGMM)': GAMGMM(skip=True),
    'Mixture Models (MIXMOD)': MIXMOD(),
}

# Outliers for original data

## Generate

In [None]:
%%capture

# Percentage of clean samples flagged as outliers, for every
# (pyod detector, pythresh thresholder) pair: rows = detectors, columns = thresholders.
df_outs = pd.DataFrame(index=list(classifiers), columns=list(thresholders))
for det_name, detector in classifiers.items():
    det_scores = detector.fit(df.loc[:, feats].values).decision_scores_
    for thr_name, thresholder in thresholders.items():
        flags = thresholder.eval(det_scores)
        share = sum(flags) / len(flags)
        df_outs.at[det_name, thr_name] = share * 100

## Plot

In [None]:
df_outs.to_excel(f"{path_save}/outliers.xlsx")
df_fig = df_outs.astype(float)
sns.set_theme(style='ticks', font_scale=1.0)
fig, ax = plt.subplots(figsize=(16, 12))
sns.heatmap(
    df_fig,
    annot=True,
    fmt=".1f",
    cmap='hot',
    linewidth=0.1,
    linecolor='black',
    cbar_kws={
        'orientation': 'horizontal',
        'location': 'top',
        'pad': 0.025,
        'aspect': 30
    },
    annot_kws={"size": 12},
    ax=ax
)
# seaborn puts DataFrame rows on the y axis and columns on the x axis:
# rows of `df_outs` are pyod detectors, columns are pythresh thresholders.
ax.set_xlabel('Thresholding Algorithms')
ax.set_ylabel('Outliers Detection Algorithms')
# The last axes of the figure is the colorbar added by the heatmap.
cbar_ax = fig.axes[-1]
cbar_ax.set_title("Outliers' percentage")
for spine in cbar_ax.spines.values():
    spine.set_linewidth(1)
plt.savefig(f"{path_save}/outliers.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/outliers.pdf", bbox_inches='tight')
plt.close(fig)

# Adversarial attacks

## Generate

In [None]:
# Relative attack strengths: ten points in 0.01..0.1 and ten in 0.1..1.0.
# The shared endpoint 0.1 appears only once thanks to the set union.
epsilons = sorted(set(np.linspace(0.1, 1.0, 10)) | set(np.linspace(0.01, 0.1, 10)))
df_eps = pd.DataFrame(index=epsilons)

# Loop invariants. The per-feature IQR scales each relative eps into an absolute
# per-feature budget, and X_origin is the clean float32 input matrix.
feats_iqr = np.array([iqr(df.loc[:, feat].values) for feat in feats])
X_origin = np.float32(df.loc[:, feats].values)


def _reg_metric_value(metric, y_pred, y_true):
    """Evaluate a stateful metric on two 1-D arrays, reset its state and return the value as float."""
    value = float(metric(torch.from_numpy(np.float32(y_pred)), torch.from_numpy(np.float32(y_true))).numpy())
    metric.reset()
    return value


for eps_raw in epsilons:

    eps = eps_raw * feats_iqr
    eps_step = 0.2 * eps_raw * feats_iqr

    attacks = {
        'MomentumIterative': MomentumIterativeMethod(
            estimator=art_regressor,
            norm=np.inf,
            eps=eps,
            eps_step=eps_step,
            decay=0.1,
            max_iter=100,
            targeted=False,
            batch_size=512,
            verbose=True
        ),
        'BasicIterative': BasicIterativeMethod(
            estimator=art_regressor,
            eps=eps,
            eps_step=eps_step,
            max_iter=100,
            targeted=False,
            batch_size=512,
            verbose=True
        ),
        'FastGradient': FastGradientMethod(
            estimator=art_regressor,
            norm=np.inf,
            eps=eps,
            eps_step=eps_step,
            targeted=False,
            num_random_init=0,
            batch_size=512,
            minimal=False,
            summary_writer=False,
        ),
    }

    for attack_name, attack in attacks.items():
        path_curr = f"{path_save}/Evasion/{attack_name}/eps_{eps_raw:0.4f}"
        pathlib.Path(path_curr).mkdir(parents=True, exist_ok=True)

        # Pass a copy so the shared clean matrix can never be altered by an attack.
        X_adv = attack.generate(X_origin.copy())

        df_adv = df.loc[:, ['Age']].copy()
        df_adv.loc[:, feats] = X_adv
        df_adv["SImAge"] = model(torch.from_numpy(np.float32(df_adv.loc[:, feats].values))).cpu().detach().numpy().ravel()
        df_adv["SImAge Error"] = df_adv["SImAge"] - df_adv["Age"]
        df_adv["|SImAge Error|"] = df_adv["SImAge Error"].abs()
        df_adv.loc[:, "Error Origin"] = df.loc[:, "SImAge"] - df.loc[:, "Age"]
        df_adv.loc[:, "Error Attack"] = df_adv.loc[:, "SImAge"] - df_adv.loc[:, "Age"]
        df_adv['Error Diff'] = df_adv['Error Attack'] - df_adv['Error Origin']
        df_adv['|Error Diff|'] = df_adv['Error Diff'].abs()

        df_adv.to_excel(f"{path_curr}/df.xlsx", index_label='sample_id')

        # Regression metrics of clean ('Origin') and attacked ('Attack') predictions
        # against the true Age, for every sample partition in ids_dict.
        metrics = get_reg_metrics()
        df_metrics = pd.DataFrame(index=[f"{m}_{p}" for m in metrics for p in ids_dict])
        for p, ids_part in ids_dict.items():
            y_true = df.loc[ids_part, "Age"].values
            for m in metrics:
                df_metrics.at[f"{m}_{p}", 'Origin'] = _reg_metric_value(
                    metrics[m][0], df.loc[ids_part, "SImAge"].values, y_true
                )
                df_metrics.at[f"{m}_{p}", 'Attack'] = _reg_metric_value(
                    metrics[m][0], df_adv.loc[ids_part, "SImAge"].values, y_true
                )
        df_metrics.to_excel(f"{path_curr}/metrics.xlsx", index_label='Metrics')

        for p in ids_dict:
            # Clean MAE does not depend on the attack, so it is recorded once (with the first attack).
            if attack_name == 'MomentumIterative':
                df_eps.loc[eps_raw, f"Origin_MAE_{p}"] = df_metrics.at[f'mean_absolute_error_{p}', 'Origin']
            df_eps.loc[eps_raw, f"{attack_name}_MAE_{p}"] = df_metrics.at[f'mean_absolute_error_{p}', 'Attack']

df_eps.to_excel(f"{path_save}/Evasion/df_eps.xlsx", index_label='eps')

## Plot error vs. epsilon

In [None]:
for p in ids_dict:
    # Per-attack MAE columns for this partition, renamed to the bare attack name, in long format.
    mae_cols = {f"{x}_MAE_{p}": x for x in colors_atks}
    df_fig = df_eps.loc[:, list(mae_cols)].rename(columns=mae_cols)
    df_fig['Eps'] = df_fig.index.values
    df_fig = df_fig.melt(id_vars="Eps", var_name='Method', value_name="MAE")
    sns.set_theme(style='ticks', font_scale=1)
    fig = plt.figure()
    lines = sns.lineplot(
        data=df_fig,
        x='Eps',
        y="MAE",
        hue="Method",
        style="Method",
        palette=colors_atks,
        hue_order=list(colors_atks.keys()),
        markers=True,
        dashes=False,
    )
    plt.xscale('log')
    lines.set_xlabel(r'$\epsilon$')
    # Pad the log x range slightly beyond the evaluated eps grid
    # (gives 0.009..1.05 for the default 0.01..1.0 grid).
    x_min = 0.9 * df_eps.index.min()
    x_max = 1.05 * df_eps.index.max()
    # MAE on clean data is the same for every eps; read it positionally rather
    # than through an exact float index key.
    mae_basic = df_eps[f"Origin_MAE_{p}"].iloc[0]
    lines.set_xlim(x_min, x_max)
    # Dashed horizontal reference line at the clean-data MAE.
    plt.gca().plot(
        [x_min, x_max],
        [mae_basic, mae_basic],
        color='k',
        linestyle='dashed',
        linewidth=1
    )
    plt.savefig(f"{path_save}/Evasion/line_mae_vs_eps_{p}.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_save}/Evasion/line_mae_vs_eps_{p}.pdf", bbox_inches='tight')
    plt.close(fig)

## Plot distributions

In [None]:
epsilons_hglt = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5]
colors_epsilons = {x: px.colors.qualitative.G10[x_id] for x_id, x in enumerate(['Origin'] + epsilons_hglt)}

df['Eps'] = 'Origin'
df['MarkerSize'] = 40


def _stack_bricks_grid(bricks, n_cols=5, blank_figsize=(3, 2)):
    """Stack bricks row by row into an n_cols-wide grid, padding the last row with fresh blank bricks.

    Each padding cell is a NEW brick, so no brick is shared between two grids. The
    padding size matches the (3, 2) feature bricks, so the grid is not distorted.
    """
    rows = []
    for start in range(0, len(bricks), n_cols):
        row = list(bricks[start:start + n_cols])
        while len(row) < n_cols:
            blank = pw.Brick(figsize=blank_figsize)
            blank.axis('off')
            row.append(blank)
        rows.append(pw.stack(row, operator="|"))
    return pw.stack(rows, operator="/")


for atk in colors_atks:

    for eps in epsilons_hglt:
        path_curr = f"{path_save}/Evasion/{atk}/eps_{eps:0.4f}"
        pathlib.Path(f"{path_curr}/SImAgeError").mkdir(parents=True, exist_ok=True)
        df_adv = pd.read_excel(f"{path_curr}/df.xlsx", index_col='sample_id')
        # Suffix attacked sample ids so they do not collide with the originals after concat.
        df_adv.index += f'_eps_{eps:0.4f}'
        df_adv['Eps'] = eps
        df_adv['MarkerSize'] = 30
        df_ori_adv = pd.concat([df, df_adv])

        palette = {'Origin': 'grey', eps: colors_epsilons[eps]}
        hue_order = ['Origin', eps]
        sns.set_theme(style='whitegrid')

        bricks_kde = []
        bricks_scatter = []
        for f in feats:
            brick_kde = pw.Brick(figsize=(3, 2))
            sns.kdeplot(
                data=df_ori_adv,
                x=f,
                hue='Eps',
                palette=palette,
                hue_order=hue_order,
                fill=True,
                common_norm=False,
                ax=brick_kde
            )
            bricks_kde.append(brick_kde)

            brick_scatter = pw.Brick(figsize=(3, 2))
            sns.scatterplot(
                data=df_ori_adv,
                x=f,
                y='Age',
                hue='Eps',
                palette=palette,
                hue_order=hue_order,
                linewidth=0.85,
                alpha=0.75,
                edgecolor="k",
                marker='o',
                s=30,
                ax=brick_scatter
            )
            bricks_scatter.append(brick_scatter)

        pw_fig_kde = _stack_bricks_grid(bricks_kde)
        pw_fig_kde.savefig(f"{path_curr}/feats_kde.png", bbox_inches='tight', dpi=200)
        pw_fig_kde.savefig(f"{path_curr}/feats_kde.pdf", bbox_inches='tight')
        pw_fig_scatter = _stack_bricks_grid(bricks_scatter)
        pw_fig_scatter.savefig(f"{path_curr}/feats_scatter.png", bbox_inches='tight', dpi=200)
        pw_fig_scatter.savefig(f"{path_curr}/feats_scatter.pdf", bbox_inches='tight')
        pw.clear()

## Outliers for attacks

In [12]:
%%capture

epsilons_hglt = [0.05, 0.1, 0.5, 1.0]

for atk in colors_atks:
    for eps in epsilons_hglt:
        path_curr = f"{path_save}/Evasion/{atk}/eps_{eps:0.4f}"
        pathlib.Path(f"{path_curr}/SImAgeError").mkdir(parents=True, exist_ok=True)
        df_adv = pd.read_excel(f"{path_curr}/df.xlsx", index_col='sample_id')

        # Percentage of attacked samples flagged as outliers per (detector, thresholder) pair:
        # rows = pyod detectors, columns = pythresh thresholders.
        df_outs = pd.DataFrame(index=list(classifiers.keys()), columns=list(thresholders.keys()))
        for pyod_m_name, pyod_m in classifiers.items():
            scores = pyod_m.fit(df_adv.loc[:, feats].values).decision_scores_
            for pythresh_m_name, pythresh_m in thresholders.items():
                labels = pythresh_m.eval(scores)
                df_outs.at[pyod_m_name, pythresh_m_name] = sum(labels) / len(labels) * 100

        df_outs.to_excel(f"{path_curr}/outliers.xlsx")

        df_fig = df_outs.astype(float)
        sns.set_theme(style='ticks', font_scale=1.0)
        fig, ax = plt.subplots(figsize=(16, 12))
        sns.heatmap(
            df_fig,
            annot=True,
            fmt=".1f",
            cmap='hot',
            linewidth=0.1,
            linecolor='black',
            cbar_kws={
                'orientation': 'horizontal',
                'location': 'top',
                'pad': 0.025,
                'aspect': 30
            },
            annot_kws={"size": 12},
            ax=ax
        )
        # Rows (y axis) are detectors and columns (x axis) are thresholders.
        ax.set_xlabel('Thresholding Algorithms')
        ax.set_ylabel('Outliers Detection Algorithms')
        # The last axes of the figure is the colorbar added by the heatmap.
        cbar_ax = fig.axes[-1]
        cbar_ax.set_title("Outliers' percentage")
        for spine in cbar_ax.spines.values():
            spine.set_linewidth(1)
        plt.savefig(f"{path_curr}/outliers.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_curr}/outliers.pdf", bbox_inches='tight')
        plt.close(fig)

# Adversarial defences against attacks

## Generate detectors

In [None]:
df_ori = df[feats].copy()
df_ori['Class'] = 'Original'


def _load_attack_feats(atk_name, eps_val):
    """Load the attacked feature values saved for one attack/eps and label every row 'Attack'."""
    df_res = pd.read_excel(f"{path_save}/Evasion/{atk_name}/eps_{eps_val:0.4f}/df.xlsx", index_col='sample_id')
    # Explicit copy: assigning 'Class' on a column slice could otherwise trigger SettingWithCopyWarning.
    df_res = df_res[feats].copy()
    df_res['Class'] = 'Attack'
    return df_res


def _detector_test_set(df_adv_eps):
    """Held-out detector data: test-split originals plus their attacked counterparts (fresh frame)."""
    return pd.concat([df_ori.loc[ids_tst, :], df_adv_eps.loc[ids_tst, :]])


for atk in colors_atks:

    # Rows: eps the detector was trained on. Columns: best model name + test accuracy per test eps.
    df_def_acc = pd.DataFrame(index=epsilons, columns=['Model'] + list(epsilons))

    # Read each attacked dataset once per attack instead of once per (train eps, test eps) pair.
    dfs_adv = {eps_val: _load_attack_feats(atk, eps_val) for eps_val in epsilons}

    for eps in tqdm(epsilons):

        path_curr = f"{path_save}/Evasion/{atk}/eps_{eps:0.4f}"
        df_def_trn_val = pd.concat([df_ori.loc[ids_trn_val, :], dfs_adv[eps].loc[ids_trn_val, :]])

        data_config = DataConfig(
            target=['Class'],
            continuous_cols=list(feats),
            continuous_feature_transform='yeo-johnson',
            normalize_continuous_features=True,
        )

        trainer_config = TrainerConfig(
            batch_size=1024,
            max_epochs=100,
            min_epochs=1,
            auto_lr_find=True,
            early_stopping='valid_loss',
            early_stopping_min_delta=0.0001,
            early_stopping_mode='min',
            early_stopping_patience=100,
            checkpoints='valid_loss',
            checkpoints_path=f"{path_curr}/detector",
            load_best=True,
            progress_bar='none',
            seed=42
        )

        optimizer_config = OptimizerConfig(
            optimizer='Adam',
            lr_scheduler='CosineAnnealingWarmRestarts',
            lr_scheduler_params={
                'T_0': 10,
                'T_mult': 1,
                'eta_min': 0.00001,
            },
            lr_scheduler_monitor_metric='valid_loss'
        )

        head_config = LinearHeadConfig(
            layers='',
            activation='ReLU',
            dropout=0.1,
            use_batch_norm=False,
            initialization='xavier',
        ).__dict__

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sweep_df, best_model = model_sweep(
                task="classification",
                train=df_def_trn_val,
                test=_detector_test_set(dfs_adv[eps]),
                data_config=data_config,
                optimizer_config=optimizer_config,
                trainer_config=trainer_config,
                model_list="standard",
                common_model_args=dict(head="LinearHead", head_config=head_config),
                metrics=[
                    'accuracy',
                    'f1_score',
                    'precision',
                    'recall',
                    'specificity',
                    'cohen_kappa',
                    'auroc'
                ],
                metrics_prob_input=[True, True, True, True, True, True, True],
                metrics_params=[
                    {'task': 'multiclass', 'num_classes': 2, 'average': 'weighted'},
                    {'task': 'multiclass', 'num_classes': 2, 'average': 'weighted'},
                    {'task': 'multiclass', 'num_classes': 2, 'average': 'weighted'},
                    {'task': 'multiclass', 'num_classes': 2, 'average': 'weighted'},
                    {'task': 'multiclass', 'num_classes': 2, 'average': 'weighted'},
                    {'task': 'multiclass', 'num_classes': 2},
                    {'task': 'multiclass', 'num_classes': 2, 'average': 'weighted'},
                ],
                rank_metric=("accuracy", "higher_is_better"),
                progress_bar=False,
                verbose=False,
                suppress_lightning_logger=True,
            )
        # Checkpoints are only needed during the sweep; delete them to save disk space.
        for ckpt in glob(f"{path_curr}/detector/*"):
            os.remove(ckpt)
        df_def_acc.at[eps, 'Model'] = best_model.config['_model_name']

        # Accuracy on HELD-OUT test samples only. Earlier code also evaluated on the trn/val
        # samples the detector was trained on, which inflated cross-eps accuracy. The training
        # eps itself is evaluated too, so the diagonal is no longer left empty.
        for tst_eps in epsilons:
            eval_res = best_model.evaluate(test=_detector_test_set(dfs_adv[tst_eps]), verbose=False)[0]
            df_def_acc.at[eps, tst_eps] = eval_res['test_accuracy']
    df_def_acc.to_excel(f"{path_save}/Evasion/{atk}/detectors_accuracy.xlsx")
    

## Plot detectors accuracy

In [None]:
for atk in colors_atks:
    dir_atk = f"{path_save}/Evasion/{atk}"
    df_def_acc = pd.read_excel(f"{dir_atk}/detectors_accuracy.xlsx", index_col=0)

    # Row label "<best model>\n<training eps>"; column labels are test eps. Both eps use 2 decimals.
    df_def_acc['Eps'] = [f"{x:.2f}" for x in df_def_acc.index.values]
    df_def_acc['index'] = df_def_acc['Model'] + '\n' + df_def_acc['Eps']
    df_def_acc = df_def_acc.set_index('index').drop(columns=['Model', 'Eps'])
    df_def_acc = df_def_acc.rename(columns=lambda col: f"{col:.2f}")

    sns.set_theme(style='ticks', font_scale=1.0)
    fig, ax = plt.subplots(figsize=(13, 12))
    sns.heatmap(
        df_def_acc.astype(float),
        annot=True,
        fmt=".2f",
        cmap='hot',
        linewidth=0.1,
        linecolor='black',
        cbar_kws=dict(orientation='horizontal', location='top', pad=0.025, aspect=30),
        annot_kws={"size": 12},
        ax=ax,
    )
    ax.set_xlabel('Test Attack Strength')
    ax.set_ylabel('Training Model and Data')

    # The colorbar is the last axes of the figure.
    colorbar_ax = ax.figure.axes[-1]
    colorbar_ax.set_title("Accuracy")
    for spine in colorbar_ax.spines.values():
        spine.set_linewidth(1)

    for ext, save_kwargs in (("png", {"dpi": 200}), ("pdf", {})):
        plt.savefig(f"{dir_atk}/detectors_accuracy.{ext}", bbox_inches='tight', **save_kwargs)
    plt.close(fig)