# Attack Success Rate (ASR) Evaluation

This notebook evaluates pre-trained ViT generators on the **Test Set** of the Oxford Pets dataset, then checks how well the same generators transfer to Food101.

**Goals:**
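1.  Load the pre-trained (targeted) ViT generator checkpoints; each generator is conditioned on a target class.
2.  Load the Surrogate (victim) Model: a CLIP RN50 visual backbone with an ArcFace head.
3.  Evaluate Clean Accuracy vs. Adversarial Accuracy on unseen Test Data.
4.  Calculate ASR (Attack Success Rate), reported here as the drop from clean to adversarial accuracy.
5.  Test universal transfer: apply the Pet-trained generators to Food101 images and a Food101 surrogate.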

In [1]:
# 1) Mount Google Drive
# The surrogate and generator checkpoints are read from Drive in later cells.
# force_remount=True asks Colab to mount again even if Drive is already mounted.
from google.colab import drive

DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT, force_remount=True)


In [2]:
# 2) Setup Repo & Dependencies
!nvidia-smi
%cd /content
import os
if not os.path.exists("MFCLIP_acv"):
    !git clone -b hamza/discrim https://github.com/1hamzaiqbal/MFCLIP_acv
%cd MFCLIP_acv
!git fetch --all
!git reset --hard origin/hamza/discrim # Force sync to latest commit

!pip install torch torchvision timm einops yacs tqdm opencv-python scikit-learn scipy pyyaml ruamel.yaml pytorch-ignite foolbox pandas matplotlib seaborn wilds ftfy


In [3]:
# 3) Setup Data (Oxford Pets)
import shutil
import os
from torchvision.datasets import OxfordIIITPet
from torchvision import transforms
from pathlib import Path

# Download Dataset
root = Path("/content/data/oxford_pets")
root.mkdir(parents=True, exist_ok=True)
_ = OxfordIIITPet(root=str(root), download=True, transform=transforms.ToTensor())

# Fetch Annotations
%cd /content
!mkdir -p /content/data/oxford_pets
!wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
!tar -xf images.tar.gz -C /content/data/oxford_pets
!tar -xf annotations.tar.gz -C /content/data/oxford_pets

# Copy Surrogate Checkpoint (Required for Evaluation)
src_ckpt = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets/RN50_ArcFace_oxford_pets.pth"
dst_ckpt = "/content/data/oxford_pets/RN50_ArcFace.pth"

if os.path.exists(src_ckpt):
    shutil.copy(src_ckpt, dst_ckpt)
    print(f"Successfully copied surrogate checkpoint to {dst_ckpt}")
else:
    raise FileNotFoundError(f"Surrogate checkpoint not found at {src_ckpt}. Cannot evaluate without the victim model!")


In [4]:
# 4) Load Models & Evaluate
# Builds the surrogate (victim) classifier and evaluates each generator
# checkpoint on the Oxford Pets test split.
# Run from the repo root: the project packages imported below (model, utils,
# dass, trainers, datasets, loss) and the relative configs/ paths used later
# resolve from here.
%cd /content/MFCLIP_acv
import torch
import torch.nn as nn
from model import ViTGenerator
from utils.util import setup_cfg
from dass.engine import build_trainer
from loss.head.head_def import HeadFactory
from torchvision import transforms
from ruamel.yaml import YAML
import argparse  # NOTE(review): not used by the code shown

# Register Trainers (Critical step!)
# Imported only for their side effects; presumably each module registers its
# trainer so build_trainer() can look it up by name (e.g. "ZeroshotCLIP") — verify.
import trainers.zsclip
import trainers.coop
import trainers.cocoop

# Register Datasets (Critical step!)
# Same idea for datasets: oxford_pets is used in this cell, food101 by the
# universal-transfer cell. NOTE(review): oxford_flowers is not used by the code shown.
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.food101 # Added for Universal Transfer Test

# Evaluation device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- B) Load Surrogate (Victim) ---
# setup_cfg() expects an argparse-style namespace; this class stands in for it so
# the trainer / model structure can be rebuilt and the surrogate loaded correctly.
class Args:
    """Hard-coded command-line options for the Oxford Pets ZeroshotCLIP trainer.

    All values are class attributes; instances may override individual ones
    (the Food101 cell does this for ``dataset`` and ``dataset_config_file``).
    """

    # Dataset and config files
    root: str = "/content/data"
    dataset: str = "oxford_pets"
    config_file: str = "configs/trainers/CoOp/rn50.yaml"
    dataset_config_file: str = "configs/datasets/oxford_pets.yaml"

    # Trainer and surrogate head
    trainer: str = "ZeroshotCLIP"
    head: str = "ArcFace"

    # Output, extra config overrides, device selection
    output_dir: str = "output"
    opts: list = []  # class-level list: the same object is shared by every Args instance
    gpu: int = 0
    device: str = "cuda:0"

    # Additional attributes required by reset_cfg
    resume: str = ""
    seed: int = -1
    source_domains = None
    target_domains = None
    transforms = None
    backbone: str = ""
    bs: int = 64
    ratio: float = 1.0


args = Args()
# Build the project config from Args and instantiate the trainer. Used below:
# trainer.dm.num_classes (head size), trainer.clip_model.visual (backbone),
# and trainer.test_loader (evaluation data).
cfg = setup_cfg(args)
trainer = build_trainer(cfg)

# Manually build the surrogate model wrapper
class Model(nn.Module):
    """Surrogate classifier: a feature backbone followed by a head.

    Args:
        backbone: module mapping a batch of images to features.
        head_factory: object whose ``get_head()`` returns the head module.
    """

    def __init__(self, backbone, head_factory):
        super().__init__()
        self.backbone = backbone
        # Store the module built by the factory, not the factory itself.
        self.head = head_factory.get_head()

    def forward(self, x, labels=None):
        """Return ``head(backbone(x), labels)``; ``labels`` is forwarded to the head unchanged."""
        return self.head(self.backbone(x), labels)

# Head config: defaults from configs/data.yaml with this dataset's class count
# and the backbone feature size filled in.
yaml_parser = YAML(typ='safe')
with open('configs/data.yaml', 'r') as data_cfg_file:  # closed as soon as it is parsed
    config = yaml_parser.load(data_cfg_file)
config['num_classes'] = trainer.dm.num_classes
config['output_dim'] = 1024  # presumably CLIP RN50's 1024-d embedding — keep in sync with the backbone
head_factory = HeadFactory(args.head, config)

backbone = trainer.clip_model.visual
# Wrap backbone with normalization (CLIP mean/std) so the surrogate takes raw pixel
# images; the attack below perturbs and clamps images in [0, 1].
# NOTE(review): if the test loader's transforms already normalize, inputs get
# normalized twice — confirm the loader yields unnormalized [0, 1] images.
normalize = transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
backbone = nn.Sequential(normalize, backbone)

surrogate = Model(backbone, head_factory).to(device)

# Load weights (default strict=True: checkpoint keys must match backbone + head exactly)
surrogate_path = "/content/data/oxford_pets/RN50_ArcFace.pth"
surrogate.load_state_dict(torch.load(surrogate_path, map_location=device))
surrogate.eval()
print("Surrogate Model Loaded.")

# --- C) Evaluation Loop ---
# For each generator checkpoint, this cell measures on the Oxford Pets test split:
# the surrogate's clean accuracy, its adversarial accuracy, the fooling rate, and
# the targeted success rate.
from ignite.metrics import Accuracy
from tqdm import tqdm
import torch.nn.functional as F

checkpoints = [
    ("/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets/vit_generator.pt", "Targeted Only"),
    ("/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets/vit_generator_mixed_loss.pt", "Targeted + Contrastive")
]

test_loader = trainer.test_loader
eps = 16/255.0  # L_inf budget: the generated noise is clamped to [-eps, eps]
num_pet_classes = trainer.dm.num_classes  # 37 for Oxford Pets; the generators are conditioned on these classes
TARGET_SEED = 0  # fixes the random target classes so every checkpoint/run sees the same targets

print(f"\nStarting Evaluation on {len(test_loader.dataset)} test images (Oxford Pets)...")

for ckpt_path, name in checkpoints:
    print(f"\n>>> Evaluating: {name}")
    if not os.path.exists(ckpt_path):
        print(f"Skipping {name}: Checkpoint not found at {ckpt_path}")
        continue

    # Load Generator. strict=False tolerates partial checkpoints, but key mismatches are
    # reported. A checkpoint that matches nothing is skipped, so a randomly initialised
    # generator is never evaluated silently.
    generator = ViTGenerator(num_classes=num_pet_classes).to(device)
    load_info = generator.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
    if load_info.missing_keys or load_info.unexpected_keys:
        print(f"  WARNING: {len(load_info.missing_keys)} missing / "
              f"{len(load_info.unexpected_keys)} unexpected keys when loading {ckpt_path}")
    if len(load_info.missing_keys) == len(generator.state_dict()):
        print(f"Skipping {name}: none of the generator's weights were found in the checkpoint")
        continue
    generator.eval()

    clean_acc_metric = Accuracy()
    adv_acc_metric = Accuracy()
    n_images = n_clean_correct = n_fooled = n_target_hits = 0
    target_rng = torch.Generator().manual_seed(TARGET_SEED)

    for batch in tqdm(test_loader, desc=name):
        images = batch['img'].to(device)
        labels = batch['label'].to(device)

        with torch.no_grad():
            # 1. Clean Accuracy
            # NOTE(review): ground-truth labels are passed to the ArcFace head in both
            # passes. If the head applies its margin when labels are given, the logits
            # depend on the true class — confirm this is intended at test time.
            clean_outputs = surrogate(images, labels)
            clean_acc_metric.update((clean_outputs, labels))

            # 2. Generate Adversarial Images (Targeted)
            # target = (label + k) mod C with k uniform in [1, C-1]. This is uniform over
            # the C-1 wrong classes and never equals the true label (bumping collisions
            # to label+1 would give that class double probability).
            offsets = torch.randint(1, num_pet_classes, tuple(labels.shape), generator=target_rng)
            target_labels = (labels + offsets.to(device)) % num_pet_classes

            noise = generator(images, target_labels)
            noise = torch.clamp(noise, -eps, eps)
            # NOTE(review): clamping to [0, 1] assumes the loader yields unnormalized images — confirm.
            adv_images = torch.clamp(images + noise, 0, 1)

            # 3. Adversarial Accuracy
            adv_outputs = surrogate(adv_images, labels)
            adv_acc_metric.update((adv_outputs, labels))

            # Per-sample attack statistics (argmax over the class logits).
            clean_correct = clean_outputs.argmax(dim=1) == labels
            adv_pred = adv_outputs.argmax(dim=1)
            n_images += labels.numel()
            n_clean_correct += clean_correct.sum().item()
            n_fooled += (clean_correct & (adv_pred != labels)).sum().item()
            n_target_hits += (adv_pred == target_labels).sum().item()

    clean_acc = clean_acc_metric.compute()
    adv_acc = adv_acc_metric.compute()
    asr = clean_acc - adv_acc  # accuracy drop, which this notebook reports as "ASR"
    fooling_rate = n_fooled / max(n_clean_correct, 1)
    targeted_success = n_target_hits / max(n_images, 1)

    print(f"Results for {name}:")
    print(f"  Clean Accuracy:       {clean_acc:.4f}")
    print(f"  Adversarial Accuracy: {adv_acc:.4f}")
    print(f"  Attack Success Rate:  {asr:.4f}  (clean - adversarial accuracy)")
    print(f"  Fooling Rate:         {fooling_rate:.4f}  (correct clean -> wrong adversarial)")
    print(f"  Targeted Success:     {targeted_success:.4f}  (adversarial prediction == target)")


In [5]:
# 5) Universal Transfer Evaluation (Food101)
# Goal: Does the Pet-Generator disrupt features on a completely different dataset (Food101)?
# This tests if the "Contrastive Loss" learned a Universal Feature Disruptor.

rule = "=" * 40
print("\n" + rule)
print("UNIVERSAL TRANSFER EVALUATION: Food101")
print(rule)

# 1. Setup Food101 Data
# Food101 is large and may not be available. This cell tries to build a Food101
# trainer; on any failure it sets food_loader = None, which skips the rest of the cell.
try:
    args_food = Args()
    # Same options as the Pets run; only the dataset and its config file change.
    args_food.dataset, args_food.dataset_config_file = "food101", "configs/datasets/food101.yaml"
    cfg_food = setup_cfg(args_food)
    trainer_food = build_trainer(cfg_food)
    food_loader = trainer_food.test_loader
    n_food_test = len(food_loader.dataset)
    print(f"Loaded Food101 with {n_food_test} test images.")
except Exception as e:
    print(f"Failed to load Food101: {e}")
    food_loader = None

if food_loader:
    # 2. Load Food101 Surrogate (Victim)
    # Rebuilt for Food101 (different num_classes) on the Food101 trainer's own CLIP visual
    # encoder. Reusing the Pets `backbone` module would share its parameters with
    # `surrogate`. Loading the Food101 checkpoint would then overwrite the Pets
    # surrogate's backbone weights as well.

    # Update config for Food101 classes
    with open('configs/data.yaml', 'r') as data_cfg_file:  # closed as soon as it is parsed
        config_food = yaml_parser.load(data_cfg_file)
    config_food['num_classes'] = trainer_food.dm.num_classes # Food101 classes
    config_food['output_dim'] = 1024
    head_factory_food = HeadFactory(args.head, config_food)

    backbone_food = nn.Sequential(normalize, trainer_food.clip_model.visual)
    surrogate_food = Model(backbone_food, head_factory_food).to(device)

    # Load Checkpoint (stored in the same Drive folder as the Oxford Pets checkpoints)
    food_ckpt = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets/RN50_ArcFace_food101.pth"

    if os.path.exists(food_ckpt):
        surrogate_food.load_state_dict(torch.load(food_ckpt, map_location=device))
        surrogate_food.eval()
        print("Food101 Surrogate Loaded.")

        # The generators were trained on Oxford Pets, so they only accept pet class ids.
        pet_num_classes = 37

        # 3. Evaluate
        for ckpt_path, name in checkpoints:
            print(f"\n>>> Evaluating (Universal Transfer): {name}")
            if not os.path.exists(ckpt_path):
                print(f"Skipping {name}: Checkpoint not found at {ckpt_path}")
                continue

            # Load Generator (Pet-Trained). strict=False tolerates partial checkpoints, but
            # mismatches are reported and a checkpoint that matches nothing is skipped.
            generator = ViTGenerator(num_classes=pet_num_classes).to(device)
            load_info = generator.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
            if load_info.missing_keys or load_info.unexpected_keys:
                print(f"  WARNING: {len(load_info.missing_keys)} missing / "
                      f"{len(load_info.unexpected_keys)} unexpected keys when loading {ckpt_path}")
            if len(load_info.missing_keys) == len(generator.state_dict()):
                print(f"Skipping {name}: none of the generator's weights were found in the checkpoint")
                continue
            generator.eval()

            clean_acc_metric = Accuracy()
            adv_acc_metric = Accuracy()
            n_clean_correct = n_fooled = 0
            # Fixed seed: every generator sees the same sequence of random pet targets.
            target_rng = torch.Generator().manual_seed(0)

            for batch in tqdm(food_loader, desc=name):
                images = batch['img'].to(device)
                labels = batch['label'].to(device)

                with torch.no_grad():
                    # 1. Clean Accuracy
                    clean_outputs = surrogate_food(images, labels)
                    clean_acc_metric.update((clean_outputs, labels))

                    # 2. Generate Adversarial Images
                    # Random Pet targets in [0, pet_num_classes), because the generator
                    # only knows pet classes. For "Universal Disruption" the target should
                    # matter little.
                    target_labels = torch.randint(
                        0, pet_num_classes, (images.shape[0],), generator=target_rng
                    ).to(device)

                    noise = generator(images, target_labels)
                    noise = torch.clamp(noise, -eps, eps)
                    adv_images = torch.clamp(images + noise, 0, 1)

                    # 3. Adversarial Accuracy (on Food101 Surrogate)
                    adv_outputs = surrogate_food(adv_images, labels)
                    adv_acc_metric.update((adv_outputs, labels))

                    # Clean-correct images that the attack pushes to a wrong class.
                    clean_correct = clean_outputs.argmax(dim=1) == labels
                    n_clean_correct += clean_correct.sum().item()
                    n_fooled += (clean_correct & (adv_outputs.argmax(dim=1) != labels)).sum().item()

            clean_acc = clean_acc_metric.compute()
            adv_acc = adv_acc_metric.compute()
            asr = clean_acc - adv_acc  # accuracy drop, which this notebook reports as "ASR"
            fooling_rate = n_fooled / max(n_clean_correct, 1)

            print(f"Results for {name} on FOOD101:")
            print(f"  Clean Accuracy:       {clean_acc:.4f}")
            print(f"  Adversarial Accuracy: {adv_acc:.4f}")
            print(f"  Attack Success Rate:  {asr:.4f}")
            print("  (Drop in Food101 accuracy caused by Pet-Generator)")
            print(f"  Fooling Rate:         {fooling_rate:.4f}  (correct clean -> wrong adversarial)")
    else:
        print(f"Food101 Checkpoint not found at {food_ckpt}")
