# Idea
Generate 1000 attacks and use the following features to fit a detection head on them for decision making:
1. $MSE(x, f_{recon}(x|\hat{c}))$
2. $D(x)$ (veracity)
3. $p(\hat{c}|x)$ from $D_{aux}$
4. $p(\hat{c}|x)$ from the victim
5. $JS(victim(x), D_{aux}(x))$
6. $\log(D(x)) + \log(p(\hat{c}|x))$, with $p(\hat{c}|x)$ taken from $D_{aux}$

# REMEMBER: Convert all outputs to probability distributions!

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys  # TODO: Fix this?

sys.path.append("PyTorch-StudioGAN/")
import itertools
import random

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from Attacks.attacks_suite import *
from Classifiers.resnet import load_model as load_victim_model
from Defense.def_utils import *
import json
from functools import partial

import numpy as np

from gan_loader import load_model as load_gan_model
from LGDPM.diffusion_loader import load_model as load_diffusion_model
from sklearn.model_selection import GridSearchCV, train_test_split
from torchmetrics.functional.classification import binary_auroc, multiclass_accuracy
from torchmetrics.functional.regression import kl_divergence
from torchvision import datasets, transforms as T
from tqdm import tqdm
from xgboost import XGBClassifier

assert torch.cuda.is_available()


In [3]:
seed = 42

# Keep cuDNN enabled, but request its deterministic algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True

# Seed every RNG the notebook touches, in a fixed order:
# Python, NumPy, the CUDA generator, then the default torch generator.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.cuda.manual_seed,
    torch.manual_seed,
):
    _seed_rng(seed)


In [4]:
class NormalizeInverse(T.Normalize):
    """Undo a ``T.Normalize(mean, std)`` transform.

    ``T.Normalize`` maps ``x`` to ``(x - mean) / std``. Building the parent
    with ``mean' = -mean / std`` and ``std' = 1 / std`` therefore maps a
    normalized tensor back to ``x * std + mean``.

    Args:
        mean: Per-channel means used by the forward normalization.
        std: Per-channel standard deviations used by the forward normalization.
        *args: Extra positional arguments forwarded to ``T.Normalize``.
        **kwargs: Extra keyword arguments forwarded to ``T.Normalize``.
    """

    def __init__(self, mean, std, *args, **kwargs):
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        # The small epsilon guards against a zero std; it makes the inverse
        # very slightly approximate.
        std_inv = 1 / (std + 1e-7)
        mean_inv = -mean * std_inv
        # Pass the moments positionally: combining ``mean=``/``std=`` keywords
        # with ``*args`` raises "multiple values for argument 'mean'" as soon
        # as a caller forwards any extra positional argument.
        super().__init__(mean_inv, std_inv, *args, **kwargs)

    def __call__(self, tensor):
        # Work on a copy so the caller's tensor is never modified.
        return super().__call__(tensor.clone())

In [5]:
def JSD(p, q, log_prob=False):
    """Jensen-Shannon divergence between two batches of distributions.

    Computes ``0.5 * KL(p || m) + 0.5 * KL(q || m)`` where ``m`` is the
    equal-weight mixture of ``p`` and ``q``.

    Args:
        p: First distribution(s); passed straight to ``kl_divergence``.
        q: Second distribution(s), same layout as ``p``.
        log_prob: Set to ``True`` when ``p`` and ``q`` hold log-probabilities
            instead of probabilities.

    Returns:
        Whatever ``kl_divergence`` returns for the given inputs (its default
        reduction is used).
    """
    if log_prob:
        import math

        # In log space the mixture is log(0.5 * (exp(p) + exp(q))).
        # Averaging the log-probabilities themselves would give the log of a
        # geometric mean, which is not a valid mixture distribution.
        m = torch.logaddexp(p, q) - math.log(2.0)
    else:
        m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m, log_prob=log_prob) + 0.5 * kl_divergence(
        q, m, log_prob=log_prob
    )


In [6]:
# Checkpoint paths of the victim classifiers, keyed by dataset name.
V_WEIGHTS = {
    "MNIST": "saved_models/victims/mnist_resnet18.pth",
    "CIFAR10": "saved_models/victims/cifar10_resnet18.pth",
    "TIMGNET": "saved_models/victims/timgnet_resnet18.pth",
}

