### Experiment 2 

In [5]:
import timm
import time 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import datasets 
from torch.utils.data import DataLoader
# from medmnist import INFO
import numpy as np
import faiss
import copy
from tqdm import tqdm

from torch.nn.functional import softmax, cosine_similarity
from collections import Counter
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os 

import warnings
warnings.filterwarnings("ignore")

In [6]:
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

Using device: cuda:2


In [7]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

class CustomImageListDataset(torch.utils.data.Dataset):
    def __init__(self, file_list, class_to_idx, transform=None):
        with open(file_list, "r") as f:
            self.samples = [line.strip() for line in f]
        self.transform = transform
        self.class_to_idx = class_to_idx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        class_folder = os.path.basename(os.path.dirname(img_path))
        label = self.class_to_idx.get(class_folder, -1)
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label


# ---------------- Create a combined class mapping ----------------
root_dir = "dataset/imagenet_tests"
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Collect class mappings from all 10 partitions
combined_class_to_idx = {}
for i in range(1, 11):
    test_dir = os.path.join(root_dir, f"test{i}")
    dataset = datasets.ImageFolder(test_dir, transform=transform)
    combined_class_to_idx.update(dataset.class_to_idx)

print(f"✅ Combined class mapping built: {len(combined_class_to_idx)} total classes")

# ---------------- Load your 1000-image subset ----------------
subset_file = "results/hard_cases_missed_by_mobilenet.txt"
hard_dataset = CustomImageListDataset(subset_file, class_to_idx=combined_class_to_idx, transform=transform)
hard_loader = DataLoader(hard_dataset, batch_size=1, shuffle=False)

print(f"✅ Loaded {len(hard_dataset)} hard samples")

✅ Combined class mapping built: 1000 total classes
✅ Loaded 1000 hard samples


In [8]:
def get_models(dataset, model_name, key): 
    if dataset == 'imagenet':
        # save_root_path = r"checkpoint/tinyimagenet"
        model = timm.create_model(model_name, pretrained=True, num_classes=1000).to(device)
        model.eval()
        if 'inc' in key or 'vit' in key or 'bit' in key:
            return torch.nn.Sequential(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), model)
        else:
            return torch.nn.Sequential(transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), model)

### Ensemble Attack 

In [9]:
from abc import abstractmethod

import torch
import torch.nn.functional as F


