# t-SNE Visualization: Clean vs Adversarial Features

This notebook visualizes how different generator variants affect the CLIP feature space.

**Generators to Compare:**
- UNet Vanilla (existing)
- UNet Contrastive (existing)
- ViT Targeted Only (after training)
- ViT Contrastive (after training)
- ViT Mixed (after training)

**Key Question:** Does contrastive loss training produce perturbations that disperse features MORE than targeted-only training?

**Expected Result:** Contrastive-trained generators should push adversarial features further from clean feature clusters.


In [None]:
# Cell 1: Mount Drive & Setup
# Mounts Google Drive and lists the contents of the checkpoint folder so the
# reader can verify which generator/surrogate files are available.
import os

from google.colab import drive

drive.mount('/content/drive', force_remount=True)

DRIVE_ROOT = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets"
print(f"Drive root: {DRIVE_ROOT}")

# List available checkpoints
print("\nAvailable checkpoints in Drive:")
if os.path.isdir(DRIVE_ROOT):
    # Sorted so the listing is stable between runs.
    for entry in sorted(os.listdir(DRIVE_ROOT)):
        entry_path = os.path.join(DRIVE_ROOT, entry)
        if os.path.isdir(entry_path):
            # getsize() on a directory is not the size of its contents, so
            # label sub-folders (e.g. a results folder) instead of printing it.
            print(f"  {entry}/ (directory)")
            continue
        size_mb = os.path.getsize(entry_path) / 1e6
        print(f"  {entry} ({size_mb:.1f} MB)")
else:
    # Without this branch a wrong path or failed mount printed the header only.
    print(f"  ⚠️ {DRIVE_ROOT} not found - check the Drive mount and the folder path")


In [None]:
# Cell 2: Clone Repo & Install Dependencies
# Show the GPU (if any) attached to this runtime.
!nvidia-smi
%cd /content

import os
# Clone only when the checkout does not exist yet; re-runs reuse it.
if not os.path.exists("MFCLIP_acv"):
    !git clone -b hamza/discrim https://github.com/1hamzaiqbal/MFCLIP_acv

%cd MFCLIP_acv
# Sync the checkout with the remote branch. NOTE: `reset --hard` discards any
# uncommitted local edits made inside /content/MFCLIP_acv.
!git fetch --all
!git reset --hard origin/hamza/discrim

# NOTE(review): package versions are not pinned, so the installed versions
# depend on when the cell is run; consider pinning them - TODO confirm.
!pip install -q torch torchvision timm einops yacs tqdm opencv-python \
    scikit-learn scipy pyyaml ruamel.yaml pytorch-ignite foolbox \
    pandas matplotlib seaborn wilds ftfy


In [None]:
# Cell 3: Setup Dataset
# NOTE: this cell calls `os.path.exists` but does not import `os` itself; it
# relies on the `import os` executed in Cell 1 / Cell 2.
# NOTE(review): `shutil` is imported here but not used in this cell.
import shutil
from pathlib import Path
from torchvision.datasets import OxfordIIITPet
from torchvision import transforms

DATA_ROOT = "/content/data"
PETS_ROOT = f"{DATA_ROOT}/oxford_pets"
# Same Drive folder as the DRIVE_ROOT defined in Cell 1.
DRIVE_ROOT = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets"

Path(PETS_ROOT).mkdir(parents=True, exist_ok=True)

# Download via torchvision (creates split files)
print("Setting up Oxford Pets...")
# The returned dataset object is discarded; the call is made for its download
# side effect only.
_ = OxfordIIITPet(root=PETS_ROOT, download=True, transform=transforms.ToTensor())

# Fetch images/annotations
# The archives are downloaded into /content, unpacked directly under PETS_ROOT
# and then removed. The whole step is skipped when PETS_ROOT/images exists.
# NOTE(review): presumably this is the layout the repo's dataset loader
# expects - confirm against datasets/oxford_pets.
%cd /content
if not os.path.exists(f"{PETS_ROOT}/images"):
    !wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
    !wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
    !tar -xf images.tar.gz -C {PETS_ROOT}
    !tar -xf annotations.tar.gz -C {PETS_ROOT}
    !rm -f images.tar.gz annotations.tar.gz
print("✓ Dataset ready!")


In [None]:
# Cell 4: Load Surrogate Model (CLIP Backbone)
# Work from the repo root: the config paths used later in this cell
# ('configs/...') are relative paths.
%cd /content/MFCLIP_acv

import torch
import torch.nn as nn
from torchvision import transforms
from ruamel.yaml import YAML
import sys
# Put the repo root first on sys.path so its top-level packages (model, utils,
# dass, loss, trainers, datasets) can be imported.
sys.path.insert(0, '/content/MFCLIP_acv')

from model import UNetLikeGenerator as UNet, ViTGenerator
from utils.util import setup_cfg, Model
from dass.engine import build_trainer
from loss.head.head_def import HeadFactory

# Register modules
# These modules are not referenced by name below; presumably importing them
# registers the trainers/datasets with the framework - TODO confirm.
import trainers.zsclip, trainers.coop, trainers.cocoop
import datasets.oxford_pets, datasets.oxford_flowers, datasets.food101

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build trainer to get CLIP backbone
class Args:
    """Plain attribute container passed to ``setup_cfg`` to build the config.

    NOTE(review): presumably mirrors the command-line flags of the repo's
    training scripts - confirm against ``utils.util.setup_cfg``.
    """

    # --- dataset / config selection ---
    root = '/content/data'
    dataset = 'oxford_pets'
    config_file = 'configs/trainers/CoOp/rn50.yaml'
    dataset_config_file = 'configs/datasets/oxford_pets.yaml'

    # --- model selection ---
    trainer = 'ZeroshotCLIP'
    head = 'ArcFace'

    # --- run-time options ---
    output_dir = 'output'
    opts = []
    gpu = 0
    device = 'cuda:0'
    resume = ''
    seed = -1

    # --- left unset for this notebook ---
    source_domains = None
    target_domains = None
    transforms = None
    backbone = ''

    # --- batch size / data ratio ---
    bs = 64
    ratio = 1.0

args = Args()
cfg = setup_cfg(args)
trainer = build_trainer(cfg)

# Build surrogate model
yaml_parser = YAML(typ='safe')
# Read the head configuration inside a context manager so the file handle is
# closed as soon as parsing finishes (it was previously left open).
with open('configs/data.yaml', 'r') as config_fh:
    config = yaml_parser.load(config_fh)
config['num_classes'] = trainer.dm.num_classes
config['output_dim'] = 1024
head_factory = HeadFactory(args.head, config)

# The feature extractor is the trainer's CLIP visual encoder preceded by a
# per-channel normalisation layer, so the backbone receives un-normalised images.
clip_backbone = trainer.clip_model.visual
normalize = transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
backbone = nn.Sequential(normalize, clip_backbone)

surrogate = Model(backbone, head_factory).to(device)

# Load surrogate weights (exact filename)
SURROGATE_PATH = f"{DRIVE_ROOT}/RN50_ArcFace_oxford_pets.pth"
if os.path.exists(SURROGATE_PATH):
    surrogate.load_state_dict(torch.load(SURROGATE_PATH, map_location=device))
    print(f"✓ Surrogate loaded from {SURROGATE_PATH}")
else:
    # Best effort: execution continues with the weights the model was built with.
    print(f"⚠️ Surrogate not found at {SURROGATE_PATH}")

surrogate.eval()
test_loader = trainer.test_loader
NUM_CLASSES = trainer.dm.num_classes
print(f"✓ Test set: {len(test_loader.dataset)} samples, {NUM_CLASSES} classes")


In [None]:
# Cell 5: Feature Extraction Functions
from tqdm import tqdm
import numpy as np

def _generator_accepts_targets(generator):
    """Return True if `generator` can be called as ``generator(images, target_labels)``.

    The decision is made from the call signature alone (no forward pass is
    run), so a ``TypeError`` raised *inside* a generator is never mistaken for
    "this generator takes no target labels".
    """
    import inspect  # local import keeps this block self-contained

    forward = getattr(generator, 'forward', generator)
    try:
        inspect.signature(forward).bind(None, None)
    except TypeError:
        # Two positional arguments do not fit the signature: image-only generator.
        return False
    except ValueError:
        # No introspectable signature; use the targeted calling convention.
        return True
    return True


def extract_features(loader, backbone, generator=None, num_samples=1000, eps=16/255., device='cuda',
                     num_classes=None):
    """
    Extract features from images, optionally with adversarial perturbations.

    Args:
        loader: Iterable of batches; each batch is indexed with 'img' and 'label'
        backbone: Feature extractor (CLIP visual)
        generator: Optional generator for adversarial perturbations. It is called
            as generator(images, target_labels) when its signature accepts two
            positional arguments, otherwise as generator(images)
        num_samples: Max samples to extract (for speed)
        eps: Perturbation epsilon; the generator output is clamped to [-eps, eps]
        device: Device the images and labels are moved to
        num_classes: Number of classes used to draw random target labels.
            Defaults to the notebook-level NUM_CLASSES when omitted

    Returns:
        features: (N, D) numpy array, D being the backbone's output width
        labels: (N,) numpy array

    Raises:
        ValueError: if the loader yields no batches
    """
    backbone.eval()
    takes_targets = False
    if generator is not None:
        generator.eval()
        takes_targets = _generator_accepts_targets(generator)
        if takes_targets and num_classes is None:
            num_classes = NUM_CLASSES

    feature_chunks = []
    label_chunks = []
    seen = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting features"):
            images = batch['img'].to(device)
            labels = batch['label'].to(device)

            if generator is not None:
                if takes_targets:
                    # For targeted attack, use random targets; a target that
                    # collides with the true label is shifted to the next class.
                    target_labels = torch.randint(0, num_classes, labels.shape, device=labels.device)
                    collides = target_labels == labels
                    target_labels[collides] = (target_labels[collides] + 1) % num_classes
                    noise = generator(images, target_labels)
                else:
                    noise = generator(images)

                # Bound the perturbation, then keep the image in [0, 1].
                noise = torch.clamp(noise, -eps, eps)
                images = torch.clamp(images + noise, 0, 1)

            # Extract features
            features = backbone(images)

            feature_chunks.append(features.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

            seen += len(labels)
            if seen >= num_samples:
                break

    if not feature_chunks:
        raise ValueError("extract_features: the loader yielded no batches")

    # The last batch may overshoot num_samples, so trim to the requested size.
    features = np.concatenate(feature_chunks, axis=0)[:num_samples]
    labels = np.concatenate(label_chunks, axis=0)[:num_samples]

    return features, labels

print("✓ Feature extraction functions defined")


In [None]:
# Cell 6: t-SNE Visualization Functions
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def run_tsne(features_dict, labels, title="t-SNE Visualization", save_path=None, perplexity=30):
    """
    Run t-SNE on multiple feature sets and plot them together.

    All feature sets are stacked and embedded jointly, so every panel shares
    one 2-D coordinate system. One panel per variant is coloured by class and
    a final panel overlays all variants, coloured by variant.

    Args:
        features_dict: dict of {name: features_array}
        labels: class labels (same for all); must be as long as each feature set
        title: Plot title
        save_path: Optional path to save figure
        perplexity: t-SNE perplexity

    Returns:
        The (total_samples, 2) embedding, rows ordered as in features_dict.

    Raises:
        ValueError: if a feature set and `labels` differ in length
    """
    names = list(features_dict.keys())

    # Fail fast: a length mismatch would otherwise only surface at plot time,
    # after the expensive t-SNE fit.
    for name, feats in features_dict.items():
        if len(feats) != len(labels):
            raise ValueError(
                f"'{name}' has {len(feats)} samples but {len(labels)} labels were given"
            )

    # Combine all features for joint t-SNE, remembering each variant's row span
    spans = []
    offset = 0
    for feats in features_dict.values():
        spans.append((offset, offset + len(feats)))
        offset += len(feats)

    combined = np.concatenate(list(features_dict.values()), axis=0)

    print(f"Running t-SNE on {combined.shape[0]} samples...")
    # The iteration count is left at the library default (1000): the `n_iter`
    # keyword was renamed in newer scikit-learn, so omitting it keeps this
    # call working across versions.
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, init='pca')
    embeddings = tsne.fit_transform(combined)

    # Plot
    n_variants = len(features_dict)
    fig, axes = plt.subplots(1, n_variants + 1, figsize=(6 * (n_variants + 1), 5))

    # Color by class (pyplot.get_cmap replaces the removed plt.cm.get_cmap)
    cmap = plt.get_cmap('tab20', NUM_CLASSES)

    # Plot each variant separately
    for ax, name, (start, end) in zip(axes, names, spans):
        emb = embeddings[start:end]
        ax.scatter(emb[:, 0], emb[:, 1], c=labels, cmap=cmap, s=8, alpha=0.6)
        ax.set_title(f"{name}\n(colored by class)", fontsize=11)
        ax.set_xticks([])
        ax.set_yticks([])

    # Plot all overlaid (last panel)
    ax = axes[-1]
    markers = ['o', 's', '^', 'D', 'v', 'P']  # Different markers for each variant
    colors = plt.cm.Set1(np.linspace(0, 1, n_variants))

    for idx, (name, (start, end)) in enumerate(zip(names, spans)):
        emb = embeddings[start:end]
        # Cycle through the markers so variants beyond the sixth are still drawn
        marker = markers[idx % len(markers)]
        ax.scatter(emb[:, 0], emb[:, 1], c=[colors[idx]], s=12, alpha=0.5, marker=marker, label=name)

    ax.legend(loc='upper right', fontsize=8)
    ax.set_title("All Variants Overlaid\n(colored by variant)", fontsize=11)
    ax.set_xticks([])
    ax.set_yticks([])

    plt.suptitle(title, fontsize=14, y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.show()

    return embeddings

print("✓ t-SNE visualization functions defined")


## Part 1: UNet Variants (Existing Checkpoints)

Compare clean features vs adversarial features from UNet Vanilla and UNet Contrastive.


In [None]:
# Cell 7: Load UNet Generators
# Drive folder that holds the trained checkpoints (same value as in Cell 1).
DRIVE_ROOT = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets"

# Exact filenames from Drive
UNET_VANILLA_PATH = "/".join((DRIVE_ROOT, "unet-vanilla-pets.pt"))
UNET_CONTRASTIVE_PATH = "/".join((DRIVE_ROOT, "unet-contrastive-pets.pt"))

def load_unet(path, name):
    """Load a UNet generator checkpoint.

    Args:
        path: Checkpoint file to load.
        name: Human-readable label used in the status messages.

    Returns:
        The generator in eval mode on `device`, or None when `path` does not exist.
    """
    # Guard clause: a missing checkpoint is reported, not raised.
    if not os.path.exists(path):
        print(f"✗ {name} not found at {path}")
        return None

    gen = UNet().to(device)
    state_dict = torch.load(path, map_location=device)
    gen.load_state_dict(state_dict)
    gen.eval()
    print(f"✓ Loaded {name} from {path}")
    return gen

# Each variable is None when its checkpoint file is missing.
unet_vanilla = load_unet(path=UNET_VANILLA_PATH, name="UNet Vanilla")
unet_contrastive = load_unet(path=UNET_CONTRASTIVE_PATH, name="UNet Contrastive")


In [None]:
# Cell 8: Extract Features - UNet Variants
NUM_SAMPLES = 1000  # Adjust based on memory/speed
EPS = 16/255.

print("Extracting clean features...")
feat_clean, labels = extract_features(
    test_loader, surrogate.backbone, generator=None,
    num_samples=NUM_SAMPLES, device=device
)
print(f"  Clean features shape: {feat_clean.shape}")

# Maps variant name -> feature array; "Clean" is the unperturbed baseline.
features_unet = {"Clean": feat_clean}

if unet_vanilla is not None:
    print("\nExtracting UNet Vanilla adversarial features...")
    feat_unet_vanilla, _ = extract_features(
        test_loader, surrogate.backbone, generator=unet_vanilla,
        num_samples=NUM_SAMPLES, eps=EPS, device=device
    )
    features_unet["UNet Vanilla"] = feat_unet_vanilla
    print(f"  UNet Vanilla shape: {feat_unet_vanilla.shape}")
    # Rebind to None instead of `del`: the reference is dropped just the same,
    # but the name stays defined, so re-running this cell does not raise a
    # NameError on the `is not None` check above.
    unet_vanilla = None  # Free memory
    torch.cuda.empty_cache()

if unet_contrastive is not None:
    print("\nExtracting UNet Contrastive adversarial features...")
    feat_unet_contrastive, _ = extract_features(
        test_loader, surrogate.backbone, generator=unet_contrastive,
        num_samples=NUM_SAMPLES, eps=EPS, device=device
    )
    features_unet["UNet Contrastive"] = feat_unet_contrastive
    print(f"  UNet Contrastive shape: {feat_unet_contrastive.shape}")
    unet_contrastive = None  # Free memory (see note above)
    torch.cuda.empty_cache()

print(f"\n✓ Extracted features for: {list(features_unet.keys())}")


In [None]:
# Cell 9: t-SNE Visualization - UNet Variants
# "Clean" alone is not a comparison: require at least one adversarial variant.
if len(features_unet) <= 1:
    print("Need at least 2 feature sets to compare. Check that UNet checkpoints loaded.")
else:
    run_tsne(
        features_unet, labels,
        title="t-SNE: Clean vs UNet Adversarial Features (Oxford Pets)",
        save_path="/content/tsne_unet_comparison.png",
    )


## Part 2: ViT Variants (After Training)

Run these cells after training the three ViT variants with `vit_train_colab.ipynb`. Any variant whose checkpoint is not found in Drive is skipped.


In [None]:
# Cell 10: Load ViT Generators (run after training completes)
DRIVE_ROOT = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets"

# Display name -> checkpoint file for each ViT variant.
vit_checkpoints = {
    "ViT Targeted": f"{DRIVE_ROOT}/vit_generator_targeted_only.pt",
    "ViT Contrastive": f"{DRIVE_ROOT}/vit_generator_contrastive.pt",
    "ViT Mixed": f"{DRIVE_ROOT}/vit_generator_mixed.pt",
}

# Only the variants whose checkpoint exists end up in this dict.
vit_generators = {}

for name, path in vit_checkpoints.items():
    if not os.path.exists(path):
        print(f"✗ {name} not found at {path}")
        continue

    gen = ViTGenerator(num_classes=NUM_CLASSES).to(device)
    # strict=False tolerates key mismatches, so report them explicitly: a
    # checkpoint that matches few or no parameters would otherwise load
    # "successfully" and leave the generator partly or wholly uninitialised.
    load_result = gen.load_state_dict(torch.load(path, map_location=device), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"⚠️ {name}: checkpoint/model key mismatch "
              f"({len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected)")
        if load_result.missing_keys:
            print(f"    missing (first 5): {load_result.missing_keys[:5]}")
        if load_result.unexpected_keys:
            print(f"    unexpected (first 5): {load_result.unexpected_keys[:5]}")
    gen.eval()
    vit_generators[name] = gen
    print(f"✓ Loaded {name} from {path}")

print(f"\nLoaded {len(vit_generators)} ViT generators")


In [None]:
# Cell 11: Extract Features - ViT Variants
NUM_SAMPLES = 1000
EPS = 16/255.

# Reuse clean features from earlier, or re-extract
if 'feat_clean' not in dir() or feat_clean is None:
    print("Extracting clean features...")
    feat_clean, labels = extract_features(
        test_loader, surrogate.backbone, generator=None,
        num_samples=NUM_SAMPLES, device=device
    )

# Maps variant name -> feature array; "Clean" is the unperturbed baseline.
features_vit = {"Clean": feat_clean}

# Pop each generator out of `vit_generators` before using it. While the dict
# still holds a reference, `del gen` only removes the loop variable and
# `empty_cache()` cannot release that model's memory.
for name in list(vit_generators):
    gen = vit_generators.pop(name)
    print(f"\nExtracting {name} adversarial features...")
    feat_adv, _ = extract_features(
        test_loader, surrogate.backbone, generator=gen,
        num_samples=NUM_SAMPLES, eps=EPS, device=device
    )
    features_vit[name] = feat_adv
    print(f"  {name} shape: {feat_adv.shape}")

    # Free memory after each (this is now the last reference to the model)
    del gen
    torch.cuda.empty_cache()

# `vit_generators` is empty at this point; re-run Cell 10 to reload the models.

print(f"\n✓ Extracted features for: {list(features_vit.keys())}")


In [None]:
# Cell 12: t-SNE Visualization - ViT Variants
# "Clean" alone is not a comparison: require at least one ViT variant.
if len(features_vit) <= 1:
    print("Need ViT checkpoints to run this. Train them first with vit_train_colab.ipynb")
else:
    run_tsne(
        features_vit, labels,
        title="t-SNE: Clean vs ViT Adversarial Features (Oxford Pets)",
        save_path="/content/tsne_vit_comparison.png",
    )


## Part 3: Combined Comparison (All Generators)

Compare all available generators together: UNet + ViT variants.


In [None]:
# Cell 13: Combined t-SNE - All Generators
# Merge UNet and ViT features (excluding duplicated "Clean")
features_all = {"Clean": feat_clean}

# Add whichever per-family dicts exist in this session: UNet first, then ViT.
for _family in ('features_unet', 'features_vit'):
    if _family not in dir():
        continue
    features_all.update(
        (variant, feats)
        for variant, feats in globals()[_family].items()
        if variant != "Clean"
    )

print(f"Combined features: {list(features_all.keys())}")

# "Clean" plus at least two generator variants are required.
if len(features_all) <= 2:
    print("Need more generator variants to make a meaningful comparison.")
else:
    run_tsne(
        features_all, labels,
        title="t-SNE: All Generator Variants Compared (Oxford Pets)",
        save_path="/content/tsne_all_generators.png",
    )


## Part 4: Quantitative Analysis

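Measure how much each generator disrupts the feature space: for every image, compute the cosine distance and the L2 distance between its clean and its adversarial feature vector, then report the mean and standard deviation per generator.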


In [None]:
# Cell 14: Quantitative Feature Disruption Analysis
from scipy.spatial.distance import cosine
import pandas as pd

def compute_disruption_metrics(feat_clean, feat_adv):
    """
    Compute metrics measuring how much adversarial features deviate from clean.

    Row i of `feat_clean` is compared with row i of `feat_adv`.

    Args:
        feat_clean: (N, D) array of clean features
        feat_adv: (N, D) array of adversarial features, same shape as feat_clean

    Returns:
        dict with the mean and standard deviation of the per-sample cosine
        distance ('cosine_dist_mean', 'cosine_dist_std') and of the per-sample
        L2 distance ('l2_dist_mean', 'l2_dist_std'). A row with zero norm
        yields NaN for the cosine statistics.

    Raises:
        ValueError: if the inputs are not 2-D arrays of identical shape
    """
    # Work in float64 so reduced-precision inputs (e.g. float16 features)
    # cannot overflow in the sums of squares below.
    clean = np.asarray(feat_clean, dtype=np.float64)
    adv = np.asarray(feat_adv, dtype=np.float64)
    if clean.ndim != 2 or clean.shape != adv.shape:
        raise ValueError(
            f"expected two 2-D arrays of equal shape, got {clean.shape} and {adv.shape}"
        )

    # Per-sample cosine distance, vectorised: 1 - <u, v> / (|u| |v|)
    dots = np.einsum('ij,ij->i', clean, adv)
    norm_products = np.linalg.norm(clean, axis=1) * np.linalg.norm(adv, axis=1)
    with np.errstate(divide='ignore', invalid='ignore'):
        cosine_dists = 1.0 - dots / norm_products
    # Rounding can push values marginally outside the valid [0, 2] range.
    cosine_dists = np.clip(cosine_dists, 0.0, 2.0)

    # L2 distance
    l2_dists = np.linalg.norm(clean - adv, axis=1)

    return {
        'cosine_dist_mean': np.mean(cosine_dists),
        'cosine_dist_std': np.std(cosine_dists),
        'l2_dist_mean': np.mean(l2_dists),
        'l2_dist_std': np.std(l2_dists),
    }

# Compute for all variants
# One row of metrics per generator; the clean baseline is not compared with itself.
results = []
for name, feats in features_all.items():
    if name != "Clean":
        row = compute_disruption_metrics(feat_clean, feats)
        row['Generator'] = name
        results.append(row)

if not results:
    print("Need generator features to compute disruption metrics.")
else:
    column_order = ['Generator', 'cosine_dist_mean', 'cosine_dist_std', 'l2_dist_mean', 'l2_dist_std']
    # Most disruptive generator (largest mean cosine distance) first.
    df = pd.DataFrame(results)[column_order].sort_values('cosine_dist_mean', ascending=False)

    rule = "=" * 70
    print("Feature Disruption Metrics (higher = more disruption):")
    print(rule)
    print(df.to_string(index=False))
    print(rule)

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    colors = plt.cm.Set2(np.linspace(0, 1, len(df)))

    # (mean column, std column, x-axis label, panel title) for each panel
    panels = (
        ('cosine_dist_mean', 'cosine_dist_std',
         'Mean Cosine Distance', 'Feature Disruption (Cosine Distance)'),
        ('l2_dist_mean', 'l2_dist_std',
         'Mean L2 Distance', 'Feature Disruption (L2 Distance)'),
    )
    for ax, (mean_col, std_col, x_label, panel_title) in zip(axes, panels):
        ax.barh(df['Generator'], df[mean_col], xerr=df[std_col], color=colors)
        ax.set_xlabel(x_label)
        ax.set_title(panel_title)
        # Put the first (most disruptive) row at the top of the chart.
        ax.invert_yaxis()

    plt.tight_layout()
    plt.savefig('/content/feature_disruption_metrics.png', dpi=150)
    plt.show()


In [None]:
# Cell 15: Save All Results to Drive
import shutil

# Destination folder on Drive; created on first use.
SAVE_DIR = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets/tsne_results"
os.makedirs(SAVE_DIR, exist_ok=True)

# Figures written by the earlier cells; any that were not produced are skipped.
files_to_save = [
    "/content/tsne_unet_comparison.png",
    "/content/tsne_vit_comparison.png",
    "/content/tsne_all_generators.png",
    "/content/feature_disruption_metrics.png",
]

print(f"Saving results to {SAVE_DIR}...")
for src in files_to_save:
    filename = os.path.basename(src)
    if not os.path.exists(src):
        print(f"  ✗ {filename} (not found)")
        continue
    shutil.copy(src, SAVE_DIR)
    print(f"  ✓ {filename}")

print("\nDone!")
