In [2]:
import timm
import time 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import datasets 
from torch.utils.data import DataLoader
# from medmnist import INFO
import numpy as np
import faiss
import copy
from tqdm import tqdm

from torch.nn.functional import softmax, cosine_similarity
from collections import Counter
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os 

import warnings
warnings.filterwarnings("ignore")

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

Using device: cuda:0


In [4]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

class CustomImageListDataset(torch.utils.data.Dataset):
    def __init__(self, file_list, class_to_idx, transform=None):
        with open(file_list, "r") as f:
            self.samples = [line.strip() for line in f]
        self.transform = transform
        self.class_to_idx = class_to_idx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        class_folder = os.path.basename(os.path.dirname(img_path))
        label = self.class_to_idx.get(class_folder, -1)
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label


# ---------------- Create a combined class mapping ----------------
root_dir = "dataset/imagenet_tests"
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Collect class mappings from all 10 partitions
combined_class_to_idx = {}
for i in range(1, 11):
    test_dir = os.path.join(root_dir, f"test{i}")
    dataset = datasets.ImageFolder(test_dir, transform=transform)
    combined_class_to_idx.update(dataset.class_to_idx)

print(f"âœ… Combined class mapping built: {len(combined_class_to_idx)} total classes")

# ---------------- Load your 1000-image subset ----------------
subset_file = "results/hard_cases_missed_by_mobilenet.txt"
hard_dataset = CustomImageListDataset(subset_file, class_to_idx=combined_class_to_idx, transform=transform)
hard_loader = DataLoader(hard_dataset, batch_size=1, shuffle=False)

print(f"âœ… Loaded {len(hard_dataset)} hard samples")

âœ… Combined class mapping built: 1000 total classes
âœ… Loaded 1000 hard samples


In [5]:
import os
from torchvision import datasets, transforms
import torch

print(f"Step 1: Loading dataset with resize transform...")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

val_data_dir = 'dataset/imagenet_validation' 

val_dataset = datasets.ImageFolder(os.path.join(val_data_dir), transform=transform) 


print(f"Validation samples: {len(val_dataset)}")

Step 1: Loading dataset with resize transform...
Validation samples: 30000


In [6]:
def get_models(dataset, model_name, key): 
    if dataset == 'imagenet':
        # save_root_path = r"checkpoint/tinyimagenet"
        model = timm.create_model(model_name, pretrained=True, num_classes=1000).to(device)
        model.eval()
        if 'inc' in key or 'vit' in key or 'bit' in key:
            return torch.nn.Sequential(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), model)
        else:
            return torch.nn.Sequential(transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), model)

### Ensemble Attack 

In [7]:
import torch
import torch.nn.functional as F
from typing import List

def ensemble_mi_fgsm(
    models: List[torch.nn.Module],
    x: torch.Tensor,
    y: torch.Tensor,
    eps: float = 8/255,
    alpha: float = 2/255,
    iters: int = 10,
    decay: float = 1.0,
    clip_min: float = 0.0,
    clip_max: float = 1.0,
    loss_fn=None,
    device: str = None,
):
    if device is None:
        device = x.device

    if not isinstance(models, (list, tuple)):
        models = [models]

    for m in models:
        m.to(device).eval()

    if loss_fn is None:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    x_orig = x.clone().detach().to(device).float()
    x_adv = x_orig.clone().detach()
    momentum = torch.zeros_like(x_adv).to(device)
    y = y.to(device)

    for _ in range(iters):
        x_adv.requires_grad_(True)

        # ----- Liu et al. (2017): sum/average logits before loss -----
        sum_logits = None
        for m in models:
            out = m(x_adv)
            if isinstance(out, (tuple, list)):
                out = out[0]
            sum_logits = out if sum_logits is None else sum_logits + out
        avg_logits = sum_logits / len(models)
        total_loss = loss_fn(avg_logits, y)
        # --------------------------------------------------------------

        grad = torch.autograd.grad(total_loss, x_adv, retain_graph=False, create_graph=False)[0]
        grad = grad / (torch.norm(grad, p=1) + 1e-8)

        # momentum update (MI-FGSM)
        momentum = decay * momentum + grad
        step = alpha * momentum.sign()

        x_adv = x_adv.detach() + step.detach()
        delta = torch.clamp(x_adv - x_orig, min=-eps, max=eps)
        x_adv = torch.clamp(x_orig + delta, min=clip_min, max=clip_max).detach()

    return x_adv


