In [1]:
import clip
import torch
import os
import json
import math
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModel, AutoProcessor, AutoModelForImageClassification
import torch.nn.functional as F
from sklearn.random_projection import SparseRandomProjection, GaussianRandomProjection
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve
from scipy.special import softmax

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
SRC_PATH = "../Data/GenImage/"
generator_names = ["adm", "bgan", "glide", "midj", "sd_14", "sd_15", "vqdm", "wukong"]
with open("classes.json", "r", encoding="utf-8") as f:
    data = json.load(f)
classes_idx = data["1k_idx"]
classes_names = data["21k_idx"]

In [None]:
def mad(x, dim=-1, keepdim=False):
    """
    Calculate median absolute deviation (MAD)
    MAD = median(|x - median(x)|)
    """
    med = x.median(dim=dim, keepdim=True)[0]
    mad = (x - med).abs().median(dim=dim, keepdim=keepdim)[0]

    mad = torch.clamp(mad, min=1e-6)
    return mad


class GaussianRPO:
    def __init__(self, D, dtype, M=1000, device="cuda", seed=None):
        """
        D: Data Dimension
        M: Number of projectors
        """
        self.D = D
        self.M = M
        self.device = device
        if seed is not None:
            torch.manual_seed(seed)

        # Gaussian random projection M x D
        U = torch.randn(M, D, device=device, dtype=dtype)
        U = U / U.norm(dim=1, keepdim=True)  # Normalization to sphere of norm 1
        self.U = U  # projector matrix

    def fit(self, X):
        """
        Input X: shape (B, D)
        Calculation of all MED 和 MAD to cache
        """
        X = X.to(self.device)
        # M x B，projection
        proj = torch.matmul(self.U, X.T)  # shape (M, B)
        self.med = proj.median(dim=1)[0]  # shape (M,)
        self.mad = mad(proj, dim=1)  # shape (M,)

    def score(self, x):
        """
        Calculation of RPO score for input batch
        x: (N, D)
        return shape (N,)
        """
        x = x.to(self.device)
        # (M, N) = (M, D) @ (D, N)
        proj_x = torch.matmul(self.U, x.T)  # shape (M, N)

        # broadcast : (M, 1)
        med = self.med.unsqueeze(1)
        mad = self.mad.unsqueeze(1)

        # computation
        score = torch.abs(proj_x - med) / mad  # shape (M, N)

        # choose the max score as final score for each sample
        rpo_score, _ = score.max(dim=0)  # shape (N,)

        return rpo_score


class SparseRPO:
    def __init__(self, D, dtype, M=1000, s=None, device="cuda", seed=None):
        """
        D: original dimension
        M: number of projectors
        s: parameter of scale of sparsity, by default sqrt(D)
        """
        self.D = D
        self.M = M
        self.device = device
        self.s = s or int(math.sqrt(D))
        if seed is not None:
            torch.manual_seed(seed)

        self.U = self._generate_sparse_projection_matrix(dtype)

    def _generate_sparse_projection_matrix(self, dtype):
        D, M, s = self.D, self.M, self.s
        prob_nonzero = 1.0 / s

        # initialize all zeros
        U = torch.zeros(M, D, device=self.device, dtype=dtype)

        # generate uniform random matrix
        rand_vals = torch.rand(M, D, device=self.device)

        pos_mask = rand_vals < (1 / (2 * s))
        neg_mask = (rand_vals >= (1 / (2 * s))) & (rand_vals < (1 / s))

        val = math.sqrt(s)
        U[pos_mask] = val
        U[neg_mask] = -val

        # Note: following the implementation of sklearn, there is no normalization

        return U

    def fit(self, X):
        """
        calculate MED and MAD for every projection direction
        X: (B, D)
        """
        X = X.to(self.device)
        proj = torch.matmul(self.U, X.T)  # (M, B)
        self.med = proj.median(dim=1)[0]  # (M,)
        self.mad = mad(proj, dim=1)  # (M,)

    def score(self, x):
        """
        RPO score calculation
        x: (N, D)
        """
        x = x.to(self.device)
        proj_x = torch.matmul(self.U, x.T)  # (M, N)
        med = self.med.unsqueeze(1)  # (M, 1)
        mad = self.mad.unsqueeze(1)  # (M, 1)

        score = torch.abs(proj_x - med) / mad  # (M, N)
        rpo_score, _ = score.max(dim=0)  # (N,)

        return rpo_score