class AdaEA_Base:
    def __init__(self, models, eps=8/255, alpha=2/255, max_value=1., min_value=0., threshold=0., beta=10,
                 device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
        assert isinstance(models, list) and len(models) >= 2, 'Error'
        self.device = device
        self.models = models
        self.num_models = len(self.models)
        for model in models:
            model.eval()

        # attack parameter
        self.eps = eps
        self.threshold = threshold
        self.max_value = max_value
        self.min_value = min_value
        self.beta = beta
        self.alpha = alpha

    def get_adv_example(self, ori_data, adv_data, grad, attack_step=None):
        """
        :param ori_data: original image
        :param adv_data: adversarial image in the last iteration
        :param grad: gradient in this iteration
        :return: adversarial example in this iteration
        """
        if attack_step is None:
            adv_example = adv_data.detach() + grad.sign() * self.alpha
        else:
            adv_example = adv_data.detach() + grad.sign() * attack_step
        delta = torch.clamp(adv_example - ori_data.detach(), -self.eps, self.eps)
        return torch.clamp(ori_data.detach() + delta, max=self.max_value, min=self.min_value)

    def agm(self, ori_data, cur_adv, grad, label):
        """
        Adaptive gradient modulation
        :param ori_data: natural images
        :param cur_adv: adv examples in last iteration
        :param grad: gradient in this iteration
        :param label: ground truth
        :return: coefficient of each model
        """
        loss_func = torch.nn.CrossEntropyLoss()

        # generate adversarial example
        adv_exp = [self.get_adv_example(ori_data=ori_data, adv_data=cur_adv, grad=grad[idx])
                   for idx in range(self.num_models)]
        loss_self = [loss_func(self.models[idx](adv_exp[idx]), label) for idx in range(self.num_models)]
        w = torch.zeros(size=(self.num_models,), device=self.device)

        for j in range(self.num_models):
            for i in range(self.num_models):
                if i == j:
                    continue
                w[j] += loss_func(self.models[i](adv_exp[j]), label) / loss_self[i] * self.beta
        w = torch.softmax(w, dim=0)

        return w

    def drf(self, grads, data_size):
        """
        disparity-reduced filter
        :param grads: gradients of each model
        :param data_size: size of input images
        :return: reduce map
        """
        reduce_map = torch.zeros(size=(self.num_models, self.num_models, data_size[0], data_size[-2], data_size[-1]),
                                 dtype=torch.float, device=self.device)
        sim_func = torch.nn.CosineSimilarity(dim=1, eps=1e-8)
        reduce_map_result = torch.zeros(size=(self.num_models, data_size[0], data_size[-2], data_size[-1]),
                                        dtype=torch.float, device=self.device)
        for i in range(self.num_models):
            for j in range(self.num_models):
                if i >= j:
                    continue
                reduce_map[i][j] = sim_func(F.normalize(grads[i], dim=1), F.normalize(grads[j], dim=1))
            if i < j:
                one_reduce_map = (reduce_map[i, :].sum(dim=0) + reduce_map[:, i].sum(dim=0)) / (self.num_models - 1)
                reduce_map_result[i] = one_reduce_map

        return reduce_map_result.mean(dim=0).view(data_size[0], 1, data_size[-2], data_size[-1])

    @abstractmethod
    def attack(self,
               data: torch.Tensor,
               label: torch.Tensor,
               idx: int = -1) -> torch.Tensor:
        ...

    def __call__(self, data, label, idx=-1):
        return self.attack(data, label, idx)

In [10]:
import torch
import torch.nn as nn


class AdaEA_MIFGSM(AdaEA_Base):
    def __init__(self, models, eps=8/255, alpha=2/255, iters=20, max_value=1., min_value=0., threshold=0., beta=10,
                 device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), momentum=0.9):
        super().__init__(models=models, eps=eps, alpha=alpha, max_value=max_value, min_value=min_value,
                         threshold=threshold, device=device, beta=beta)
        self.iters = iters
        self.momentum = momentum

    def attack(self, data, label, idx=-1):
        B, C, H, W = data.size()
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)
        loss_func = nn.CrossEntropyLoss()

        # init pert
        adv_data = data.clone().detach() + 0.001 * torch.randn(data.shape, device=self.device)
        adv_data = adv_data.detach()

        grad_mom = torch.zeros_like(data, device=self.device)

        for i in range(self.iters):
            adv_data.requires_grad = True

            outputs = [self.models[idx](adv_data) for idx in range(len(self.models))]
            losses = [loss_func(outputs[idx], label) for idx in range(len(self.models))]
            grads = [torch.autograd.grad(losses[idx], adv_data, retain_graph=True, create_graph=False)[0]
                     for idx in range(len(self.models))]

            # AGM
            alpha = self.agm(ori_data=data, cur_adv=adv_data, grad=grads, label=label)

            # DRF
            cos_res = self.drf(grads, data_size=(B, C, H, W))
            cos_res[cos_res >= self.threshold] = 1.
            cos_res[cos_res < self.threshold] = 0.

            output = torch.stack(outputs, dim=0) * alpha.view(self.num_models, 1, 1)
            output = output.sum(dim=0)
            loss = loss_func(output, label)
            grad = torch.autograd.grad(loss.sum(dim=0), adv_data)[0]
            grad = grad * cos_res

            # momentum
            grad = grad / torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True)
            grad = grad + self.momentum * grad_mom
            grad_mom = grad

            # add perturbation
            adv_data = self.get_adv_example(ori_data=data, adv_data=adv_data, grad=grad)
            adv_data.detach_()

        return adv_data

In [11]:
import torch
import torch.nn as nn

