# PGD attacks

This notebook tests **EuroSAT ResNet18 model** on images modified with **PGD attacks**.

### 1. Setup environment and imports

In [6]:
import sys, os, re
sys.path.append(os.path.abspath(".."))

if not hasattr(sys, "frozen"):
    os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"
import random
import torch
import torch.nn as nn
from torchvision import models
import matplotlib.pyplot as plt
from PIL import Image
import glob
import numpy as np
import tifffile
from skimage.transform import resize

from src.training.simple_cnn import SimpleCNN
from src.data.dataloader import get_dataloaders
from src.attacks.evaluate import evaluate_pgd             
from src.attacks.metrics_eval import evaluate_adv, plot_confusion_matrix  
from src.attacks.utils import select_rgb_bands, gdal_style_scale

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model_name = "resnet50" # or resnet18 simplecnn
data_dir = '../data/raw'
batch_size = 64

Using device: cpu


- Get path

In [7]:
if model_name == "simplecnn":
        checkpoint_path = "../experiments/checkpoints/simplecnn_best.pth"
elif model_name == "resnet18":
        checkpoint_path = "../experiments/checkpoints/resnet18_best.pth"
elif model_name == "resnet50":
        checkpoint_path = "../experiments/checkpoints/resnet50_e2.pth"


### 2. Load data and model

- Load dataloaders

In [8]:
train_loader, val_loader, test_loader, classes = get_dataloaders(data_dir=data_dir, batch_size=batch_size)
print(f'Loaded {len(classes)} classes: {classes}')

num_classes = len(classes)

Loaded 4 classes: ['AnnualCrop', 'Forest', 'Residential', 'River']


- Load model

In [9]:
if model_name.lower() == "resnet18":
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
elif model_name.lower() == "resnet50":
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
elif model_name.lower() == "simplecnn":
    model = SimpleCNN(num_classes=num_classes)
else:
    raise ValueError(f"Unsupported model_name: {model_name}")


- Load trained weights

In [10]:
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

### 3. Run PGD attacks and save adversarial images 

In [11]:
out_dir = '../data/pgd'

eps = 0.005
alpha = eps / 10
iters = 200

res = evaluate_pgd(
    model=model,
    dataloader=test_loader,
    device=device,
    eps=eps,
    alpha=alpha,
    iters=iters,
    out_dir=out_dir,
    save_every=40,
    max_save=64,
    targeted=False,
    target_class=None
)

print('PGD run result:', res)


  mean = torch.tensor(t.mean).view(-1, 1, 1)
  std = torch.tensor(t.std).view(-1, 1, 1)
                                                                       

PGD run result: {'clean_acc': 0.998840579710145, 'adv_acc': 0.991304347826087, 'clean_loss': 0.005903373479951119, 'adv_loss': 0.021152364750919136, 'eps': 0.005, 'saved': 0, 'out_dir': '/Users/joseantonioruizheredia/Code/Python/ml-satellite-adv/data/pgd'}




### 4. Evaluate saved adversarial images for each epsilon

- Evaluation of metrics

In [13]:
print(f"\n=== Model Evaluation on Adversarial Images ===")

metrics_adv = evaluate_adv(
    adv_folder=out_dir,
    model_path=checkpoint_path,
    data_dir=data_dir,
    batch_size=batch_size,
    model_name=model_name,
    device=device,
    mean_std_sample_size=2000,
    image_pattern="*.tif"
)

print(f"Num images: {metrics_adv['num_images']}")

print(f"Accuracy: {metrics_adv['accuracy']*100:.2f}%")
print(f"Loss: {metrics_adv['loss']:.4f}")
print(f"Precision: {metrics_adv['precision']:.4f}")
print(f"Recall: {metrics_adv['recall']:.4f}")
print(f"F1-score: {metrics_adv['f1']:.4f}")

print("\nClassification metrics per category:\n\n", metrics_adv["classification_report"])