# Only the CIFAR-10 and Tiny-ImageNet victims are loaded in this notebook.
cifar10_victim = load_victim_model(V_WEIGHTS["CIFAR10"], num_classes=10, in_channels=3)
timgnet_victim = load_victim_model(V_WEIGHTS["TIMGNET"], num_classes=200, in_channels=3)

VICTIM_MODELS = {
    "CIFAR10": cifar10_victim,
    "TIMGNET": timgnet_victim,
}

# Victims are only ever evaluated here, so switch them to eval mode.
for _victim in VICTIM_MODELS.values():
    _victim.eval()


In [7]:
# Checkpoint paths of the defender models, keyed by "<DATASET>_<DEFENDER>".
# - "*_DIFF" entries are a single path, passed to `load_diffusion_model`.
# - "*_REACGAN" entries are a 2-tuple whose elements [0] and [1] are passed, in
#   that order, to `load_gan_model` (the G_*/D_* file names suggest
#   generator first, discriminator second).
D_WEIGHTS = {
    "CIFAR10_DIFF": "saved_models/defenders/diffusion/cifar10_diff.pth",
    "TIMGNET_DIFF": "saved_models/defenders/diffusion/timgnet_diff.pth",
    "CIFAR10_REACGAN": (
        "saved_models/defenders/acgan/cifar10/G_cifar10_reacgan.pth",
        "saved_models/defenders/acgan/cifar10/D_cifar10_reacgan.pth",
    ),
    "TIMGNET_REACGAN": (
        "saved_models/defenders/acgan/timgnet/G_timgnet_reacgan.pth",
        "saved_models/defenders/acgan/timgnet/D_timgnet_reacgan.pth",
    ),
}


In [8]:
def generate_sample_baskets(n_samples=200):
    """Draw a class-balanced random subset from each supported dataset.

    Args:
        n_samples: Total number of samples to draw per dataset. Must be
            divisible by the dataset's number of classes.

    Returns:
        dict mapping dataset name ("CIFAR10", "TIMGNET") to a list of the
        items returned by indexing the dataset (normalized image, label).

    Raises:
        AssertionError: If ``n_samples`` cannot be split equally over the
            classes of a dataset.
    """

    def _to_normalized_tensor(moments):
        # Same pipeline for every dataset: to tensor, then normalize with
        # that dataset's default moments.
        return T.Compose([T.ToTensor(), T.Normalize(moments.mean, moments.std)])

    sources = {
        "CIFAR10": datasets.CIFAR10(
            root="PyTorch-StudioGAN/data/",
            transform=_to_normalized_tensor(DEFAULT_MOMENTS.CIFAR10),
        ),
        "TIMGNET": datasets.ImageFolder(
            root="PyTorch-StudioGAN/data/tiny-imagenet-200/train/",
            transform=_to_normalized_tensor(DEFAULT_MOMENTS.TIMGNET),
        ),
    }

    baskets = {}
    for ds_name, dataset in sources.items():
        targets = torch.tensor(dataset.targets)
        class_ids = torch.unique(targets)
        assert (
            n_samples % len(class_ids) == 0
        ), f"Cannot sample equally with the provided n_samples={n_samples} and n_targets={len(class_ids)}"
        per_class = n_samples // len(class_ids)

        chosen_indices = []
        for class_id in class_ids:
            # Indices of every item of this class, then a random draw from them.
            members = torch.where(targets == class_id)[0].tolist()
            chosen_indices += random.sample(members, per_class)

        baskets[ds_name] = [dataset[index] for index in chosen_indices]
    return baskets


