# Adversarial Training Defense for EuroSAT

### 1. Setup environment and imports

In [None]:

import sys, os, torch, torch.nn as nn, torch.optim as optim
from torchvision import models
from tqdm import tqdm
from typing import Optional, Tuple
from dataclasses import dataclass


sys.path.append(os.path.abspath(".."))
from src.data.dataloader import get_dataloaders
from src.attacks.utils import extract_mean_std
from src.training.evaluate import evaluate_model
from src.attacks.pgd import pgd_attack_batch
from src.attacks.metrics_eval import evaluate_adv, plot_confusion_matrix  



@dataclass
class PGDConfig:
    """Settings bundle passed to ``pgd_attack_batch`` through its ``config`` argument.

    ``train_adversarial`` fills every field explicitly when it builds an
    instance. The precise effect of each field is decided inside
    ``pgd_attack_batch``, which is not visible in this file, so the per-field
    notes below are assumptions to be confirmed against that implementation.
    """
    # presumably the maximum perturbation size — TODO confirm norm and value space
    eps: float = 0.005
    # presumably the per-iteration step size; None leaves the choice to the consumer — confirm
    alpha: Optional[float] = None   
    # presumably the number of PGD iterations (train_adversarial passes its `steps` here)
    iters: int = 50
    # NOTE(review): likely a step-size fraction of eps used when alpha is None — confirm
    small_step_fraction: float = 0.2
    # NOTE(review): looks like the fraction of gradient entries kept/masked — confirm
    grad_mask_fraction: float = 0.25
    # NOTE(review): looks like a Gaussian sigma applied to the gradient — confirm
    grad_blur_sigma: float = 1.0
    # NOTE(review): looks like a Gaussian sigma applied to the perturbation — confirm
    smooth_perturb_sigma: float = 1.0
    # NOTE(review): presumably toggles random dithering of the perturbation — confirm
    random_dither: bool = True
    # NOTE(review): presumably the dithering magnitude; units unknown from here
    dither_scale: float = 0.5
    # device the attack should run on; None presumably means "decide in the consumer"
    device: Optional[torch.device] = None



# Device setup: use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The leading emoji is written as a real Unicode character so the message
# renders correctly instead of showing mis-decoded bytes.
print(f"🚀 Using device: {device}")



🚀 Using device: cpu


### 2. Configuration

In [2]:

# --- Optimisation hyper-parameters -----------------------------------------
epochs, batch_size = 20, 32
lr = 5e-5

# --- PGD attack hyper-parameters -------------------------------------------
epsilon = 0.005
pgd_steps = 10
alpha = epsilon / 10  # step size expressed as a tenth of epsilon

# --- Model and file locations ----------------------------------------------
model_name = "resnet50"
data_dir = "../data/raw"
checkpoint_dir = "../experiments/checkpoints"
# Weights the notebook loads before adversarial training starts.
checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_e2.pth")
# Where the adversarially trained weights are written afterwards.
adv_checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_adv_trained.pth")



### 3. Load dataset

In [None]:

# Build the train / validation / test loaders. The fourth return value is
# presumably the list of class names; its length sets the classifier width.
train_loader, val_loader, test_loader, classes = get_dataloaders(
    data_dir=data_dir,
    batch_size=batch_size,
)
num_classes = len(classes)
# Emoji written as a real Unicode character so the status line renders correctly.
print(f"✅ Loaded {num_classes} classes.")

### 4. Load pretrained model

In [None]:
# Supported architectures, keyed by lower-cased model name.
_MODEL_BUILDERS = {
    "resnet50": models.resnet50,
    "resnet18": models.resnet18,
}

_build_model = _MODEL_BUILDERS.get(model_name.lower())
if _build_model is None:
    raise ValueError(f"Unsupported model_name: {model_name}")

# weights=None: the strict load_state_dict() call below overwrites every
# parameter and buffer, so fetching ImageNet weights first would only add a
# network download whose result is immediately discarded.
model = _build_model(weights=None)
# Replace the classification head so its output size matches the dataset.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Restore the fine-tuned weights; map_location keeps this working on
# machines without the device the checkpoint was saved from.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Emoji written as a real Unicode character so the status line renders correctly.
print(f"📦 Loaded pretrained weights from {checkpoint_path}")

### 5. Adversarial Training Loop

In [None]:

