## USE CLS TOKEN

In [None]:
# Mount Google Drive (Colab only) so the paths under /content/drive used
# further down in this notebook are readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import random
import pandas as pd
import pickle
import re
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision import transforms
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# The pretrained ViT is instantiated here so that timm can derive the data
# config from it; that config is then turned into the image transform.
vit = timm.create_model('vit_base_patch16_224', pretrained=True)
vit = vit.eval()
config = resolve_data_config({}, model=vit)
transform = create_transform(**config)

def extract_original_filename(path):
    """Return the original ILSVRC validation filename embedded in ``path``.

    Only the last path component is inspected. If it contains a substring of
    the form ``ILSVRC2012_val_<digits>.JPEG`` that substring is returned;
    otherwise the last path component is returned unchanged.
    """
    leaf = os.path.basename(path)
    found = re.search(r'ILSVRC2012_val_\d+\.JPEG', leaf)
    return found.group(0) if found else leaf

def save_used_basenames(train_csv, val_csv, output_path):
    """Collect the original filenames referenced by two metadata CSVs and pickle them.

    Args:
        train_csv: Path to a CSV with an ``image_path`` column (train metadata).
        val_csv: Path to a CSV with an ``image_path`` column (val metadata).
        output_path: Destination file; the union of both filename sets is
            written there with ``pickle.dump``.

    Progress and summary counts are printed along the way.
    """
    train_paths = pd.read_csv(train_csv)['image_path']
    train_used = {extract_original_filename(p) for p in train_paths}
    print(f"Found {len(train_used)} unique images in train adversarial dataset")

    val_paths = pd.read_csv(val_csv)['image_path']
    val_used = {extract_original_filename(p) for p in val_paths}
    print(f"Found {len(val_used)} unique images in val adversarial dataset")

    # Union of both splits: every original image touched by either dataset.
    used = train_used | val_used
    shared = train_used & val_used

    print(f"Total unique images across both datasets: {len(used)}")
    print(f"Overlap between train and val adversarial: {len(shared)}")

    with open(output_path, 'wb') as sink:
        pickle.dump(used, sink)
    print(f"Saved {len(used)} original ILSVRC basenames to {output_path}")
    print("These images will be excluded when adding extra clean samples to training")

def load_used_basenames(pkl_path):
    """Unpickle and return whatever object is stored at ``pkl_path``."""
    with open(pkl_path, 'rb') as source:
        return pickle.load(source)

# Loaded once at module level; the training split of the dataset class reads
# this global when choosing extra clean images.
# NOTE(review): presumably this pickle is the set written by
# save_used_basenames — confirm it is regenerated whenever the CSVs change.
used_filenames = load_used_basenames('/content/drive/MyDrive/my231n/used_basenames.pkl')

def get_extra_clean_examples(val_dir, used_filenames, n=7000, seed=None):
    """Sample unused clean images from ``val_dir`` and describe them as metadata rows.

    Args:
        val_dir: Directory that is walked recursively for ``.JPEG`` files.
        used_filenames: Container of file basenames to exclude.
        n: Number of images to sample. If fewer unused images exist, a
            warning is printed and all of them are used.
        seed: Optional seed for a private RNG. When ``None`` (the default)
            the module-level ``random`` state is used, as before.

    Returns:
        A DataFrame with columns ``image_path`` (``'val'`` joined with the
        path relative to ``val_dir``), ``attack_type`` (always ``'clean'``)
        and ``original_class`` (always ``-1``). The columns are present even
        when no image qualifies.
    """
    # Sorting removes the dependence on filesystem enumeration order, so the
    # same RNG state always picks the same files.
    candidates = sorted(
        os.path.relpath(os.path.join(root, name), val_dir)
        for root, _, files in os.walk(val_dir)
        for name in files
        if name.endswith('.JPEG') and name not in used_filenames
    )

    if len(candidates) < n:
        print(f"Warning: Requested {n} clean samples, but only found {len(candidates)} unused ones.")
        print(f"Using all {len(candidates)} available unused samples.")
        n = len(candidates)

    sampler = random if seed is None else random.Random(seed)
    sampled = sampler.sample(candidates, n)
    new_rows = [
        {
            'image_path': os.path.join('val', path),
            'attack_type': 'clean',
            'original_class': -1,
        }
        for path in sampled
    ]
    # Explicit columns keep the schema stable when ``new_rows`` is empty.
    return pd.DataFrame(new_rows, columns=['image_path', 'attack_type', 'original_class'])