In [9]:
def generate_attacks(victim, ds_name, real_sample_baskets, targeted=False, **kwargs):
    """Run the attack suite against ``victim`` on one dataset's sample basket.

    Args:
        victim: Model under attack; forwarded to ``run_attack_suite``.
        ds_name: Key into ``real_sample_baskets`` and name given to the suite.
        real_sample_baskets: Mapping from dataset name to a list of
            ``(image, label)`` pairs.
        targeted: ``False`` for untargeted attacks (``targets=None``),
            ``True`` to let the suite pick targets (``targets="auto"``), or a
            ``torch.Tensor`` of explicit targets.
        **kwargs: Forwarded unchanged to ``run_attack_suite``.

    Returns:
        Whatever ``run_attack_suite`` returns.

    Raises:
        TypeError: If ``targeted`` is neither a bool nor a ``torch.Tensor``.
    """
    # Validate first: previously an unsupported type left `suite` unbound and
    # failed with an UnboundLocalError only after the data was already on GPU.
    if isinstance(targeted, bool):
        suite = partial(run_attack_suite, targets="auto" if targeted else None)
    elif isinstance(targeted, torch.Tensor):
        suite = partial(run_attack_suite, targets=targeted)
    else:
        raise TypeError(
            f"targeted must be a bool or a torch.Tensor, got {type(targeted).__name__}"
        )

    images, labels = zip(*real_sample_baskets[ds_name])
    images = torch.stack(images).to("cuda")
    labels = torch.tensor(list(labels)).to("cuda", dtype=torch.long)

    attacks = suite(victim, ds_name, images, labels, **kwargs)
    # Release cached GPU memory left over from attack generation.
    torch.cuda.empty_cache()
    return attacks


In [10]:
def _load_or_build(cache_path, build):
    """Return the object cached at ``cache_path``; otherwise build, cache and return it."""
    if os.path.exists(cache_path):
        # NOTE: torch.load unpickles the file; only load caches you produced yourself.
        return torch.load(cache_path)
    built = build()
    torch.save(built, cache_path)
    return built


# Balanced clean samples (1000 per dataset), cached on disk.
real_sample_baskets = _load_or_build(
    "1k_real_sample_baskets.pth",
    lambda: generate_sample_baskets(n_samples=1000),
)

# Targeted adversarial examples built from those samples, cached on disk.
ADV_ATTACKS = _load_or_build(
    "1k_ADV_ATTACKS.pth",
    lambda: {
        "CIFAR10": generate_attacks(
            cifar10_victim, "CIFAR10", real_sample_baskets, targeted=True, splits=10
        ),
        "TIMGNET": generate_attacks(
            timgnet_victim, "TIMGNET", real_sample_baskets, targeted=True, splits=10
        ),
    },
)


In [11]:
# TODO: have attacks and clean images here
# -> Discriminator Outputs
# -> Victim outputs
# -> Targeted Reconstruction
# -> Make training data
# -> Fit XGB

In [12]:
# Default defender per family ("DIFF" / "ACGAN") and dataset.
# The dataset entries are keys into D_WEIGHTS.
DEFENDER_DEFAULT_CONFIGS = {
    "DIFF": {
        "CIFAR10": "CIFAR10_DIFF",
        "TIMGNET": "TIMGNET_DIFF",
        # Not a dataset: the GAN family whose discriminator accompanies the
        # diffusion defender; combined with the dataset name as
        # "<DATASET>_<type>" to look up its weights.
        "aux_d_type_for_diff": "REACGAN",
    },
    "ACGAN": {
        "CIFAR10": "CIFAR10_REACGAN",
        "TIMGNET": "TIMGNET_REACGAN",
    },
}

In [13]:
# Root output directory. Each defender family writes into its own sub-folder
# ("DIFF", "ACGAN"): cached training features, grid-search history, best
# hyper-parameters and the fitted XGBoost model.
BASE_PATH = "DHT_models/"

# DIFFUSION

In [14]:
@torch.inference_mode()
def diffusion_metrics(
    conditional_diffuser,
    auxiliary_discriminator,
    inputs,
    labels,
    splits=1,
    **targeted_purify_kwargs
):
    """Compute the diffusion-defender detection signals for a batch.

    Args:
        conditional_diffuser: Model handed to ``diffusion_purify_targeted``.
        auxiliary_discriminator: Discriminator exposing ``forward_emb``, whose
            result provides ``"adv_output"`` and ``"cls_output"``.
        inputs: Image batch; dimension 1 is used as the channel count to pick
            the discriminator's normalization moments.
        labels: Class labels the targeted purification is conditioned on.
        splits: Number of chunks the batch is processed in.
        **targeted_purify_kwargs: Forwarded to ``diffusion_purify_targeted``.

    Returns:
        Tuple ``(recon_loss, veracity, class_posteriors)``: the value returned
        by ``diffusion_purify_targeted``, the sigmoid of the discriminator's
        ``"adv_output"``, and the softmax (over dim 1) of its ``"cls_output"``.
    """
    moments = DISCRIMINATOR_MOMENTS[inputs.shape[1]]
    to_discriminator_space = T.Normalize(mean=moments["mean"], std=moments["std"])

    recon_loss = diffusion_purify_targeted(
        conditional_diffuser,
        inputs,
        labels,
        return_purified=False,
        splits=splits,
        **targeted_purify_kwargs
    )

    # The discriminator expects its own normalization; run it chunk by chunk
    # to bound memory use.
    veracity_parts, posterior_parts = [], []
    for batch in torch.chunk(to_discriminator_space(inputs).detach(), splits):
        out = auxiliary_discriminator.forward_emb(batch)
        veracity_parts.append(torch.sigmoid(out["adv_output"].detach()))
        posterior_parts.append(F.softmax(out["cls_output"].detach(), dim=1))

    return recon_loss, torch.cat(veracity_parts), torch.cat(posterior_parts)


