In [4]:
import timm
import time 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import datasets 
from torch.utils.data import DataLoader
# from medmnist import INFO
import numpy as np
import faiss
import copy
from tqdm import tqdm

from torch.nn.functional import softmax, cosine_similarity
from collections import Counter
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os 

import warnings
warnings.filterwarnings("ignore")

In [5]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

Using device: cuda:1


In [6]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

class CustomImageListDataset(torch.utils.data.Dataset):
    def __init__(self, file_list, class_to_idx, transform=None):
        with open(file_list, "r") as f:
            self.samples = [line.strip() for line in f]
        self.transform = transform
        self.class_to_idx = class_to_idx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        class_folder = os.path.basename(os.path.dirname(img_path))
        label = self.class_to_idx.get(class_folder, -1)
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label


# ---------------- Create a combined class mapping ----------------
root_dir = "dataset/imagenet_tests"
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Collect class mappings from all 10 partitions
combined_class_to_idx = {}
for i in range(1, 11):
    test_dir = os.path.join(root_dir, f"test{i}")
    dataset = datasets.ImageFolder(test_dir, transform=transform)
    combined_class_to_idx.update(dataset.class_to_idx)

print(f"âœ… Combined class mapping built: {len(combined_class_to_idx)} total classes")

# ---------------- Load your 1000-image subset ----------------
subset_file = "results/hard_cases_missed_by_mobilenet.txt"
hard_dataset = CustomImageListDataset(subset_file, class_to_idx=combined_class_to_idx, transform=transform)
hard_loader = DataLoader(hard_dataset, batch_size=1, shuffle=False)

print(f"âœ… Loaded {len(hard_dataset)} hard samples")

âœ… Combined class mapping built: 1000 total classes
âœ… Loaded 1000 hard samples


In [7]:
import os
from torchvision import datasets, transforms
import torch

print(f"Step 1: Loading dataset with resize transform...")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

val_data_dir = 'dataset/imagenet_validation' 

val_dataset = datasets.ImageFolder(os.path.join(val_data_dir), transform=transform) 


print(f"Validation samples: {len(val_dataset)}")

Step 1: Loading dataset with resize transform...
Validation samples: 30000


In [8]:
def get_models(dataset, model_name, key): 
    if dataset == 'imagenet':
        # save_root_path = r"checkpoint/tinyimagenet"
        model = timm.create_model(model_name, pretrained=True, num_classes=1000).to(device)
        model.eval()
        if 'inc' in key or 'vit' in key or 'bit' in key:
            return torch.nn.Sequential(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), model)
        else:
            return torch.nn.Sequential(transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), model)

### Ensemble Attack 

In [9]:
from abc import abstractmethod

import torch
import torch.nn.functional as F


