# ViT Generator Evaluation - ASR Test

This notebook evaluates three ViT generator variants (`targeted_only`, `contrastive`, `mixed`) on:

1. **Oxford Pets** (the training dataset) - measures in-domain ASR
2. **Food101** (an unseen dataset) - measures cross-domain transfer, i.e. how well the learned perturbations generalize

**Metrics:**
- **Clean Accuracy**: surrogate accuracy on clean images
- **Adversarial Accuracy**: surrogate accuracy on adversarial images
- **Attack Success Rate (ASR)**: accuracy drop, Clean Acc - Adv Acc (higher = stronger attack; negative if the perturbation actually helps the surrogate)
- **Targeted Success Rate**: fraction of adversarial images the surrogate classifies as the randomly chosen target class (Oxford Pets only)


In [None]:
# Cell 1: Mount Google Drive
# Mounts Google Drive (forcing a fresh mount) and records the root folder that
# holds all trained checkpoints for this experiment.
import os

from google.colab import drive

drive.mount('/content/drive', force_remount=True)

# Root of the experiment outputs on Drive.
DRIVE_OUTPUT = "/content/drive/MyDrive/grad/comp_vision/hanson_loss"
print(f"Checkpoint directory: {DRIVE_OUTPUT}")


In [None]:
# Cell 2: Setup Repository & Dependencies
!nvidia-smi
%cd /content

import os
if not os.path.exists("MFCLIP_acv"):
    !git clone -b hamza/discrim https://github.com/1hamzaiqbal/MFCLIP_acv

%cd MFCLIP_acv
!git fetch --all
!git reset --hard origin/hamza/discrim

# Install dependencies
!pip install -q torch torchvision timm einops yacs tqdm opencv-python \
    scikit-learn scipy pyyaml ruamel.yaml pytorch-ignite foolbox \
    pandas matplotlib seaborn wilds ftfy


In [None]:
# Cell 3: Setup Oxford Pets Dataset
import shutil
import os
from pathlib import Path

DATA_ROOT = "/content/data"
PETS_ROOT = f"{DATA_ROOT}/oxford_pets"
DRIVE_OUTPUT = "/content/drive/MyDrive/grad/comp_vision/hanson_loss"

Path(PETS_ROOT).mkdir(parents=True, exist_ok=True)

%cd /content
if not os.path.exists(f"{PETS_ROOT}/images"):
    print("Downloading Oxford Pets dataset...")
    !wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
    !wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
    !tar -xf images.tar.gz -C {PETS_ROOT}
    !tar -xf annotations.tar.gz -C {PETS_ROOT}
    !rm -f images.tar.gz annotations.tar.gz
    print("Dataset ready!")
else:
    print("Oxford Pets dataset already exists.")

# Copy surrogate checkpoint
SURROGATE_SRC = f"{DRIVE_OUTPUT}/oxford_pets/RN50_ArcFace_oxford_pets.pth"
SURROGATE_DST = f"{PETS_ROOT}/RN50_ArcFace.pth"
if os.path.exists(SURROGATE_SRC):
    shutil.copy(SURROGATE_SRC, SURROGATE_DST)
    print(f"Surrogate checkpoint ready: {SURROGATE_DST}")
else:
    print(f"WARNING: Surrogate not found at {SURROGATE_SRC}")


In [None]:
# Cell 4: Copy Generator Checkpoints from Drive
# Copies each trained ViT generator variant from Drive into PETS_ROOT so the
# evaluation cells can load them from local disk; missing files are only
# reported here (the evaluation cells skip them).
import os
import shutil

# Dedicated name for the Pets sub-folder on Drive. Cells 1 and 3 define
# DRIVE_OUTPUT as the Drive *root*; re-binding it to a sub-folder here would make
# its meaning depend on which cell happened to run last.
GENERATOR_DRIVE_DIR = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets"
PETS_ROOT = "/content/data/oxford_pets"

checkpoints = [
    "vit_generator_targeted_only.pt",
    "vit_generator_contrastive.pt",
    "vit_generator_mixed.pt",
]

print("Copying generator checkpoints from Drive...")
for ckpt in checkpoints:
    src = os.path.join(GENERATOR_DRIVE_DIR, ckpt)
    dst = os.path.join(PETS_ROOT, ckpt)
    if os.path.exists(src):
        shutil.copy(src, dst)
        print(f"  ✓ {ckpt}")
    else:
        print(f"  ✗ {ckpt} (not found in Drive)")


In [None]:
# Cell 5: Load Models and Setup Evaluation Infrastructure
%cd /content/MFCLIP_acv

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from ignite.metrics import Accuracy
from ruamel.yaml import YAML
import sys
sys.path.insert(0, '/content/MFCLIP_acv')

from model import ViTGenerator
from utils.util import setup_cfg, Model
from dass.engine import build_trainer
from loss.head.head_def import HeadFactory
from torchvision import transforms
import argparse

# Register trainers and datasets
import trainers.zsclip
import trainers.coop
import trainers.cocoop
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.food101

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build trainer for Oxford Pets
class Args:
    root = "/content/data"
    dataset = "oxford_pets"
    config_file = "configs/trainers/CoOp/rn50.yaml"
    dataset_config_file = "configs/datasets/oxford_pets.yaml"
    trainer = "ZeroshotCLIP"
    head = "ArcFace"
    output_dir = "output"
    opts = []
    gpu = 0
    device = "cuda:0"
    resume = ""
    seed = -1
    source_domains = None
    target_domains = None
    transforms = None
    backbone = ""
    bs = 64
    ratio = 1.0

args = Args()
cfg = setup_cfg(args)
trainer = build_trainer(cfg)

# Build Surrogate Model
yaml_parser = YAML(typ='safe')
config = yaml_parser.load(open('configs/data.yaml', 'r'))
config['num_classes'] = trainer.dm.num_classes
config['output_dim'] = 1024
head_factory = HeadFactory(args.head, config)

backbone = trainer.clip_model.visual
normalize = transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
backbone = nn.Sequential(normalize, backbone)

surrogate = Model(backbone, head_factory).to(device)

# Load surrogate weights
surrogate_path = "/content/data/oxford_pets/RN50_ArcFace.pth"
surrogate.load_state_dict(torch.load(surrogate_path, map_location=device))
surrogate.eval()
print("âœ“ Surrogate Model Loaded")

# Get test loader
test_loader = trainer.test_loader
print(f"âœ“ Test set: {len(test_loader.dataset)} samples")


## Evaluation Function

This function evaluates a generator on a given dataset and computes ASR metrics.


In [None]:
# Cell 6: Evaluation Function
def sample_target_labels(labels, num_classes, rng=None):
    """Draw one random target class per sample, uniformly among the classes != its label.

    Draws an offset in [0, num_classes - 1) and shifts every value >= the true
    label up by one. This skips the true label without favouring any other
    class. (Re-mapping collisions to `label + 1` would give that class twice the
    probability of the rest.)

    Args:
        labels: integer tensor of true labels (any shape, any device).
        num_classes: number of classes; must be at least 2.
        rng: optional CPU ``torch.Generator`` for reproducible sampling; when
            None, torch's global RNG is used.

    Returns:
        Long tensor with the same shape as ``labels``, on ``labels.device``.

    Raises:
        ValueError: if ``num_classes < 2`` (no valid different target exists).
    """
    if num_classes < 2:
        raise ValueError(f"need at least 2 classes to pick a target != label, got {num_classes}")
    # Sample on CPU (where `rng` lives), then move next to the labels.
    targets = torch.randint(0, num_classes - 1, labels.shape, generator=rng).to(labels.device)
    return targets + (targets >= labels).long()


def evaluate_generator(generator, surrogate, test_loader, num_classes, eps=16/255., device='cuda', seed=None):
    """
    Evaluate a target-conditioned generator against a surrogate classifier.

    For every batch, a random target class different from the true label is
    drawn. The generator output is clipped to [-eps, eps] and added to the
    images, and the result is clamped to [0, 1]. The surrogate is then scored
    on both the clean and the adversarial images.

    Args:
        generator: model called as ``generator(images, target_labels)``; its output
            is used as the additive perturbation.
        surrogate: model called as ``surrogate(images, labels)`` returning logits.
        test_loader: iterable of dict batches with ``'img'`` and ``'label'`` entries.
        num_classes: size of the class space targets are drawn from (>= 2).
        eps: L-infinity perturbation budget.
        device: device the batches are moved to.
        seed: optional int; when given, target sampling is reproducible and
            independent of torch's global RNG.

    Returns:
        dict with clean_acc, adv_acc, asr (= clean_acc - adv_acc) and
        targeted_success (fraction of adversarial images predicted as the target).
    """
    generator.eval()
    surrogate.eval()

    rng = torch.Generator().manual_seed(seed) if seed is not None else None

    clean_acc_metric = Accuracy()
    adv_acc_metric = Accuracy()
    targeted_success_metric = Accuracy()

    for batch in tqdm(test_loader, desc="Evaluating"):
        images = batch['img'].to(device)
        labels = batch['label'].to(device)

        # Random target labels != true labels (uniform over the other classes)
        target_labels = sample_target_labels(labels, num_classes, rng)

        with torch.no_grad():
            # Clean accuracy
            clean_outputs = surrogate(images, labels)
            clean_acc_metric.update((clean_outputs, labels))

            # Generate adversarial images.
            # NOTE(review): clamping to [0, 1] assumes the loader yields
            # un-normalized images — confirm against the dataset transforms.
            noise = generator(images, target_labels)
            noise = torch.clamp(noise, -eps, eps)
            adv_images = torch.clamp(images + noise, 0, 1)

            # Adversarial accuracy
            adv_outputs = surrogate(adv_images, labels)
            adv_acc_metric.update((adv_outputs, labels))

            # Targeted success (does model predict target class?)
            targeted_success_metric.update((adv_outputs, target_labels))

    clean_acc = clean_acc_metric.compute()
    adv_acc = adv_acc_metric.compute()
    asr = clean_acc - adv_acc
    targeted_success = targeted_success_metric.compute()

    return {
        'clean_acc': clean_acc,
        'adv_acc': adv_acc,
        'asr': asr,
        'targeted_success': targeted_success,
    }

print("✓ Evaluation function defined")


In [None]:
# Cell 7: Evaluate All Variants on Oxford Pets
# Loads each generator variant, evaluates it in-domain with evaluate_generator,
# and collects the metrics into `results_pets` / `df_pets`.
import os
import warnings

import pandas as pd

PETS_ROOT = "/content/data/oxford_pets"
NUM_CLASSES = 37  # Oxford Pets has 37 classes
EPS = 16/255.
# Torch's global RNG is re-seeded before every variant. evaluate_generator draws
# targets from that RNG when no seed is passed, so every variant is attacked
# with the same random targets and the targeted-success numbers are comparable.
SEED = 0

variants = [
    ("targeted_only", f"{PETS_ROOT}/vit_generator_targeted_only.pt"),
    ("contrastive", f"{PETS_ROOT}/vit_generator_contrastive.pt"),
    ("mixed", f"{PETS_ROOT}/vit_generator_mixed.pt"),
]

results_pets = []

print("=" * 60)
print("OXFORD PETS EVALUATION (In-Domain)")
print("=" * 60)

for name, ckpt_path in variants:
    print(f"\n>>> Evaluating: {name}")

    if not os.path.exists(ckpt_path):
        print(f"  ✗ Checkpoint not found: {ckpt_path}")
        results_pets.append({'variant': name, 'clean_acc': None, 'adv_acc': None, 'asr': None, 'targeted_success': None})
        continue

    # Load generator. strict=False tolerates key mismatches, so surface them:
    # unmatched parameters silently keep their random initialization, which
    # would make the metrics below meaningless without any error.
    generator = ViTGenerator(num_classes=NUM_CLASSES).to(device)
    incompatible = generator.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"{name}: checkpoint does not match ViTGenerator - "
            f"missing keys: {incompatible.missing_keys}, unexpected keys: {incompatible.unexpected_keys}"
        )
    generator.eval()

    # Evaluate (same seed for every variant -> same random targets)
    torch.manual_seed(SEED)
    metrics = evaluate_generator(generator, surrogate, test_loader, NUM_CLASSES, EPS, device)

    print(f"  Clean Accuracy:       {metrics['clean_acc']:.4f}")
    print(f"  Adversarial Accuracy: {metrics['adv_acc']:.4f}")
    print(f"  Attack Success Rate:  {metrics['asr']:.4f}")
    print(f"  Targeted Success:     {metrics['targeted_success']:.4f}")

    results_pets.append({
        'variant': name,
        'clean_acc': metrics['clean_acc'],
        'adv_acc': metrics['adv_acc'],
        'asr': metrics['asr'],
        'targeted_success': metrics['targeted_success'],
    })

    # Free memory
    del generator
    torch.cuda.empty_cache()

# Display results table
df_pets = pd.DataFrame(results_pets)
print("\n" + "=" * 60)
print("OXFORD PETS RESULTS SUMMARY")
print("=" * 60)
print(df_pets.to_string(index=False))


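## Setup Food101 for Cross-Domain Evaluation

Download and prepare the Food101 dataset and build its test loader. The Food101 surrogate is not trained here. **Cell 9** loads it from a pre-trained checkpoint on Drive. Cell 9 sits at the end of this notebook but must be run before Cell 10.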


In [None]:
# Cell 8: Setup Food101 Dataset
import os

FOOD_ROOT = "/content/data/food-101"

print("Checking Food101 dataset...")
if not os.path.exists(f"{FOOD_ROOT}/images"):
    print("Downloading Food101 (approx 5GB)... This might take a few minutes.")
    %cd /content/data
    !wget -q http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz
    print("Extracting Food101...")
    !tar -xzf food-101.tar.gz
    !rm -f food-101.tar.gz
    print("âœ“ Food101 Ready!")
else:
    print("âœ“ Food101 dataset already exists.")

# Setup Food101 trainer
args_food = Args()
args_food.dataset = "food101"
args_food.dataset_config_file = "configs/datasets/food101.yaml"

%cd /content/MFCLIP_acv
cfg_food = setup_cfg(args_food)
trainer_food = build_trainer(cfg_food)
food_loader = trainer_food.test_loader
food_num_classes = trainer_food.dm.num_classes

print(f"âœ“ Food101 test set: {len(food_loader.dataset)} samples")
print(f"âœ“ Food101 classes: {food_num_classes}")


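## Evaluate on Food101 (Cross-Domain Transfer)

Test how well the Pet-trained generators transfer to a completely different domain (food images).
The generators only know the 37 Oxford Pets classes, so they are conditioned on random Pets targets. As a result, only the untargeted accuracy drop of the Food101 surrogate is measured.
This measures the generalizability of the learned perturbations.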


In [None]:
# Cell 10: Evaluate Cross-Domain Transfer on Food101
PETS_ROOT = "/content/data/oxford_pets"
PETS_NUM_CLASSES = 37

def evaluate_cross_domain(generator, surrogate_target, test_loader, generator_num_classes, target_num_classes, eps=16/255., device='cuda'):
    """
    Measure how much a generator trained on one domain degrades a surrogate on another.

    The generator was never trained on the target domain's classes. Targets are
    therefore drawn uniformly from its own label space
    ``[0, generator_num_classes)``, and only the untargeted accuracy drop on the
    target domain is reported. ``target_num_classes`` is accepted for interface
    symmetry but is not used.

    Returns:
        dict with clean_acc, adv_acc and asr (= clean_acc - adv_acc).
    """
    for net in (generator, surrogate_target):
        net.eval()

    scores = {'clean': Accuracy(), 'adv': Accuracy()}

    for batch in tqdm(test_loader, desc="Cross-Domain Eval"):
        x = batch['img'].to(device)
        y = batch['label'].to(device)

        # Targets live in the generator's class space (e.g. Pet classes 0-36).
        y_target = torch.randint(0, generator_num_classes, y.shape).to(device)

        with torch.no_grad():
            # Clean accuracy on the target domain.
            scores['clean'].update((surrogate_target(x, y), y))

            # Perturbation from the source-domain generator, L-inf clipped.
            delta = generator(x, y_target).clamp(-eps, eps)
            x_adv = (x + delta).clamp(0, 1)

            # Adversarial accuracy on the target domain.
            scores['adv'].update((surrogate_target(x_adv, y), y))

    clean_acc = scores['clean'].compute()
    adv_acc = scores['adv'].compute()
    return {
        'clean_acc': clean_acc,
        'adv_acc': adv_acc,
        'asr': clean_acc - adv_acc,
    }

# Only run if the Food101 surrogate is available.
# Cell 9 builds `surrogate_food` and sets FOOD_SURROGATE_AVAILABLE, but it is
# placed at the end of the notebook. Under "Run All" neither name exists yet
# when this cell executes, so look them up defensively instead of raising NameError.
import warnings

FOOD_EVAL_SEED = 0  # re-seeded per variant so all variants draw the same random targets

food_surrogate_ready = bool(globals().get("FOOD_SURROGATE_AVAILABLE", False)) and "surrogate_food" in globals()

if food_surrogate_ready:
    results_food = []

    print("=" * 60)
    print("FOOD101 EVALUATION (Cross-Domain Transfer)")
    print("=" * 60)
    print("Using Pet-trained generators on Food images\n")

    for name, ckpt_path in variants:
        print(f"\n>>> Evaluating: {name}")

        if not os.path.exists(ckpt_path):
            print("  ✗ Checkpoint not found")
            results_food.append({'variant': name, 'clean_acc': None, 'adv_acc': None, 'asr': None})
            continue

        # Load generator (trained on Pets with 37 classes). strict=False
        # tolerates key mismatches, so report them rather than silently
        # evaluating a partially initialized generator.
        generator = ViTGenerator(num_classes=PETS_NUM_CLASSES).to(device)
        incompatible = generator.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"{name}: checkpoint does not match ViTGenerator - "
                f"missing keys: {incompatible.missing_keys}, unexpected keys: {incompatible.unexpected_keys}"
            )
        generator.eval()

        # Evaluate cross-domain. evaluate_cross_domain draws targets from torch's
        # global RNG, so seeding here gives every variant the same targets.
        torch.manual_seed(FOOD_EVAL_SEED)
        metrics = evaluate_cross_domain(
            generator, surrogate_food, food_loader,
            generator_num_classes=PETS_NUM_CLASSES,
            target_num_classes=food_num_classes,
            eps=EPS, device=device
        )

        print(f"  Clean Accuracy:       {metrics['clean_acc']:.4f}")
        print(f"  Adversarial Accuracy: {metrics['adv_acc']:.4f}")
        print(f"  Attack Success Rate:  {metrics['asr']:.4f}")

        results_food.append({
            'variant': name,
            'clean_acc': metrics['clean_acc'],
            'adv_acc': metrics['adv_acc'],
            'asr': metrics['asr'],
        })

        del generator
        torch.cuda.empty_cache()

    # Display results
    df_food = pd.DataFrame(results_food)
    print("\n" + "=" * 60)
    print("FOOD101 RESULTS SUMMARY")
    print("=" * 60)
    print(df_food.to_string(index=False))
else:
    print("Skipping Food101 evaluation - surrogate not available (run Cell 9 first).")


## Visualization: Perturbation Comparison


In [None]:
# Cell 11: Visualize Perturbations from All Variants
# For one sample Pets image, shows per variant: the clean image, the generated
# perturbation (min-max rescaled), the adversarial image and the amplified
# |adversarial - clean| difference.
import glob
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

PETS_ROOT = "/content/data/oxford_pets"
VIS_TARGET_CLASS = 5  # arbitrary Pets class every generator is conditioned on
VIS_EPS = 16/255      # same L-inf budget as the evaluation cells
DIFF_GAIN = 10        # amplification of the |adv - clean| panel


def show_tensor(t, ax, title):
    """Draw a (1, 3, H, W) or (3, H, W) image tensor, clipped to [0, 1], on `ax`."""
    im = t.squeeze().cpu().permute(1, 2, 0).numpy()
    ax.imshow(np.clip(im, 0, 1))
    ax.set_title(title, fontsize=10)
    ax.axis('off')


# Sorted so the same sample image is picked on every run (glob order is arbitrary).
image_paths = sorted(glob.glob(f"{PETS_ROOT}/images/*.jpg"))
if image_paths:
    img_path = image_paths[0]
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    img = Image.open(img_path).convert('RGB')
    img_t = transform(img).unsqueeze(0).to(device)

    target_label = torch.tensor([VIS_TARGET_CLASS]).to(device)

    # One row per variant; squeeze=False keeps `axes` 2-D even for a single row.
    n_rows = len(variants)
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

    for row, (name, ckpt_path) in enumerate(variants):
        if not os.path.exists(ckpt_path):
            for col in range(4):
                axes[row, col].axis('off')
                axes[row, col].set_title(f"{name}: Not Found")
            continue

        # strict=False tolerates key mismatches; report them so a mismatched
        # checkpoint is not visualized as if it were trained.
        gen = ViTGenerator(num_classes=NUM_CLASSES).to(device)
        incompatible = gen.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"{name}: checkpoint does not match ViTGenerator - "
                f"missing keys: {incompatible.missing_keys}, unexpected keys: {incompatible.unexpected_keys}"
            )
        gen.eval()

        with torch.no_grad():
            noise = torch.clamp(gen(img_t, target_label), -VIS_EPS, VIS_EPS)
            adv = torch.clamp(img_t + noise, 0, 1)

        # Original
        show_tensor(img_t, axes[row, 0], "Original")

        # Perturbation (min-max rescaled for visibility)
        noise_np = noise.squeeze().cpu().permute(1, 2, 0).numpy()
        noise_vis = (noise_np - noise_np.min()) / (noise_np.max() - noise_np.min() + 1e-8)
        axes[row, 1].imshow(noise_vis)
        axes[row, 1].set_title(f"{name}: Perturbation")
        axes[row, 1].axis('off')

        # Adversarial
        show_tensor(adv, axes[row, 2], f"{name}: Adversarial")

        # Difference (amplified)
        diff_np = (adv - img_t).abs().squeeze().cpu().permute(1, 2, 0).numpy()
        axes[row, 3].imshow(np.clip(diff_np * DIFF_GAIN, 0, 1))
        axes[row, 3].set_title(f"{name}: Diff ({DIFF_GAIN}x)")
        axes[row, 3].axis('off')

        del gen
        torch.cuda.empty_cache()

    plt.tight_layout()
    plt.savefig(f"{PETS_ROOT}/perturbation_comparison.png", dpi=150)
    plt.show()
else:
    print("No images found for visualization.")


## Final Summary

Comparison of all variants across both datasets.


In [None]:
# Cell 12: Final Summary
# Prints the in-domain (and, if available, cross-domain) result tables and
# saves a side-by-side ASR bar chart to PETS_ROOT.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

ASR_COLORS = ['#2196F3', '#F44336', '#4CAF50']


def plot_asr(ax, results, title):
    """Draw one ASR bar per entry of `results` (dicts with 'variant' and 'asr') on `ax`.

    Entries without a result (asr is None) are drawn at 0 and labelled "N/A",
    so they cannot be mistaken for a measured ASR of 0. ASR is an accuracy
    *difference* and can be negative; in that case the y-axis extends below 0
    instead of clipping the bar.
    """
    names = [r['variant'] for r in results]
    values = [r['asr'] for r in results]
    heights = [0.0 if v is None else float(v) for v in values]
    colors = [ASR_COLORS[i % len(ASR_COLORS)] for i in range(len(names))]

    bars = ax.bar(names, heights, color=colors)
    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_ylabel('Attack Success Rate')
    ax.set_title(title)
    lowest = min(heights, default=0.0)
    ax.set_ylim(lowest - 0.1 if lowest < 0 else 0, 1)

    for bar, v in zip(bars, values):
        x = bar.get_x() + bar.get_width() / 2
        if v is None:
            ax.text(x, 0.02, 'N/A', ha='center', va='bottom', fontsize=10)
        elif v < 0:
            ax.text(x, v - 0.02, f'{v:.3f}', ha='center', va='top', fontsize=10)
        else:
            ax.text(x, v + 0.02, f'{v:.3f}', ha='center', va='bottom', fontsize=10)


# Food101 results exist only if Cell 10 actually ran its evaluation. That needs
# the surrogate from Cell 9, which is placed after this cell, so the lookup is
# defensive rather than relying on FOOD_SURROGATE_AVAILABLE / df_food existing.
food_results = globals().get("results_food")

print("=" * 70)
print("FINAL RESULTS SUMMARY")
print("=" * 70)

print("\n📊 OXFORD PETS (In-Domain):")
print(df_pets.to_string(index=False))

if food_results is not None:
    print("\n📊 FOOD101 (Cross-Domain Transfer):")
    print(pd.DataFrame(food_results).to_string(index=False))

# Plot comparison bar chart
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Oxford Pets ASR
plot_asr(axes[0], results_pets, 'Oxford Pets (In-Domain)')

# Plot 2: Food101 ASR (if available)
if food_results is not None:
    plot_asr(axes[1], food_results, 'Food101 (Cross-Domain Transfer)')
else:
    axes[1].text(0.5, 0.5, 'Food101 Surrogate\nNot Available',
                 ha='center', va='center', fontsize=14, transform=axes[1].transAxes)
    axes[1].set_title('Food101 (Cross-Domain Transfer)')

plt.tight_layout()
plt.savefig(f"{PETS_ROOT}/asr_comparison.png", dpi=150)
plt.show()

print(f"\nResults saved to {PETS_ROOT}/")


In [None]:
# Cell 9: Build Food101 Surrogate (if checkpoint exists)
# NOTE: although placed at the end of the notebook, this cell must run after
# Cell 8 (it needs `food_num_classes`) and before Cell 10 (which evaluates with
# `surrogate_food`).
import copy
import os

DRIVE_OUTPUT = "/content/drive/MyDrive/grad/comp_vision/hanson_loss"
FOOD_SURROGATE_SRC = f"{DRIVE_OUTPUT}/food101/RN50_ArcFace_food101.pth"

# Build Food101 surrogate model: same backbone architecture, Food101-sized head.
with open('configs/data.yaml', 'r') as config_file:
    config_food = yaml_parser.load(config_file)
config_food['num_classes'] = food_num_classes
config_food['output_dim'] = 1024
head_factory_food = HeadFactory(args.head, config_food)

# Give the Food101 surrogate its own copy of the backbone. `backbone` is the
# module that was passed to the Pets `surrogate` in Cell 5, which presumably
# holds it as a submodule. Sharing it would make the Food101 state dict
# loaded below overwrite the Pets surrogate's backbone weights as well.
surrogate_food = Model(copy.deepcopy(backbone), head_factory_food).to(device)

if os.path.exists(FOOD_SURROGATE_SRC):
    surrogate_food.load_state_dict(torch.load(FOOD_SURROGATE_SRC, map_location=device))
    surrogate_food.eval()
    print(f"✓ Food101 Surrogate loaded from {FOOD_SURROGATE_SRC}")
    FOOD_SURROGATE_AVAILABLE = True
else:
    print(f"✗ Food101 Surrogate not found at {FOOD_SURROGATE_SRC}")
    print("  Cross-domain evaluation will measure feature disruption only.")
    FOOD_SURROGATE_AVAILABLE = False