In [13]:
def detector(test_tensor, gt_tensor, labels, M, seed, save_path, device = "cuda"):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    test_tensor = test_tensor.to(device)
    gt_tensor = gt_tensor.to(device)
    datatype = gt_tensor.dtype
    dim = gt_tensor.shape[1]
    # rpo = RPO(D=dim, dtype = datatype, M=M, device=device, seed =seed)
    rpo = SparseRPO(D=dim, dtype = datatype, M = M, device = device, seed = seed)
    rpo.fit(gt_tensor)
    scores = rpo.score(test_tensor).cpu()
    scores = scores.numpy()

    fpr, tpr, thresholds = roc_curve(labels, scores)
    idx = np.where(tpr >= 0.95)[0][0]
    fpr_95 = fpr[idx]
    distances = np.sqrt((1 - tpr) ** 2 + fpr**2)
    best_threshold = thresholds[np.argmin(distances)]
    print("Best threshold(ROC):", best_threshold)

    roc_auc = roc_auc_score(labels, scores)
    # print("AUROC:", roc_auc)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    ax1.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (area = {roc_auc:.2f})")
    ax1.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    ax1.set_xlim([0.0, 1.0])
    ax1.set_ylim([0.0, 1.05])
    ax1.set_xlabel("False Positive Rate (FPR)")
    ax1.set_ylabel("True Positive Rate (TPR)")
    ax1.set_title("Receiver Operating Characteristic (ROC) Curve")
    ax1.legend(loc="lower right")

    precision, recall, thresholds = precision_recall_curve(labels, scores)
    pr_auc = auc(recall, precision)
    # print("AUPRC:", pr_auc)

    ax2.plot(recall, precision, color="blue", lw=2, label=f"PR curve (area = {pr_auc:.2f})")
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel("Recall")
    ax2.set_ylabel("Precision")
    ax2.set_title("Precision-Recall Curve")
    ax2.legend(loc="best")

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

    return scores, roc_auc, pr_auc, fpr_95




In [8]:
class Dinov2Dataset(Dataset):
    def __init__(self, image_paths):
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            image = Image.open(path).convert("RGB")
            return image, str(path)
        except Exception as e:
            print(f"Failure open image because of {e}")
            return None


def dinov2_collate_fn(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return None, None
    images, paths = zip(*batch)
    return list(images), paths

def dinov2_encode(image_paths, batch_size = 64, model_name='facebook/dinov2-with-registers-base', device='cuda'):
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()
    image_paths = Path(image_paths)
    image_paths = list(image_paths.glob("*"))
    dataset = Dinov2Dataset(image_paths)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=dinov2_collate_fn)
    embeddings = []

    for images, paths in tqdm(dataloader, desc="Extracting cls tokens"):
        if images is None:
                continue
        # processor expects a list of PIL images
        inputs = processor(images=images, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
            cls_token = outputs.last_hidden_state[:, 0, :]  # remove CLS
            embeddings.append(cls_token.detach().cpu())

    embedding_all = torch.cat(embeddings, dim=0)
    
    return embedding_all

In [12]:
for cls in classes_idx:
    for generator in ["bgan", "midj", "sd_15", "nature", "nature_2"]:
        cls_tokens = dinov2_encode(f"../Data/GenImage/{cls}/{generator}")
        save_path = f"../Data/Features/dinov2cls/{cls}/{generator}.pt"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(cls_tokens, save_path)
    

Extracting cls tokens: 100%|██████████| 3/3 [00:02<00:00,  1.37it/s]
Extracting cls tokens: 100%|██████████| 3/3 [00:08<00:00,  2.79s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:04<00:00,  1.56s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:04<00:00,  1.35s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:03<00:00,  1.04s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:09<00:00,  3.12s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:04<00:00,  1.64s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:03<00:00,  1.31s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:03<00:00,  1.31s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:03<00:00,  1.06s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:09<00:00,  3.13s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:04<00:00,  1.65s/it]
Extracting cls tokens: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it]
Extracting cls tokens: 100%|██████