class AdaEA_PGD(AdaEA_Base):
    def __init__(self, models, eps=8/255, alpha=2/255, iters=20, max_value=1., min_value=0., threshold=0.,
                 beta=10, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
                 random_start=False):
        super().__init__(models=models, eps=eps, alpha=alpha, max_value=max_value, min_value=min_value,
                         threshold=threshold, device=device, beta=beta)
        self.iters = iters
        self.random_start = random_start

    def attack(self, data, label, idx=-1):
        """
        AdaEA PGD attack — same AGM + DRF + ensemble averaging logic as AdaEA_MIFGSM,
        but with PGD (no momentum).
        """
        B, C, H, W = data.size()
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)
        loss_func = nn.CrossEntropyLoss()

        # init pert: either small gaussian (like MIFGSM) or a uniform random start within eps-ball
        if self.random_start:
            adv_data = data.clone().detach() + torch.empty_like(data).uniform_(-self.eps, self.eps)
            adv_data = torch.clamp(adv_data, min=self.min_value, max=self.max_value).detach()
        else:
            adv_data = data.clone().detach() + 0.001 * torch.randn(data.shape, device=self.device)
            adv_data = adv_data.detach()

        for i in range(self.iters):
            adv_data.requires_grad = True

            # forward each model on current adversarial example
            outputs = [self.models[m_idx](adv_data) for m_idx in range(len(self.models))]
            losses = [loss_func(outputs[m_idx], label) for m_idx in range(len(self.models))]

            # per-model gradient (for AGM)
            grads = [torch.autograd.grad(losses[m_idx], adv_data, retain_graph=True, create_graph=False)[0]
                     for m_idx in range(len(self.models))]

            # AGM: obtain per-model coefficients w (shape: num_models,)
            alpha_coeffs = self.agm(ori_data=data, cur_adv=adv_data, grad=grads, label=label)

            # DRF: cosine similarity based reduce map
            cos_res = self.drf(grads, data_size=(B, C, H, W))
            cos_res[cos_res >= self.threshold] = 1.
            cos_res[cos_res < self.threshold] = 0.

            # ensemble-weighted logits like in MIFGSM
            output = torch.stack(outputs, dim=0) * alpha_coeffs.view(self.num_models, 1, 1)
            output = output.sum(dim=0)
            loss = loss_func(output, label)

            # compute gradient of ensemble loss
            grad = torch.autograd.grad(loss.sum(dim=0), adv_data)[0]

            # apply DRF mask
            grad = grad * cos_res

            # normalization (same style as your MIFGSM: per-sample mean absolute normalization)
            grad = grad / (torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True) + 1e-12)

            # PGD step: sign-based step (keeps get_adv_example semantics)
            adv_data = self.get_adv_example(ori_data=data, adv_data=adv_data, grad=grad)
            adv_data.detach_()

        return adv_data