class AdaEA_Base:
    def __init__(self, models, eps=8/255, alpha=2/255, max_value=1., min_value=0., threshold=0., beta=10,
                 device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
        assert isinstance(models, list) and len(models) >= 2, 'Error'
        self.device = device
        self.models = models
        self.num_models = len(self.models)
        for model in models:
            model.eval()

        # attack parameter
        self.eps = eps
        self.threshold = threshold
        self.max_value = max_value
        self.min_value = min_value
        self.beta = beta
        self.alpha = alpha

    def get_adv_example(self, ori_data, adv_data, grad, attack_step=None):
        """
        :param ori_data: original image
        :param adv_data: adversarial image in the last iteration
        :param grad: gradient in this iteration
        :return: adversarial example in this iteration
        """
        if attack_step is None:
            adv_example = adv_data.detach() + grad.sign() * self.alpha
        else:
            adv_example = adv_data.detach() + grad.sign() * attack_step
        delta = torch.clamp(adv_example - ori_data.detach(), -self.eps, self.eps)
        return torch.clamp(ori_data.detach() + delta, max=self.max_value, min=self.min_value)

    def agm(self, ori_data, cur_adv, grad, label):
        """
        Adaptive gradient modulation
        :param ori_data: natural images
        :param cur_adv: adv examples in last iteration
        :param grad: gradient in this iteration
        :param label: ground truth
        :return: coefficient of each model
        """
        loss_func = torch.nn.CrossEntropyLoss()

        # generate adversarial example
        adv_exp = [self.get_adv_example(ori_data=ori_data, adv_data=cur_adv, grad=grad[idx])
                   for idx in range(self.num_models)]
        loss_self = [loss_func(self.models[idx](adv_exp[idx]), label) for idx in range(self.num_models)]
        w = torch.zeros(size=(self.num_models,), device=self.device)

        for j in range(self.num_models):
            for i in range(self.num_models):
                if i == j:
                    continue
                w[j] += loss_func(self.models[i](adv_exp[j]), label) / loss_self[i] * self.beta
        w = torch.softmax(w, dim=0)

        return w

    def drf(self, grads, data_size):
        """
        disparity-reduced filter
        :param grads: gradients of each model
        :param data_size: size of input images
        :return: reduce map
        """
        reduce_map = torch.zeros(size=(self.num_models, self.num_models, data_size[0], data_size[-2], data_size[-1]),
                                 dtype=torch.float, device=self.device)
        sim_func = torch.nn.CosineSimilarity(dim=1, eps=1e-8)
        reduce_map_result = torch.zeros(size=(self.num_models, data_size[0], data_size[-2], data_size[-1]),
                                        dtype=torch.float, device=self.device)
        for i in range(self.num_models):
            for j in range(self.num_models):
                if i >= j:
                    continue
                reduce_map[i][j] = sim_func(F.normalize(grads[i], dim=1), F.normalize(grads[j], dim=1))
            if i < j:
                one_reduce_map = (reduce_map[i, :].sum(dim=0) + reduce_map[:, i].sum(dim=0)) / (self.num_models - 1)
                reduce_map_result[i] = one_reduce_map

        return reduce_map_result.mean(dim=0).view(data_size[0], 1, data_size[-2], data_size[-1])

    @abstractmethod
    def attack(self,
               data: torch.Tensor,
               label: torch.Tensor,
               idx: int = -1) -> torch.Tensor:
        ...

    def __call__(self, data, label, idx=-1):
        return self.attack(data, label, idx)

In [10]:
import torch
import torch.nn as nn


class AdaEA_MIFGSM(AdaEA_Base):
    def __init__(self, models, eps=8/255, alpha=2/255, iters=20, max_value=1., min_value=0., threshold=0., beta=10,
                 device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), momentum=0.9):
        super().__init__(models=models, eps=eps, alpha=alpha, max_value=max_value, min_value=min_value,
                         threshold=threshold, device=device, beta=beta)
        self.iters = iters
        self.momentum = momentum

    def attack(self, data, label, idx=-1):
        B, C, H, W = data.size()
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)
        loss_func = nn.CrossEntropyLoss()

        # init pert
        adv_data = data.clone().detach() + 0.001 * torch.randn(data.shape, device=self.device)
        adv_data = adv_data.detach()

        grad_mom = torch.zeros_like(data, device=self.device)

        for i in range(self.iters):
            adv_data.requires_grad = True

            outputs = [self.models[idx](adv_data) for idx in range(len(self.models))]
            losses = [loss_func(outputs[idx], label) for idx in range(len(self.models))]
            grads = [torch.autograd.grad(losses[idx], adv_data, retain_graph=True, create_graph=False)[0]
                     for idx in range(len(self.models))]

            # AGM
            alpha = self.agm(ori_data=data, cur_adv=adv_data, grad=grads, label=label)

            # DRF
            cos_res = self.drf(grads, data_size=(B, C, H, W))
            cos_res[cos_res >= self.threshold] = 1.
            cos_res[cos_res < self.threshold] = 0.

            output = torch.stack(outputs, dim=0) * alpha.view(self.num_models, 1, 1)
            output = output.sum(dim=0)
            loss = loss_func(output, label)
            grad = torch.autograd.grad(loss.sum(dim=0), adv_data)[0]
            grad = grad * cos_res

            # momentum
            grad = grad / torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True)
            grad = grad + self.momentum * grad_mom
            grad_mom = grad

            # add perturbation
            adv_data = self.get_adv_example(ori_data=data, adv_data=adv_data, grad=grad)
            adv_data.detach_()

        return adv_data

In [11]:
import torch
import torch.nn as nn

