In [1]:
import timm
import time 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import datasets 
from torch.utils.data import DataLoader
# from medmnist import INFO
import numpy as np
import faiss
import copy
from tqdm import tqdm

from torch.nn.functional import softmax, cosine_similarity
from collections import Counter
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os 

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: cuda:0


In [3]:
# Echo the selected torch.device as this cell's output.
device

device(type='cuda', index=0)

In [4]:
import torch
from torch.utils.data.dataset import Subset
from torch.utils.data import DataLoader

# Register Subset with torch's allow-list of globals for unpickling.
# NOTE(review): both loads below pass weights_only=False, so this allow-list
# looks redundant here — confirm before removing it.
torch.serialization.add_safe_globals([Subset])

# Load the pre-built dataset subsets from disk.
# SECURITY: weights_only=False performs a full pickle load, which can execute
# arbitrary code embedded in the file. Only load files you created or trust.
test_subset = torch.load("data/cifar10_selected_test.pt", weights_only=False)
val_subset  = torch.load("data/cifar10_extended_val.pt", weights_only=False)

# The test loader yields one image per step, in a fixed order.
testloader = DataLoader(test_subset, batch_size=1, shuffle=False)
# NOTE(review): valloader is not referenced elsewhere in the visible notebook;
# VisionDES_2 builds its own loader from val_subset.
valloader  = DataLoader(val_subset,  batch_size=32, shuffle=True)

### Load Models 

In [5]:
import model_helper as helper

# Architecture identifiers of the models that form the ensemble pool.
ens_models_args = [
    "resnet18",
    "inception_v3",
    "deit_tiny_patch16_224",
    "vit_tiny_patch16_224",
    "efficientnet_b0",
    "gcvit_tiny",
]

# Load every architecture through the project helper, move it to the active
# device and switch it to inference mode.
ens_models = [
    helper.load_model_hub(arch_name).to(device).eval()
    for arch_name in ens_models_args
]


ðŸ”¹ Loading deit_tiny_patch16_224 from Models/deit_tiny.pth.tar
Load result: <All keys matched successfully>

ðŸ”¹ Loading vit_tiny_patch16_224 from Models/vit_tiny.pth.tar
Load result: <All keys matched successfully>


In [6]:
class NormalizedModel(nn.Module):
    """Wrap a model so inputs are standardised per channel before the forward pass.

    ``forward`` computes ``model((x - mean) / std)``. The statistics are stored
    as buffers of shape ``(1, C, 1, 1)``, so they follow ``.to(device)`` and are
    saved in the ``state_dict``.

    Args:
        model: The wrapped ``nn.Module``.
        mean: Per-channel means (sequence, NumPy array, tensor or scalar).
        std: Per-channel standard deviations, same form as ``mean``.
    """

    def __init__(self, model, mean, std):
        super().__init__()
        self.model = model
        self.register_buffer('mean', self._as_channel_stats(mean))
        self.register_buffer('std', self._as_channel_stats(std))

    @staticmethod
    def _as_channel_stats(values):
        """Convert channel statistics to a ``(1, C, 1, 1)`` tensor in the default float dtype."""
        # The explicit dtype stops float64 inputs (e.g. NumPy arrays) from
        # promoting the whole forward pass to float64. The clone stops the
        # buffer from aliasing a tensor owned by the caller.
        stats = torch.as_tensor(values, dtype=torch.get_default_dtype())
        # -1 instead of a hard-coded 3 so grayscale/multispectral stats work too.
        return stats.detach().clone().reshape(1, -1, 1, 1)

    def forward(self, x):
        """Standardise ``x`` channel-wise and return the wrapped model's output."""
        x = (x - self.mean) / self.std
        return self.model(x)

def get_last_linear_layer(model):
    """
    Return the final nn.Linear of a model (a NormalizedModel is unwrapped first).

    The well-known classifier attributes are tried in order; the first one that
    is, or contains, an nn.Linear wins. If none qualifies, the last nn.Linear
    found while walking the whole network is returned.

    Raises:
        RuntimeError: if the network contains no nn.Linear at all.
    """
    net = model.model if isinstance(model, NormalizedModel) else model

    for head_name in ('head', 'heads', 'fc', 'classifier', 'mlp_head'):
        if not hasattr(net, head_name):
            continue
        head = getattr(net, head_name)

        # The head itself is the linear classifier.
        if isinstance(head, nn.Linear):
            return head

        # The head is a container: take the last Linear in traversal order.
        if isinstance(head, nn.Module):
            linears_in_head = [sub for sub in head.modules() if isinstance(sub, nn.Linear)]
            if linears_in_head:
                return linears_in_head[-1]

    # No usable head attribute: fall back to a scan of every sub-module.
    all_linears = [sub for sub in net.modules() if isinstance(sub, nn.Linear)]
    if all_linears:
        return all_linears[-1]

    raise RuntimeError(f"No Linear layer found in model {net.__class__.__name__}")



