In [3]:
import os, json, ast, requests
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision import models, datasets, transforms
from tqdm import tqdm

# ─── Config ────────────────────────────────────────────────────────────────
# Patch geometry: NUM_PATCHES squares of PATCH_SIZE x PATCH_SIZE pixels per image.
PATCH_SIZE = 32
NUM_PATCHES = 10

# Attack schedule. EPSILON and ALPHA are in *normalized* units: the attack
# perturbs tensors that have already gone through Normalize(MEAN, STD).
EPSILON = 0.5       # element-wise (L-inf) bound on the perturbation
ALPHA = 0.05        # step size per iteration
STEPS = 100         # iterations per restart
RESTARTS = 10       # independent random restarts per batch
MOMENTUM = 0.9      # decay factor of the accumulated gradient direction

TARGET_CLASS = 270  # fixed target
BATCH_SIZE = 64

# Per-channel normalization statistics (also used by `transform`).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Output directory for the generated adversarial test sets.
SAVE_DIR = "AdversarialTestSets"
os.makedirs(SAVE_DIR, exist_ok=True)

# ─── Patch Mask ────────────────────────────────────────────────────────────
def multi_patch_mask(b, H, W, patch_size, num_patches, device):
    """Build a binary mask made of randomly placed square patches.

    For each of the ``b`` samples, ``num_patches`` top-left corners are drawn
    uniformly so that every ``patch_size x patch_size`` square lies fully
    inside the ``H x W`` frame. Corners are drawn independently, so patches
    may overlap: each sample covers between ``patch_size**2`` and
    ``num_patches * patch_size**2`` pixels (when ``num_patches >= 1``).

    Args:
        b: number of masks (batch size).
        H, W: spatial size of each mask.
        patch_size: side length of each square patch, in pixels.
        num_patches: number of patches per sample.
        device: device the mask is created on.

    Returns:
        Tensor of shape ``(b, 1, H, W)`` in the default float dtype, 1.0 inside
        any patch and 0.0 elsewhere. The singleton channel dimension
        broadcasts over the image channels.

    Raises:
        ValueError: if ``patch_size`` is negative or larger than ``H`` or ``W``.
    """
    if patch_size < 0 or patch_size > H or patch_size > W:
        raise ValueError(
            f"patch_size={patch_size} must be in [0, min(H, W)] for a {H}x{W} mask"
        )
    dtype = torch.get_default_dtype()
    if b == 0 or num_patches <= 0 or patch_size == 0:
        return torch.zeros(b, 1, H, W, device=device, dtype=dtype)

    # Vectorized sampling: one randint call per axis instead of a Python loop
    # that forced a device->host sync (`.item()`) for every patch, which is
    # costly when this runs on every attack step. The trailing singleton dim
    # broadcasts against the pixel coordinates below.
    tops = torch.randint(0, H - patch_size + 1, (b, num_patches, 1), device=device)
    lefts = torch.randint(0, W - patch_size + 1, (b, num_patches, 1), device=device)
    rows = torch.arange(H, device=device)
    cols = torch.arange(W, device=device)
    in_rows = (rows >= tops) & (rows < tops + patch_size)    # (b, num_patches, H)
    in_cols = (cols >= lefts) & (cols < lefts + patch_size)  # (b, num_patches, W)
    # A pixel is masked if it falls inside ANY of the sample's patches.
    covered = (in_rows.unsqueeze(-1) & in_cols.unsqueeze(-2)).any(dim=1)  # (b, H, W)
    return covered.unsqueeze(1).to(dtype)

def clamp_norm(x, mean, std):
    """Clip a normalized tensor so it maps back into the [0, 1] pixel range.

    ``x`` is assumed to be channel-first and normalized per channel as
    ``(p - mean) / std``. The bounds ``(0 - mean) / std`` and
    ``(1 - mean) / std`` are built with shape ``(C, 1, 1)`` so they broadcast
    over the trailing spatial dimensions (and any leading batch dimension).
    """
    per_channel = (-1, 1, 1)
    mu = torch.tensor(mean, device=x.device).view(per_channel)
    sigma = torch.tensor(std, device=x.device).view(per_channel)
    lower = (0.0 - mu) / sigma
    upper = (1.0 - mu) / sigma
    # Cap from above first, then from below (same order as min-then-max).
    return x.minimum(upper).maximum(lower)

