In [None]:
#Imports & setup
import torch
import torch.nn.functional as F
import numpy as np
import glob
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the trained base model
NUM_CLASSES = 5  # number of output channels of the segmentation head

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # all weights come from the checkpoint loaded below
    in_channels=3,
    classes=NUM_CLASSES,
).to(device)

# The checkpoint is consumed by load_state_dict, i.e. it is a plain state_dict.
# weights_only=True restricts unpickling to tensors and primitive containers, so
# a tampered .pth file cannot execute arbitrary code when it is loaded.
model.load_state_dict(
    torch.load("models/unet_resnet34.pth", map_location=device, weights_only=True)
)
model.eval()  # switch layers such as BatchNorm/Dropout to inference behaviour
print("✅ Base model loaded")

In [None]:
# Dataset & validation loader
class MapDataset(Dataset):
    """Dataset of RGB images, each with a single integer class label.

    Images (``*.jpg`` in ``img_dir``) and labels (``*.txt`` in ``lbl_dir``)
    are paired by position after sorting each directory listing, so the two
    directories must contain names that sort into the same order. Each label
    file holds one integer.

    Args:
        img_dir: directory containing the ``.jpg`` images.
        lbl_dir: directory containing one ``.txt`` label file per image.
        size: ``(width, height)`` passed to ``PIL.Image.resize``
            (default ``(256, 256)``).

    Raises:
        ValueError: if the number of images and label files differ.
    """

    def __init__(self, img_dir, lbl_dir, size=(256, 256)):
        self.img_paths = sorted(glob.glob(img_dir + "/*.jpg"))
        self.lbl_paths = sorted(glob.glob(lbl_dir + "/*.txt"))
        self.size = size
        # Pairing is positional, so a count mismatch would silently attach
        # the wrong label to every sample after the first gap.
        if len(self.img_paths) != len(self.lbl_paths):
            raise ValueError(
                f"found {len(self.img_paths)} images in {img_dir!r} but "
                f"{len(self.lbl_paths)} label files in {lbl_dir!r}"
            )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``.

        ``image`` is a float32 tensor in channels-first layout, scaled to
        ``[0, 1]``. ``label`` is the integer read from the label file.
        """
        # Context managers close the file handles promptly instead of leaving
        # them to the garbage collector (DataLoader workers can exhaust fds).
        with Image.open(self.img_paths[idx]) as img:
            rgb = img.convert("RGB").resize(self.size)
        arr = (np.array(rgb) / 255.0).transpose(2, 0, 1).astype("float32")
        with open(self.lbl_paths[idx]) as fh:
            label = int(fh.read().strip())
        return torch.tensor(arr), torch.tensor(label)

In [None]:
# Validation split: <VAL_DIR>/images/*.jpg paired with <VAL_DIR>/labels/*.txt.
# One sample per batch, in a fixed (unshuffled) order.
VAL_DIR = "data/val_1"
val_ds = MapDataset(img_dir=f"{VAL_DIR}/images", lbl_dir=f"{VAL_DIR}/labels")
val_dl = DataLoader(dataset=val_ds, batch_size=1, shuffle=False)

In [None]:
# FGSM attack
def fgsm_attack(model, x, y, eps):
    """Fast Gradient Sign Method against image-level labels.

    The model's output is averaged over dims ``(2, 3)`` to get one score per
    class. The input is then moved a single step of size ``eps`` in the sign
    of the input gradient of the cross-entropy loss.

    Args:
        model: callable expected to return 4-D logits ``(B, K, H, W)``.
        x: clean input batch; presumably in ``[0, 1]`` (the result is clamped there).
        y: integer class labels, one per sample.
        eps: L-infinity perturbation budget.

    Returns:
        Detached adversarial batch on ``device``, clamped to ``[0, 1]``.
    """
    x_adv = x.clone().detach().to(device).requires_grad_(True)
    y = y.to(device)

    # enable_grad lets the attack work even when called under torch.no_grad().
    # autograd.grad differentiates w.r.t. the input only, so no .grad buffers
    # are allocated on (or left behind in) the model's parameters.
    with torch.enable_grad():
        logits = model(x_adv).mean(dim=(2, 3))
        loss = F.cross_entropy(logits, y)
        (grad,) = torch.autograd.grad(loss, x_adv)

    return torch.clamp(x_adv.detach() + eps * grad.sign(), 0, 1)

In [None]:
# FGSM evaluation
def eval_fgsm(model, dataloader, eps):
    """Return the accuracy (in %) of ``model`` on FGSM-perturbed samples.

    Each batch is attacked with ``fgsm_attack`` at budget ``eps``. It is then
    classified by taking the argmax of the model's output averaged over dims
    ``(2, 3)``.

    Args:
        model: network under evaluation.
        dataloader: iterable of ``(x, y)`` batches.
        eps: L-infinity FGSM budget.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    correct, total = 0, 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        adv = fgsm_attack(model, x, y, eps)
        with torch.no_grad():
            pred = model(adv).mean(dim=(2, 3)).argmax(dim=1)
        correct += (pred == y).sum().item()
        # Count samples, not batches, so the ratio is right for any batch_size.
        total += y.numel()
    if total == 0:
        raise ValueError("eval_fgsm: dataloader yielded no samples")
    return 100 * correct / total