In [60]:
class AdaEA_TIDIM(AdaEA_Base):
    def __init__(self,
                 models,
                 eps=8/255,
                 alpha=2/255,
                 iters=10,
                 max_value=1.,
                 min_value=0.,
                 threshold=0.,
                 beta=10,
                 device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
                 momentum=1.0,
                 kernel_size=5,
                 resize_prob=0.5,
                 diversity_scale=0.12,
                 n_eot=1):
        """
        TI-DIM: Translation-Invariant + Diverse-Input + Momentum Iterative FGSM
        n_eot: number of EOT samples per iteration (1 is standard; >1 improves stability)
        """
        super().__init__(models=models, eps=eps, alpha=alpha,
                         max_value=max_value, min_value=min_value,
                         threshold=threshold, device=device, beta=beta)

        self.iters = iters
        self.momentum = momentum
        self.kernel_size = kernel_size
        self.resize_prob = resize_prob
        self.diversity_scale = diversity_scale
        self.n_eot = n_eot

        # create gaussian TI kernel (1,1,k,k), we'll expand it per-channel in TI_gradient
        self.kernel = self._make_gaussian_kernel(kernel_size, sigma=kernel_size / 3.0).to(self.device)

    @staticmethod
    def _make_gaussian_kernel(kernel_size=5, sigma=1.0):
        ax = torch.arange(-(kernel_size // 2), kernel_size // 2 + 1, dtype=torch.float32)
        xx, yy = torch.meshgrid(ax, ax, indexing='xy')
        kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
        kernel = kernel / kernel.sum()
        kernel = kernel.view(1, 1, kernel_size, kernel_size)
        return kernel

    def DI_transform(self, x):
        """
        Diverse Input transform:
        - with probability `resize_prob` randomly resize to rnd in [H*(1-s), H] then pad/crop back
        - uses torch RNG for reproducibility
        """
        if torch.rand(1).item() > self.resize_prob:
            return x

        B, C, H, W = x.shape
        s = float(self.diversity_scale)
        rnd = int(H * (1.0 - s) + torch.rand(1).item() * H * s)
        rnd = max(1, rnd)

        x_resized = F.interpolate(x, size=(rnd, rnd), mode='bilinear', align_corners=False)

        if rnd < H:
            pad_h = H - rnd
            pad_w = W - rnd
            pad_top = pad_h // 2
            pad_left = pad_w // 2
            pad_bottom = pad_h - pad_top
            pad_right = pad_w - pad_left
            x_padded = F.pad(x_resized, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0.0)
            return x_padded
        elif rnd > H:
            # center-crop back
            start = (rnd - H) // 2
            return x_resized[:, :, start:start + H, start:start + W]
        else:
            return x_resized

    def TI_gradient(self, grad):
        """
        Apply translation-invariant smoothing to gradient using grouped conv.
        grad: (B,C,H,W)
        returns smoothed grad of same shape
        """
        B, C, H, W = grad.shape
        # kernel is (1,1,k,k) -> expand to (C,1,k,k) for grouped conv
        kernel = self.kernel.to(grad.device).type(grad.dtype)
        if kernel.shape[0] == 1:
            kernel = kernel.repeat(C, 1, 1, 1)
        # grouped conv
        pad = kernel.shape[-1] // 2
        grad_smooth = F.conv2d(grad, weight=kernel, bias=None, stride=1, padding=pad, groups=C)
        return grad_smooth

    def attack(self, data, label, idx=-1):
        """
        TI-DIM attack:
        - Uses AGM to compute per-model weights and averages logits with those weights
        - Applies DI transform, TI smoothing, momentum, and MI-FGSM updates
        """
        B, C, H, W = data.size()
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)
        loss_func = torch.nn.CrossEntropyLoss(reduction='mean')

        # initialization
        adv_data = data.clone().detach() + 0.001 * torch.randn_like(data).to(self.device)
        adv_data = adv_data.detach()
        grad_mom = torch.zeros_like(data).to(self.device)

        step_size = self.alpha if hasattr(self, 'alpha') and self.alpha is not None else (self.eps / float(self.iters))
        eps_small = 1e-12

        for it in range(self.iters):
            # EOT accumulation (average over n_eot DI samples)
            grad_eot_accum = torch.zeros_like(adv_data).to(self.device)

            # We'll also collect per-model grads for AGM/DRF if needed (useful for agm/drf)
            per_model_grads_for_agm = [torch.zeros_like(adv_data).to(self.device) for _ in range(self.num_models)]

            for e in range(self.n_eot):
                adv_data.requires_grad_(True)

                # apply DI
                adv_DI = self.DI_transform(adv_data)

                # forward each model to get logits
                logits_list = []
                for m in self.models:
                    out = m(adv_DI)
                    if isinstance(out, (tuple, list)):
                        out = out[0]
                    logits_list.append(out)

                # compute per-model losses (for AGM weighting)
                losses_per_model = [loss_func(logits_list[j], label) for j in range(self.num_models)]

                # compute per-model grads (w.r.t adv_data) if AGM or DRF needs them
                grads = [torch.autograd.grad(losses_per_model[j], adv_data, retain_graph=True, create_graph=False)[0]
                         for j in range(self.num_models)]

                # keep grads for agm/drf outside EOT loop (we average later)
                for j in range(self.num_models):
                    per_model_grads_for_agm[j] += grads[j].detach()

                # compute AGM weights (use detached grads for stability). AGM expects grads list
                try:
                    agm_w = self.agm(ori_data=data, cur_adv=adv_data.detach(), grad=[g.detach() for g in grads], label=label)
                    # ensure shape (num_models,)
                    if agm_w is None:
                        agm_w = torch.ones(self.num_models, device=self.device) / float(self.num_models)
                except Exception:
                    agm_w = torch.ones(self.num_models, device=self.device) / float(self.num_models)

                # apply AGM weights on logits: weighted sum of logits
                # logits_list is list of (B, num_classes)
                stacked_logits = torch.stack(logits_list, dim=0)  # (num_models, B, C_out)
                # agm_w: (num_models,) -> (num_models, 1, 1) to broadcast
                wview = agm_w.view(self.num_models, 1, 1)
                weighted_logits = (stacked_logits * wview).sum(dim=0)  # (B, C_out)

                # combined loss from weighted logits
                loss_combined = loss_func(weighted_logits, label)
                # gradient of combined loss w.r.t adv_data
                grad_combined = torch.autograd.grad(loss_combined, adv_data, retain_graph=False, create_graph=False)[0]

                grad_eot_accum += grad_combined.detach()

                # free grads for next eot iteration
                adv_data.grad = None
                for m in self.models:
                    if hasattr(m, 'zero_grad'):
                        m.zero_grad()

            # average over EOT samples
            grad_avg = grad_eot_accum / float(self.n_eot)

            # compute DRF mask from averaged per-model grads if possible
            try:
                per_model_grads_avg = [g / float(self.n_eot) for g in per_model_grads_for_agm]
                cos_res = self.drf(per_model_grads_avg, data_size=(B, C, H, W))
                # thresholding
                cos_res = (cos_res >= self.threshold).float().to(self.device)
                # broadcast if necessary
                if cos_res.shape != grad_avg.shape:
                    # attempt to broadcast: (B,1,H,W) -> (B,C,H,W)
                    if cos_res.dim() == 4 and cos_res.size(1) == 1:
                        cos_res = cos_res.repeat(1, C, 1, 1)
                grad_avg = grad_avg * cos_res
            except Exception:
                # if drf fails, continue without mask
                pass

            # TI smoothing
            grad_ti = self.TI_gradient(grad_avg)

            # normalize per-sample
            denom = torch.mean(torch.abs(grad_ti), dim=(1, 2, 3), keepdim=True).clamp(min=eps_small)
            grad_norm = grad_ti / denom

            # momentum update
            grad_mom = self.momentum * grad_mom + grad_norm

            # MI-FGSM update (sign of momentum)
            adv_data = adv_data.detach() + step_size * torch.sign(grad_mom)

            # clip to eps ball and valid range
            adv_data = torch.max(torch.min(adv_data, data + self.eps), data - self.eps)
            adv_data = torch.clamp(adv_data, min=self.min_value, max=self.max_value).detach()

            # zero model grads
            for m in self.models:
                if hasattr(m, 'zero_grad'):
                    m.zero_grad()

        return adv_data.detach()


### Setup 

In [61]:
ens_models = [
    get_models("imagenet", "resnet18", "resnet18"), 
    get_models("imagenet", "inception_v3", "inc_v3"), 
    get_models("imagenet", "deit_tiny_patch16_224", "deit_t"),
    get_models("imagenet", "vit_tiny_patch16_224", "vit_t"), 
    # get_models("imagenet", "efficientnet_b0", "efficientnet_b0"), 
    # get_models("imagenet", "xcit_tiny_12_p8_224", "swin_t"), 
] 


In [67]:
max_value = 1.0 
min_value = 0.0 
eps = 8/255
alpha = 2/255 
iters = 10 
threshold = -0.3
beta = 10 

# attack_method = AdaEA_MIFGSM(ens_models, eps=eps, alpha=alpha, iters=iters, max_value=max_value, 
#                             min_value=min_value, beta=beta, threshold=threshold, device=device)

# attack_method = AdaEA_PGD(ens_models, eps=eps, alpha=alpha, iters=iters, max_value=max_value, 
#                             min_value=min_value, beta=beta, threshold=threshold, device=device)



attack_method = AdaEA_TIDIM(ens_models, eps=16/255, alpha=2/255, iters=iters, kernel_size=15,  device=device)

In [68]:
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
import torch

# --- before loop (clear previous lists) ---
adv_list = []
orig_list = []
labels_list = []
noise_rates = []
pixel_diffs = []

def ensure_batch(x):
    return x if x.dim() == 4 else x.unsqueeze(0)

def to_unit_range(x):
    """
    Ensure x is in [0,1]. If tensor values appear to be in [0,255] (max>1.5),
    convert by dividing by 255. Returns a float tensor on same device.
    """
    x = ensure_batch(x).float()
    if x.max().item() > 1.5:
        x = x / 255.0
    return torch.clamp(x, 0.0, 1.0)

# --- attack loop (same as yours, but using the simplified functions) ---
for img, label in tqdm(hard_loader, desc="Generating MI-FGSM adversarials (GPU)"):
    img, label = img.to(device), label.to(device)

    with torch.enable_grad():
        adv_img = attack_method(img, label) 

    # store for later (move to CPU)
    adv_list.append(adv_img.squeeze(0).cpu())
    orig_list.append(img.squeeze(0).cpu())
    labels_list.append(label.squeeze(0).cpu())

    # compute SSIM and pixel diffs on [0,1] images
    img_for_ssim = to_unit_range(img)       # (1,C,H,W) in [0,1]
    adv_for_ssim = to_unit_range(adv_img)   # (1,C,H,W) in [0,1]

    ssim_val = ssim(adv_for_ssim, img_for_ssim)  # scalar tensor
    noise_rates.append((1.0 - float(ssim_val)))
    pixel_diffs.append((adv_for_ssim - img_for_ssim).abs().mean().item())

# --- stack everything on CPU ---
adv_all = torch.stack(adv_list).cpu()
orig_all = torch.stack(orig_list).cpu()
labels_all = torch.stack(labels_list).cpu()

noise_rates = torch.tensor(noise_rates)
pixel_diffs = torch.tensor(pixel_diffs)

print(f"✅ Generated {adv_all.size(0)} adversarial images. Shape: {adv_all.shape}")
print(f"Noise (1 - SSIM): mean={noise_rates.mean():.6f}, std={noise_rates.std():.6f}, min={noise_rates.min():.6f}, max={noise_rates.max():.6f}")
print(f"Mean absolute pixel diff (after clamp to [0,1]): mean={pixel_diffs.mean():.6f}, std={pixel_diffs.std():.6f}")


Generating MI-FGSM adversarials (GPU): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [57:34<00:00,  3.45s/it]


✅ Generated 1000 adversarial images. Shape: torch.Size([1000, 3, 224, 224])
Noise (1 - SSIM): mean=0.169276, std=0.070937, min=0.000290, max=0.539290
Mean absolute pixel diff (after clamp to [0,1]): mean=0.044744, std=0.003677


### Test on Target models 

In [64]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

batch_size = 32  # tune this for your GPU
dataset = TensorDataset(adv_all, labels_all)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                    num_workers=4, pin_memory=True)