def get_features_before_last_linear(model, x):
    """
    Capture the tensor fed into a model's classifier head during one forward pass.

    The head is the first of ``head``, ``heads``, ``fc``, ``classifier``,
    ``mlp_head`` that exists on the network and is an ``nn.Module``. If none
    qualifies, the last module in ``modules()`` order is used.

    Args:
        model: The network, optionally wrapped in ``NormalizedModel``.
        x: Input batch passed to ``model`` unchanged.

    Returns:
        The detached first positional input of the hooked module.

    Raises:
        RuntimeError: if the hooked module was never called in the forward pass.

    Side effects:
        Switches ``model`` to eval mode (it is not switched back).
    """
    # Hook the inner network, but run the forward pass through the wrapper:
    # calling the unwrapped net directly would skip the input normalisation.
    runner = model
    inner = model.model if isinstance(model, NormalizedModel) else model

    # Common classifier attributes; skip anything that cannot take a hook
    # (e.g. an attribute that exists but is None).
    classifier = None
    for attr in ('head', 'heads', 'fc', 'classifier', 'mlp_head'):
        candidate = getattr(inner, attr, None)
        if isinstance(candidate, nn.Module):
            classifier = candidate
            break

    if classifier is None:
        # Fallback: attach the hook to the last module in traversal order.
        classifier = list(inner.modules())[-1]

    features = {}

    def hook(module, inputs, output):
        features['feat'] = inputs[0].detach()

    handle = classifier.register_forward_hook(hook)
    try:
        runner.eval()
        with torch.no_grad():
            _ = runner(x)
    finally:
        # Always detach the hook, even if the forward pass raises; a leaked
        # hook would keep firing on every later call of the model.
        handle.remove()

    if 'feat' not in features:
        raise RuntimeError(f"Failed to capture features from model {inner.__class__.__name__}")

    return features['feat']

In [7]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
import numpy as np
import faiss
from tqdm import tqdm
import torch.nn.functional as F


def resize_for_dino(x, size=(224, 224)):
    """Bilinearly resample an image batch to the DINO encoder's input resolution.

    Args:
        x: Batch of images laid out as (N, C, H, W).
        size: Target (height, width); defaults to the 224x224 used in this notebook.

    Returns:
        The resampled batch.
    """
    return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
    
def softmax(x, dim=0):
    """Softmax of ``x`` along ``dim`` (first axis by default).

    Note: this definition replaces the ``softmax`` imported from
    ``torch.nn.functional`` at the top of the notebook.
    """
    probabilities = F.softmax(x, dim=dim)
    return probabilities


def cosine_similarity(x, y, dim=1, eps=1e-8):
    """Cosine similarity between ``x`` and ``y`` along ``dim``.

    ``eps`` is forwarded to ``F.cosine_similarity``. This definition replaces
    the ``cosine_similarity`` imported from ``torch.nn.functional`` at the top
    of the notebook.
    """
    similarity = F.cosine_similarity(x, y, dim=dim, eps=eps)
    return similarity


# Placeholder visualisation hook (swap in a real plotting routine).
def visualize_test_and_roc(test_img, roc_imgs, local_labels):
    """Print a stand-in message and the RoC labels; the image arguments are not used."""
    report = (
        "Visualization placeholder: Test image + RoC samples",
        f"RoC labels: {local_labels}",
    )
    for line in report:
        print(line)


