In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from typing import Optional, Tuple

class ClusteredGateNetwork(nn.Module):
    """One AdaptiveGateNetwork per cluster, applied to each sample by its cluster id.

    Samples whose cluster id is -1 pass through unchanged. Every other sample has
    each of its attention heads transformed by the gate of its cluster.
    """

    def __init__(self, num_clusters: int, num_heads: int, seq_len: int):
        super().__init__()
        self.num_clusters = num_clusters
        self.gates = nn.ModuleList([
            AdaptiveGateNetwork(num_heads, seq_len) for _ in range(num_clusters)
        ])

    def forward(self, attn_weights: torch.Tensor, cluster_ids: torch.Tensor):
        """
        attn_weights: (B, num_heads, N, N)
        cluster_ids: (B,) integer tensor; -1 means "no gating", any other value
            indexes ``self.gates``.
        Returns:
            gated_attn: (B, num_heads, N, N)
        """
        _, H, N, _ = attn_weights.shape
        cluster_ids = cluster_ids.reshape(-1).to(device=attn_weights.device, dtype=torch.long)

        # Ungated samples (-1) keep their original attention.
        gated = attn_weights.clone()

        # One batched gate call per cluster instead of one call per (sample, head).
        # Each gate is called on (k, 1, N, N) maps, exactly the shape the per-head
        # loop used, so stacking the selected samples' heads along the batch
        # dimension only changes how many maps go through in one call.
        for cid in torch.unique(cluster_ids).tolist():
            if cid == -1:
                continue
            idx = (cluster_ids == cid).nonzero(as_tuple=True)[0]
            per_head = attn_weights[idx].reshape(-1, 1, N, N)
            gated[idx] = self.gates[cid](per_head).reshape(-1, H, N, N)
        return gated


In [None]:
class AdaptiveGateNetwork(nn.Module):
    """Learned transform applied independently to every (seq_len, seq_len) attention map.

    Each map is layer-normalized over both dimensions and then passed through a
    two-layer MLP acting on its last dimension. ``num_heads`` is accepted for
    interface compatibility but is not used.
    """

    def __init__(self, num_heads: int, seq_len: int):
        super().__init__()
        self.gate = nn.Sequential(
            nn.LayerNorm([seq_len, seq_len]),
            nn.Linear(seq_len, seq_len),
            nn.ReLU(),
            nn.Linear(seq_len, seq_len)
        )

    def forward(self, attn_weights: torch.Tensor):
        """
        attn_weights: (B, H, N, N) with N == seq_len; any H is accepted.
        Returns:
            (B, H, N, N) gated maps.
        """
        B, H, N, M = attn_weights.shape
        # Fold heads into the batch dimension so any H works. squeeze(1) only
        # removed the head axis when H == 1 and otherwise produced a 5-D output.
        out = self.gate(attn_weights.reshape(B * H, N, M))
        return out.reshape(B, H, N, M)


class GatedAttentionBlock(nn.Module):
    """Drop-in replacement for a timm ViT attention module that can gate its attention map.

    Re-implements the wrapped module's forward pass by reusing its ``qkv``/``proj``
    layers. When present, it also reuses its ``q_norm``/``k_norm``/``attn_drop``/``proj_drop``
    sub-modules; otherwise these are no-ops. When ``apply_gating`` is True and
    ``cluster_ids`` is set, the post-softmax attention map is routed through
    ``gate_network(attn, cluster_ids)`` before it is applied to V.
    """

    def __init__(self, original_attn: nn.Module, gate_network: "ClusteredGateNetwork"):
        super().__init__()
        self.original_attn = original_attn
        self.gate_network = gate_network
        self.qkv = original_attn.qkv
        self.scale = original_attn.scale
        self.proj = original_attn.proj
        self.num_heads = original_attn.num_heads
        self.head_dim = original_attn.head_dim

        def _submodule_or_identity(name: str) -> nn.Module:
            # Some timm versions/configs lack these attributes; fall back to a no-op.
            mod = getattr(original_attn, name, None)
            return mod if isinstance(mod, nn.Module) else nn.Identity()

        self.q_norm = _submodule_or_identity('q_norm')
        self.k_norm = _submodule_or_identity('k_norm')
        self.attn_drop = _submodule_or_identity('attn_drop')
        self.proj_drop = _submodule_or_identity('proj_drop')

        # Set externally (by ViTDefenseSystem.forward) before this block runs.
        self.apply_gating = False
        self.cluster_ids = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, N, C) token sequence -> (B, N, C)."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q, k = self.q_norm(q), self.k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)

        if self.apply_gating and self.cluster_ids is not None:
            attn = self.gate_network(attn, self.cluster_ids)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))