In [None]:
# PGD attack
def pgd_attack(model, x, y, eps=0.04, alpha=0.005, steps=20):
    """Projected Gradient Descent (L-infinity) attack against image-level labels.

    The attack starts from the clean input, with no random initialisation.
    It takes ``steps`` signed-gradient ascent steps of size ``alpha`` on the
    cross-entropy of the model's output averaged over dims ``(2, 3)``. After
    every step it projects back onto the ``eps``-ball around ``x`` and onto
    ``[0, 1]``.

    Args:
        model: callable expected to return 4-D logits ``(B, K, H, W)``.
        x: clean input batch; presumably in ``[0, 1]``.
        y: integer class labels, one per sample.
        eps: L-infinity perturbation budget.
        alpha: step size per iteration.
        steps: number of gradient steps.

    Returns:
        Detached adversarial batch on ``device``.
    """
    x_orig = x.clone().detach().to(device)
    y = y.to(device)
    x_adv = x_orig.clone()

    for _ in range(steps):
        x_adv.requires_grad_(True)
        # enable_grad: still works when the caller is under torch.no_grad().
        # autograd.grad: gradient w.r.t. the input only; parameter .grad untouched.
        with torch.enable_grad():
            loss = F.cross_entropy(model(x_adv).mean(dim=(2, 3)), y)
            (grad,) = torch.autograd.grad(loss, x_adv)

        stepped = x_adv.detach() + alpha * grad.sign()
        delta = torch.clamp(stepped - x_orig, -eps, eps)  # project onto eps-ball
        x_adv = torch.clamp(x_orig + delta, 0, 1)         # keep valid pixel range

    return x_adv

In [None]:
# PGD evaluation
def eval_pgd(model, dataloader, eps, **attack_kwargs):
    """Return the accuracy (in %) of ``model`` on PGD-perturbed samples.

    Each batch is attacked with ``pgd_attack`` at budget ``eps``. It is then
    classified by taking the argmax of the model's output averaged over dims
    ``(2, 3)``.

    Args:
        model: network under evaluation.
        dataloader: iterable of ``(x, y)`` batches.
        eps: L-infinity PGD budget.
        **attack_kwargs: extra keyword arguments forwarded to ``pgd_attack``
            (e.g. ``alpha``, ``steps``); omitted means its defaults are used.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    correct, total = 0, 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        adv = pgd_attack(model, x, y, eps, **attack_kwargs)
        with torch.no_grad():
            pred = model(adv).mean(dim=(2, 3)).argmax(dim=1)
        correct += (pred == y).sum().item()
        # Count samples, not batches, so the ratio is right for any batch_size.
        total += y.numel()
    if total == 0:
        raise ValueError("eval_pgd: dataloader yielded no samples")
    return 100 * correct / total

In [None]:
# Run attacks: sweep every budget in EPS_LIST for each attack in turn.
# With eps = 0.0 the attacks add no perturbation (only a clamp to [0, 1]).
EPS_LIST = [0.0, 0.02, 0.04, 0.06, 0.08]

for header, evaluate in (("FGSM Results", eval_fgsm), ("PGD Results", eval_pgd)):
    print(header)
    for eps in EPS_LIST:
        accuracy = evaluate(model, val_dl, eps)
        print(eps, accuracy)