In [14]:
auroc = []
auprc = []
fpr95 = []
generator = []
for cls in classes_idx:
    bgan = torch.load(f"../Data/Features/dinov2cls/{cls}/bgan.pt", weights_only=True)
    midj = torch.load(f"../Data/Features/dinov2cls/{cls}/midj.pt", weights_only=True)
    sd_15 = torch.load(f"../Data/Features/dinov2cls/{cls}/sd_15.pt", weights_only=True)
    nature = torch.load(f"../Data/Features/dinov2cls/{cls}/nature.pt", weights_only=True)
    nature_2 = torch.load(f"../Data/Features/dinov2cls/{cls}/nature_2.pt", weights_only=True)
    bgan_m = torch.cat([nature_2, bgan], dim=0).to(device)
    midj_m = torch.cat([nature_2, midj], dim=0).to(device)
    sd_15_m = torch.cat([nature_2, sd_15], dim=0).to(device)
    labels = np.concatenate((np.zeros(nature_2.shape[0]), np.ones(bgan.shape[0])))
    s1, r1, p1, f1 = detector(
        bgan_m,
        nature,
        labels,
        1000,
        2025,
        f"../Data/RPO/Sparse/dinov2cls/{cls}/bgan.png",
    )
    generator.append("bgan")
    auroc.append(r1)
    auprc.append(p1)
    fpr95.append(f1)
    s2, r2, p2, f2 = detector(
        midj_m,
        nature,
        labels,
        1000,
        2025,
        f"../Data/RPO/Sparse/dinov2cls/{cls}/midj.png",
    )
    generator.append("midj")
    auroc.append(r2)
    auprc.append(p2)
    fpr95.append(f2)
    s3, r3, p3, f3 = detector(
        sd_15_m,
        nature,
        labels,
        1000,
        2025,
        f"../Data/RPO/Sparse/dinov2cls/{cls}/sd_15.png",
    )
    generator.append("sd_15")
    auroc.append(r3)
    auprc.append(p3)
    fpr95.append(f3)
data = {
    "CLASS": [x for x in classes_idx for _ in range(3)],
    "GENERATOR": generator,
    "AUROC": auroc,
    "AUPRC": auprc,
    "FPR95": fpr95,
}
df = pd.DataFrame(data)
df.to_csv("dinov2cls_rpo2_result.csv", index=False)
print(f"dinov2cls auroc: {np.mean(auroc)}, auprc: {np.mean(auprc)}, fpr95: {np.mean(fpr95)}")

Best threshold(ROC): 6.0072775
Best threshold(ROC): 5.009192
Best threshold(ROC): 5.340401
Best threshold(ROC): 5.97304
Best threshold(ROC): 5.496476
Best threshold(ROC): 3.9286146
Best threshold(ROC): 6.998828
Best threshold(ROC): 4.7329526
Best threshold(ROC): 4.4633145
Best threshold(ROC): 5.0392346
Best threshold(ROC): 5.189156
Best threshold(ROC): 4.7727094
Best threshold(ROC): 5.613031
Best threshold(ROC): 3.730882
Best threshold(ROC): 4.224724
Best threshold(ROC): 5.3669615
Best threshold(ROC): 4.2515326
Best threshold(ROC): 5.650461
Best threshold(ROC): 6.17908
Best threshold(ROC): 5.7054234
Best threshold(ROC): 6.1334214
Best threshold(ROC): 7.3345876
Best threshold(ROC): 6.2571163
Best threshold(ROC): 6.298293
Best threshold(ROC): 5.52156
Best threshold(ROC): 6.9023576
Best threshold(ROC): 6.9989862
Best threshold(ROC): 5.5880017
Best threshold(ROC): 5.0488777
Best threshold(ROC): 6.041898
dinov2cls auroc: 0.6201760402377686, auprc: 0.6058481069587169, fpr95: 0.70452674897119

