In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score
from yolo_threat import YoloThreat

In [None]:
# Pick the compute device once for the whole notebook: the GPU when PyTorch can
# see one, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def fgsm_attack(model, images, labels, epsilon):
    """Craft single-step FGSM adversarial examples for a binary classifier.

    The model output is treated as logits: its last dimension is squeezed and
    compared with ``labels`` through binary cross-entropy with logits. Each
    input element is then moved by ``epsilon`` in the direction of the sign of
    the loss gradient, and the result is clamped to ``[0, 1]``.

    The caller's ``images`` tensor is left untouched (no ``requires_grad``
    flip, no ``.grad`` accumulation) and the model's parameter gradients are
    not modified, so the function can be called repeatedly on the same batch.

    Args:
        model: Callable module mapping ``images`` to logits.
        images: Input batch; expected to already live on the model's device.
        labels: Binary targets, one per sample; cast to float for the loss.
        epsilon: Step size of the perturbation.

    Returns:
        A detached tensor shaped like ``images`` holding the perturbed inputs.
    """
    # Work on a private leaf copy: flipping requires_grad on the caller's
    # tensor would leak state (and an ever-growing .grad) into later cells.
    adv_inputs = images.clone().detach().requires_grad_(True)

    logits = model(adv_inputs)
    loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())

    # autograd.grad returns only d(loss)/d(input); unlike loss.backward() it
    # does not write into the parameters' .grad buffers.
    (input_grad,) = torch.autograd.grad(loss, adv_inputs)

    perturbed_images = adv_inputs.detach() + epsilon * input_grad.sign()
    return torch.clamp(perturbed_images, 0, 1)

In [None]:
def pgd_attack(model, images, labels, epsilon=0.3, alpha=2 / 255, iters=40):
    """Craft PGD adversarial examples for a binary classifier.

    Starting from the clean inputs, repeats ``iters`` times: take a step of
    size ``alpha`` along the sign of the loss gradient, clip the accumulated
    perturbation element-wise to ``[-epsilon, epsilon]`` around the original
    inputs, and clamp the result to ``[0, 1]``. The loss is binary
    cross-entropy with logits on the model output with its last dimension
    squeezed.

    Args:
        model: Module mapping an image batch to logits.
        images: Clean input batch; it is copied, never modified.
        labels: Binary targets, one per sample; cast to float for the loss.
        epsilon: Maximum absolute perturbation per element.
        alpha: Step size of each iteration.
        iters: Number of gradient steps; ``0`` returns a copy of ``images``.

    Returns:
        A detached tensor shaped like ``images`` holding the adversarial
        inputs, located on the model's device.
    """
    # Run on whatever device the model's parameters live on instead of relying
    # on a notebook-level global; fall back to the inputs' own device for a
    # parameter-free model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else images.device

    ori_images = images.clone().detach().to(target_device)
    targets = labels.to(target_device).float()
    loss_fn = nn.BCEWithLogitsLoss()

    adv_images = ori_images.clone()
    for _ in range(iters):
        adv_images.requires_grad_(True)
        loss = loss_fn(model(adv_images).squeeze(-1), targets)

        # Only the input gradient is needed; autograd.grad leaves the
        # parameters' .grad buffers alone.
        (input_grad,) = torch.autograd.grad(loss, adv_images)

        stepped = adv_images.detach() + alpha * input_grad.sign()
        # Keep the total perturbation inside the epsilon box, then keep the
        # result inside the valid [0, 1] input range.
        eta = torch.clamp(stepped - ori_images, min=-epsilon, max=epsilon)
        adv_images = torch.clamp(ori_images + eta, min=0, max=1).detach()

    # The original implementation fell off the end here and returned None.
    return adv_images

In [None]:
def test(model_t, X, y):
    """Score a binary classifier on inputs ``X`` against labels ``y``.

    Switches the model to eval mode, runs a single forward pass without
    gradient tracking, applies a sigmoid to the output and thresholds it at
    0.5 to obtain one hard prediction per sample.

    Args:
        model_t: Model to evaluate; it is left in eval mode.
        X: Input batch passed to ``model_t.forward``.
        y: Ground-truth labels, compared element-wise with the flattened
            predictions.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; the last three come from
        scikit-learn with ``zero_division=0``.
    """
    model_t.eval()
    with torch.no_grad():
        probabilities = torch.sigmoid(model_t.forward(X))
        # One 0./1. decision per sample, flattened and moved to host memory.
        decisions = (probabilities > 0.5).float().reshape(-1).cpu()
        targets = y.cpu()

        accuracy = np.mean((decisions == targets).numpy())

        # scikit-learn works on NumPy arrays.
        y_true = targets.numpy()
        y_hat = decisions.numpy()
        precision, recall, f1 = (
            scorer(y_true, y_hat, zero_division=0)
            for scorer in (precision_score, recall_score, f1_score)
        )

        return accuracy, precision, recall, f1

