# PGD attacks

This notebook evaluates the **EuroSAT ResNet18 model** on images perturbed with **PGD (Projected Gradient Descent) attacks**.

### 1. Setup environment and imports

In [None]:

import sys, os
sys.path.append(os.path.abspath(".."))

import torch
import torch.nn as nn
from torchvision import models
import matplotlib.pyplot as plt
from PIL import Image
import glob
import numpy as np

from src.data.dataloader import get_dataloaders, compute_mean_std
from src.attacks.pgd import evaluate_pgd
from src.attacks.evaluate import evaluate_adv, plot_confusion_matrix

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Shared configuration reused by the cells below.
# Relative paths are resolved against the notebook's working directory.
batch_size = 64                                                   # batch size for the dataloaders and the adversarial evaluation
checkpoint_path = '../experiments/checkpoints/resnet18_best.pth'  # fine-tuned ResNet18 weights
data_dir = '../data/raw'                                          # dataset root handed to the loaders and the attack/evaluation helpers

### 2. Load data and model

- Load dataloaders

In [None]:
# `get_dataloaders` returns four values, unpacked here as the train / validation /
# test loaders plus the collection of class names (its length sizes the model head later).
train_loader, val_loader, test_loader, classes = get_dataloaders(
    data_dir=data_dir,
    batch_size=batch_size,
)
print(f'Loaded {len(classes)} classes: {classes}')



- Load model

In [None]:
# Build the bare ResNet18 architecture without pretrained ImageNet weights.
# The checkpoint loaded below overwrites every parameter and buffer (load_state_dict
# is strict by default), so requesting `ResNet18_Weights.DEFAULT` only caused a
# needless download and made the notebook fail on machines without network access.
model = models.resnet18(weights=None)

# Replace the classification head so its output size matches the number of dataset classes.
model.fc = nn.Linear(model.fc.in_features, len(classes))

# `map_location=device` lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()  # switch to inference mode before attacking / evaluating
print(f'Model loaded from {checkpoint_path}')

### 3. Run PGD attacks and save adversarial images

In [None]:
# Perturbation budget forwarded to the attack as `epsilon`; later cells reuse it in
# their printed messages and in the figure title.
eps = 1e-5
# Handed to `evaluate_pgd` as `out_dir`; the evaluation cells read the adversarial
# images back from this same path.
out_dir = '../data/adversarial/pgd'

# NOTE(review): `save_every`, `max_save` and `mean_std_sample_size` are passed straight
# through to `evaluate_pgd`; presumably they control how often / how many adversarial
# images are written and how many samples are used for normalisation statistics —
# confirm against src/attacks/pgd.py.
res = evaluate_pgd(
    model=model, dataloader=test_loader, device=device,
    data_dir=data_dir, out_dir=out_dir,
    epsilon=eps,
    save_every=20, max_save=64, mean_std_sample_size=2000,
)
print('PGD run result:', res)


### 4. Evaluate saved adversarial images for the chosen epsilon

In [None]:
# Directory the evaluation cells read adversarial images from
# (same path as `out_dir` in the attack cell).
folder = '../data/adversarial/pgd'

# Only warn when the directory is missing; execution is not stopped here.
if os.path.isdir(folder) is False:
    print(f"No folder found for eps={eps}: {folder}")

- Evaluation of metrics

In [None]:
print(f"\n=== Evaluation on adversarial folder eps={eps} ===")

# Score the saved adversarial images with the checkpointed model.
# `evaluate_adv` returns a mapping; the keys read below are the ones this notebook relies on.
metrics_adv = evaluate_adv(
    adv_folder=folder, data_dir=data_dir,
    model_path=checkpoint_path, model_name="resnet18",
    batch_size=batch_size, device=device,
    mean_std_sample_size=2000,
)

# Headline numbers, one per line. Accuracy is scaled by 100 to display as a percentage.
print(
    f"Num images: {metrics_adv['num_images']}",
    f"Accuracy: {metrics_adv['accuracy']*100:.2f}%",
    f"Loss: {metrics_adv['loss']:.4f}",
    f"Precision: {metrics_adv['precision']:.4f}",
    f"Recall: {metrics_adv['recall']:.4f}",
    f"F1-score: {metrics_adv['f1']:.4f}",
    sep="\n",
)

# Per-class breakdown as produced by `evaluate_adv`.
print("\nClassification metrics per category:\n\n", metrics_adv["classification_report"])



- Confusion Matrix

In [None]:
# Plot the confusion matrix returned by the evaluation, labelled with its class names.
# NOTE(review): `normalize=True` is forwarded to the project helper; presumably it
# rescales counts to proportions — confirm in src/attacks/evaluate.py.
plot_confusion_matrix(
    metrics_adv['confusion_matrix'],
    metrics_adv['class_names'],
    normalize=True,
)


- Show a small sample of images 

In [None]:
# Preview at most ten adversarial PNGs (sorted by path) in a grid five images wide.
sample_paths = sorted(glob.glob(os.path.join(folder, '*.png')))[:10]

if sample_paths:
    cols = 5
    rows = -(-len(sample_paths) // cols)  # ceiling division: enough rows for every image
    fig = plt.figure(figsize=(cols*2, rows*2))
    for slot, path in enumerate(sample_paths, start=1):
        img = Image.open(path).convert("RGB")
        ax = fig.add_subplot(rows, cols, slot)
        ax.axis('off')  # hide ticks and frame; only the picture matters
        ax.imshow(img)
    fig.suptitle(f'Adversarial examples (eps={eps})')
    plt.show()
else:
    print("No images to display.")