In [15]:
NUM_SPLITS = 10


def _diff_detection_features(
    victim, diffuser, discriminator, norm_images, denorm_images, splits
):
    """Build the detection-feature matrix for one batch (diffusion defender).

    Args:
        victim: Victim classifier; fed ``norm_images``.
        diffuser: Conditional diffusion model passed to ``diffusion_metrics``.
        discriminator: Auxiliary discriminator passed to ``diffusion_metrics``.
        norm_images: Images normalized for the victim.
        denorm_images: The same images, un-normalized, for the defender.
        splits: Number of chunks used for the forward passes.

    Returns:
        CPU tensor with one row per image and six columns, in this order:
        reconstruction loss, veracity, discriminator posterior of the
        victim-predicted class, victim posterior of its predicted class,
        JS divergence between both posteriors, log(veracity) + log(posterior).
    """
    # Plain inference: no autograd graph is needed for the victim.
    with torch.no_grad():
        victim_logits = torch.cat(
            [victim(chunk) for chunk in torch.chunk(norm_images, splits)], dim=0
        )
    victim_posteriors = F.softmax(victim_logits, dim=1)
    victim_labels = torch.argmax(victim_posteriors, dim=1)
    victim_confidence = victim_posteriors.gather(
        dim=1, index=victim_labels[:, None]
    ).squeeze()

    # The purification is conditioned on the class the victim predicted.
    recon_loss, veracity, dis_posteriors = diffusion_metrics(
        diffuser,
        discriminator,
        denorm_images,
        victim_labels,
        splits=splits,
        disable_tqdm=False,
    )
    dis_confidence = dis_posteriors.gather(
        dim=1, index=victim_labels[:, None]
    ).squeeze()

    # JSD is evaluated per sample: each row is passed as a batch of one.
    jsd = torch.as_tensor(
        list(
            map(
                JSD,
                victim_posteriors.unsqueeze(-2),
                dis_posteriors.unsqueeze(-2),
            )
        )
    )
    sum_of_logs = torch.log(veracity) + torch.log(dis_confidence)

    return torch.vstack(
        [
            recon_loss.cpu(),
            veracity.cpu(),
            dis_confidence.cpu(),
            victim_confidence.cpu(),
            jsd.cpu(),
            sum_of_logs.cpu(),
        ]
    ).T


# Make sure the output folder exists before anything is written into it.
os.makedirs(os.path.join(BASE_PATH, "DIFF"), exist_ok=True)