In [None]:
def evaluate_robustness(model, X, y, epsilons, attack_fn, attack_name):
    """Measure classifier metrics under an attack for a sweep of epsilons.

    For every value in ``epsilons`` the inputs are perturbed with
    ``attack_fn(model, X, y, epsilon=...)`` and scored with ``test``. When
    ``attack_name`` is ``"Normal"`` the inputs are scored unchanged and
    ``attack_fn`` is never called.

    Args:
        model: Model under evaluation.
        X: Input batch.
        y: Ground-truth labels for ``X``.
        epsilons: Iterable of perturbation sizes to evaluate.
        attack_fn: Callable producing perturbed inputs; must accept an
            ``epsilon`` keyword argument.
        attack_name: Label used in the progress output; ``"Normal"`` selects
            the no-attack path.

    Returns:
        Dict mapping each epsilon to a dict with the keys ``"accuracy"``,
        ``"precision"``, ``"recall"`` and ``"f1_score"``.
    """
    skip_attack = attack_name == "Normal"
    results = {}

    for epsilon in epsilons:
        print(f"Evaluating {attack_name} attack with epsilon={epsilon:.3f}")

        # The baseline run scores the untouched inputs.
        eval_inputs = X if skip_attack else attack_fn(model, X, y, epsilon=epsilon)

        acc, precision, recall, f1 = test(model, eval_inputs, y)
        results[epsilon] = dict(
            accuracy=acc,
            precision=precision,
            recall=recall,
            f1_score=f1,
        )
        print(
            f"Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}"
        )

    return results

In [None]:
def plot_accuracy(results_dict):
    """Plot accuracy against epsilon, one line per attack.

    Args:
        results_dict: Mapping of attack name to a per-epsilon results dict,
            where each per-epsilon entry has an ``"accuracy"`` key.
    """
    _, ax = plt.subplots(figsize=(8, 6))

    for attack_name, per_epsilon in results_dict.items():
        epsilon_values = list(per_epsilon)
        accuracy_values = [per_epsilon[eps]["accuracy"] for eps in epsilon_values]
        ax.plot(epsilon_values, accuracy_values, marker='o', label=attack_name)

    ax.set_xlabel("Epsilon")
    ax.set_ylabel("Accuracy")
    ax.set_title("Adversarial Robustness: Accuracy vs Epsilon")
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
# Load the train/test inputs and labels from ../data/danger/raw.
# map_location=device remaps the stored tensors while they are deserialized, so
# files saved from a CUDA session still load on a CPU-only machine (the same
# option the model checkpoint load uses). The trailing .to(device) is then a
# no-op for tensors that are already in place.
Xtrain = torch.load('../data/danger/raw/train.pt', map_location=device).to(device)
ytrain = torch.load('../data/danger/raw/train_labels.pt', map_location=device).to(device)
Xtest = torch.load('../data/danger/raw/test.pt', map_location=device).to(device)
ytest = torch.load('../data/danger/raw/test_labels.pt', map_location=device).to(device)

In [None]:
# Build a YoloThreat instance through its load_new_model() factory and move it
# to the selected device. NOTE(review): presumably an nn.Module subclass, given
# the .to / .load_state_dict / .eval calls below -- confirm in yolo_threat.
model = YoloThreat.load_new_model().to(device)
# Replace its parameters with the checkpoint in trained_model.pt; map_location
# remaps the saved tensors onto `device` while the file is deserialized.
model.load_state_dict(torch.load('trained_model.pt', map_location=device))
# Put the model in evaluation mode. As the last expression of the cell, the
# value returned by eval() is what the notebook displays.
model.eval()

In [None]:
# Perturbation sizes swept by the attack evaluations; 0 is a zero-size step.
epsilons = [
    0,
    0.05,
    0.1,
    0.2,
    0.3,
]

In [None]:
print("Evaluating normal model performance...")
# Clean-data baseline. With attack_name "Normal", evaluate_robustness scores the
# inputs as they are, so the pass-through lambda only fills the attack_fn slot.
normal_results = evaluate_robustness(
    model=model,
    X=Xtest,
    y=ytest,
    epsilons=[0],
    attack_fn=lambda m, x, y, epsilon: x,
    attack_name="Normal",
)

In [None]:
# Sweep the FGSM attack over every epsilon on the test set.
fgsm_results = evaluate_robustness(
    model=model,
    X=Xtest,
    y=ytest,
    epsilons=epsilons,
    attack_fn=fgsm_attack,
    attack_name="FGSM",
)

In [None]:
# Sweep the PGD attack over every epsilon on the test set.
pgd_results = evaluate_robustness(
    model=model,
    X=Xtest,
    y=ytest,
    epsilons=epsilons,
    attack_fn=pgd_attack,
    attack_name="PGD",
)

In [None]:
# Collect the per-attack results under the labels used in the plot legend.
results_dict = {
    "Normal": normal_results,
    "FGSM": fgsm_results,
    "PGD": pgd_results,
}

In [None]:
# Draw accuracy versus epsilon for every evaluated setting.
plot_accuracy(results_dict=results_dict)