class AdversarialDetectionDataset(Dataset):
    """Image dataset that labels each sample as clean (1) or not clean (0).

    The rows come from a metadata CSV with ``image_path`` and ``attack_type``
    columns. For ``split='train'`` the rows are filtered and topped up with
    extra clean images; for ``split='val'`` only rows whose path contains
    ``adversarial_val_dataset`` are kept.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``, or if
            the resulting table is empty.
    """

    def __init__(self, metadata_csv, root_dir, split, transform):
        self.root_dir = root_dir
        self.transform = transform

        if split == 'train':
            frame = self._load_train_frame(metadata_csv)
        elif split == 'val':
            frame = self._load_val_frame(metadata_csv)
        else:
            raise ValueError(f"Unsupported split: {split}")

        self.df = frame.reset_index(drop=True)
        if len(self.df) == 0:
            raise ValueError(f"No data found for split '{split}'.")

    def _load_train_frame(self, metadata_csv):
        """Build the training table: drop CW rows and val overlaps, add extra clean rows."""
        frame = pd.read_csv(metadata_csv)
        frame = frame[frame['attack_type'] != 'CW']

        # Original filenames present in the validation metadata; any training
        # row derived from one of them is removed.
        val_frame = pd.read_csv('/content/drive/MyDrive/my231n/adversarial_val_dataset/metadata_with_clean.csv')
        val_original_names = set(val_frame['image_path'].apply(extract_original_filename))

        size_before = len(frame)
        in_val = frame['image_path'].apply(extract_original_filename).isin(val_original_names)
        frame = frame[~in_val]

        print(f"Removed {in_val.sum()} overlapping images from training set")
        print(f"Training set reduced from {size_before} to {len(frame)} samples")

        is_clean = frame['attack_type'].str.lower() == 'clean'
        clean_df = frame[is_clean]
        adv_df = frame[~is_clean]

        # ``used_filenames`` is the module-level set of basenames to exclude.
        extra_clean_df = get_extra_clean_examples(
            val_dir=os.path.join(self.root_dir, 'val'),
            used_filenames=used_filenames,
            n=8512,
        )
        print(f"Dataset composition: {len(clean_df)} original clean, {len(adv_df)} adversarial, {len(extra_clean_df)} extra clean")
        return pd.concat([clean_df, adv_df, extra_clean_df], ignore_index=True)

    def _load_val_frame(self, metadata_csv):
        """Build the validation table: keep rows whose path names the val dataset folder."""
        frame = pd.read_csv(metadata_csv)
        split_keyword = 'adversarial_val_dataset'
        return frame[frame['image_path'].str.contains(split_keyword)]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``label`` is 1 when ``attack_type`` is ``clean`` (case-insensitive)
        and 0 otherwise. If loading or transforming the image fails, the
        error is printed and a zero tensor of shape (3, 224, 224) is returned
        with label -1.
        """
        record = self.df.iloc[idx]
        path = record['image_path']
        if not os.path.isabs(path):
            path = os.path.join(self.root_dir, path)
        label = int(record['attack_type'].lower() == 'clean')

        try:
            image = Image.open(path).convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Failed to load {path}: {e}")
            return torch.zeros((3, 224, 224)), -1

        return image, label

from collections import Counter

# Locations of the metadata tables and the directory that relative image
# paths are resolved against.
image_root = '/content/drive/MyDrive/my231n/'
metadata_csv_path_train = '/content/drive/MyDrive/my231n/adversarial_train_dataset/metadata_with_clean.csv'
metadata_csv_path_val = '/content/drive/MyDrive/my231n/adversarial_val_dataset/metadata_with_clean.csv'

train_dataset = AdversarialDetectionDataset(
    metadata_csv_path_train,
    image_root,
    split='train',
    transform=transform,
)
val_dataset = AdversarialDetectionDataset(
    metadata_csv_path_val,
    image_root,
    split='val',
    transform=transform,
)

# Both loaders share batch size and worker count; only training is shuffled.
_loader_options = {'batch_size': 32, 'num_workers': 4}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

print(f"Loaded {len(train_dataset)} training and {len(val_dataset)} validation samples.")

def verify_no_overlap():
    """Print whether the train and validation tables share any original image.

    Reads the module-level ``train_dataset`` and ``val_dataset``, maps every
    ``image_path`` to its original filename and reports the size of the
    intersection (plus up to five offending names when it is non-empty).
    """
    train_original_names = {
        extract_original_filename(path) for path in train_dataset.df['image_path']
    }
    val_original_names = {
        extract_original_filename(path) for path in val_dataset.df['image_path']
    }

    overlap = train_original_names.intersection(val_original_names)
    print(f"Overlap verification: {len(overlap)} overlapping images found")
    if overlap:
        print(f"Warning: Found overlapping images: {list(overlap)[:5]}...")
        return
    print("✓ No overlap detected between train and validation sets")
# Run the check right after the datasets are built so its report is printed
# in this cell's output.
verify_no_overlap()

Using device: cuda
Removed 2000 overlapping images from training set
Training set reduced from 9728 to 7728 samples
Dataset composition: 966 original clean, 6762 adversarial, 8512 extra clean
Loaded 16240 training and 13068 validation samples.
Overlap verification: 0 overlapping images found
✓ No overlap detected between train and validation sets


In [None]:
import torch
import torch.nn as nn
import pandas as pd
import tqdm
import timm

class Detector(nn.Module):
    """Binary detector that classifies from the CLS token of a truncated ViT.

    Args:
        vit: Backbone exposing ``patch_embed``, ``cls_token``, ``pos_embed``,
            ``pos_drop``, ``blocks``, ``norm`` and ``num_features``
            (assumes a timm ViT whose position embedding already includes
            the CLS slot — TODO confirm for other timm versions/models).
        num_blocks: Number of leading transformer blocks to run.
    """

    def __init__(self, vit, num_blocks=3):
        super().__init__()
        self.vit = vit
        # Small MLP head mapping the backbone feature to a single logit.
        self.fc = nn.Sequential(
            nn.Linear(vit.num_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        self.num_blocks = num_blocks

    def forward(self, x):
        """Return one probability per input sample, shape ``(batch,)``."""
        x = self.vit.patch_embed(x)
        cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.vit.pos_embed
        x = self.vit.pos_drop(x)

        # Only the first ``num_blocks`` blocks are evaluated.
        for i in range(self.num_blocks):
            x = self.vit.blocks[i](x)

        x = self.vit.norm(x)
        # The CLS token was concatenated first, so it sits at index 0.
        feats = x[:, 0]

        out = self.fc(feats)
        # squeeze(-1) removes only the trailing size-1 feature axis; a bare
        # squeeze() would also drop the batch axis for a batch of one and
        # break the loss against a (1,)-shaped label tensor.
        return torch.sigmoid(out).squeeze(-1)




In [None]:

from tqdm import tqdm

vit = timm.create_model('vit_base_patch16_224', pretrained=True)
vit.eval()
# Fall back to the CPU instead of hard-coding 'cuda', which fails on hosts
# without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detector = Detector(vit,6).to(device)

from torch.optim import Adam
import torch.nn.functional as F

# NOTE(review): detector.parameters() includes the ViT backbone stored on the
# detector, so the backbone is optimised too — confirm that is intended.
optimizer = Adam(detector.parameters(), lr=1e-4)
print(len(train_loader))

for epoch in range(10):
    detector.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        # The dataset marks images it could not load with label -1; those
        # are not valid BCE targets, so drop them from the batch.
        valid = labels >= 0
        if not valid.any():
            continue

        images, labels = images[valid].to(device), labels[valid].to(device).float()
        # Flatten so a batch of one still matches the (batch,) label shape.
        preds = detector(images).reshape(-1)
        loss = F.binary_cross_entropy(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size for a per-sample average.
        total_loss += loss.item() * images.size(0)

        predicted_labels = (preds >= 0.5).float()
        total_correct += (predicted_labels == labels).sum().item()
        total_samples += labels.size(0)
    torch.save(detector.state_dict(), f'detector_epoch_{epoch+1}.pth')

    # Guard against an epoch in which no valid sample was seen.
    avg_loss = total_loss / max(total_samples, 1)
    accuracy = total_correct / max(total_samples, 1)

    print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")


508


Epoch 1: 100%|██████████| 508/508 [24:44<00:00,  2.92s/it]


Epoch 1, Avg Loss: 0.2277, Accuracy: 0.9181


Epoch 2: 100%|██████████| 508/508 [01:01<00:00,  8.29it/s]


Epoch 2, Avg Loss: 0.1974, Accuracy: 0.9315


Epoch 3: 100%|██████████| 508/508 [01:01<00:00,  8.28it/s]


Epoch 3, Avg Loss: 0.1903, Accuracy: 0.9342


Epoch 4: 100%|██████████| 508/508 [01:01<00:00,  8.32it/s]


Epoch 4, Avg Loss: 0.1912, Accuracy: 0.9344


Epoch 5: 100%|██████████| 508/508 [01:01<00:00,  8.29it/s]


Epoch 5, Avg Loss: 0.1949, Accuracy: 0.9323


Epoch 6: 100%|██████████| 508/508 [01:01<00:00,  8.27it/s]


Epoch 6, Avg Loss: 0.1926, Accuracy: 0.9337


Epoch 7: 100%|██████████| 508/508 [01:01<00:00,  8.30it/s]


Epoch 7, Avg Loss: 0.1832, Accuracy: 0.9381


Epoch 8: 100%|██████████| 508/508 [01:01<00:00,  8.29it/s]


Epoch 8, Avg Loss: 0.1866, Accuracy: 0.9363


Epoch 9: 100%|██████████| 508/508 [01:01<00:00,  8.30it/s]


Epoch 9, Avg Loss: 0.1894, Accuracy: 0.9349


Epoch 10: 100%|██████████| 508/508 [01:01<00:00,  8.31it/s]


Epoch 10, Avg Loss: 0.1893, Accuracy: 0.9355


In [None]:
# Persist the weights of the detector as it stands after the training cell.
final_checkpoint_path = 'detector_epoch_final.pth'
torch.save(detector.state_dict(), final_checkpoint_path)

In [None]:

detector.eval()
val_loss = 0
val_correct = 0
val_total = 0
val_skipped = 0
from itertools import islice

with torch.no_grad():
    for images, labels in val_loader:
        # Label -1 marks images the dataset failed to load; they carry no
        # ground truth, so exclude them from loss and accuracy.
        valid = labels >= 0
        val_skipped += int((~valid).sum())
        if not valid.any():
            continue

        images, labels = images[valid].to(device), labels[valid].to(device).float()
        # Flatten so a batch of one still matches the (batch,) label shape.
        preds = detector(images).reshape(-1)
        loss = F.binary_cross_entropy(preds, labels)

        val_loss += loss.item() * images.size(0)
        predicted_labels = (preds >= 0.5).float()
        val_correct += (predicted_labels == labels).sum().item()
        val_total += labels.size(0)

if val_skipped:
    print(f"Skipped {val_skipped} validation samples that failed to load")

# Guard against a run in which no valid sample was seen.
val_avg_loss = val_loss / max(val_total, 1)
val_accuracy = val_correct / max(val_total, 1)

print(f"[Final Validation] Loss: {val_avg_loss:.4f}, Accuracy: {val_accuracy:.4f}")


[Final Validation] Loss: 0.4176, Accuracy: 0.8783


## USE POOLING

In [None]:
import torch
import torch.nn as nn
import pandas as pd
import tqdm
import timm

class Detector(nn.Module):
    """Binary detector that classifies from mean-pooled patch tokens of a truncated ViT.

    Args:
        vit: Backbone exposing ``patch_embed``, ``cls_token``, ``pos_embed``,
            ``pos_drop``, ``blocks``, ``norm`` and ``num_features``
            (assumes a timm ViT whose position embedding already includes
            the CLS slot — TODO confirm for other timm versions/models).
        num_blocks: Number of leading transformer blocks to run.
    """

    def __init__(self, vit, num_blocks=3):
        super().__init__()
        self.vit = vit
        # Small MLP head mapping the pooled feature to a single logit.
        self.fc = nn.Sequential(
            nn.Linear(vit.num_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        self.num_blocks = num_blocks

    def forward(self, x):
        """Return one probability per input sample, shape ``(batch,)``."""
        x = self.vit.patch_embed(x)
        cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.vit.pos_embed
        x = self.vit.pos_drop(x)

        # Only the first ``num_blocks`` blocks are evaluated.
        for i in range(self.num_blocks):
            x = self.vit.blocks[i](x)

        x = self.vit.norm(x)
        # Skip index 0 (the CLS token concatenated above) and average the
        # remaining tokens along the sequence axis.
        feats = x[:, 1:, :].mean(dim=1)

        out = self.fc(feats)
        # squeeze(-1) removes only the trailing size-1 feature axis; a bare
        # squeeze() would also drop the batch axis for a batch of one and
        # break the loss against a (1,)-shaped label tensor.
        return torch.sigmoid(out).squeeze(-1)


# Fresh pretrained backbone for the pooling variant, so weights updated while
# training the earlier detector are not reused.
vit = timm.create_model('vit_base_patch16_224', pretrained=True)
vit.eval()
# Fall back to the CPU instead of hard-coding 'cuda', which fails on hosts
# without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detector = Detector(vit).to(device)


In [None]:
from torch.optim import Adam
import torch.nn.functional as F

# NOTE(review): detector.parameters() includes the ViT backbone stored on the
# detector, so the backbone is optimised too — confirm that is intended.
optimizer = Adam(detector.parameters(), lr=1e-4)

for epoch in range(10):
    detector.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    for i, batch in enumerate(train_loader, 1):
        # The dataset yields (image, label); taking the first two entries
        # also tolerates loaders that append extra fields per sample.
        images, labels = batch[0], batch[1]
        # Report the epoch 1-based, like the summary lines below.
        print(f"Epoch {epoch+1}, Step {i}")

        # Label -1 marks images the dataset failed to load; they are not
        # valid BCE targets, so drop them from the batch.
        valid = labels >= 0
        if not valid.any():
            continue

        images, labels = images[valid].to(device), labels[valid].to(device).float()
        # Flatten so a batch of one still matches the (batch,) label shape.
        preds = detector(images).reshape(-1)
        loss = F.binary_cross_entropy(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size for a per-sample average.
        total_loss += loss.item() * images.size(0)
        predicted_labels = (preds >= 0.5).float()
        total_correct += (predicted_labels == labels).sum().item()
        total_samples += labels.size(0)

    train_avg_loss = total_loss / max(total_samples, 1)
    train_accuracy = total_correct / max(total_samples, 1)
    print(f"[Train] Epoch {epoch+1}, Loss: {train_avg_loss:.4f}, Accuracy: {train_accuracy:.4f}")

    # Validation
    detector.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in val_loader:
            images, labels = batch[0], batch[1]
            valid = labels >= 0
            if not valid.any():
                continue

            images, labels = images[valid].to(device), labels[valid].to(device).float()
            preds = detector(images).reshape(-1)
            loss = F.binary_cross_entropy(preds, labels)

            val_loss += loss.item() * images.size(0)
            predicted_labels = (preds >= 0.5).float()
            val_correct += (predicted_labels == labels).sum().item()
            val_total += labels.size(0)

    val_avg_loss = val_loss / max(val_total, 1)
    val_accuracy = val_correct / max(val_total, 1)
    print(f"[Val]   Epoch {epoch+1}, Loss: {val_avg_loss:.4f}, Accuracy: {val_accuracy:.4f}")


304
Epoch 1, Avg Loss: 0.3847, Accuracy: 0.8750
Epoch 2, Avg Loss: 0.3810, Accuracy: 0.8750
Epoch 3, Avg Loss: 0.3799, Accuracy: 0.8750
Epoch 4, Avg Loss: 0.3822, Accuracy: 0.8750
Epoch 5, Avg Loss: 0.3794, Accuracy: 0.8750
Epoch 6, Avg Loss: 0.3794, Accuracy: 0.8750
Epoch 7, Avg Loss: 0.3783, Accuracy: 0.8750
Epoch 8, Avg Loss: 0.3789, Accuracy: 0.8750
Epoch 9, Avg Loss: 0.3786, Accuracy: 0.8750
Epoch 10, Avg Loss: 0.3781, Accuracy: 0.8750


In [None]:
# NOTE(review): bare attribute reference — this evaluates `torch.save` without
# calling it, so no checkpoint is written by this cell.
torch.save