for dataset_name in ["CIFAR10", "TIMGNET"]:
    for attack_method in ["CW", "FGSM", "PGD"]:
        training_file_name = os.path.join(
            BASE_PATH, "DIFF", f"DIFF_{dataset_name}_{attack_method}_train.pth"
        )
        if not os.path.exists(training_file_name):
            victim = VICTIM_MODELS[dataset_name]
            defender = DEFENDER_DEFAULT_CONFIGS["DIFF"][dataset_name]
            aux_d_name = (
                dataset_name
                + "_"
                + DEFENDER_DEFAULT_CONFIGS["DIFF"]["aux_d_type_for_diff"]
            )
            _, defender_diff_ema = load_diffusion_model(D_WEIGHTS[defender], defender)
            _, defender_D, _ = load_gan_model(
                D_WEIGHTS[aux_d_name][0], D_WEIGHTS[aux_d_name][1], aux_d_name
            )
            defender_D.eval()
            defender_diff_ema.eval()

            # Clean basket images are stored normalized for the victim.
            clean_images = torch.stack(
                [image for image, _ in real_sample_baskets[dataset_name]]
            ).to("cuda")
            attack = ADV_ATTACKS[dataset_name][attack_method]
            inverse_normalizer = NormalizeInverse(
                mean=DEFAULT_MOMENTS[dataset_name]["mean"],
                std=DEFAULT_MOMENTS[dataset_name]["std"],
            )

            ### CLEAN
            clean_X_train = _diff_detection_features(
                victim,
                defender_diff_ema,
                defender_D,
                clean_images,
                inverse_normalizer(clean_images),
                NUM_SPLITS,
            )

            ### ADVERSARIAL
            adv_X_train = _diff_detection_features(
                victim,
                defender_diff_ema,
                defender_D,
                attack["normalized_clipped_samples"],
                attack["unnormalized_clipped_samples"],
                NUM_SPLITS,
            )

            ### AGGREGATE
            # Label 0 = clean, label 1 = adversarial.
            X_train = torch.vstack([clean_X_train, adv_X_train]).numpy()
            y_train = torch.vstack(
                [
                    torch.zeros(clean_X_train.shape[0], 1),
                    torch.ones(adv_X_train.shape[0], 1),
                ]
            ).numpy()
            training_data = {
                "X": X_train,
                "y": y_train,
            }
            torch.save(training_data, training_file_name)
        else:
            training_data = torch.load(training_file_name)
            X_train, y_train = training_data["X"], training_data["y"]

        ### TRAIN
        clf_file_name = os.path.join(
            BASE_PATH, "DIFF", f"DIFF_{dataset_name}_{attack_method}_XGB.json"
        )
        gridsearch_hist_file_name = os.path.join(
            BASE_PATH, "DIFF", f"DIFF_{dataset_name}_{attack_method}_GridSearchCV.hist"
        )
        best_params_file_name = os.path.join(
            BASE_PATH, "DIFF", f"DIFF_{dataset_name}_{attack_method}_best_params.json"
        )
        base_xgb_model = XGBClassifier(objective="binary:logistic", eval_metric="auc")
        gridsearch_clf = GridSearchCV(
            base_xgb_model,
            {
                "max_depth": [1, 2, 3, 4, 5],
                "n_estimators": [2, 5, 10, 50],
            },
            verbose=1,
        )
        gridsearch_clf.fit(X_train, y_train)
        print(
            f"Best CV score: {gridsearch_clf.best_score_} - Best params: {gridsearch_clf.best_params_}"
        )
        torch.save(gridsearch_clf.cv_results_, gridsearch_hist_file_name)
        with open(best_params_file_name, "w") as hist_fp:
            json.dump(gridsearch_clf.best_params_, hist_fp)
        # GridSearchCV (refit=True by default) has already refitted the best
        # configuration on the full training set, so no second fit is needed.
        best_xgb_clf = gridsearch_clf.best_estimator_
        # NOTE: this score is measured on the training data itself.
        print(
            f"{dataset_name} - {attack_method} XGB clf Score: {best_xgb_clf.score(X_train, y_train)}"
        )
        best_xgb_clf.save_model(clf_file_name)
        print(best_xgb_clf.feature_importances_)
        torch.cuda.empty_cache()


10it [02:34, 15.46s/it]
10it [02:24, 14.41s/it]


Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9970000000000001 - Best params: {'max_depth': 1, 'n_estimators': 50}
CIFAR10 - CW XGB clf Score: 0.9995
[0.00583138 0.00497438 0.8556871  0.09806246 0.00697095 0.02847368]


10it [02:22, 14.27s/it]
10it [02:22, 14.28s/it]


Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9345000000000001 - Best params: {'max_depth': 1, 'n_estimators': 50}
CIFAR10 - FGSM XGB clf Score: 0.9445
[0.01572933 0.01057607 0.11516175 0.7250773  0.01718033 0.11627524]


10it [02:22, 14.26s/it]
10it [02:22, 14.27s/it]


Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9935 - Best params: {'max_depth': 1, 'n_estimators': 50}
CIFAR10 - PGD XGB clf Score: 0.998
[0.         0.0088792  0.5938653  0.09941047 0.02393176 0.2739133 ]


