## Setup

In [None]:
INPUT_PATH = "images/data"
OUTPUT_PATH = "images/outputs"
MODEL_PATH = "adversarial_models/CLIP_discriminator.pt"
GENERATORS = [
    "dalle",
    "openjourney",
    "stable_diff",
    "openjourney_v4",
    "titan"
]
BATCH_SIZE = 4
EXPERIMENT_MODE = "pgd" # "pgd" or "patch" or "attn"

In [86]:
from google.colab import drive
drive.mount("/content/drive")

ModuleNotFoundError: No module named 'google.colab'

In [87]:
!pip install torchattacks --quiet
!pip install transformers --quiet

/bin/bash: pip: command not found
/bin/bash: pip: command not found


In [88]:
import os
import ast
import csv

import numpy as np
import pandas as pd

from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel

from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
)

from torchattacks import PGD
from torchattacks.attack import Attack

import torchvision

from tqdm import tqdm

os.makedirs(OUTPUT_PATH, exist_ok=True) 

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cpu'

## Discriminator

In [115]:
class CLIPSVMDiscriminator:
    def __init__(self, model_name="openai/clip-vit-base-patch32", device=None):
        self.device = (
            device
            if device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        print("Running on:", self.device)
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        self.svm = SVC(kernel="linear", C=1.0, probability=True)
        self.svm_trained = False

    def run_clip(self, imgs):
        # inputs = self.processor(images=imgs, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.vision_model(imgs)
            image_features = outputs.last_hidden_state[:, 0, :]
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            return image_features.squeeze().cpu().numpy()

    def train_svm(self, X_train, y_train):
        self.svm.fit(X_train, y_train)
        self.svm_trained = True
        train_accuracy = self.svm.score(X_train, y_train)
        print(f"Training accuracy for discriminator: {train_accuracy:.4f}")
        return self.svm

    def predict_from_embeddings(self, embeddings):
        preds = self.svm.predict(embeddings)
        probs = self.svm.predict_proba(embeddings)[:, 1]
        return preds, probs

    def evaluate(self, X_test, y_test):
        model = self.svm
        y_pred = model.predict(X_test)
        y_pred_proba = model.predict_proba(X_test)[:, 1]
        accuracy = accuracy_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred, average="weighted")
        recall = recall_score(y_test, y_pred, average="weighted")
        f1 = f1_score(y_test, y_pred, average="weighted")
        auc = roc_auc_score(y_test, y_pred_proba)
        ap_per_class = []
        for class_label in np.unique(y_test):
            y_test_binary = (y_test == class_label).astype(int)
            ap = average_precision_score(y_test_binary, y_pred_proba)
            ap_per_class.append(ap)
        map_score = np.mean(ap_per_class)
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1 Score: {f1:.4f}")
        print(f"AUC: {auc:.4f}")
        print(f"mAP: {map_score:.4f}")
        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "auc": auc,
            "map": map_score,
        }

In [116]:
model: CLIPSVMDiscriminator = torch.load(
    MODEL_PATH, weights_only=False, map_location=DEVICE
)

## Dataloader

In [None]:
class ArtEmbeddingDataset(Dataset):
    def __init__(self, ai_only=False):
        self.transform = torchvision.transforms.Resize((224, 224))
        self.image_info = {}
        for directory in GENERATORS:
            if not os.path.exists(os.path.join(INPUT_PATH, directory)):
                print(f"{directory} does not exist. Skipping.")
                continue
            print(f"{directory} has {len(os.listdir(os.path.join(INPUT_PATH, directory)))} images.")
            for filepath in tqdm(os.listdir(os.path.join(INPUT_PATH, directory)), desc="Loading "+directory):
                full_path = os.path.join(INPUT_PATH, directory, filepath)
                if full_path.endswith(".png") or full_path.endswith(".jpg"):
                    id_idx = filepath.rfind('_') + 1
                    id = filepath[id_idx:-4]
                    label = 1 if directory == "real" else 0
                    self.image_info[full_path] = {
                        "generator": directory,
                        "label": label, # 0 = fake, 1 = real
                        "id": id,
                        "data": torchvision.io.read_image(full_path),
                    } 
        self.paths = list(self.image_info.keys())

    def preprocess_image(self, image_path):
        # Load the image to torch
        image = torchvision.io.read_image(image_path)
        # image = self.transform(image)
        # image = image.float() / 255.0  # Normalize to [0, 1]
        return image
    
    def __len__(self):
        return len(self.image_info)

    def __getitem__(self, idx):
        filepath = self.paths[idx]
        image = self.image_info[filepath]["data"]
        label = self.image_info[filepath]["label"]

        sample = {"filepath": filepath, "data": image, "label": label}

        return sample

In [None]:
dataset = ArtEmbeddingDataset(ai_only=True)
len(dataset)

dalle has 435 images.


Loading dalle: 100%|██████████| 435/435 [00:03<00:00, 128.75it/s]


openjourney does not exist. Skipping.
stable_diff does not exist. Skipping.
openjourney_v4 does not exist. Skipping.
titan has 2061 images.