In [None]:
class Detector(nn.Module):
    """Truncated ViT backbone plus an MLP head producing one sigmoid score per image.

    Runs ``vit``'s patch/position embedding and its first ``num_blocks``
    transformer blocks, applies ``vit.norm``, and maps the CLS embedding through
    ``fc``. The attribute names ``vit`` and ``fc`` define the state_dict layout
    used by saved checkpoints.
    """

    def __init__(self, vit, num_blocks=3):
        super().__init__()
        self.vit = vit
        self.fc = nn.Sequential(
            nn.Linear(vit.num_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        self.num_blocks = num_blocks

    def forward(self, x):
        """Return per-sample sigmoid scores of shape (B,)."""
        x = self.vit.patch_embed(x)
        cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.vit.pos_embed
        x = self.vit.pos_drop(x)

        for i in range(self.num_blocks):
            x = self.vit.blocks[i](x)

        x = self.vit.norm(x)
        feats = x[:, 0]

        out = self.fc(feats)
        # squeeze(-1) rather than squeeze(): a batch of one must stay shape (1,),
        # not collapse to a 0-d tensor that mismatches a (1,) target.
        return torch.sigmoid(out).squeeze(-1)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from typing import Optional, Tuple

class ViTDefenseSystem(nn.Module):
    """Frozen ViT whose later attention maps are gated for inputs flagged as adversarial.

    Per forward pass:
      1. Run ``base_model``'s embedding and its first ``detect_after_block``
         transformer blocks unchanged.
      2. Score the final-norm'd CLS embedding at that depth with the frozen ``fc``
         head of a pre-trained ``Detector``. Scores above ``detection_threshold``
         mark a sample as adversarial.
      3. Assign each flagged sample a cluster with the pickled PCA + KMeans
         diagnoser. Unflagged samples get cluster id -1.
      4. Run the remaining blocks. Their attention modules are wrapped in
         ``GatedAttentionBlock``, so the cluster's gate is applied to the attention
         map (cluster -1 passes through ungated).
    Only ``gate_network`` parameters are trainable.
    """

    def __init__(self,
                 base_model,
                 classifier_path: str,
                 diagnoser_path: str,
                 diagnoser_head_path: str,
                 num_clusters: int,
                 detection_threshold: float = 0.5,
                 detect_after_block: int = 6):
        """
        Args:
            base_model: timm ViT. It is modified in place: attention modules of blocks
                with index >= detect_after_block are wrapped in GatedAttentionBlock.
            classifier_path: state_dict of a ``Detector`` built on
                'vit_base_patch16_224'; only its ``fc`` head is kept.
            diagnoser_path: pickle holding fitted 'pca' and 'kmeans' estimators.
                NOTE: pickle.load can execute arbitrary code; only load trusted files.
            diagnoser_head_path: state_dict for AttentionDiagnosisHead (loaded and
                frozen; not used by forward()).
            num_clusters: number of gates, one per diagnoser cluster.
            detection_threshold: detector scores above this count as adversarial.
            detect_after_block: number of blocks run before detection; must match
                the ``num_blocks`` the detector was trained with.
        """
        import timm  # also imported in a later cell; local import keeps this cell self-contained

        super().__init__()

        self.base_model = base_model
        self.detection_threshold = detection_threshold
        self.num_clusters = num_clusters
        self.detect_after_block = detect_after_block

        print("Freezing base ViT model...")
        for param in self.base_model.parameters():
            param.requires_grad = False
        self.base_model.eval()
        print(f"Frozen {sum(p.numel() for p in self.base_model.parameters())} base model parameters")

        # Only the detector's fc head is kept. The strict load_state_dict below
        # overwrites every backbone weight, so there is no need to download pretrained ones.
        vit_dummy = timm.create_model('vit_base_patch16_224', pretrained=False)
        detector = Detector(vit_dummy, num_blocks=detect_after_block)
        detector.load_state_dict(torch.load(classifier_path, map_location='cpu'))

        self.classifier = detector.fc
        self.classifier.eval()
        for p in self.classifier.parameters():
            p.requires_grad = False

        with open(diagnoser_path, 'rb') as f:
            self.diagnoser_utils = pickle.load(f)

        self.diagnoser_head = AttentionDiagnosisHead(input_dim=768, hidden_dim=64)
        self.diagnoser_head.load_state_dict(torch.load(diagnoser_head_path, map_location='cpu'))
        self.diagnoser_head.eval()
        for p in self.diagnoser_head.parameters():
            p.requires_grad = False

        # pos_embed is added to [CLS] + patch tokens, so its length is the token count.
        seq_len = base_model.pos_embed.shape[1]
        num_heads = base_model.blocks[0].attn.num_heads

        self.gated_blocks = nn.ModuleList()
        self.gate_network = ClusteredGateNetwork(num_clusters, num_heads, seq_len)

        for i, block in enumerate(self.base_model.blocks):
            if i >= detect_after_block:
                block.attn = GatedAttentionBlock(block.attn, self.gate_network)
            self.gated_blocks.append(block)
        self._verify_freezing()

    def _verify_freezing(self):
        """Print a trainable/frozen parameter summary and check only the gates are trainable."""
        total_params = 0
        trainable_params = 0
        frozen_params = 0

        for name, param in self.named_parameters():
            total_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
                print(f"TRAINABLE: {name} ({param.numel()} params)")
            else:
                frozen_params += param.numel()

        print(f"\nSUMMARY:")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Frozen parameters: {frozen_params:,}")
        print(f"Trainable percentage: {100*trainable_params/total_params:.2f}%")

        expected_trainable = sum(p.numel() for p in self.gate_network.parameters())
        if trainable_params == expected_trainable:
            print("SUCCESS: Only gate network parameters are trainable!")
        else:
            print("ERROR: Unexpected trainable parameters detected!")
            print(f"Expected {expected_trainable}, got {trainable_params}")

    def _set_gating(self, start: int, cluster_ids):
        """Configure every GatedAttentionBlock: enable from block ``start`` on, disable before it."""
        for idx, blk in enumerate(self.base_model.blocks):
            if isinstance(blk.attn, GatedAttentionBlock):
                enabled = cluster_ids is not None and idx >= start
                blk.attn.apply_gating = enabled
                blk.attn.cluster_ids = cluster_ids if enabled else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            x: image batch accepted by ``base_model.patch_embed``.
        Returns:
            (logits, debug_info): logits from ``base_model.head`` on the CLS token.
            debug_info holds 'cluster_ids' and 'is_adversarial' once detection ran.
        """
        B = x.size(0)
        debug_info = {}

        # Clear state left by a previous call so stale cluster ids are never reused.
        self._set_gating(len(self.base_model.blocks), None)

        x = self.base_model.patch_embed(x)
        cls_token = self.base_model.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.base_model.pos_drop(x + self.base_model.pos_embed)

        for i, block in enumerate(self.base_model.blocks):
            if i == self.detect_after_block:
                with torch.no_grad():
                    # x has been through blocks 0..i-1, the same depth Detector(num_blocks=i)
                    # uses. The Detector also applies the final norm before its fc head.
                    # NOTE(review): assumes base_model.norm matches the detector's own norm,
                    # i.e. the detector's backbone was not fine-tuned. Also assumes the
                    # PCA/KMeans diagnoser was fit on these same features. Confirm both.
                    cls_embeddings = self.base_model.norm(x)[:, 0]
                    adv_probs = torch.sigmoid(self.classifier(cls_embeddings)).squeeze(-1)
                    # NOTE(review): AdversarialDetectionDataset labels clean=1 / adversarial=0.
                    # If the detector was trained on those labels, this score is P(clean)
                    # and the comparison below is inverted. Confirm against the detector training.
                    is_adv = adv_probs > self.detection_threshold

                cluster_ids = self._diagnose_cluster(cls_embeddings, is_adv)
                debug_info["cluster_ids"] = cluster_ids
                debug_info["is_adversarial"] = is_adv

                # Gate this block and every later one.
                self._set_gating(i, cluster_ids)

            x = block(x)

        x = self.base_model.norm(x)
        logits = self.base_model.head(x[:, 0])

        return logits, debug_info

    def _diagnose_cluster(self, cls_embeddings: torch.Tensor, is_adv: torch.Tensor) -> torch.Tensor:
        """Map each flagged sample to a diagnoser cluster; unflagged samples get -1.

        Args:
            cls_embeddings: (B, D) features. Columns beyond the PCA's
                ``n_features_in_`` are dropped.
            is_adv: (B,) or 0-d boolean mask of samples flagged as adversarial.
        Returns:
            (B,) long tensor on ``cls_embeddings.device``.
        """
        pca = self.diagnoser_utils['pca']
        kmeans = self.diagnoser_utils['kmeans']

        emb_np = cls_embeddings.detach().cpu().numpy()
        if emb_np.shape[1] > pca.n_features_in_:
            emb_np = emb_np[:, :pca.n_features_in_]

        cluster_preds = kmeans.predict(pca.transform(emb_np))
        cluster_ids = torch.as_tensor(cluster_preds, dtype=torch.long, device=cls_embeddings.device)
        adv_mask = torch.as_tensor(is_adv, dtype=torch.bool, device=cluster_ids.device).reshape(-1)
        # -1 tells ClusteredGateNetwork to leave the sample's attention ungated.
        return cluster_ids.masked_fill(~adv_mask, -1)

In [None]:
import os
import random
import pandas as pd
import pickle
import re
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision import transforms
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# A pretrained ViT in eval mode; here it is used to derive the matching preprocessing.
vit = timm.create_model('vit_base_patch16_224', pretrained=True)
vit.eval()
config = resolve_data_config({}, model=vit)
transform = create_transform(**config)

def extract_original_filename(path):
    """Return the 'ILSVRC2012_val_<digits>.JPEG' name embedded in the basename of ``path``.

    Adversarial/derived files may wrap the original ILSVRC name with prefixes or
    suffixes. When no such name is present, the plain basename is returned.
    """
    name = os.path.basename(path)
    found = re.search(r'ILSVRC2012_val_\d+\.JPEG', name)
    return found.group(0) if found else name

def save_used_basenames(train_csv, val_csv, output_path):
    """Pickle the set of original ILSVRC filenames referenced by two metadata CSVs.

    Args:
        train_csv: train metadata CSV with an 'image_path' column.
        val_csv: validation metadata CSV with an 'image_path' column.
        output_path: destination of the pickled set.
    Prints per-CSV counts, the union size and the train/val overlap.
    """
    def _originals(csv_path):
        # Map every image path back to the ILSVRC file it was derived from.
        return set(pd.read_csv(csv_path)['image_path'].apply(extract_original_filename))

    train_used = _originals(train_csv)
    print(f"Found {len(train_used)} unique images in train adversarial dataset")

    val_used = _originals(val_csv)
    print(f"Found {len(val_used)} unique images in val adversarial dataset")

    used = train_used | val_used
    print(f"Total unique images across both datasets: {len(used)}")
    print(f"Overlap between train and val adversarial: {len(train_used & val_used)}")

    with open(output_path, 'wb') as fh:
        pickle.dump(used, fh)
    print(f"Saved {len(used)} original ILSVRC basenames to {output_path}")
    print("These images will be excluded when adding extra clean samples to training")

def load_used_basenames(pkl_path):
    """Unpickle and return the object stored at ``pkl_path`` (as written by save_used_basenames).

    NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
    """
    with open(pkl_path, 'rb') as fh:
        return pickle.load(fh)

# ILSVRC filenames already used by the adversarial train/val sets; excluded when sampling
# extra clean images. Requires Google Drive to be mounted (the drive.mount cell is further down).
used_filenames = load_used_basenames(pkl_path='/content/drive/MyDrive/my231n/used_basenames.pkl')

def get_extra_clean_examples(val_dir, used_filenames, n=7000, seed=None):
    """Sample extra clean examples from ``val_dir``, excluding already-used images.

    Args:
        val_dir: directory walked recursively for '*.JPEG' files.
        used_filenames: basenames to exclude.
        n: number of samples requested. If fewer unused files exist, all of them are used.
        seed: optional seed for a reproducible sample. None uses the global ``random`` state.
    Returns:
        DataFrame with columns 'image_path' (``'val/<relative path>'``),
        'attack_type' ('clean') and 'original_class' (-1). It keeps these columns
        even when empty.
    """
    all_clean = []
    for root, _, files in os.walk(val_dir):
        for file in files:
            if file.endswith('.JPEG') and file not in used_filenames:
                all_clean.append(os.path.relpath(os.path.join(root, file), val_dir))
    # os.walk order depends on the filesystem; sort so a fixed seed gives a stable sample.
    all_clean.sort()

    if len(all_clean) < n:
        print(f"Warning: Requested {n} clean samples, but only found {len(all_clean)} unused ones.")
        print(f"Using all {len(all_clean)} available unused samples.")
        n = len(all_clean)

    rng = random.Random(seed) if seed is not None else random
    sampled = rng.sample(all_clean, n)
    new_rows = [{
        'image_path': os.path.join('val', path),
        'attack_type': 'clean',
        'original_class': -1
    } for path in sampled]
    return pd.DataFrame(new_rows, columns=['image_path', 'attack_type', 'original_class'])

class AdversarialDetectionDataset(Dataset):
    """Clean-vs-adversarial image dataset built from a metadata CSV.

    Items are ``(image, label)``. The label is 1 for rows whose 'attack_type' is
    'clean' (case-insensitive) and 0 otherwise. Images that fail to load are
    replaced by a zero tensor with label -1.

    split='train': drops CW attacks, removes every row whose source ILSVRC image
    also appears in the validation metadata, and appends extra clean images
    sampled from ``<root_dir>/val`` via ``get_extra_clean_examples``.
    split='val': keeps only rows whose path contains 'adversarial_val_dataset'.
    """

    # Validation metadata used to de-duplicate the train split when none is given.
    DEFAULT_VAL_METADATA_CSV = '/content/drive/MyDrive/my231n/adversarial_val_dataset/metadata_with_clean.csv'

    def __init__(self, metadata_csv, root_dir, split, transform,
                 val_metadata_csv: Optional[str] = None,
                 extra_clean_n: int = 8512,
                 used_basenames=None):
        """
        Args:
            metadata_csv: CSV with at least 'image_path' and 'attack_type' columns.
            root_dir: directory that relative image paths are resolved against.
            split: 'train' or 'val'.
            transform: callable applied to each PIL image (may be None).
            val_metadata_csv: (train only) validation metadata whose source images
                are removed from the train split; defaults to DEFAULT_VAL_METADATA_CSV.
            extra_clean_n: (train only) number of extra clean images to sample.
            used_basenames: (train only) filenames excluded from extra clean
                sampling; defaults to the module-level ``used_filenames``.
        Raises:
            ValueError: for an unsupported split, or when the split ends up empty.
        """
        self.root_dir = root_dir
        self.transform = transform

        if split == 'train':
            df = pd.read_csv(metadata_csv)
            # Case-insensitive, consistent with the 'clean' checks below.
            df = df[df['attack_type'].str.upper() != 'CW']

            val_df = pd.read_csv(val_metadata_csv or self.DEFAULT_VAL_METADATA_CSV)
            val_original_names = set(val_df['image_path'].apply(extract_original_filename))

            original_len = len(df)
            overlap_mask = df['image_path'].apply(extract_original_filename).isin(val_original_names)
            df = df[~overlap_mask]

            print(f"Removed {overlap_mask.sum()} overlapping images from training set")
            print(f"Training set reduced from {original_len} to {len(df)} samples")

            is_clean = df['attack_type'].str.lower() == 'clean'
            clean_df = df[is_clean]
            adv_df = df[~is_clean]

            extra_clean_df = get_extra_clean_examples(
                val_dir=os.path.join(root_dir, 'val'),
                used_filenames=used_filenames if used_basenames is None else used_basenames,
                n=extra_clean_n
            )
            print(f"Dataset composition: {len(clean_df)} original clean, {len(adv_df)} adversarial, {len(extra_clean_df)} extra clean")
            self.df = pd.concat([clean_df, adv_df, extra_clean_df], ignore_index=True)

        elif split == 'val':
            df = pd.read_csv(metadata_csv)
            split_keyword = 'adversarial_val_dataset'
            # Literal substring match; rows with a missing path are dropped.
            self.df = df[df['image_path'].str.contains(split_keyword, regex=False, na=False)]
        else:
            raise ValueError(f"Unsupported split: {split}")

        self.df = self.df.reset_index(drop=True)
        if len(self.df) == 0:
            raise ValueError(f"No data found for split '{split}'.")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (image, label): 1 for 'clean' rows, 0 for attacked rows, -1 if loading failed."""
        row = self.df.iloc[idx]
        rel_path = row['image_path']
        full_path = rel_path if os.path.isabs(rel_path) else os.path.join(self.root_dir, rel_path)
        label = 1 if row['attack_type'].lower() == 'clean' else 0

        try:
            # `with` closes the file handle Image.open keeps.
            with Image.open(full_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # Best-effort: a missing/corrupt file yields a zero image with sentinel
            # label -1 instead of aborting the epoch. Consumers must skip label -1.
            print(f"Failed to load {full_path}: {e}")
            image = torch.zeros((3, 224, 224))
            label = -1

        return image, label

# Metadata CSVs and the root that relative image paths are resolved against.
metadata_csv_path_train = '/content/drive/MyDrive/my231n/adversarial_train_dataset/metadata_with_clean.csv'
metadata_csv_path_val = '/content/drive/MyDrive/my231n/adversarial_val_dataset/metadata_with_clean.csv'
image_root = '/content/drive/MyDrive/my231n/'

train_dataset = AdversarialDetectionDataset(
    metadata_csv=metadata_csv_path_train, root_dir=image_root, split='train', transform=transform,
)
val_dataset = AdversarialDetectionDataset(
    metadata_csv=metadata_csv_path_val, root_dir=image_root, split='val', transform=transform,
)

from collections import Counter
# Shuffle only the training stream; validation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False, num_workers=4)

print(f"Loaded {len(train_dataset)} training and {len(val_dataset)} validation samples.")

def verify_no_overlap(train_ds=None, val_ds=None):
    """Check that no source ILSVRC image appears in both the train and validation datasets.

    Args:
        train_ds, val_ds: datasets exposing a ``df`` with an 'image_path' column;
            they default to the module-level ``train_dataset`` / ``val_dataset``.
    Returns:
        The set of overlapping original filenames (empty when the splits are disjoint).
    """
    train_ds = train_dataset if train_ds is None else train_ds
    val_ds = val_dataset if val_ds is None else val_ds

    # Vectorized map instead of iterrows: same names, far fewer Python-level row objects.
    train_original_names = set(train_ds.df['image_path'].map(extract_original_filename))
    val_original_names = set(val_ds.df['image_path'].map(extract_original_filename))

    overlap = train_original_names & val_original_names
    print(f"Overlap verification: {len(overlap)} overlapping images found")
    if len(overlap) > 0:
        print(f"Warning: Found overlapping images: {list(overlap)[:5]}...")
    else:
        print("✓ No overlap detected between train and validation sets")
    return overlap

verify_no_overlap();

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Removed 2000 overlapping images from training set
Training set reduced from 9728 to 7728 samples
Using all 0 available unused samples.
Dataset composition: 966 original clean, 6762 adversarial, 0 extra clean
Loaded 7728 training and 13068 validation samples.
Overlap verification: 0 overlapping images found
✓ No overlap detected between train and validation sets


In [None]:
class AttentionDiagnosisHead(nn.Module):
    """Two-layer MLP mapping an ``input_dim`` feature vector to 2 logits.

    The ``classifier`` attribute name fixes the state_dict keys used by saved checkpoints.
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (..., 2) for features of shape (..., input_dim)."""
        # Without this method, calling the module raised NotImplementedError.
        return self.classifier(x)


In [None]:
from google.colab import drive
# Every path in this notebook lives under /content/drive. On a fresh kernel, run this
# before the data cells; when already mounted it only prints a notice.
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import timm

vit = timm.create_model('vit_base_patch16_224', pretrained=True)
# The pretrained checkpoint already includes a 1000-way ImageNet head. Replacing it with
# a fresh nn.Linear would discard those weights, and ViTDefenseSystem freezes every
# base-model parameter, so the classifier would stay random. Only rebuild the head
# when the loaded one is not already a 1000-way linear layer.
if not (isinstance(vit.head, nn.Linear) and vit.head.out_features == 1000):
    vit.head = nn.Linear(vit.num_features, 1000)

classifier_path = '/content/drive/MyDrive/my231n/cs231n_project/detector_epoch_final.pth'
diagnoser_path = '/content/drive/MyDrive/my231n/cs231n_project/diagnoser_utils.pkl'
diagnoser_head_path = '/content/drive/MyDrive/my231n/cs231n_project/diagnoser_epoch_5.pt'

defense_model = ViTDefenseSystem(
    base_model=vit,
    classifier_path=classifier_path,
    diagnoser_path=diagnoser_path,
    diagnoser_head_path=diagnoser_head_path,
    num_clusters=3,
    detection_threshold=0.5
)


In [None]:
import os
from tqdm import tqdm
import torch.nn.functional as F

# Only the gate network is optimized; everything else in defense_model is frozen.
optimizer = torch.optim.Adam(defense_model.gate_network.parameters(), lr=1e-4)
save_every = 10
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# The model and every batch must be on the same device (`device` is set in the data cell).
defense_model.to(device)

# FIXME: AdversarialDetectionDataset yields detection labels (1 = clean, 0 = adversarial),
# not ImageNet class ids. Cross-entropy against the 1000-way logits therefore trains
# toward ImageNet classes 0/1. The dataset should return the metadata's
# 'original_class' column, and that should be used as the target here.
for step, (images, labels) in enumerate(tqdm(train_loader)):
    images = images.to(device)
    labels = labels.to(device)

    # Images that failed to load come back with label -1; skip all-failed batches
    # and exclude individual failures from the loss.
    if bool((labels == -1).all()):
        continue

    logits, _debug = defense_model(images)
    loss = F.cross_entropy(logits, labels, ignore_index=-1)

    # If no sample in the batch was routed through a gate, the loss has no
    # trainable path and backward() would raise; just skip the update.
    if loss.requires_grad:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if step % save_every == 0:
        ckpt_path = os.path.join(checkpoint_dir, f"defense_step_{step}.pt")
        torch.save({
            'step': step,
            'model_state_dict': defense_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
        }, ckpt_path)
        print(f"[Checkpoint saved] → {ckpt_path}")