# ─── PGD Patch Attack ──────────────────────────────────────────────────────
def attack_pgd_patch(x, model, target_class):
    """Targeted momentum-PGD attack whose perturbation is confined to random patches.

    Each of RESTARTS restarts draws a uniform random perturbation ``delta`` in
    [-EPSILON, EPSILON] and runs STEPS iterations. Every iteration samples a
    fresh set of NUM_PATCHES squares of side PATCH_SIZE (``multi_patch_mask``),
    applies ``delta`` only inside them, and clips the result to the valid pixel
    range (``clamp_norm``). It then takes a signed-momentum step that
    *decreases* the cross-entropy toward ``target_class``. Per sample, the
    restart whose final adversarial image has the lowest target loss is kept.

    Hyper-parameters are read from module-level config: DEVICE, EPSILON, ALPHA,
    STEPS, RESTARTS, MOMENTUM, PATCH_SIZE, NUM_PATCHES, MEAN, STD.

    Args:
        x: batch of images, shape (b, c, H, W), normalized with MEAN/STD.
        model: classifier mapping images to logits; switched to eval mode.
        target_class: class index every sample is pushed toward.

    Returns:
        Detached tensor on DEVICE with the same shape as ``x``. If no restart
        yields a loss below +inf for a sample (e.g. RESTARTS == 0), that sample
        is returned unperturbed.
    """
    model.eval()
    b, _, H, W = x.shape
    x_nat = x.to(DEVICE)
    best_adv = x_nat.clone()
    if b == 0:
        # Nothing to attack; also avoids reshape(0, -1) failing below.
        return best_adv.detach()

    targets = torch.full((b,), target_class, device=DEVICE, dtype=torch.long)
    best_loss = torch.full((b,), float('inf'), device=DEVICE)

    for _ in range(RESTARTS):
        delta = torch.zeros_like(x_nat).uniform_(-EPSILON, EPSILON).requires_grad_(True)
        v = torch.zeros_like(delta)
        # Bound before the loop so the final scoring below always has a mask,
        # even when STEPS == 0 (the loop variable would otherwise be unbound).
        mask = None
        for _ in range(STEPS):
            # Patch locations are resampled every step, so updates land at
            # different positions over the run.
            mask = multi_patch_mask(b, H, W, PATCH_SIZE, NUM_PATCHES, DEVICE)
            adv = clamp_norm(x_nat + delta * mask, MEAN, STD)
            loss = F.cross_entropy(model(adv), targets, reduction='sum')
            (grads,) = torch.autograd.grad(loss, delta)
            # Per-sample mean-|g| normalization so every sample contributes a
            # comparable step; the clamp guards against division by zero.
            denom = grads.abs().reshape(b, -1).mean(dim=1).clamp(min=1e-8).view(b, 1, 1, 1)
            v = MOMENTUM * v + grads / denom
            # Targeted attack: step DOWN the loss, only inside the current
            # patches, then project back onto the [-EPSILON, EPSILON] box.
            with torch.no_grad():
                delta.sub_(ALPHA * v.sign() * mask).clamp_(-EPSILON, EPSILON)
        if mask is None:
            mask = multi_patch_mask(b, H, W, PATCH_SIZE, NUM_PATCHES, DEVICE)

        with torch.no_grad():
            # Scored with the last step's mask (or a fresh one when STEPS == 0).
            final = clamp_norm(x_nat + delta * mask, MEAN, STD)
            losses = F.cross_entropy(model(final), targets, reduction='none')
            better = losses < best_loss
            best_loss[better] = losses[better]
            best_adv[better] = final[better]

    return best_adv.detach()

# ─── Eval ──────────────────────────────────────────────────────────────────
def evaluate(model, loader, ks=(1, 5), device=None):
    """Compute top-k accuracy of ``model`` over every batch yielded by ``loader``.

    Args:
        model: classifier returning logits of shape (batch, num_classes);
            switched to eval mode.
        loader: iterable of ``(inputs, labels)`` batches with integer labels.
        ks: the k values to report.
        device: device to run on; defaults to the module-level ``DEVICE``.

    Returns:
        Dict mapping each k in ``ks`` to the fraction of samples whose label
        is among the model's top-k predictions.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined).
    """
    if device is None:
        device = DEVICE
    ks = tuple(ks)
    model.eval()
    correct = dict.fromkeys(ks, 0)
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # topk raises when k exceeds the number of classes; any such k is
            # trivially always correct, so cap the request at num_classes.
            max_k = min(max(ks), logits.size(1))
            _, pred = logits.topk(max_k, dim=1)
            hits = pred == y.unsqueeze(1)  # (batch, max_k), ordered by rank
            for k in ks:
                correct[k] += hits[:, :k].any(dim=1).sum().item()
            total += y.size(0)
    if total == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return {k: correct[k] / total for k in ks}

# ─── Load Model and Data ───────────────────────────────────────────────────
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1).to(DEVICE).eval()

DATASET_PATH = "./test_dataset"
LABELS_JSON = os.path.join(DATASET_PATH, "labels_list.json")

# Each entry is expected to start with "<index>:"; only that leading integer is used.
with open(LABELS_JSON, "r", encoding="utf-8") as f:
    entries = json.load(f)