class VisionDES_2:
    """Dynamic ensemble selection (DES) over a pool of image classifiers.

    A DINO ViT embeds every image of the selection set (DSEL); the embedding at
    token index 0 (treated as the CLS token) is stored in a FAISS L2 index. For
    a query image, its k nearest DSEL images form the region of competence
    (RoC). Each classifier in the pool is scored by a blend of its accuracy on
    the RoC and a term derived from input-gradient similarity to the other
    pool members.

    Relies on notebook-level names: ``device``, ``timm``, ``faiss``,
    ``DataLoader``, ``tqdm`` and ``resize_for_dino``.
    """

    def __init__(self, dsel_dataset, pool):
        """
        Args:
            dsel_dataset: Indexable dataset whose items are (image, label) pairs.
            pool: Sequence of classifiers (indexable by integer position).
        """
        self.dsel_dataset = dsel_dataset
        self.dsel_loader = DataLoader(dsel_dataset, batch_size=32, shuffle=False)
        self.dino_model = timm.create_model('vit_base_patch16_224_dino', pretrained=True).to(device)
        self.dino_model.eval()
        self.pool = pool

        # Not read or written by any method of this class.
        self.suspected_model_votes = []

    def dino_embedder(self, images):
        """Return ``forward_features`` of the DINO model; 1-channel input is tiled to 3 channels."""
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)
        return self.dino_model.forward_features(images)

    def _cls_embedding(self, images):
        """Return, per image, the token at index 0 of the DINO features (used as CLS embedding)."""
        return self.dino_embedder(images)[:, 0, :]

    def fit(self):
        """Embed the whole DSEL set and build the FAISS L2 index over the CLS embeddings.

        Sets ``self.dsel_embeddings`` (float32 array), ``self.dsel_labels``
        (NumPy array) and ``self.index``.
        """
        cls_batches = []
        label_batches = []

        with torch.no_grad():
            for imgs, labels in tqdm(self.dsel_loader):
                imgs_resized = resize_for_dino(imgs.to(device))
                # Keep only the CLS token per batch. Storing all tokens for the
                # whole set and slicing afterwards needs far more host memory.
                cls_batches.append(self._cls_embedding(imgs_resized).contiguous().cpu())
                label_batches.append(labels)

        cls_embeddings = np.ascontiguousarray(torch.cat(cls_batches).numpy(), dtype='float32')
        self.dsel_embeddings = cls_embeddings
        self.dsel_labels = torch.cat(label_batches).numpy()

        # Build FAISS index
        embedding_dim = cls_embeddings.shape[1]
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.index.add(cls_embeddings)

    def get_top_n_competent_models(self, test_img, k=7, top_n=3, use_sim=False, sim_threshold=0, alpha=0.6):
        """Rank the pool for one query image and return the best ``top_n`` models.

        Args:
            test_img: Single image tensor without batch dimension.
            k: Number of nearest DSEL neighbours forming the RoC. It is capped
                at the number of indexed samples.
            top_n: Number of models to return (fewer if the pool is smaller).
            use_sim: Accepted for backward compatibility; currently has no effect.
            sim_threshold: Accepted for backward compatibility; currently has no effect.
            alpha: Weight of the RoC accuracy; ``1 - alpha`` weights the
                gradient-similarity term.

        Returns:
            List of models from ``self.pool``, best first.

        Raises:
            RuntimeError: if the FAISS index contains no samples.
        """
        if len(self.pool) == 0:
            return []

        test_batch = test_img.unsqueeze(0).to(device)

        # Step 1: DINO CLS embedding of the query, via the same embedder as fit()
        # so single-channel images are handled consistently.
        with torch.no_grad():
            test_emb = self._cls_embedding(resize_for_dino(test_batch))
        test_emb = np.ascontiguousarray(test_emb.cpu().numpy(), dtype='float32')

        # Step 2: k nearest neighbours. FAISS pads missing neighbours with -1,
        # which would silently index the LAST label, so cap k at the index size.
        n_indexed = int(self.index.ntotal)
        if n_indexed == 0:
            raise RuntimeError("FAISS index is empty; fit() must index at least one sample")
        k_eff = max(1, min(int(k), n_indexed))
        _, neighbors = self.index.search(test_emb, k_eff)
        neighbor_idxs = neighbors[0]
        local_labels = np.array(self.dsel_labels[neighbor_idxs]).flatten()

        # Step 3: RoC images
        roc_imgs = torch.stack([self.dsel_dataset[int(idx)][0] for idx in neighbor_idxs]).to(device)

        # Step 4: per-classifier RoC accuracy and input gradient.
        # (An earlier feature-similarity computation was dropped: it cost two
        # extra forward passes per model and its result was never used.)
        competences, grad_vectors = [], []

        for clf in self.pool:
            clf.eval()
            with torch.no_grad():
                preds = clf(roc_imgs).argmax(dim=1).cpu().numpy()
            correct = (preds == local_labels).sum()
            competences.append(correct / k_eff)

            # Input gradient of the loss w.r.t. the model's own prediction.
            # enable_grad makes this work even when the caller is inside no_grad.
            with torch.enable_grad():
                test_img_req = test_batch.clone().detach().requires_grad_(True)
                out = clf(test_img_req)
                pseudo_label = out.argmax(dim=1).detach()  # argmax carries no gradient
                loss = F.cross_entropy(out, pseudo_label)
                grad = torch.autograd.grad(loss, test_img_req)[0]
            grad_vectors.append(grad.flatten())

        # Step 5: pairwise cosine similarity of the L2-normalised gradients.
        grads_norm = F.normalize(torch.stack(grad_vectors), p=2, dim=1, eps=1e-8)  # (K, D)
        cos_sim = grads_norm @ grads_norm.t()  # (K, K)

        # Drop self-similarity, then average over the OTHER models only.
        num_models = cos_sim.size(0)
        cos_sim.fill_diagonal_(0.0)
        # max(..., 1): a one-model pool would otherwise divide 0 by 0 -> NaN.
        mean_sim_other = cos_sim.sum(dim=1) / float(max(num_models - 1, 1))  # (K,)

        # NOTE(review): despite the name, this score grows with the mean cosine
        # similarity to the other models, i.e. models whose input gradients
        # AGREE with the rest of the pool are favoured. If low similarity was
        # meant to be rewarded, this needs `1 - mean_sim_other`. The existing
        # behaviour is kept so earlier results stay reproducible.
        diversity_scores = mean_sim_other.cpu().numpy()
        diversity_scores = (diversity_scores - diversity_scores.min()) / (diversity_scores.max() - diversity_scores.min() + 1e-8)

        # Step 6: blend accuracy and gradient term.
        final_scores = [alpha * c + (1 - alpha) * d for c, d in zip(competences, diversity_scores)]

        # Step 7: select top_n models, highest score first.
        top_indices = np.argsort(final_scores)[::-1][:top_n]
        top_models = [self.pool[i] for i in top_indices]

        return top_models