class AdaEA_PGD(AdaEA_Base):
    def __init__(self, models, eps=8/255, alpha=2/255, iters=20, max_value=1., min_value=0., threshold=0.,
                 beta=10, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
                 random_start=False):
        super().__init__(models=models, eps=eps, alpha=alpha, max_value=max_value, min_value=min_value,
                         threshold=threshold, device=device, beta=beta)
        self.iters = iters
        self.random_start = random_start

    def attack(self, data, label, idx=-1):
        """
        AdaEA PGD attack â€” same AGM + DRF + ensemble averaging logic as AdaEA_MIFGSM,
        but with PGD (no momentum).
        """
        B, C, H, W = data.size()
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)
        loss_func = nn.CrossEntropyLoss()

        # init pert: either small gaussian (like MIFGSM) or a uniform random start within eps-ball
        if self.random_start:
            adv_data = data.clone().detach() + torch.empty_like(data).uniform_(-self.eps, self.eps)
            adv_data = torch.clamp(adv_data, min=self.min_value, max=self.max_value).detach()
        else:
            adv_data = data.clone().detach() + 0.001 * torch.randn(data.shape, device=self.device)
            adv_data = adv_data.detach()

        for i in range(self.iters):
            adv_data.requires_grad = True

            # forward each model on current adversarial example
            outputs = [self.models[m_idx](adv_data) for m_idx in range(len(self.models))]
            losses = [loss_func(outputs[m_idx], label) for m_idx in range(len(self.models))]

            # per-model gradient (for AGM)
            grads = [torch.autograd.grad(losses[m_idx], adv_data, retain_graph=True, create_graph=False)[0]
                     for m_idx in range(len(self.models))]

            # AGM: obtain per-model coefficients w (shape: num_models,)
            alpha_coeffs = self.agm(ori_data=data, cur_adv=adv_data, grad=grads, label=label)

            # DRF: cosine similarity based reduce map
            cos_res = self.drf(grads, data_size=(B, C, H, W))
            cos_res[cos_res >= self.threshold] = 1.
            cos_res[cos_res < self.threshold] = 0.

            # ensemble-weighted logits like in MIFGSM
            output = torch.stack(outputs, dim=0) * alpha_coeffs.view(self.num_models, 1, 1)
            output = output.sum(dim=0)
            loss = loss_func(output, label)

            # compute gradient of ensemble loss
            grad = torch.autograd.grad(loss.sum(dim=0), adv_data)[0]

            # apply DRF mask
            grad = grad * cos_res

            # normalization (same style as your MIFGSM: per-sample mean absolute normalization)
            grad = grad / (torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True) + 1e-12)

            # PGD step: sign-based step (keeps get_adv_example semantics)
            adv_data = self.get_adv_example(ori_data=data, adv_data=adv_data, grad=grad)
            adv_data.detach_()

        return adv_data

### DEAA 

#### Helper (DEAA)

In [12]:
class NormalizedModel(nn.Module):
    def __init__(self, model, mean, std):
        super().__init__()
        self.model = model
        self.register_buffer('mean', torch.tensor(mean).view(1,3,1,1))
        self.register_buffer('std', torch.tensor(std).view(1,3,1,1))

    def forward(self, x):
        x = (x - self.mean) / self.std
        return self.model(x)

def get_last_linear_layer(model):
    """
    Find the last Linear layer in timm models or wrapped models (NormalizedModel).
    """
    # unwrap NormalizedModel if needed
    if isinstance(model, NormalizedModel):
        model = model.model

    # Common attribute names for classifier heads
    candidate_attrs = ['head', 'heads', 'fc', 'classifier', 'mlp_head']

    for attr in candidate_attrs:
        if hasattr(model, attr):
            layer = getattr(model, attr)
            # If it's a Linear layer
            if isinstance(layer, nn.Linear):
                return layer
            # If it's a Sequential or Module, search inside
            if isinstance(layer, nn.Module):
                last_linear = None
                for m in reversed(list(layer.modules())):
                    if isinstance(m, nn.Linear):
                        last_linear = m
                        break
                if last_linear is not None:
                    return last_linear

    # Fallback: scan all modules
    last_linear = None
    for m in model.modules():
        if isinstance(m, nn.Linear):
            last_linear = m
    if last_linear is not None:
        return last_linear

    raise RuntimeError(f"No Linear layer found in model {model.__class__.__name__}")