subset_indices = [int(e.split(":")[0]) for e in entries]
subset_index_set = set(subset_indices)  # O(1) membership test in the filter below

url = "https://gist.githubusercontent.com/fnielsen/4a5c94eaa6dcdf29b7a62d886f540372/raw/imagenet_label_to_wordnet_synset.txt"
# Bounded wait, and fail loudly on HTTP errors instead of handing an error
# page to literal_eval (which would surface as a confusing SyntaxError).
response = requests.get(url, timeout=30)
response.raise_for_status()
# literal_eval only accepts Python literals, so the downloaded text cannot run code.
synset_map = ast.literal_eval(response.text)
# NOTE(review): assumes v["id"] looks like "<digits>-n", so the WordNet folder
# name is "n" + the digits — confirm against the gist.
idx_to_wnid = {
    int(k): "n" + v["id"].split("-")[0]
    for k, v in synset_map.items()
    if int(k) in subset_index_set
}
wnid_to_global = {wnid: idx for idx, wnid in idx_to_wnid.items()}
valid_wnids = set(wnid_to_global.keys())

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

class FilteredImageFolder(datasets.ImageFolder):
    """ImageFolder that only treats sub-directories named in ``valid_wnids`` as classes.

    Class indices are assigned 0..N-1 in sorted folder-name order, the same
    convention as torchvision's default ``find_classes``.
    """

    def find_classes(self, directory):
        """Return ``(sorted class names, {name: index})`` for the kept class folders.

        Raises:
            FileNotFoundError: if no sub-directory of ``directory`` is in
                ``valid_wnids``. This mirrors torchvision's default
                ``find_classes`` instead of silently yielding an empty dataset.
        """
        with os.scandir(directory) as it:
            classes = sorted(
                entry.name for entry in it
                if entry.is_dir() and entry.name in valid_wnids
            )
        if not classes:
            raise FileNotFoundError(
                f"Couldn't find any class folder in {directory} matching valid_wnids."
            )
        return classes, {name: i for i, name in enumerate(classes)}

base_ds = FilteredImageFolder(DATASET_PATH, transform=transform)
# ImageFolder numbers classes 0..N-1 in sorted folder order; translate each
# local index back to the global index recorded in `wnid_to_global`.
folder_to_global = dict(enumerate(wnid_to_global[wnid] for wnid in base_ds.classes))

class SubsetDataset(Dataset):
    """Wrap a dataset and translate each sample's integer label via ``index_map``.

    Samples keep the order and count of ``base_ds``; only the label changes,
    becoming ``index_map[label]``.
    """

    def __init__(self, base_ds, index_map):
        self.base = base_ds
        self.index_map = index_map

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, local_label = self.base[idx]
        global_label = self.index_map[local_label]
        return image, global_label

dataset = SubsetDataset(base_ds, index_map=folder_to_global)
# shuffle=False: batches follow the dataset order, so the saved adversarial
# set is in ImageFolder order.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# ─── Run the Attack ────────────────────────────────────────────────────────
ADV_FILENAME = "adv_test_set3_C1.pt"

adv_imgs, adv_labels = [], []
for imgs, labels in tqdm(loader, desc="Generating Adversarial Set"):
    adv_batch = attack_pgd_patch(imgs, model, TARGET_CLASS)
    # Move each batch off the attack device right away so device memory
    # does not grow with the size of the dataset.
    adv_imgs.append(adv_batch.cpu())
    adv_labels.append(labels)

if not adv_imgs:
    # torch.cat([]) would otherwise fail with an unhelpful message.
    raise RuntimeError(f"No images were loaded from {DATASET_PATH}; nothing to attack.")
adv_imgs = torch.cat(adv_imgs)
adv_labels = torch.cat(adv_labels)
# Images are saved still normalized with MEAN/STD, exactly as produced by
# `transform` + the attack; labels are the global class indices.
torch.save((adv_imgs, adv_labels), os.path.join(SAVE_DIR, ADV_FILENAME))
print(f"✅ Saved to {ADV_FILENAME}")

# ─── Final Eval ────────────────────────────────────────────────────────────
# Use the shared BATCH_SIZE rather than a separate hard-coded 64.
eval_loader = DataLoader(TensorDataset(adv_imgs, adv_labels), batch_size=BATCH_SIZE, shuffle=False)
accs = evaluate(model, eval_loader)
print(f"Top-1 Accuracy: {accs[1]*100:.2f}%, Top-5 Accuracy: {accs[5]*100:.2f}%")


Generating Adversarial Set: 100%|██████████| 8/8 [14:06<00:00, 105.85s/it]


✅ Saved to adv_test_set3_C1.pt
Top-1 Accuracy: 5.20%, Top-5 Accuracy: 19.60%