10it [14:12, 85.28s/it]
10it [14:22, 86.21s/it]


Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9515 - Best params: {'max_depth': 3, 'n_estimators': 50}
TIMGNET - CW XGB clf Score: 0.973
[0.02677916 0.02733272 0.5959568  0.2210356  0.08465185 0.04424397]


10it [15:15, 91.54s/it]
10it [14:49, 88.92s/it]


Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9844999999999999 - Best params: {'max_depth': 1, 'n_estimators': 50}
TIMGNET - FGSM XGB clf Score: 0.9875
[0.00554014 0.00526516 0.10966279 0.68457156 0.1291508  0.06580959]


10it [14:33, 87.39s/it]
10it [14:27, 86.73s/it]


Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9630000000000001 - Best params: {'max_depth': 2, 'n_estimators': 50}
TIMGNET - PGD XGB clf Score: 0.9755
[0.0131713  0.01470757 0.14458884 0.13044733 0.617773   0.07931195]


# ACGAN

In [16]:
def acgan_metrics(
    conditional_generator,
    auxiliary_discriminator,
    inputs,
    labels,
    z_dim,
    splits=1,
    **targeted_purify_kwargs
):
    """Compute the ACGAN-defender detection signals for a batch.

    Args:
        conditional_generator: Generator handed to ``acgan_purify_targeted``.
        auxiliary_discriminator: Discriminator exposing ``forward_emb``, whose
            result provides ``"adv_output"`` and ``"cls_output"``.
        inputs: Image batch; dimension 1 is used as the channel count to pick
            the discriminator's normalization moments.
        labels: Class labels the targeted purification is conditioned on.
        z_dim: Latent dimensionality forwarded to ``acgan_purify_targeted``.
        splits: Number of chunks the batch is processed in.
        **targeted_purify_kwargs: Forwarded to ``acgan_purify_targeted``.

    Returns:
        Tuple ``(recon_loss, veracity, class_posteriors)``: the value returned
        by ``acgan_purify_targeted``, the sigmoid of the discriminator's
        ``"adv_output"``, and the softmax (over dim 1) of its ``"cls_output"``.
    """
    num_channels = inputs.shape[1]
    discriminator_normalizer = T.Normalize(
        mean=DISCRIMINATOR_MOMENTS[num_channels]["mean"],
        std=DISCRIMINATOR_MOMENTS[num_channels]["std"],
    )
    # Left outside the no-grad region on purpose: the purification presumably
    # optimizes the latent code by gradient descent -- TODO confirm.
    recon_loss = acgan_purify_targeted(
        conditional_generator,
        inputs,
        labels,
        z_dim,
        return_purified=False,
        splits=splits,
        **targeted_purify_kwargs
    )
    # The discriminator is only evaluated, so skip building autograd graphs
    # instead of building them and detaching afterwards.
    with torch.no_grad():
        dis_inputs = discriminator_normalizer(inputs)
        veracity, class_posteriors = [], []
        for dis_inputs_chunk in torch.chunk(dis_inputs, splits):
            d_result = auxiliary_discriminator.forward_emb(dis_inputs_chunk)
            veracity.append(torch.sigmoid(d_result["adv_output"]))
            class_posteriors.append(F.softmax(d_result["cls_output"], dim=1))
    return recon_loss, torch.cat(veracity), torch.cat(class_posteriors)

In [17]:
NUM_SPLITS = 10