Loading titan: 100%|██████████| 2061/2061 [01:04<00:00, 32.12it/s]


['dalle', 'openjourney', 'stable_diff', 'openjourney_v4', 'titan']

In [75]:
TRANSFORM = torchvision.transforms.Resize((224, 224))

def collate_fn(batch):
    images = [item["data"] for item in batch]
    labels = [item["label"] for item in batch]
    paths = [item["filepath"] for item in batch]
    labels = torch.tensor(labels)
    
    images = torch.stack(images)
    images = images.float() / 255.0  # Normalize to [0, 1]
    images = TRANSFORM(images)

    images = images.to(DEVICE)
    labels = labels.to(DEVICE)
    
    return images, labels, paths

In [91]:
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

## Attacks

In [98]:
class CLIPPGDAttack(PGD):
    def __init__(self, model, svm, eps=8 / 255, alpha=2 / 255, steps=10, random_start=True):
        super().__init__(model, eps, alpha, steps, random_start)
        self.svm_weights = torch.FloatTensor(svm.coef_[0])
        self.svm_bias = torch.tensor(svm.intercept_[0])

  
    def get_logits(self, inputs):
        if self._normalization_applied is False:
            inputs = self.normalize(inputs)

        # Get image features from the vision model
        vision_outputs = self.model.vision_model(inputs)
        image_features = vision_outputs.last_hidden_state[:, 0, :]

        return image_features
    
    def svm_boundary_loss(self, clip_embedding):
        # Distance to decision boundary (negative = wrong side)
        if self.svm_weights is None or self.svm_bias is None:
            raise ValueError("SVM weights and bias not set. Call set_svm_params() first.")
        
        distance = torch.matmul(clip_embedding, self.svm_weights) + self.svm_bias
        # Loss is higher when distance is positive (correct classification)
        return -distance  # Maximize to cross boundary
        
    def forward(self, images, labels):
        """
        Override forward method to use custom loss function
        """
        images = images.clone().detach().to(self.device)
        
        adv_images = images.clone().detach()
        
        if self.random_start:
            # Starting at a uniformly random point
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            
            # Forward pass
            outputs = self.get_logits(adv_images)
            
            # Calculate loss using our custom SVM boundary loss
            loss = self.svm_boundary_loss(outputs).mean()
            # print(f"Loss: {loss.item():.6f}")
            
            # Backward pass
            grad = torch.autograd.grad(loss, adv_images,
                                      retain_graph=False, create_graph=False)[0]

            adv_images = adv_images.detach() - self.alpha * grad.sign()
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images
    

In [99]:
class CLIPPatchPGDAttack():
    pass

In [100]:
class CLIPAttentionPatchPGDAttack():
    pass

In [101]:
if EXPERIMENT_MODE == "pgd":
    attack = CLIPPGDAttack(model.model, model.svm, eps=8 / 255, alpha=2 / 255, steps=10, random_start=True)   
elif EXPERIMENT_MODE == "patch":
    attack = CLIPPatchPGDAttack(model.model, model.svm, eps=8 / 255, alpha=2 / 255, steps=10, random_start=True)
elif EXPERIMENT_MODE == "attn":
    attack = CLIPAttentionPatchPGDAttack(model.model, model.svm, eps=8 / 255, alpha=2 / 255, steps=10, random_start=True)
else:
    raise ValueError(f"Invalid experiment mode:", EXPERIMENT_MODE)

## Generate Adversarial Images

In [102]:
for i, (images, labels, paths) in enumerate(tqdm(dataloader)):
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    # Generate adversarial examples
    adv_images = attack(images, labels)
    
    # Save the adversarial images
    for j in range(len(images)):
        image_path = paths[j]
        basename = os.path.basename(image_path)
        adv_image_path = os.path.join(OUTPUT_PATH, f"adv_{EXPERIMENT_MODE}_{basename}")
        torchvision.utils.save_image(adv_images[j], adv_image_path)

  0%|          | 0/624 [00:00<?, ?it/s]

  0%|          | 1/624 [01:30<15:39:25, 90.47s/it]


KeyboardInterrupt: 

## Test Attack Effectiveness

In [None]:
images = []
labels = []
for file in tqdm(os.listdir(OUTPUT_PATH), desc="Loading generated images"):
    images.append(torchvision.io.read_image(os.path.join(OUTPUT_PATH, file)))
    labels.append(0)
for file in tqdm(os.listdir(os.path.join(INPUT_PATH, "real")), desc="Loading real images"):
    images.append(TRANSFORM(torchvision.io.read_image(os.path.join(INPUT_PATH, "real", file))))
    labels.append(1)


images = torch.stack(images)
images = images.float() / 255.0  # Normalize to [0, 1]

clip_embeddings = model.run_clip(images)
print(clip_embeddings.shape)

model.evaluate(clip_embeddings, labels)

Loading generated images: 100%|██████████| 4/4 [00:00<00:00, 720.86it/s]
Loading real images: 100%|██████████| 3606/3606 [01:01<00:00, 58.29it/s]


KeyboardInterrupt: 