In [5]:
import sys
import os
import numpy as np
import torch
from art.estimators.classification import PyTorchClassifier
from art.attacks.evasion import FastGradientMethod
import time
from pathlib import Path

sys.path.append(os.path.abspath("../utils"))
import myextensions

In [6]:
# --- Reproducibility --------------------------------------------------------
# Seed both torch and the legacy NumPy global RNG before the model is built.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)

# Model under attack, as returned by the project-local helper module.
model = myextensions.get_vgg()

# --- Run configuration ------------------------------------------------------
DATASET_PATH = "../../dataset/filtered_correct_vgg16/"
MODE = "fgsm"

# NOTE(review): absolute, machine-specific location -- adjust before running
# this notebook on another machine.
SAVE_DIR_PATH = f"/home/cat/uni/bakis/inputs/{MODE}"

# Create one output folder per attack outcome; existing folders are left as-is.
for outcome_name in ("fail", "success"):
    Path(f"{SAVE_DIR_PATH}/{outcome_name}").mkdir(parents=True, exist_ok=True)



In [7]:
# Wrap the PyTorch model in ART's classifier interface so ART attacks can
# query predictions and loss gradients through it.
loss_fn = torch.nn.CrossEntropyLoss()

classifier_config = {
    "model": model,
    "loss": loss_fn,
    "input_shape": (3, 224, 224),
    "nb_classes": 1000,
    # Lower/upper bound of the input feature values handed to ART.
    "clip_values": (0.0, 1.0),
}
classifier = PyTorchClassifier(**classifier_config)

# Untargeted Fast Gradient Method with a perturbation budget of eps=0.1.
attack = FastGradientMethod(
    estimator=classifier,
    eps=0.1,
    targeted=False,
)


In [8]:
# Attack every image in DATASET_PATH, record how long each adversarial example
# takes to generate, and save the result under "success" (prediction changed)
# or "fail" (prediction unchanged).
timings = []

# os.listdir returns entries in arbitrary order; sorting makes the processing
# order, and therefore the order of `timings`, identical on every run.
for filename in sorted(os.listdir(DATASET_PATH)):
    full_path = os.path.join(DATASET_PATH, filename)

    # Only regular files are handed to the image loader; sub-directories
    # inside the dataset folder are skipped instead of crashing the run.
    if not os.path.isfile(full_path):
        continue

    _, input_image_np = myextensions.get_image(full_path, myextensions.PREPROCESS_ATTACK)

    # Time only the adversarial-example generation, not loading or prediction.
    # With y=None, ART crafts the attack against the model's own predictions
    # rather than supplied labels.
    start_time = time.perf_counter()
    x_adv = attack.generate(x=input_image_np, y=None)
    elapsed = time.perf_counter() - start_time

    input_pred = np.argmax(classifier.predict(input_image_np), axis=1)
    adv_pred = np.argmax(classifier.predict(x_adv), axis=1)

    # np.array_equal returns a plain bool, so the branch does not depend on
    # the truthiness of a NumPy array (which raises for more than one element).
    # NOTE(review): assumes get_image yields one image per call -- confirm.
    outcome = "fail" if np.array_equal(input_pred, adv_pred) else "success"
    save_path = f'{SAVE_DIR_PATH}/{outcome}'

    full_save_path = os.path.join(save_path, filename)
    myextensions.save_attack_image(x_adv, full_save_path)

    timings.append(elapsed)


In [9]:
def print_time(timings, filenum, save_dir=None, eps=None):
    """Write summary statistics of the recorded timings to a text file.

    The report goes to ``{save_dir}/time_statistics{filenum}.txt`` and
    overwrites any existing file of that name.

    Parameters
    ----------
    timings : sequence of float
        Elapsed times in seconds, one entry per processed image.
    filenum : int or str
        Suffix inserted into the report file name.
    save_dir : str or path-like, optional
        Directory the report is written to. Defaults to the notebook-level
        ``SAVE_DIR_PATH``.
    eps : optional
        Value written on the "Epsilon" line. Defaults to ``attack.eps`` of
        the notebook-level ``attack`` object.

    Notes
    -----
    An empty ``timings`` produces a report with a count of 0 and ``n/a`` for
    every statistic, rather than failing on the min/max of an empty array and
    leaving a half-written file behind.
    """
    if save_dir is None:
        save_dir = SAVE_DIR_PATH
    if eps is None:
        eps = attack.eps

    timings_np = np.array(timings)

    if timings_np.size > 0:
        mean_s, std_s, min_s, max_s = (
            f"{value:.4f} s"
            for value in (
                timings_np.mean(),
                timings_np.std(),
                timings_np.min(),
                timings_np.max(),
            )
        )
    else:
        mean_s = std_s = min_s = max_s = "n/a"

    # The full report is assembled before the file is opened, so an error
    # while computing statistics can never leave a truncated file on disk.
    # NOTE(review): the header says "Attribution" although callers in this
    # notebook pass attack-generation timings -- confirm the intended wording.
    report = (
        "=== Attribution Time Stats ===\n"
        f"Total images:      {len(timings)}\n"
        f"Average time:      {mean_s}\n"
        f"Standard deviation:{std_s}\n"
        f"Minimum time:      {min_s}\n"
        f"Maximum time:      {max_s}\n"
        f"Epsilon: {eps}\n"
    )

    with open(f'{save_dir}/time_statistics{filenum}.txt', "w") as f:
        f.write(report)

# Write the timing report for this run; the file name gets the suffix 0.
print_time(timings=timings, filenum=0)