def _acgan_detection_features(
    victim, generator, discriminator, z_dim, norm_images, denorm_images, splits
):
    """Build the detection-feature matrix for one batch (ACGAN defender).

    Args:
        victim: Victim classifier; fed ``norm_images``.
        generator: Conditional generator passed to ``acgan_metrics``.
        discriminator: Auxiliary discriminator passed to ``acgan_metrics``.
        z_dim: Latent dimensionality passed to ``acgan_metrics``.
        norm_images: Images normalized for the victim.
        denorm_images: The same images, un-normalized, for the defender.
        splits: Number of chunks used for the forward passes.

    Returns:
        CPU tensor with one row per image and six columns, in this order:
        reconstruction loss, veracity, discriminator posterior of the
        victim-predicted class, victim posterior of its predicted class,
        JS divergence between both posteriors, log(veracity) + log(posterior).
    """
    # no_grad (not inference_mode) so the predicted labels stay ordinary
    # tensors that the purification step can use freely.
    with torch.no_grad():
        victim_logits = torch.cat(
            [victim(chunk) for chunk in torch.chunk(norm_images, splits)], dim=0
        )
    victim_posteriors = F.softmax(victim_logits, dim=1)
    victim_labels = torch.argmax(victim_posteriors, dim=1)
    victim_confidence = victim_posteriors.gather(
        dim=1, index=victim_labels[:, None]
    ).squeeze()

    # The purification is conditioned on the class the victim predicted.
    recon_loss, veracity, dis_posteriors = acgan_metrics(
        generator,
        discriminator,
        denorm_images,
        victim_labels,
        z_dim,
        splits=splits,
        disable_tqdm=False,
    )
    dis_confidence = dis_posteriors.gather(
        dim=1, index=victim_labels[:, None]
    ).squeeze()

    # JSD is evaluated per sample: each row is passed as a batch of one.
    jsd = torch.as_tensor(
        list(
            map(
                JSD,
                victim_posteriors.unsqueeze(-2),
                dis_posteriors.unsqueeze(-2),
            )
        )
    )
    sum_of_logs = torch.log(veracity) + torch.log(dis_confidence)

    # detach(): results of acgan_metrics may still be attached to a graph.
    return torch.vstack(
        [
            recon_loss.detach().cpu(),
            veracity.detach().cpu(),
            dis_confidence.detach().cpu(),
            victim_confidence.detach().cpu(),
            jsd.detach().cpu(),
            sum_of_logs.detach().cpu(),
        ]
    ).T


# Make sure the output folder exists before anything is written into it.
os.makedirs(os.path.join(BASE_PATH, "ACGAN"), exist_ok=True)