In [65]:
target_models = [
    get_models("imagenet", "resnet152", "resnet152"),
    get_models("imagenet", "wide_resnet101_2", "wrn101_2"),     
    get_models("imagenet", "regnety_320", "regnety_320"),
    get_models("imagenet", "vgg19", "vgg19"),
    get_models("imagenet", "vit_base_patch16_224", "vit_b"),
    get_models("imagenet", "deit_base_patch16_224", "deit_b"),
    # get_models("imagenet", "swin_base_patch4_window7_224", "swin_b"), 
    # get_models("imagenet", "mixer_b16_224", "vit_t"), 
    # get_models("imagenet", "convmixer_768_32", "vit_t")
] 

In [66]:
with torch.no_grad():
    for t_model in target_models:
        name = getattr(t_model, "name", t_model.__class__.__name__)
        t_model.eval()
        t_model.to(device)

        fooled = 0
        total = 0

        for imgs_cpu, labels_cpu in tqdm(loader, desc=f"ASR {name}"):
            # Move to device here
            imgs = imgs_cpu.to(device, non_blocking=True)
            labels = labels_cpu.to(device, non_blocking=True)

            outputs = t_model(imgs)
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            preds = outputs.argmax(dim=1)

            fooled += (preds != labels).sum().item()
            total += labels.size(0)

            # free cache per batch (helps on tight GPUs)
            if device.type == "cuda":
                torch.cuda.empty_cache()

        asr = 100.0 * fooled / total if total > 0 else 0.0
        print(f"{name}: ASR = {asr:.2f}%  ({fooled}/{total} fooled)")

ASR Sequential: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.80it/s]


Sequential: ASR = 21.60%  (216/1000 fooled)


ASR Sequential: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00,  7.16it/s]


Sequential: ASR = 54.70%  (547/1000 fooled)


ASR Sequential: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.13it/s]


Sequential: ASR = 20.20%  (202/1000 fooled)


ASR Sequential: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  9.92it/s]


Sequential: ASR = 64.10%  (641/1000 fooled)


ASR Sequential: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00,  7.10it/s]


Sequential: ASR = 17.70%  (177/1000 fooled)


ASR Sequential: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00,  7.20it/s]

Sequential: ASR = 34.60%  (346/1000 fooled)