In [8]:
import torch
import torch.nn.functional as F
from typing import List

def ensemble_pgd(
    models: List[torch.nn.Module],
    x: torch.Tensor,
    y: torch.Tensor,
    eps: float = 8/255,
    alpha: float = 2/255,
    iters: int = 10,
    clip_min: float = 0.0,
    clip_max: float = 1.0,
    loss_fn=None,
    device: str = None,
    random_start: bool = False,
):
    """
    Ensemble PGD (average-logits) â€” same ensemble formulation as your MI-FGSM function
    but using PGD (no momentum).
    - models: list of models (or single model)
    - random_start: if True, initialize x_adv with a random perturbation within the eps-ball
    """
    if device is None:
        device = x.device

    if not isinstance(models, (list, tuple)):
        models = [models]

    for m in models:
        m.to(device).eval()

    if loss_fn is None:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    x_orig = x.clone().detach().to(device).float()

    # optionally random start inside the L-inf ball
    if random_start:
        x_adv = x_orig + torch.empty_like(x_orig).uniform_(-eps, eps)
        x_adv = torch.clamp(x_adv, min=clip_min, max=clip_max).detach()
    else:
        x_adv = x_orig.clone().detach()

    y = y.to(device)

    for _ in range(iters):
        x_adv.requires_grad_(True)

        # ----- Liu et al. (2017): sum/average logits before loss -----
        sum_logits = None
        for m in models:
            out = m(x_adv)
            if isinstance(out, (tuple, list)):
                out = out[0]
            sum_logits = out if sum_logits is None else sum_logits + out
        avg_logits = sum_logits / len(models)
        total_loss = loss_fn(avg_logits, y)
        # --------------------------------------------------------------

        grad = torch.autograd.grad(total_loss, x_adv, retain_graph=False, create_graph=False)[0]
        # normalize (L1) like in your MI-FGSM code
        grad = grad / (torch.norm(grad, p=1) + 1e-8)

        # PGD step (no momentum) â€” use sign to match typical PGD / I-FGSM
        step = alpha * grad.sign()
        x_adv = x_adv.detach() + step.detach()

        # project back to epsilon L-inf ball around original and clip to valid range
        delta = torch.clamp(x_adv - x_orig, min=-eps, max=eps)
        x_adv = torch.clamp(x_orig + delta, min=clip_min, max=clip_max).detach()

    return x_adv

In [9]:
import torch
import torch.nn.functional as F
from torch import nn
from typing import List, Union, Optional