def get_features_before_last_linear(model, x):
    """
    Extract features before the final classifier, works for CNNs and ViTs.
    """
    # unwrap NormalizedModel if present
    if isinstance(model, NormalizedModel):
        model = model.model

    # Common classifier attributes
    candidate_attrs = ['head', 'heads', 'fc', 'classifier', 'mlp_head']
    classifier = None
    for attr in candidate_attrs:
        if hasattr(model, attr):
            classifier = getattr(model, attr)
            break

    features = {}

    def hook(module, input, output):
        features['feat'] = input[0].detach()

    if classifier is not None:
        handle = classifier.register_forward_hook(hook)
    else:
        # fallback: attach hook to last module
        last_module = list(model.modules())[-1]
        handle = last_module.register_forward_hook(hook)

    model.eval()
    with torch.no_grad():
        _ = model(x)
    handle.remove()

    if 'feat' not in features:
        raise RuntimeError(f"Failed to capture features from model {model.__class__.__name__}")

    return features['feat']

#### DEAA Main 

In [13]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
import numpy as np
import faiss
from tqdm import tqdm
import torch.nn.functional as F


def softmax(x, dim=0):
    return F.softmax(x, dim=dim)


def cosine_similarity(x, y, dim=1, eps=1e-8):
    return F.cosine_similarity(x, y, dim=dim, eps=eps)


# ðŸ‘€ Dummy visualization (replace with your function)
def visualize_test_and_roc(test_img, roc_imgs, local_labels):
    print("Visualization placeholder: Test image + RoC samples")
    print(f"RoC labels: {local_labels}")


class VisionDES_2: 
    def __init__(self, dsel_dataset, pool): 
        self.dsel_dataset = dsel_dataset
        self.dsel_loader = DataLoader(dsel_dataset, batch_size=32, shuffle=False) 
        self.dino_model = timm.create_model('vit_base_patch16_224.dino', pretrained=True).to(device)
        self.dino_model.eval()  
        self.pool = pool 

        self.suspected_model_votes = [] 
        
        
    def dino_embedder(self, images):
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)
        return self.dino_model.forward_features(images)


    def fit(self): 
        dsel_embeddings = []
        dsel_labels = []
    
        with torch.no_grad():
            for imgs, labels in tqdm(self.dsel_loader):
                imgs = imgs.to(device)
                embs = self.dino_embedder(imgs).cpu()  
                dsel_embeddings.append(embs)
                dsel_labels.append(labels)
    
        # Keep as tensor
        dsel_embeddings_tensor = torch.cat(dsel_embeddings).detach().cpu()  
        cls_tensor = dsel_embeddings_tensor[:, 0, :]  
    
        # Convert to NumPy
        cls_embeddings = np.ascontiguousarray(cls_tensor.numpy(), dtype='float32')
        self.dsel_embeddings = cls_embeddings
        self.dsel_labels = torch.cat(dsel_labels).numpy()
    
        # Build FAISS index
        embedding_dim = cls_embeddings.shape[1]
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.index.add(cls_embeddings)

    
    def get_top_n_competent_models(self, test_img, k=7, top_n=3, use_sim=False, sim_threshold=0, alpha=0.6):
        # Step 1: Get DINO CLS embedding for the test image
        img_for_dino = test_img.unsqueeze(0).to(device)

        with torch.no_grad():
            test_emb = self.dino_model.forward_features(img_for_dino).cpu().numpy().astype('float32')
            test_emb = test_emb[:, 0, :]  # CLS token only
    
        # Step 2: Find k nearest neighbors in FAISS
        distances, neighbors = self.index.search(test_emb, k)
        neighbor_idxs = neighbors[0]
        local_labels = np.array(self.dsel_labels[neighbor_idxs]).flatten()
    
        # Step 3: Get RoC images
        with torch.no_grad():
            roc_imgs = torch.stack([self.dsel_dataset[idx][0] for idx in neighbor_idxs]).to(device)
    
        # Step 4: Evaluate classifiers
        competences, feature_similarities, correct_counts, grad_vectors  = [], [], [], []
    
        for clf in self.pool:
            clf.eval()
            with torch.no_grad():
                outputs = clf(roc_imgs)
                preds = outputs.argmax(dim=1).cpu().numpy()
                correct = (preds == local_labels).sum()
                competences.append(correct / k)
                correct_counts.append(correct)
    
                # Feature similarity
                test_feat = get_features_before_last_linear(clf, test_img.unsqueeze(0).to(device))
                roc_feats = get_features_before_last_linear(clf, roc_imgs)
                mean_feat = roc_feats.mean(dim=0, keepdim=True)
                sim = cosine_similarity(test_feat.flatten().unsqueeze(0), mean_feat.flatten().unsqueeze(0))
                feature_similarities.append(sim.item())

            test_img_req = test_img.clone().detach().unsqueeze(0).to(device)
            test_img_req.requires_grad_(True)

            out = clf(test_img_req)
            pseudo_label = out.argmax(dim=1)  # keep grad path alive
            loss = F.cross_entropy(out, pseudo_label)
            grad = torch.autograd.grad(loss, test_img_req)[0]
            grad_vec = grad.flatten()  # flatten for cosine similarity
            grad_vectors.append(grad_vec.cpu())


        # Step 5: Compute gradient-based diversity
        grads_tensor = torch.stack(grad_vectors).to(device) 
        # L2-normalize per-model gradient vector
        grads_norm = F.normalize(grads_tensor, p=2, dim=1, eps=1e-8)  # (K, D)
        
        # compute cosine similarity matrix
        cos_sim = grads_norm @ grads_norm.t()  # (K, K)  (equivalent to pairwise cosine) 

        # zero out diagonal (self-sim = 1)
        K = cos_sim.size(0)
        cos_sim.fill_diagonal_(0.0)
        
        # average similarity over other models only
        mean_sim_other = cos_sim.sum(dim=1) / float(K - 1)  # (K,)
        
        # diversity score: lower similarity -> higher diversity
        diversity_scores = (mean_sim_other).cpu().numpy()
        diversity_scores = (diversity_scores - diversity_scores.min()) / (diversity_scores.max() - diversity_scores.min() + 1e-8)

        # print("Diversity", diversity_scores)
        # print("Competences", competences) 

        final_scores = [alpha * c + (1 - alpha) * d for c, d in zip(competences, diversity_scores)] 
    
        # Step 7: Select top_n models
        top_indices = np.argsort(final_scores)[::-1][:top_n]
        # print("final_scores", final_scores)
        # print("Top_indices", top_indices)
        top_models = [self.pool[i] for i in top_indices]
    
        return top_models