=== Model Evaluation on Adversarial Images ===


FileNotFoundError: No images matching *.tif found in ../data/pgd

- Confusion Matrix

In [None]:
plot_confusion_matrix(metrics_adv['confusion_matrix'], metrics_adv['class_names'], normalize=True)

- Show a small sample of images 

In [None]:
adv_folder = '../data/pgd'
raw_folder = '../data/raw'
adv_paths = sorted(glob.glob(os.path.join(adv_folder, "*.tif")))

pairs = []
for adv_p in adv_paths:
    base = os.path.basename(adv_p)
    orig_base = re.sub(r"_true\d+_pred\d+\.tif$", ".tif", base)
    orig_candidates = glob.glob(os.path.join(raw_folder, "**", orig_base), recursive=True)
    if not orig_candidates:
        continue
    pairs.append((adv_p, orig_candidates[0]))

if len(pairs) == 0:
    print("No matching pairs found.")
else:
    random.shuffle(pairs)
    pairs = pairs[:10]
    fig, axs = plt.subplots(len(pairs), 3, figsize=(12, 4 * len(pairs)), dpi=400)
    axs = np.atleast_2d(axs)

    for i, (adv_p, orig_p) in enumerate(pairs):
        adv = tifffile.imread(adv_p)
        orig = tifffile.imread(orig_p)
        raw_diff = np.abs(adv.astype(np.float32) - orig.astype(np.float32))

        if adv.ndim == 3 and adv.shape[0] in [3,4,13]:
            adv = np.transpose(adv, (1,2,0))
        if orig.ndim == 3 and orig.shape[0] in [3,4,13]:
            orig = np.transpose(orig, (1,2,0))

        adv_rgb = select_rgb_bands(adv)
        orig_rgb = select_rgb_bands(orig)
        if adv_rgb.shape != orig_rgb.shape:
            orig_rgb = resize(orig_rgb, adv_rgb.shape, preserve_range=True, anti_aliasing=True)

        adv_disp = gdal_style_scale(adv_rgb)
        orig_disp = gdal_style_scale(orig_rgb)

        class_name = os.path.basename(os.path.dirname(orig_p))

        ax0, ax1, ax2 = axs[i]
        
        # Original image
        ax0.imshow(orig_disp)
        ax0.set_title(f"Original (normalized)\nClass: {class_name}", fontsize=9)
        ax0.axis("off")
        
        # Adversarial image
        ax1.imshow(adv_disp)
        ax1.set_title(f"Adversarial (normalized)\nClass: {class_name}", fontsize=9)
        ax1.axis("off")

        # Difference heatmap
        diff = np.mean(np.abs(adv_disp - orig_disp), axis=2)
        p99 = np.percentile(diff, 90)
        diff_clipped = np.clip(diff / (p99 + 1e-12), 0, 10)
        gamma = 0.5  
        diff_vis = diff_clipped ** gamma
        
        im = ax2.imshow(diff_vis, cmap='hot_r', interpolation='nearest')
        m_true = re.search(r"_true(\d+)", adv_p)
        m_pred = re.search(r"_pred(\d+)", adv_p)
        true_label = m_true.group(1) if m_true else "?"
        pred_label = m_pred.group(1) if m_pred else "?"

        ax2.set_title(f"Diff heatmap\nTrue: {true_label}, Pred: {pred_label}", fontsize=9)
        ax2.axis("off")

        ax2.text(
            0.9, 0.1,
            f"Normalize diff: {diff.mean():.4f}\nRaw diff: {raw_diff.mean():.4f}",
            color='white',
            fontsize=9,
            ha='right',
            va='bottom',
            transform=ax2.transAxes,
            bbox=dict(facecolor='black', alpha=0.8, pad=2)
        )

        cbar = fig.colorbar(im, ax=ax2, fraction=0.046, pad=0.02)
        cbar.ax.tick_params(labelsize=8)

    plt.tight_layout()
    plt.show()