for dataset_name in ["CIFAR10", "TIMGNET"]:
    for attack_method in ["CW", "FGSM", "PGD"]:
        training_file_name = os.path.join(
            BASE_PATH, "ACGAN", f"ACGAN_{dataset_name}_{attack_method}_train.pth"
        )
        if not os.path.exists(training_file_name):
            victim = VICTIM_MODELS[dataset_name]
            defender = DEFENDER_DEFAULT_CONFIGS["ACGAN"][dataset_name]
            defender_G, defender_D, z_dim = load_gan_model(
                D_WEIGHTS[defender][0], D_WEIGHTS[defender][1], defender
            )
            defender_G.eval()
            defender_D.eval()

            # Clean basket images are stored normalized for the victim.
            clean_images = torch.stack(
                [image for image, _ in real_sample_baskets[dataset_name]]
            ).to("cuda")
            attack = ADV_ATTACKS[dataset_name][attack_method]
            inverse_normalizer = NormalizeInverse(
                mean=DEFAULT_MOMENTS[dataset_name]["mean"],
                std=DEFAULT_MOMENTS[dataset_name]["std"],
            )

            ### CLEAN
            clean_X_train = _acgan_detection_features(
                victim,
                defender_G,
                defender_D,
                z_dim,
                clean_images,
                inverse_normalizer(clean_images),
                NUM_SPLITS,
            )

            ### ADVERSARIAL
            adv_X_train = _acgan_detection_features(
                victim,
                defender_G,
                defender_D,
                z_dim,
                attack["normalized_clipped_samples"],
                attack["unnormalized_clipped_samples"],
                NUM_SPLITS,
            )

            ### AGGREGATE
            # Label 0 = clean, label 1 = adversarial.
            X_train = torch.vstack([clean_X_train, adv_X_train]).numpy()
            y_train = torch.vstack(
                [
                    torch.zeros(clean_X_train.shape[0], 1),
                    torch.ones(adv_X_train.shape[0], 1),
                ]
            ).numpy()
            training_data = {
                "X": X_train,
                "y": y_train,
            }
            torch.save(training_data, training_file_name)
        else:
            training_data = torch.load(training_file_name)
            X_train, y_train = training_data["X"], training_data["y"]

        ### TRAIN
        clf_file_name = os.path.join(
            BASE_PATH, "ACGAN", f"ACGAN_{dataset_name}_{attack_method}_XGB.json"
        )
        gridsearch_hist_file_name = os.path.join(
            BASE_PATH,
            "ACGAN",
            f"ACGAN_{dataset_name}_{attack_method}_GridSearchCV.hist",
        )
        best_params_file_name = os.path.join(
            BASE_PATH, "ACGAN", f"ACGAN_{dataset_name}_{attack_method}_best_params.json"
        )
        base_xgb_model = XGBClassifier(objective="binary:logistic", eval_metric="auc")
        gridsearch_clf = GridSearchCV(
            base_xgb_model,
            {
                "max_depth": [1, 2, 3, 4, 5],
                "n_estimators": [2, 5, 10, 50],
            },
            verbose=1,
        )
        gridsearch_clf.fit(X_train, y_train)
        print(
            f"Best CV score: {gridsearch_clf.best_score_} - Best params: {gridsearch_clf.best_params_}"
        )
        torch.save(gridsearch_clf.cv_results_, gridsearch_hist_file_name)
        with open(best_params_file_name, "w") as hist_fp:
            json.dump(gridsearch_clf.best_params_, hist_fp)
        # GridSearchCV (refit=True by default) has already refitted the best
        # configuration on the full training set, so no second fit is needed.
        best_xgb_clf = gridsearch_clf.best_estimator_
        # NOTE: this score is measured on the training data itself.
        print(
            f"{dataset_name} - {attack_method} XGB clf Score: {best_xgb_clf.score(X_train, y_train)}"
        )
        best_xgb_clf.save_model(clf_file_name)
        print(best_xgb_clf.feature_importances_)
        torch.cuda.empty_cache()


ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:43<00:00, 11.40it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.31it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.32it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.33it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.31it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.31it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.30it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.31it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.34it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.31it/s]
10it [07:21, 44.19s/it]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:43<00:00, 11.40it/s]
ACGAN Purify Target

Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9970000000000001 - Best params: {'max_depth': 1, 'n_estimators': 50}
CIFAR10 - CW XGB clf Score: 0.9995
[0.         0.00584522 0.86298263 0.09538148 0.00707443 0.02871611]


ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:43<00:00, 11.41it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.28it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.29it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.29it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.28it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.28it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.27it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.27it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.30it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.28it/s]
10it [07:23, 44.31s/it]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:43<00:00, 11.37it/s]
ACGAN Purify Target

Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9404999999999999 - Best params: {'max_depth': 1, 'n_estimators': 50}
CIFAR10 - FGSM XGB clf Score: 0.948
[0.01429752 0.01099937 0.12091529 0.71401864 0.0177051  0.12206416]


ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:43<00:00, 11.41it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.28it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.28it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.29it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.28it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.27it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.27it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.27it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.30it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:44<00:00, 11.27it/s]
10it [07:23, 44.32s/it]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [00:43<00:00, 11.39it/s]
ACGAN Purify Target

Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9935 - Best params: {'max_depth': 1, 'n_estimators': 50}
CIFAR10 - PGD XGB clf Score: 0.998
[0.         0.0088792  0.5938653  0.09941047 0.02393176 0.2739133 ]


ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.69it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.60it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:09<00:00,  7.23it/s]
10it [10:59, 65.93s/it]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:04<00:00,  7.77it/s]
ACGAN Purify Target

Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.9515 - Best params: {'max_depth': 3, 'n_estimators': 50}
TIMGNET - CW XGB clf Score: 0.9725
[0.01939472 0.02565976 0.61750895 0.22071359 0.0800981  0.03662479]


ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:04<00:00,  7.79it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:04<00:00,  7.72it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:04<00:00,  7.71it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:04<00:00,  7.72it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.68it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:04<00:00,  7.70it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.67it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.61it/s]
10it [10:50, 65.10s/it]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.66it/s]
ACGAN Purify Target

Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.984 - Best params: {'max_depth': 1, 'n_estimators': 50}
TIMGNET - FGSM XGB clf Score: 0.987
[0.00441428 0.00552707 0.11697416 0.6908551  0.12165723 0.06057219]


ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:04<00:00,  7.70it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.64it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.64it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.63it/s]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.64it/s]
10it [10:54, 65.49s/it]
ACGAN Purify Targeted: Optimizing Z: 100%|██████████| 500/500 [01:05<00:00,  7.66it/s]
ACGAN Purify Target

Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best CV score: 0.959 - Best params: {'max_depth': 3, 'n_estimators': 50}
TIMGNET - PGD XGB clf Score: 0.9875
[0.01349808 0.01824572 0.12875135 0.07903984 0.6781514  0.08231358]