### Setup 

In [14]:
ens_models = [
    get_models("imagenet", "resnet18", "resnet18"), 
    get_models("imagenet", "inception_v3", "inc_v3"), 
    get_models("imagenet", "deit_tiny_patch16_224", "deit_t"),
    get_models("imagenet", "vit_tiny_patch16_224", "vit_t"), 
    get_models("imagenet", "efficientnet_b0", "efficientnet_b0"), 
    get_models("imagenet", "xcit_tiny_12_p8_224", "swin_t"), 
]  

In [15]:
des_model = VisionDES_2(val_dataset, ens_models)
des_model.fit()

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 938/938 [06:24<00:00,  2.44it/s]


In [16]:
saved_index = des_model.index
saved_dsel_embeddings = des_model.dsel_embeddings 
saved_dsel_labels = des_model.dsel_labels

# des_model.index = saved_index 
# des_model.dsel_embeddings = saved_dsel_embeddings 
# des_model.dsel_labels = saved_dsel_labels

In [29]:
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
import torch

# --- before loop (clear previous lists) ---
adv_list = []
orig_list = []
labels_list = []
noise_rates = []
pixel_diffs = []

max_value = 1.0 
min_value = 0.0 
eps = 8/255
alpha = 2/255 
iters = 10 
threshold = -0.3
beta = 10 

def ensure_batch(x):
    return x if x.dim() == 4 else x.unsqueeze(0)

def to_unit_range(x):
    """
    Ensure x is in [0,1]. If tensor values appear to be in [0,255] (max>1.5),
    convert by dividing by 255. Returns a float tensor on same device.
    """
    x = ensure_batch(x).float()
    if x.max().item() > 1.5:
        x = x / 255.0
    return torch.clamp(x, 0.0, 1.0)