In [8]:
# Build the selector with the validation subset as DSEL and the loaded models
# as the pool, then embed the DSEL images and index them.
des_model = VisionDES_2(val_subset, ens_models)
des_model.fit()

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 438/438 [00:41<00:00, 10.47it/s]


### Ensemble Attack 

In [9]:
import argparse

def get_args(argv=None):
    """Build the attack configuration namespace.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, as before. Pass ``[]`` to get pure defaults
            regardless of how the interpreter or kernel was launched.

    Returns:
        ``argparse.Namespace`` with the parsed options. Unrecognised arguments
        (such as those a Jupyter kernel is started with) are ignored.
    """
    parser = argparse.ArgumentParser(description='SMER')

    parser.add_argument('--dataset', type=str, default='imagenet_compatible')
    parser.add_argument('--batch-size', type=int, default=50)
    parser.add_argument('--image-size', type=int, default=224)
    parser.add_argument('--num_worker', type=int, default=4)
    parser.add_argument('--attack_method', type=str, default='MI_FGSM_SMER')
    parser.add_argument('--image-dir', type=str)
    parser.add_argument('--image-info', type=str, default='')
    parser.add_argument('--gpu-id', type=int, default=0)

    # attack params
    parser.add_argument('--eps', type=float, default=8.0)
    parser.add_argument('--alpha', type=float, default=2)
    parser.add_argument('--iters', type=int, default=10)
    parser.add_argument('--momentum', type=float, default=1.0)
    parser.add_argument('--beta', type=float, default=10)

    # parse_known_args (not parse_args) so kernel-supplied arguments do not abort parsing.
    args, _unknown = parser.parse_known_args(argv)
    return args

# Parse the configuration once at notebook level.
# NOTE(review): in the visible notebook `args` is only referenced by a
# commented-out MI_FGSM_SMER call.
args = get_args()

In [25]:
import ensemble_attacks
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
import torch