def train_adversarial(
    model,
    train_loader,
    criterion,
    optimizer,
    device,
    eps=0.005,
    alpha=None,
    steps=10,
    epochs=20,
    grad_mask_fraction=0.25,
    smooth_sigma=1.0,
    dither_scale=0.3
):
    """Fine-tune ``model`` on batches made of clean plus PGD-perturbed samples.

    For every batch an adversarial copy is generated with
    ``pgd_attack_batch``; the clean and adversarial images are concatenated
    (labels duplicated) and a single optimizer step is taken on the combined
    batch. Per-epoch mean batch loss and accuracy over the combined samples
    are printed.

    Args:
        model: Network to train; moved to ``device`` and updated in place.
        train_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device the model and each batch are moved to.
        eps: Forwarded to ``PGDConfig.eps``.
        alpha: Forwarded to ``PGDConfig.alpha``; ``None`` means ``eps * 0.2``.
        steps: Forwarded to ``PGDConfig.iters``.
        epochs: Number of passes over ``train_loader``.
        grad_mask_fraction: Forwarded to ``PGDConfig.grad_mask_fraction``.
        smooth_sigma: Forwarded to ``PGDConfig.smooth_perturb_sigma``.
        dither_scale: Forwarded to ``PGDConfig.dither_scale``.

    Returns:
        The same ``model`` object after training.
    """
    if alpha is None:
        alpha = eps * 0.2

    # NOTE(review): no normalization statistics are handed to the attack here;
    # confirm that pgd_attack_batch clamps in the correct value range itself.
    pgd_conf = PGDConfig(
        eps=eps,
        alpha=alpha,
        iters=steps,
        small_step_fraction=0.2,
        grad_mask_fraction=grad_mask_fraction,
        grad_blur_sigma=1.0,
        smooth_perturb_sigma=smooth_sigma,
        random_dither=True,
        dither_scale=dither_scale,
        device=device,
    )

    model.to(device)
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        total, correct = 0, 0
        num_batches = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)

            # Craft the attack in eval mode so the many PGD forward passes do
            # not update BatchNorm running statistics or apply dropout.
            model.eval()
            adv_images = pgd_attack_batch(
                model=model,
                images=images,
                labels=labels,
                config=pgd_conf,
                targeted=False,
                target_labels=None,
            )
            # Cut any autograd history left by the attack so the training
            # backward pass only flows through the forward pass below.
            adv_images = adv_images.detach()
            # Re-enter train mode explicitly; this also covers an attack
            # implementation that switches the model to eval internally.
            model.train()

            # Train on clean and adversarial samples together.
            combined_images = torch.cat([images, adv_images], dim=0)
            combined_labels = torch.cat([labels, labels], dim=0)

            # zero_grad after the attack: clears any parameter gradients the
            # attack's backward passes may have accumulated.
            optimizer.zero_grad()
            outputs = model(combined_images)
            loss = criterion(outputs, combined_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, preds = outputs.max(1)
            correct += (preds == combined_labels).sum().item()
            total += combined_labels.size(0)
            num_batches += 1

        # An empty loader would otherwise raise ZeroDivisionError here.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        accuracy = 100 * correct / total if total else float("nan")
        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"| Loss: {mean_loss:.4f} "
            f"| Acc: {accuracy:.2f}%"
        )

    return model

In [None]:
print("\n🧠 Starting Adversarial Training...")

# Create the output folder up front: discovering a missing directory only
# at torch.save() time would throw away the whole training run.
os.makedirs(os.path.dirname(adv_checkpoint_path) or ".", exist_ok=True)

adv_model = train_adversarial(
    model,
    train_loader,
    criterion,
    optimizer,
    device=device,
    eps=epsilon,
    alpha=alpha,
    steps=pgd_steps,
    epochs=epochs,
)

# Persist only the weights (state_dict), matching how the starting
# checkpoint is loaded earlier in the notebook.
torch.save(adv_model.state_dict(), adv_checkpoint_path)
print(f"\n✅ Adversarially trained model saved at: {adv_checkpoint_path}")


### 6. Evaluate on Clean Test Data

In [None]:
print("\n📊 Evaluating on clean test set...")
# Evaluate the adversarially trained checkpoint on the unperturbed test data.
metrics_clean = evaluate_model(
    model_path=adv_checkpoint_path,
    data_dir=data_dir,
    batch_size=batch_size,
    model_name=model_name,
    device=device,
)

# Read the results from `metrics_clean`, the name they were bound to above;
# an unbound `metrics` name would raise NameError.
print("\n📈 Test Set Performance:")
print(f"Accuracy:  {metrics_clean['accuracy']*100:.2f}%")
print(f"Loss:      {metrics_clean['loss']:.4f}")
print(f"Precision: {metrics_clean['precision']:.4f}")
print(f"Recall:    {metrics_clean['recall']:.4f}")
print(f"F1-score:  {metrics_clean['f1']:.4f}")

print("\n🔍 Classification metrics per class:\n")
print(metrics_clean["classification_report"])

### 7. Evaluate on Adversarial Test Data

In [None]:
print("\n⚔️ Evaluating on PGD adversarial test set...")
# Folder holding the pre-generated PGD adversarial images.
adv_dir = '../data/pgd'

# NOTE(review): `model_path` receives the in-memory `adv_model` here, while
# the clean evaluation passes the checkpoint path. Confirm that evaluate_adv
# accepts a module object; if it expects a path, pass adv_checkpoint_path.
res_eval = evaluate_adv(
    model_path=adv_model,
    adv_folder=adv_dir,
    device=device,
    data_dir=data_dir,
    batch_size=batch_size,
    mean_std_sample_size=2000,
    image_pattern="*.tif"
)

print(f"Num images: {res_eval['num_images']}")

print(f"Accuracy: {res_eval['accuracy']*100:.2f}%")
print(f"Loss: {res_eval['loss']:.4f}")
print(f"Precision: {res_eval['precision']:.4f}")
print(f"Recall: {res_eval['recall']:.4f}")
print(f"F1-score: {res_eval['f1']:.4f}")

print("\nClassification metrics per category:\n\n", res_eval["classification_report"])

- Confusion Matrix

In [None]:
# The adversarial evaluation results are bound to `res_eval`; no `metrics_adv`
# name is defined anywhere in this notebook, so reading it raised NameError.
# NOTE(review): assumes evaluate_adv returns 'confusion_matrix' and
# 'class_names' keys — confirm against its implementation.
plot_confusion_matrix(res_eval['confusion_matrix'], res_eval['class_names'], normalize=True)