# --- attack loop (same as yours, but using the simplified functions) ---
for img, label in tqdm(hard_loader, desc="Generating MI-FGSM adversarials (GPU)"):
    img, label = img.to(device), label.to(device)

    with torch.enable_grad():
        selected_models = des_model.get_top_n_competent_models(
                img[0],                    
                k=7,
                top_n=4,
                use_sim=False,
                sim_threshold=0,
                alpha=0.3
            )

        # attack_method = AdaEA_MIFGSM(selected_models, eps=eps, alpha=alpha, iters=iters, max_value=max_value, 
        #                     min_value=min_value, beta=beta, threshold=threshold, device=device)
        attack_method = AdaEA_PGD(selected_models, eps=eps, alpha=alpha, iters=iters, max_value=max_value, 
                            min_value=min_value, beta=beta, threshold=threshold, device=device)
        
        adv_img = attack_method(img, label)
        
    # store for later (move to CPU)
    adv_list.append(adv_img.squeeze(0).cpu())
    orig_list.append(img.squeeze(0).cpu())
    labels_list.append(label.squeeze(0).cpu())

    # compute SSIM and pixel diffs on [0,1] images
    img_for_ssim = to_unit_range(img)       # (1,C,H,W) in [0,1]
    adv_for_ssim = to_unit_range(adv_img)   # (1,C,H,W) in [0,1]

    ssim_val = ssim(adv_for_ssim, img_for_ssim)  # scalar tensor
    noise_rates.append((1.0 - float(ssim_val)))
    pixel_diffs.append((adv_for_ssim - img_for_ssim).abs().mean().item())

# --- stack everything on CPU ---
adv_all = torch.stack(adv_list).cpu()
orig_all = torch.stack(orig_list).cpu()
labels_all = torch.stack(labels_list).cpu()

noise_rates = torch.tensor(noise_rates)
pixel_diffs = torch.tensor(pixel_diffs)

print(f"âœ… Generated {adv_all.size(0)} adversarial images. Shape: {adv_all.shape}")
print(f"Noise (1 - SSIM): mean={noise_rates.mean():.6f}, std={noise_rates.std():.6f}, min={noise_rates.min():.6f}, max={noise_rates.max():.6f}")
print(f"Mean absolute pixel diff (after clamp to [0,1]): mean={pixel_diffs.mean():.6f}, std={pixel_diffs.std():.6f}")


Generating MI-FGSM adversarials (GPU): 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1000/1000 [1:22:48<00:00,  4.97s/it]


âœ… Generated 1000 adversarial images. Shape: torch.Size([1000, 3, 224, 224])
Noise (1 - SSIM): mean=0.119132, std=0.055926, min=0.000337, max=0.597748
Mean absolute pixel diff (after clamp to [0,1]): mean=0.016718, std=0.000805


### Test on Target Models 

In [24]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

batch_size = 32  # tune this for your GPU
dataset = TensorDataset(adv_all, labels_all)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                    num_workers=4, pin_memory=True)

In [27]:
target_models = [
    get_models("imagenet", "resnet152", "resnet152"),
    get_models("imagenet", "wide_resnet101_2", "wrn101_2"),     
    get_models("imagenet", "regnety_320", "regnety_320"),
    # get_models("imagenet", "vgg19", "vgg19"),
    # get_models("imagenet", "vit_base_patch16_224", "vit_b"),
    # get_models("imagenet", "deit_base_patch16_224", "deit_b"),
    # get_models("imagenet", "swin_base_patch4_window7_224", "swin_b"), 
    # get_models("imagenet", "mixer_b16_224", "vit_t"), 
    # get_models("imagenet", "convmixer_768_32", "vit_t")
] 

In [28]:
with torch.no_grad():
    for t_model in target_models:
        name = getattr(t_model, "name", t_model.__class__.__name__)
        t_model.eval()
        t_model.to(device)

        fooled = 0
        total = 0

        for imgs_cpu, labels_cpu in tqdm(loader, desc=f"ASR {name}"):
            # Move to device here
            imgs = imgs_cpu.to(device, non_blocking=True)
            labels = labels_cpu.to(device, non_blocking=True)

            outputs = t_model(imgs)
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            preds = outputs.argmax(dim=1)

            fooled += (preds != labels).sum().item()
            total += labels.size(0)

            # free cache per batch (helps on tight GPUs)
            if device.type == "cuda":
                torch.cuda.empty_cache()

        asr = 100.0 * fooled / total if total > 0 else 0.0
        print(f"{name}: ASR = {asr:.2f}%  ({fooled}/{total} fooled)")

ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:03<00:00,  8.54it/s]


Sequential: ASR = 53.20%  (532/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:04<00:00,  6.92it/s]


Sequential: ASR = 72.50%  (725/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:10<00:00,  3.07it/s]

Sequential: ASR = 52.20%  (522/1000 fooled)