# --- before loop (clear previous lists) ---
# Per-image accumulators filled by the attack loop below.
adv_list = []
orig_list = []
labels_list = []
noise_rates = []
pixel_diffs = []

# Attack settings passed to the ensemble_attacks calls below.
max_value = 1.0 
min_value = 0.0 
eps = 8/255
alpha = 2/255 
iters = 10 
# threshold and beta are only passed to the AdaEA_MIFGSM constructor.
threshold = -0.3
beta = 10 



def ensure_batch(x):
    """Return ``x`` unchanged if it is 4-D, otherwise with a new leading axis."""
    if x.dim() == 4:
        return x
    return x.unsqueeze(0)

def to_unit_range(x):
    """
    Return ``x`` as a batched float tensor with values in [0,1].

    A maximum above 1.5 is taken as a sign that values are on a [0,255] scale,
    in which case everything is divided by 255 first. The result is always
    clamped to [0,1].
    """
    batch = ensure_batch(x).float()
    looks_like_8bit = batch.max().item() > 1.5
    if looks_like_8bit:
        batch = batch / 255.0
    return batch.clamp(0.0, 1.0)

# --- attack loop ---
# For every test image (testloader yields one image per step): pick the top
# models for that image, craft an adversarial example against them, and record
# the result together with two distortion measures.
for img, label in tqdm(testloader, desc="Generating DEA adversarials (GPU)"):
    img, label = img.to(device), label.to(device)

    # Gradient tracking is explicitly enabled for model selection and the attack.
    with torch.enable_grad():
        # img[0] drops the batch axis: the selector expects a single image.
        selected_models = des_model.get_top_n_competent_models(
                img[0],                    
                k=7,
                top_n=4,
                use_sim=False,
                sim_threshold=0,
                alpha=0.4
            )

        # NOTE(review): attack_method is constructed on every iteration but is
        # only used by the commented-out `attack_method(img, label)` call below.
        attack_method = ensemble_attacks.AdaEA_MIFGSM(selected_models, eps=eps, alpha=alpha, iters=iters, max_value=max_value, 
                             min_value=min_value, beta=beta, threshold=threshold, device=device)

        
        # Active attack: ensemble MI-FGSM against the selected models.
        adv_img = ensemble_attacks.ensemble_mi_fgsm(selected_models, img, label, eps=eps, alpha=alpha, iters=iters, clip_min=0.0, clip_max=1.0, device=device)
        # Alternative attacks, kept for experimentation:
        # adv_img = ensemble_attacks.ensemble_svre_mi_fgsm(
        #     selected_models, img, label,
        #     eps=8/255, alpha=2/255, iters=10,
        #     decay=1.0, sample_k=2, refresh=5
        # )
        # adv_img = attack_method(img, label) 
        # adv_img = ensemble_attacks.MI_FGSM_SMER(
        #     selected_models,
        #     img,
        #     label,
        #     args,
        #     num_iter=10
        # )
        
    # store for later (move to CPU); squeeze(0) removes the batch axis of size 1
    adv_list.append(adv_img.squeeze(0).cpu())
    orig_list.append(img.squeeze(0).cpu())
    labels_list.append(label.squeeze(0).cpu())

    # compute SSIM and pixel diffs on [0,1] images
    img_for_ssim = to_unit_range(img)       # (1,C,H,W) in [0,1]
    adv_for_ssim = to_unit_range(adv_img)   # (1,C,H,W) in [0,1]

    ssim_val = ssim(adv_for_ssim, img_for_ssim)  # scalar tensor
    # "Noise rate" is defined here as 1 - SSIM between adversarial and clean image.
    noise_rates.append((1.0 - float(ssim_val)))
    pixel_diffs.append((adv_for_ssim - img_for_ssim).abs().mean().item())

# --- stack everything on CPU ---
# The list elements were already moved to the CPU when appended, so the
# trailing .cpu() calls do not move any data.
adv_all = torch.stack(adv_list).cpu()
orig_all = torch.stack(orig_list).cpu()
labels_all = torch.stack(labels_list).cpu()

# From here on noise_rates / pixel_diffs are tensors, no longer Python lists.
noise_rates = torch.tensor(noise_rates)
pixel_diffs = torch.tensor(pixel_diffs)

