In [None]:
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp

import numpy as np
import cv2
import io
import matplotlib.pyplot as plt
import pandas as pd

from PIL import Image
from io import BytesIO

# Prefer the GPU when PyTorch can see one; models and tensors below are moved here.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

In [None]:
# Side length (pixels) of the square resolution every input image is resized to.
IMG_SIZE = 256

# Class names, indexed by the model's output channel.
classes = ["LULC", "River", "Road", "Settlement", "Soil"]
# Derived from the label list so the model head size and the names cannot drift apart.
NUM_CLASSES = len(classes)

def load_model(path, encoder_name="resnet34"):
    """Build a U-Net, load trained weights from `path`, and return it in eval mode.

    Args:
        path: Path to a checkpoint holding a plain ``state_dict`` (it is passed
            straight to ``load_state_dict``).
        encoder_name: segmentation_models_pytorch encoder backbone; must match
            the architecture the checkpoint was trained with.

    Returns:
        The model on ``DEVICE``, switched to ``eval()`` mode.
    """
    model = smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=None,  # weights come from the checkpoint, not ImageNet
        in_channels=3,
        classes=NUM_CLASSES,
    ).to(DEVICE)
    # weights_only=True restricts unpickling to tensors/containers, so a
    # tampered .pth file cannot execute arbitrary code on load.
    state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Baseline model and its adversarially trained counterpart (same architecture).
base_model = load_model("models/unet_resnet34.pth")
adv_model = load_model("models/unet_resnet34_adv.pth")

# The original literal was mojibake (UTF-8 check mark decoded as cp1252).
print("✅ Base and adversarially trained models loaded")

In [None]:
def bytes_to_tensor(raw_bytes):
    """Decode encoded image bytes into a resized RGB image plus a model-ready tensor.

    Returns:
        (pil_img, batch): the RGB PIL image resized to IMG_SIZE x IMG_SIZE, and
        a float32 tensor of shape (1, 3, IMG_SIZE, IMG_SIZE) with values in
        [0, 1], placed on DEVICE.
    """
    pil_img = Image.open(io.BytesIO(raw_bytes)).convert("RGB").resize((IMG_SIZE, IMG_SIZE))

    # (H, W, 3) uint8 -> (3, H, W) float in [0, 1]
    hwc = np.array(pil_img) / 255.0
    chw = hwc.transpose(2, 0, 1)

    batch = torch.tensor(chw, dtype=torch.float32)[None]
    return pil_img, batch.to(DEVICE)

In [None]:
def fgsm_attack(model, x, y, eps):
    """Single-step Fast Gradient Sign Method attack on the pooled-logit classifier.

    The model's output is averaged over dims 2 and 3 (spatial) to get one logit
    vector per image. The input is then pushed one step of size ``eps`` in the
    sign of the cross-entropy gradient.

    Args:
        model: Network whose output has spatial dims at positions 2 and 3.
        x: Input batch with pixel values in [0, 1] (presumably (N, 3, H, W)).
        y: Class-index tensor used as the cross-entropy target.
        eps: L-inf step size.

    Returns:
        Detached adversarial batch clamped to [0, 1], on the model's device.
        Model parameter ``.grad`` fields are left untouched.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else x.device

    x_adv = x.clone().detach().to(device).requires_grad_(True)
    logits = model(x_adv).mean(dim=(2, 3))
    loss = F.cross_entropy(logits, y.to(device))

    # Differentiate w.r.t. the input only: unlike loss.backward(), this does
    # not accumulate gradients into (and leak memory via) model parameters.
    (grad,) = torch.autograd.grad(loss, x_adv)

    adv = x_adv.detach() + eps * grad.sign()
    return torch.clamp(adv, 0, 1)


def pgd_attack(model, x, y, eps=0.04, alpha=0.005, steps=15, random_start=False):
    """Iterative L-inf Projected Gradient Descent attack on the pooled-logit classifier.

    Each step moves the input by ``alpha`` in the sign of the cross-entropy
    gradient. It then projects back into the L-inf ball of radius ``eps``
    around ``x`` and into the [0, 1] pixel range.

    Args:
        model: Network whose output has spatial dims at positions 2 and 3.
        x: Clean input batch with values in [0, 1] (presumably (N, 3, H, W)).
        y: Class-index tensor used as the cross-entropy target.
        eps: Radius of the L-inf perturbation budget.
        alpha: Per-step size.
        steps: Number of gradient steps (0 returns a copy of ``x``).
        random_start: If True, start from a uniform random point inside the
            eps-ball, as in Madry et al.; the default False starts at ``x``.

    Returns:
        Detached adversarial batch on the model's device. Model parameter
        ``.grad`` fields are left untouched.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else x.device

    x_orig = x.clone().detach().to(device)
    y = y.to(device)
    x_adv = x_orig.clone()
    if random_start:
        noise = torch.empty_like(x_adv).uniform_(-eps, eps)
        x_adv = torch.clamp(x_adv + noise, 0, 1)

    for _ in range(steps):
        x_adv.requires_grad_(True)
        logits = model(x_adv).mean(dim=(2, 3))
        loss = F.cross_entropy(logits, y)
        # Gradient w.r.t. the input only; model parameters get no .grad.
        (grad,) = torch.autograd.grad(loss, x_adv)

        x_adv = x_adv.detach() + alpha * grad.sign()
        # Project onto the eps-ball around x_orig, then onto valid pixels.
        perturb = torch.clamp(x_adv - x_orig, -eps, eps)
        x_adv = torch.clamp(x_orig + perturb, 0, 1).detach()

    return x_adv