def get_gaussian_kernel(kernel_size: int = 15, sigma: float = 3.0, channels: int = 3, device='cpu'):
    ax = torch.arange(-(kernel_size // 2), kernel_size // 2 + 1, device=device, dtype=torch.float32)
    xx, yy = torch.meshgrid(ax, ax, indexing='xy')
    kernel = torch.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
    kernel = kernel / kernel.sum()
    kernel = kernel.view(1, 1, kernel_size, kernel_size)
    kernel = kernel.repeat(channels, 1, 1, 1)  # grouped conv
    return kernel

def diverse_input(x: torch.Tensor, prob: float = 0.7, resize_min: int = 290, resize_max: int = 330):
    """
    Random resize + pad/crop. Works on x of shape (B,C,H,W).
    If probability check fails, returns x unchanged.
    """
    if prob <= 0 or torch.rand(1).item() > prob:
        return x
    B, C, H, W = x.shape
    # choose a random size (square)
    size = int(torch.randint(resize_min, resize_max + 1, (1,)).item())
    x_resized = F.interpolate(x, size=(size, size), mode='bilinear', align_corners=False)
    pad_h = max(0, H - size)
    pad_w = max(0, W - size)
    pad_top = int(torch.randint(0, pad_h + 1, (1,)).item()) if pad_h > 0 else 0
    pad_left = int(torch.randint(0, pad_w + 1, (1,)).item()) if pad_w > 0 else 0
    pad_bottom = pad_h - pad_top
    pad_right = pad_w - pad_left
    x_padded = F.pad(x_resized, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0.0)
    # if resized is bigger than original, center crop
    if size > H:
        start_h = (size - H) // 2
        start_w = (size - W) // 2
        x_padded = x_padded[:, :, start_h:start_h+H, start_w:start_w+W]
    return x_padded

def ensemble_ti_dim_mi_fgsm(
    models: Union[nn.Module, List[nn.Module]],
    images: torch.Tensor,
    labels: torch.Tensor,
    eps: float = 8/255,
    iters: int = 10,
    alpha: Optional[float] = None,
    decay: float = 1.0,                 # momentum mu
    prob: float = 0.7,                  # DIM probability
    resize_min: int = 290,
    resize_max: int = 330,
    kernel_size: int = 15,              # TI kernel size
    sigma: float = 3.0,                 # TI gaussian sigma
    ensemble_mode: str = 'logits',      # 'logits' | 'grad_avg' | 'grad_sum'
    targeted: bool = False,
    clip_min: float = 0.0,
    clip_max: float = 1.0,
    device: Optional[str] = None,
    normalize: bool = False,
    mean: Optional[torch.Tensor] = None,
    std: Optional[torch.Tensor] = None,
    loss_fn: Optional[nn.Module] = None,
):
    """
    Ensemble TI-DIM-MI-FGSM attack.
    - models: single model or list of models
    - images: float tensor in [0,1], shape (B,C,H,W)
    - labels: long tensor (B,)
    - ensemble_mode:
        'logits'  -> average logits then compute loss (Liu et al. style)
        'grad_avg'-> compute per-model grads and average them
        'grad_sum'-> sum per-model grads (sum of losses)
    Returns adversarial images (B,C,H,W)
    """
    if isinstance(models, nn.Module):
        models = [models]
    assert isinstance(models, (list, tuple)) and len(models) > 0

    if device is None:
        device = images.device

    for m in models:
        m.to(device).eval()

    images = images.clone().detach().to(device).float()
    orig = images.clone().detach()
    labels = labels.to(device)

    B, C, H, W = images.shape
    if alpha is None:
        alpha = eps / float(iters)

    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss(reduction='mean')

    kernel = get_gaussian_kernel(kernel_size, sigma, channels=C, device=device)
    momentum = torch.zeros_like(images).to(device)

    if normalize:
        assert mean is not None and std is not None
        mean = mean.to(device).view(1, C, 1, 1)
        std = std.to(device).view(1, C, 1, 1)
        def norm_fn(x): return (x - mean) / std
    else:
        norm_fn = lambda x: x

    x_adv = images.clone().detach()

    for _ in range(iters):
        x_adv.requires_grad_(True)

        # apply diverse input transform (stochastic)
        x_in = diverse_input(x_adv, prob=prob, resize_min=resize_min, resize_max=resize_max)
        x_in_norm = norm_fn(x_in)

        if ensemble_mode == 'logits':
            # average logits across models then compute loss
            sum_logits = None
            for m in models:
                out = m(x_in_norm)
                if isinstance(out, (tuple, list)):
                    out = out[0]
                sum_logits = out if sum_logits is None else sum_logits + out
            avg_logits = sum_logits / len(models)
            loss = loss_fn(avg_logits, labels)
            if targeted:
                loss = -loss
            grad = torch.autograd.grad(loss, x_adv, retain_graph=False, create_graph=False)[0]

        else:
            # compute per-model grads (w.r.t x_adv) and combine them
            grads = []
            for m in models:
                out = m(x_in_norm)
                if isinstance(out, (tuple, list)):
                    out = out[0]
                loss_m = loss_fn(out, labels)
                if targeted:
                    loss_m = -loss_m
                g = torch.autograd.grad(loss_m, x_adv, retain_graph=True, create_graph=False)[0]
                grads.append(g)
            # combine grads
            if ensemble_mode == 'grad_sum':
                grad = sum(grads)
            else:  # 'grad_avg'
                grad = sum(grads) / float(len(grads))

        # TI: smooth gradient with gaussian kernel (grouped conv)
        grad_conv = F.conv2d(grad, weight=kernel, bias=None, stride=1, padding=kernel_size//2, groups=C)

        # normalize gradient per-sample by mean absolute value (stability)
        denom = torch.mean(torch.abs(grad_conv), dim=(1,2,3), keepdim=True) + 1e-12
        grad_norm = grad_conv / denom

        # momentum update
        momentum = decay * momentum + grad_norm

        # MI-FGSM step (sign of momentum)
        if targeted:
            x_adv = x_adv.detach() - alpha * torch.sign(momentum)
        else:
            x_adv = x_adv.detach() + alpha * torch.sign(momentum)

        # clip to eps-ball and valid range
        x_adv = torch.max(torch.min(x_adv, orig + eps), orig - eps)
        x_adv = torch.clamp(x_adv, clip_min, clip_max).detach()

        # zero grads
        for m in models:
            m.zero_grad()
        if x_adv.grad is not None:
            x_adv.grad.detach_()
            x_adv.grad.zero_()

    return x_adv


### DEAA 

#### Helper (DEAA)

In [10]:
class NormalizedModel(nn.Module):
    def __init__(self, model, mean, std):
        super().__init__()
        self.model = model
        self.register_buffer('mean', torch.tensor(mean).view(1,3,1,1))
        self.register_buffer('std', torch.tensor(std).view(1,3,1,1))

    def forward(self, x):
        x = (x - self.mean) / self.std
        return self.model(x)

def get_last_linear_layer(model):
    """
    Find the last Linear layer in timm models or wrapped models (NormalizedModel).
    """
    # unwrap NormalizedModel if needed
    if isinstance(model, NormalizedModel):
        model = model.model

    # Common attribute names for classifier heads
    candidate_attrs = ['head', 'heads', 'fc', 'classifier', 'mlp_head']

    for attr in candidate_attrs:
        if hasattr(model, attr):
            layer = getattr(model, attr)
            # If it's a Linear layer
            if isinstance(layer, nn.Linear):
                return layer
            # If it's a Sequential or Module, search inside
            if isinstance(layer, nn.Module):
                last_linear = None
                for m in reversed(list(layer.modules())):
                    if isinstance(m, nn.Linear):
                        last_linear = m
                        break
                if last_linear is not None:
                    return last_linear

    # Fallback: scan all modules
    last_linear = None
    for m in model.modules():
        if isinstance(m, nn.Linear):
            last_linear = m
    if last_linear is not None:
        return last_linear

    raise RuntimeError(f"No Linear layer found in model {model.__class__.__name__}")



def get_features_before_last_linear(model, x):
    """
    Extract features before the final classifier, works for CNNs and ViTs.
    """
    # unwrap NormalizedModel if present
    if isinstance(model, NormalizedModel):
        model = model.model

    # Common classifier attributes
    candidate_attrs = ['head', 'heads', 'fc', 'classifier', 'mlp_head']
    classifier = None
    for attr in candidate_attrs:
        if hasattr(model, attr):
            classifier = getattr(model, attr)
            break

    features = {}

    def hook(module, input, output):
        features['feat'] = input[0].detach()

    if classifier is not None:
        handle = classifier.register_forward_hook(hook)
    else:
        # fallback: attach hook to last module
        last_module = list(model.modules())[-1]
        handle = last_module.register_forward_hook(hook)

    model.eval()
    with torch.no_grad():
        _ = model(x)
    handle.remove()

    if 'feat' not in features:
        raise RuntimeError(f"Failed to capture features from model {model.__class__.__name__}")

    return features['feat']

#### DEAA Main 

In [11]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
import numpy as np
import faiss
from tqdm import tqdm
import torch.nn.functional as F


def softmax(x, dim=0):
    return F.softmax(x, dim=dim)


def cosine_similarity(x, y, dim=1, eps=1e-8):
    return F.cosine_similarity(x, y, dim=dim, eps=eps)


# ðŸ‘€ Dummy visualization (replace with your function)
def visualize_test_and_roc(test_img, roc_imgs, local_labels):
    print("Visualization placeholder: Test image + RoC samples")
    print(f"RoC labels: {local_labels}")


class VisionDES_2: 
    def __init__(self, dsel_dataset, pool): 
        self.dsel_dataset = dsel_dataset
        self.dsel_loader = DataLoader(dsel_dataset, batch_size=32, shuffle=False) 
        self.dino_model = timm.create_model('vit_base_patch16_224.dino', pretrained=True).to(device)
        self.dino_model.eval()  
        self.pool = pool 

        self.suspected_model_votes = [] 
        
        
    def dino_embedder(self, images):
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)
        return self.dino_model.forward_features(images)


    def fit(self): 
        dsel_embeddings = []
        dsel_labels = []
    
        with torch.no_grad():
            for imgs, labels in tqdm(self.dsel_loader):
                imgs = imgs.to(device)
                embs = self.dino_embedder(imgs).cpu()  
                dsel_embeddings.append(embs)
                dsel_labels.append(labels)
    
        # Keep as tensor
        dsel_embeddings_tensor = torch.cat(dsel_embeddings).detach().cpu()  
        cls_tensor = dsel_embeddings_tensor[:, 0, :]  
    
        # Convert to NumPy
        cls_embeddings = np.ascontiguousarray(cls_tensor.numpy(), dtype='float32')
        self.dsel_embeddings = cls_embeddings
        self.dsel_labels = torch.cat(dsel_labels).numpy()
    
        # Build FAISS index
        embedding_dim = cls_embeddings.shape[1]
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.index.add(cls_embeddings)

    
    def get_top_n_competent_models(self, test_img, k=7, top_n=3, use_sim=False, sim_threshold=0, alpha=0.6):
        # Step 1: Get DINO CLS embedding for the test image
        img_for_dino = test_img.unsqueeze(0).to(device)

        with torch.no_grad():
            test_emb = self.dino_model.forward_features(img_for_dino).cpu().numpy().astype('float32')
            test_emb = test_emb[:, 0, :]  # CLS token only
    
        # Step 2: Find k nearest neighbors in FAISS
        distances, neighbors = self.index.search(test_emb, k)
        neighbor_idxs = neighbors[0]
        local_labels = np.array(self.dsel_labels[neighbor_idxs]).flatten()
    
        # Step 3: Get RoC images
        with torch.no_grad():
            roc_imgs = torch.stack([self.dsel_dataset[idx][0] for idx in neighbor_idxs]).to(device)
    
        # Step 4: Evaluate classifiers
        competences, feature_similarities, correct_counts, grad_vectors  = [], [], [], []
    
        for clf in self.pool:
            clf.eval()
            with torch.no_grad():
                outputs = clf(roc_imgs)
                preds = outputs.argmax(dim=1).cpu().numpy()
                correct = (preds == local_labels).sum()
                competences.append(correct / k)
                correct_counts.append(correct)
    
                # Feature similarity
                test_feat = get_features_before_last_linear(clf, test_img.unsqueeze(0).to(device))
                roc_feats = get_features_before_last_linear(clf, roc_imgs)
                mean_feat = roc_feats.mean(dim=0, keepdim=True)
                sim = cosine_similarity(test_feat.flatten().unsqueeze(0), mean_feat.flatten().unsqueeze(0))
                feature_similarities.append(sim.item())

            test_img_req = test_img.clone().detach().unsqueeze(0).to(device)
            test_img_req.requires_grad_(True)

            out = clf(test_img_req)
            pseudo_label = out.argmax(dim=1)  # keep grad path alive
            loss = F.cross_entropy(out, pseudo_label)
            grad = torch.autograd.grad(loss, test_img_req)[0]
            grad_vec = grad.flatten()  # flatten for cosine similarity
            grad_vectors.append(grad_vec.cpu())


        # Step 5: Compute gradient-based diversity
        grads_tensor = torch.stack(grad_vectors).to(device) 
        # L2-normalize per-model gradient vector
        grads_norm = F.normalize(grads_tensor, p=2, dim=1, eps=1e-8)  # (K, D)
        
        # compute cosine similarity matrix
        cos_sim = grads_norm @ grads_norm.t()  # (K, K)  (equivalent to pairwise cosine) 

        # zero out diagonal (self-sim = 1)
        K = cos_sim.size(0)
        cos_sim.fill_diagonal_(0.0)
        
        # average similarity over other models only
        mean_sim_other = cos_sim.sum(dim=1) / float(K - 1)  # (K,)
        
        # diversity score: lower similarity -> higher diversity
        diversity_scores = (mean_sim_other).cpu().numpy()
        diversity_scores = (diversity_scores - diversity_scores.min()) / (diversity_scores.max() - diversity_scores.min() + 1e-8)

        # print("Diversity", diversity_scores)
        # print("Competences", competences) 

        final_scores = [alpha * c + (1 - alpha) * d for c, d in zip(competences, diversity_scores)] 
    
        # Step 7: Select top_n models
        top_indices = np.argsort(final_scores)[::-1][:top_n]
        # print("final_scores", final_scores)
        # print("Top_indices", top_indices)
        top_models = [self.pool[i] for i in top_indices]
    
        return top_models

### Setup 

In [12]:
ens_models = [
    get_models("imagenet", "resnet18", "resnet18"), 
    get_models("imagenet", "inception_v3", "inc_v3"), 
    get_models("imagenet", "deit_tiny_patch16_224", "deit_t"),
    get_models("imagenet", "vit_tiny_patch16_224", "vit_t"), 
    get_models("imagenet", "efficientnet_b0", "efficientnet_b0"), 
    get_models("imagenet", "xcit_tiny_12_p8_224", "swin_t"), 
]  

In [None]:
des_model = VisionDES_2(val_dataset, ens_models)
des_model.fit()

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 938/938 [03:50<00:00,  4.07it/s]


In [13]:
saved_index = des_model.index
saved_dsel_embeddings = des_model.dsel_embeddings 
saved_dsel_labels = des_model.dsel_labels

# des_model.index = saved_index 
# des_model.dsel_embeddings = saved_dsel_embeddings 
# des_model.dsel_labels = saved_dsel_labels

In [14]:
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
import torch

# --- before loop (clear previous lists) ---
adv_list = []
orig_list = []
labels_list = []
noise_rates = []
pixel_diffs = []

eps = 8/255
alpha = 2/255
iters = 10

def ensure_batch(x):
    return x if x.dim() == 4 else x.unsqueeze(0)

def to_unit_range(x):
    """
    Ensure x is in [0,1]. If tensor values appear to be in [0,255] (max>1.5),
    convert by dividing by 255. Returns a float tensor on same device.
    """
    x = ensure_batch(x).float()
    if x.max().item() > 1.5:
        x = x / 255.0
    return torch.clamp(x, 0.0, 1.0)

# --- attack loop (same as yours, but using the simplified functions) ---
for img, label in tqdm(hard_loader, desc="Generating MI-FGSM adversarials (GPU)"):
    img, label = img.to(device), label.to(device)

    with torch.enable_grad():
        selected_models = des_model.get_top_n_competent_models(
                img[0],                    
                k=7,
                top_n=2,
                use_sim=False,
                sim_threshold=0,
                alpha=0.4
            )
        
        adv_img = ensemble_mi_fgsm(selected_models, img, label, eps=eps, alpha=alpha, iters=iters, clip_min=0.0, clip_max=1.0, device=device)
        # adv_img = ensemble_pgd(selected_models, img, label, eps=eps, alpha=alpha, iters=iters, clip_min=0.0, clip_max=1.0, device=device)
        # adv_img = ensemble_ti_dim_mi_fgsm(ens_models, img, label,
        #                       eps=8/255, iters=10, alpha=2/255,
        #                       decay=1.0, prob=0.7,
        #                       kernel_size=15, sigma=1.0,
        #                       ensemble_mode='logits',
        #                       device=device)
    # store for later (move to CPU)
    adv_list.append(adv_img.squeeze(0).cpu())
    orig_list.append(img.squeeze(0).cpu())
    labels_list.append(label.squeeze(0).cpu())

    # compute SSIM and pixel diffs on [0,1] images
    img_for_ssim = to_unit_range(img)       # (1,C,H,W) in [0,1]
    adv_for_ssim = to_unit_range(adv_img)   # (1,C,H,W) in [0,1]

    ssim_val = ssim(adv_for_ssim, img_for_ssim)  # scalar tensor
    noise_rates.append((1.0 - float(ssim_val)))
    pixel_diffs.append((adv_for_ssim - img_for_ssim).abs().mean().item())

# --- stack everything on CPU ---
adv_all = torch.stack(adv_list).cpu()
orig_all = torch.stack(orig_list).cpu()
labels_all = torch.stack(labels_list).cpu()

noise_rates = torch.tensor(noise_rates)
pixel_diffs = torch.tensor(pixel_diffs)

print(f"âœ… Generated {adv_all.size(0)} adversarial images. Shape: {adv_all.shape}")
print(f"Noise (1 - SSIM): mean={noise_rates.mean():.6f}, std={noise_rates.std():.6f}, min={noise_rates.min():.6f}, max={noise_rates.max():.6f}")
print(f"Mean absolute pixel diff (after clamp to [0,1]): mean={pixel_diffs.mean():.6f}, std={pixel_diffs.std():.6f}")


Generating MI-FGSM adversarials (GPU): 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1000/1000 [27:10<00:00,  1.63s/it]


âœ… Generated 1000 adversarial images. Shape: torch.Size([1000, 3, 224, 224])
Noise (1 - SSIM): mean=0.161277, std=0.071326, min=0.014407, max=0.538927
Mean absolute pixel diff (after clamp to [0,1]): mean=0.025323, std=0.002007


### Test on Target Models 

In [15]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

batch_size = 32  # tune this for your GPU
dataset = TensorDataset(adv_all, labels_all)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                    num_workers=4, pin_memory=True)

In [16]:
target_models = [
    # get_models("imagenet", "resnet152", "resnet152"),
    # get_models("imagenet", "wide_resnet101_2", "wrn101_2"),     
    # get_models("imagenet", "regnety_320", "regnety_320"),
    # get_models("imagenet", "vgg19", "vgg19"),
    # get_models("imagenet", "vit_base_patch16_224", "vit_b"),
    # get_models("imagenet", "deit_base_patch16_224", "deit_b"),
    get_models("imagenet", "swin_base_patch4_window7_224", "swin_b"), 
    get_models("imagenet", "mixer_b16_224", "vit_t"), 
    get_models("imagenet", "convmixer_768_32", "vit_t")
] 

In [17]:
with torch.no_grad():
    for t_model in target_models:
        name = getattr(t_model, "name", t_model.__class__.__name__)
        t_model.eval()
        t_model.to(device)

        fooled = 0
        total = 0

        for imgs_cpu, labels_cpu in tqdm(loader, desc=f"ASR {name}"):
            # Move to device here
            imgs = imgs_cpu.to(device, non_blocking=True)
            labels = labels_cpu.to(device, non_blocking=True)

            outputs = t_model(imgs)
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            preds = outputs.argmax(dim=1)

            fooled += (preds != labels).sum().item()
            total += labels.size(0)

            # free cache per batch (helps on tight GPUs)
            if device.type == "cuda":
                torch.cuda.empty_cache()

        asr = 100.0 * fooled / total if total > 0 else 0.0
        print(f"{name}: ASR = {asr:.2f}%  ({fooled}/{total} fooled)")

ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:05<00:00,  5.79it/s]


Sequential: ASR = 66.80%  (668/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:03<00:00,  8.93it/s]


Sequential: ASR = 85.60%  (856/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:08<00:00,  3.74it/s]

Sequential: ASR = 91.50%  (915/1000 fooled)