# Summary of how many examples were generated and how strongly they were perturbed.
print(f"âœ… Generated {adv_all.size(0)} adversarial images. Shape: {adv_all.shape}")
print(f"Noise (1 - SSIM): mean={noise_rates.mean():.6f}, std={noise_rates.std():.6f}, min={noise_rates.min():.6f}, max={noise_rates.max():.6f}")
print(f"Mean absolute pixel diff (after clamp to [0,1]): mean={pixel_diffs.mean():.6f}, std={pixel_diffs.std():.6f}")

Generating DEA adversarials (GPU):   5%|â–‰                   | 46/1000 [00:18<06:24,  2.48it/s]


KeyboardInterrupt: 

### Test 

In [23]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# Wrap the stacked adversarial images and their labels in a dataset and
# iterate over it in fixed order, 32 samples per batch.
batch_size = 32 
dataset = TensorDataset(adv_all, labels_all)
# pin_memory pairs with the non_blocking host-to-device copies in the
# evaluation loop.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                    num_workers=4, pin_memory=True)

In [13]:
import model_helper as helper

# Architecture identifiers of the models the adversarial examples are evaluated on.
target_models_args = [
    "resnetv2_101x1_bitm",
    "resnet152",
    "regnety_160",
    "vit_base_patch16_224",
    "deit_base_patch16_224",
    "swin_base_patch4_window7_224",
    "convmixer_768_32",
]

# Load every architecture through the project helper, move it to the active
# device and switch it to inference mode.
target_models = [
    helper.load_model_hub(arch_name).to(device).eval()
    for arch_name in target_models_args
]


ðŸ”¹ Loading resnetv2_101x1_bitm from Models/target/resnetv2_101x1_bitm.pth.tar
Load result: <All keys matched successfully>

ðŸ”¹ Loading vit_base_patch16_224 from Models/target/vit_base.pth.tar
Load result: <All keys matched successfully>

ðŸ”¹ Loading deit_base_patch16_224 from Models/target/deit_base.pth.tar
Load result: <All keys matched successfully>

ðŸ”¹ Loading swin_base_patch4_window7_224 from Models/target/swin_base.pth.tar
Load result: <All keys matched successfully>


In [24]:
# Attack success rate (ASR) of the adversarial images on every target model.
# Each model is labelled with its architecture identifier: the loaded models
# all share one wrapper class, so the class name alone makes every result line
# read the same.
if len(target_models_args) == len(target_models):
    target_names = list(target_models_args)
else:
    # Lists out of sync: fall back to the class name below.
    target_names = [None] * len(target_models)

# NOTE(review): every misclassification counts as "fooled". This equals the
# attack success rate only if the targets classify the clean images
# correctly — confirm for the selected test subset.
with torch.no_grad():
    for arch_name, t_model in zip(target_names, target_models):
        name = getattr(t_model, "name", None) or arch_name or t_model.__class__.__name__
        t_model.eval()
        t_model.to(device)

        fooled = 0
        total = 0

        for imgs_cpu, labels_cpu in tqdm(loader, desc=f"ASR {name}"):
            # Move to device here
            imgs = imgs_cpu.to(device, non_blocking=True)
            labels = labels_cpu.to(device, non_blocking=True)

            outputs = t_model(imgs)
            # Some models return (logits, aux...); keep the first element.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            preds = outputs.argmax(dim=1)

            fooled += (preds != labels).sum().item()
            total += labels.size(0)

            # free cache per batch (helps on tight GPUs)
            if device.type == "cuda":
                torch.cuda.empty_cache()

        asr = 100.0 * fooled / total if total > 0 else 0.0
        print(f"{name}: ASR = {asr:.2f}%  ({fooled}/{total} fooled)")

ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:03<00:00,  8.14it/s]


Sequential: ASR = 96.80%  (968/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:02<00:00, 11.68it/s]


Sequential: ASR = 94.50%  (945/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:03<00:00,  8.77it/s]


Sequential: ASR = 92.00%  (920/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:03<00:00,  9.11it/s]


Sequential: ASR = 77.00%  (770/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:03<00:00,  9.29it/s]


Sequential: ASR = 84.60%  (846/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:04<00:00,  7.15it/s]


Sequential: ASR = 87.80%  (878/1000 fooled)


ASR Sequential: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:06<00:00,  4.75it/s]

Sequential: ASR = 94.50%  (945/1000 fooled)





In [14]:
# Display the installed timm version.
timm.__version__

'1.0.22'