In [None]:
def jpeg_def(arr, quality=30):
    """JPEG re-compression defense: round-trip the image through a lossy JPEG encode/decode.

    Args:
        arr: Image array accepted by ``PIL.Image.fromarray`` (the caller passes uint8 RGB).
        quality: JPEG quality factor; lower values discard more high-frequency detail.

    Returns:
        The decoded RGB image as a numpy array.
    """
    with BytesIO() as stream:
        Image.fromarray(arr).save(stream, format="JPEG", quality=quality)
        stream.seek(0)
        decoded = Image.open(stream).convert("RGB")
    return np.array(decoded)

def gaussian_def(arr, ksize=5):
    """Gaussian-blur defense: smooth the image with a square ``ksize`` x ``ksize`` kernel.

    Sigma is passed as 0, so OpenCV derives it from the kernel size.

    Args:
        arr: Image array accepted by ``cv2.GaussianBlur``.
        ksize: Kernel side length; must be a positive odd integer (default 5).

    Raises:
        ValueError: If ``ksize`` is not a positive odd integer.
    """
    if ksize < 1 or ksize % 2 == 0:
        raise ValueError(f"ksize must be a positive odd integer, got {ksize}")
    return cv2.GaussianBlur(arr, (ksize, ksize), 0)

def median_def(arr, ksize=5):
    """Median-filter defense: replace each pixel with the median of its ``ksize`` x ``ksize`` neighbourhood.

    Args:
        arr: Image array accepted by ``cv2.medianBlur`` (the pipeline passes uint8).
        ksize: Aperture size; must be an odd integer >= 3 (default 5).

    Raises:
        ValueError: If ``ksize`` is not an odd integer >= 3.
    """
    if ksize < 3 or ksize % 2 == 0:
        raise ValueError(f"ksize must be an odd integer >= 3, got {ksize}")
    return cv2.medianBlur(arr, ksize)

def bitdepth_def(arr, bits=4):
    """Bit-depth reduction defense: keep only the ``bits`` most significant bits of each 8-bit value.

    With the default ``bits=4`` the low 4 bits are zeroed (255 -> 240, 17 -> 16),
    which is identical to the original ``(arr >> 4) << 4``.

    Args:
        arr: Integer image array with 8-bit values (the pipeline passes uint8).
        bits: Number of high-order bits to keep, in [1, 8]; 8 is the identity.

    Raises:
        ValueError: If ``bits`` is outside [1, 8].
        TypeError: If ``arr`` is not an integer array (bit shifts are undefined for floats).
    """
    if not 1 <= bits <= 8:
        raise ValueError(f"bits must be in [1, 8], got {bits}")
    arr = np.asarray(arr)
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError(f"bitdepth_def expects an integer array, got dtype {arr.dtype}")
    shift = 8 - bits
    return (arr >> shift) << shift

def hybrid_defense(arr):
    """Apply the input-transformation defenses in a fixed order.

    Pipeline: JPEG re-compression -> Gaussian blur -> median blur ->
    bit-depth reduction, each with its default parameters. Returns the
    transformed image array.
    """
    out = arr
    for stage in (jpeg_def, gaussian_def, median_def, bitdepth_def):
        out = stage(out)
    return out