In [16]:
for embedder in ["clip", "dinov2", "wst"]:
        auroc = []
        auprc = []
        fpr95 = []
        generator = []
        for cls in classes_idx:
            bgan = torch.load(f"../Data/Features/{embedder}/{cls}/bgan.pt", weights_only=True)["features"]
            midj = torch.load(f"../Data/Features/{embedder}/{cls}/midj.pt", weights_only=True)["features"]
            sd_15 = torch.load(f"../Data/Features/{embedder}/{cls}/sd_15.pt", weights_only=True)["features"]
            nature = torch.load(f"../Data/Features/{embedder}/{cls}/nature.pt", weights_only=True)["features"]
            nature_2 = torch.load(f"../Data/Features/{embedder}/{cls}/nature_2.pt", weights_only=True)["features"]
            bgan_m = torch.cat([bgan, nature_2], dim=0).to(device)
            midj_m = torch.cat([midj, nature_2], dim=0).to(device)
            sd_15_m = torch.cat([sd_15, nature_2], dim=0).to(device)
            labels = np.concatenate((np.zeros(bgan.shape[0]), np.ones(nature_2.shape[0])))
            s1, r1, p1, f1 = detector(
                bgan_m,
                nature,
                labels,
                1000,
                2025,
                f"../Data/RPO/Sparse/{embedder}/{cls}/bgan.png",
            )
            generator.append("bgan")
            auroc.append(r1)
            auprc.append(p1)
            fpr95.append(f1)
            s2, r2, p2, f2 = detector(
                midj_m,
                nature,
                labels,
                1000,
                2025,
                f"../Data/RPO/Sparse/{embedder}/{cls}/midj.png",
            )
            generator.append("midj")
            auroc.append(r2)
            auprc.append(p2)
            fpr95.append(f2)
            s3, r3, p3, f3 = detector(
                sd_15_m,
                nature,
                labels,
                1000,
                2025,
                f"../Data/RPO/Sparse/{embedder}/{cls}/sd_15.png",
            )
            generator.append("sd_15")
            auroc.append(r3)
            auprc.append(p3)
            fpr95.append(f3)
        data = {
            "CLASS": [x for x in classes_idx for _ in range(3)],
            "GENERATOR": generator,
            "AUROC": auroc,
            "AUPRC": auprc,
            "FPR95": fpr95,
        }
        df = pd.DataFrame(data)
        df.to_csv(embedder + "_rpo2_result.csv", index=False)
        print(embedder + f" auroc: {np.mean(auroc)}, auprc: {np.mean(auprc)}, fpr95: {np.mean(fpr95)}")

Best threshold(ROC): -5.47
Best threshold(ROC): -5.684
Best threshold(ROC): -4.73
Best threshold(ROC): -5.004
Best threshold(ROC): -5.152
Best threshold(ROC): -4.887
Best threshold(ROC): -5.055
Best threshold(ROC): -5.258
Best threshold(ROC): -4.465
Best threshold(ROC): -4.8
Best threshold(ROC): -5.84
Best threshold(ROC): -4.8
Best threshold(ROC): -5.137
Best threshold(ROC): -5.06
Best threshold(ROC): -4.707
Best threshold(ROC): -4.945
Best threshold(ROC): -4.82
Best threshold(ROC): -4.945
Best threshold(ROC): -5.812
Best threshold(ROC): -5.445
Best threshold(ROC): -4.223
Best threshold(ROC): -5.543
Best threshold(ROC): -5.68
Best threshold(ROC): -5.344
Best threshold(ROC): -5.344
Best threshold(ROC): -5.65
Best threshold(ROC): -5.65
Best threshold(ROC): -5.14
Best threshold(ROC): -4.76
Best threshold(ROC): -4.38
clip auroc: 0.5177729512777524, auprc: 0.5676631833058587, fpr95: 0.9370370370370369
Best threshold(ROC): -5.9389734
Best threshold(ROC): -5.422695
Best threshold(ROC): -5.149