In [None]:
def classify(model, x, labels=None):
    """Predict one image-level label by averaging the model's output over its spatial dims.

    Only the first sample of the batch (index 0) is classified.

    Args:
        model: Network whose output has spatial dims at positions 2 and 3.
        x: Input batch (only ``x[0]`` contributes to the result).
        labels: Sequence of class names indexed by output channel; defaults
            to the notebook-level ``classes`` list.

    Returns:
        (label, confidence): the argmax class name and its softmax probability
        as a Python float.
    """
    names = classes if labels is None else labels
    with torch.no_grad():
        logits = model(x).mean(dim=(2, 3))
        probs = torch.softmax(logits, dim=1)[0]
        idx = int(probs.argmax())
    return names[idx], float(probs[idx])

In [None]:
# Attack strength (L-inf budget, in [0, 1] pixel units)
EPS = 0.04

# Ground-truth class index, if known. When None, the clean prediction is used
# as the pseudo-label, so the untargeted attacks push the model away from what
# it currently predicts. A hard-coded 0 would instead attack an arbitrary class
# that may not be the predicted one.
TRUE_LABEL_IDX = None

# `decrypted_bytes` is not defined in this notebook; presumably it is produced
# by a decryption step elsewhere (TODO confirm). Fail clearly on a fresh kernel.
if "decrypted_bytes" not in globals():
    raise NameError(
        "decrypted_bytes is not defined: load/decrypt the input image bytes before running this cell"
    )

# Convert decrypted image
clean_img, x = bytes_to_tensor(decrypted_bytes)

# Clean prediction
clean_pred, clean_conf = classify(base_model, x)

label_idx = classes.index(clean_pred) if TRUE_LABEL_IDX is None else TRUE_LABEL_IDX
true_label = torch.tensor([label_idx], device=DEVICE)

# FGSM attack, crafted against the base model. The adversarially trained
# model is evaluated on the same (transferred) examples.
adv_fgsm = fgsm_attack(base_model, x, true_label, EPS)
fgsm_pred_no, fgsm_conf_no = classify(base_model, adv_fgsm)
fgsm_pred_def, fgsm_conf_def = classify(adv_model, adv_fgsm)

# PGD attack
adv_pgd = pgd_attack(base_model, x, true_label, EPS)
pgd_pred_no, pgd_conf_no = classify(base_model, adv_pgd)

# Hybrid defense for PGD: (1, 3, H, W) float in [0, 1] -> (H, W, 3) uint8.
# Round before casting; astype(np.uint8) alone truncates, biasing every pixel
# down by up to one intensity level.
arr = np.rint(adv_pgd[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
arr_def = hybrid_defense(arr)

t_def = torch.tensor(
    arr_def / 255.0, dtype=torch.float32
).permute(2, 0, 1).unsqueeze(0).to(DEVICE)

pgd_pred_def, pgd_conf_def = classify(base_model, t_def)

In [None]:
# Adversarial tensors are (1, 3, H, W) in [0, 1]; imshow expects (H, W, 3).
fgsm_hwc = adv_fgsm[0].permute(1, 2, 0).cpu().numpy()
pgd_hwc = adv_pgd[0].permute(1, 2, 0).cpu().numpy()

titles = [
    f"Clean\n{clean_pred} ({clean_conf:.2f})",
    f"FGSM Attack\n{fgsm_pred_no} ({fgsm_conf_no:.2f})",
    f"FGSM Defense\n{fgsm_pred_def} ({fgsm_conf_def:.2f})",
    f"PGD Attack\n{pgd_pred_no} ({pgd_conf_no:.2f})",
    f"PGD Defense\n{pgd_pred_def} ({pgd_conf_def:.2f})",
]

# The FGSM "defense" is a different model, not a preprocessing step, so its
# panel shows the same FGSM input image.
images = [clean_img, fgsm_hwc, fgsm_hwc, pgd_hwc, arr_def]

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
flat_axes = axes.ravel()
for ax, title, image in zip(flat_axes, titles, images):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
# Hide the unused sixth panel of the 2x3 grid.
for ax in flat_axes[len(images):]:
    ax.axis("off")

fig.tight_layout()
plt.show()

In [None]:
# One row per pipeline stage: (stage, predicted label, softmax confidence).
df = pd.DataFrame(
    [
        ("Clean", clean_pred, clean_conf),
        ("FGSM Attack", fgsm_pred_no, fgsm_conf_no),
        ("FGSM Defense (Adv Model)", fgsm_pred_def, fgsm_conf_def),
        ("PGD Attack", pgd_pred_no, pgd_conf_no),
        ("PGD Defense (Hybrid)", pgd_pred_def, pgd_conf_def),
    ],
    columns=["Stage", "Prediction", "Confidence"],
)

print("\n===== CLASSIFICATION COMPARISON =====\n")
df