In [1]:
import os
import random
import shutil
from pathlib import Path
import json
import math
import pandas as pd

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from PIL import Image
from torchvision import transforms

from kymatio.torch import Scattering2D
import pywt
from collections import defaultdict
import clip
from transformers import AutoImageProcessor, AutoModel, AutoProcessor, AutoModelForImageClassification

from scipy.special import softmax

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from encoder import Encoder
from wst2vit import clip_encoder, dinov2_encoder
import time
from rpo import SparseProjector, GaussianProjector
import utils
from eval import Metrics
from padim import mahalanobis_detector, padim_detector

In [3]:
from patchcore import greedy_coreset_selection, coreset_detector, patchcore_detector

In [4]:
# Load the class lookup tables from classes.json. `classes_idx` is the list the
# per-class loops in the cells below iterate over.
# NOTE(review): `classes_names` is filled from the "21k_idx" key - confirm that this
# entry really holds class names and not a second index table.
with open("classes.json", "r", encoding="utf-8") as f:
    data = json.loads(f.read())
classes_idx, classes_names = data["1k_idx"], data["21k_idx"]

# Device string used by the later `.to(DEVICE)` calls: GPU when PyTorch can see one,
# CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

## Encoding 

In [6]:
# Encode every GenImage folder with the CLIP encoder and write one feature file per
# (class, generator) pair under ../Data/Features/clip/.
# NOTE(review): this assignment rebinds `clip_encoder`, a name the import cell also
# brings in from `wst2vit`; once this cell has run, the imported object is no longer
# reachable under that name.
clip_encoder = Encoder("clip")
image_sources = ("bgan", "midj", "sd_15", "nature", "nature_2")
for cls in classes_idx:
    for generator in image_sources:
        save_path_clip = f"../Data/Features/clip/{cls}/{generator}.pt"
        # Create the per-class output folder on first use.
        Path(save_path_clip).parent.mkdir(parents=True, exist_ok=True)
        tensor = clip_encoder(f"../Data/GenImage/{cls}/{generator}")
        torch.save(tensor, save_path_clip)


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.17it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.21s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.56s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.14s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.04s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.31s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.46it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.39s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.45s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.10s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.02s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.54it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.36s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.47s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.12it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.08it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.52it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.43s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.28s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.10it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.08it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.50it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.15s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.47s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.20it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.05it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.53it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.18s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.47s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.17it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.07it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.11s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.47s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.07s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.46it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:07<00:00,  3.56s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.44s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.04s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.08s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.39it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:06<00:00,  3.43s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.45s/it]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


torch.Size([162, 512])


Clip Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.05it/s]

torch.Size([162, 512])





In [7]:
# Patch-level DINOv2 features: one tensor per (class, generator) folder, written to
# ../Data/Features/dinov2_patch/.
# The CLIP patch variant (Encoder("clip", patchify=True), saved under
# ../Data/Features/clip_patch/) and the per-call timing into `tc` are disabled here.
# NOTE(review): the timing cell that follows still reads tc["clip"] and tc["dinov2"],
# which this cell no longer fills.
dinov2_patch_encoder = Encoder("dinov2", patchify=True)
for cls in classes_idx:
    for generator in ("bgan", "midj", "sd_15", "nature", "nature_2"):
        save_path_dinov2 = f"../Data/Features/dinov2_patch/{cls}/{generator}.pt"
        # Create the per-class output folder on first use.
        Path(save_path_dinov2).parent.mkdir(parents=True, exist_ok=True)
        tensor2 = dinov2_patch_encoder(f"../Data/GenImage/{cls}/{generator}")
        torch.save(tensor2, save_path_dinov2)



Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.21s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:07<00:00,  3.67s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.76s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.61s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.47s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.12s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:08<00:00,  4.02s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.93s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.43s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.46s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.14s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:08<00:00,  4.17s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.76s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.48s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.41s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.10s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:08<00:00,  4.00s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.79s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.36s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.40s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.13s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:07<00:00,  3.90s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.37s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.39s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.11s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:07<00:00,  3.74s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.78s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.33s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.37s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.11s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:07<00:00,  3.87s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.81s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.37s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.37s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.14s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:07<00:00,  3.63s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.78s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.39s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.42s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.15s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:08<00:00,  4.04s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.80s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.50s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.48s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.15s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:08<00:00,  4.19s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.86s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.35s/it]


torch.Size([162, 256, 768])


Dinov2 Patches Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.43s/it]

torch.Size([162, 256, 768])





In [6]:
# Mean of the recorded per-call encoding times for the two patch encoders.
# NOTE(review): hidden state - no active cell above populates tc["clip"] or
# tc["dinov2"] (the timing code in the patch-encoding cell is commented out), so this
# cell fails on a fresh kernel. Re-enable the timing or drop this cell.
clip_samples, dinov2_samples = tc["clip"], tc["dinov2"]
clip_time = np.mean(clip_samples)
dinov2_time = np.mean(dinov2_samples)


In [7]:
# Show the two averages computed in the previous cell: CLIP first, then DINOv2.
print(clip_time, dinov2_time)

3.2608294677734375 3.895852737426758


In [11]:
# Image-level DINOv2 features (encoder built without `patchify`): one tensor per
# (class, generator) folder, written to ../Data/Features/dinov2_mean/.
# The duration of every encoder call is appended to tc["dinov2"], in seconds.
dinov2_mean_encoder = Encoder("dinov2")
tc = {"dinov2": []}
for cls in classes_idx:
    for generator in ["bgan", "midj", "sd_15", "nature", "nature_2"]:
        input_path = f"../Data/GenImage/{cls}/{generator}"
        save_path_dinov2 = f"../Data/Features/dinov2_mean/{cls}/{generator}.pt"
        os.makedirs(os.path.dirname(save_path_dinov2), exist_ok=True)
        # perf_counter() is a monotonic, high-resolution clock; time.time() can jump
        # when the system clock is adjusted and would then corrupt the measurement.
        start = time.perf_counter()
        tensor2 = dinov2_mean_encoder(input_path)
        tc["dinov2"].append(time.perf_counter() - start)
        # Saving is deliberately outside the timed region.
        torch.save(tensor2, save_path_dinov2)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.66s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:08<00:00,  4.45s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.36s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.05s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.91s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:09<00:00,  4.55s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.41s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.87s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.92s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.52s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:09<00:00,  4.63s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.33s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.02s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.96s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.53s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:09<00:00,  4.59s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.33s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.82s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.92s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.57s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:09<00:00,  4.84s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.50s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.90s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.94s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.56s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:09<00:00,  4.55s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.44s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.81s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.94s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:09<00:00,  4.67s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.35s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.83s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.81s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.52s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:08<00:00,  4.31s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.30s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.84s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.86s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.53s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:09<00:00,  4.71s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.28s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.89s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.92s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.56s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:09<00:00,  4.62s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.38s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.78s/it]


torch.Size([162, 768])


Dinov2 Direct Encoding: 100%|██████████| 2/2 [00:03<00:00,  1.86s/it]

torch.Size([162, 768])





In [12]:
# Mean duration of one encoder call, taken over the timings stored in tc["dinov2"].
print(np.mean(tc["dinov2"]))

4.94347270488739


In [8]:
# Encode every GenImage folder with two "wst" encoders (J=2 and J=3) and write the
# results to ../Data/Features/wst2/ and ../Data/Features/wst3/ respectively.
# The duration of every encoder call is appended to tc["wst2"] / tc["wst3"], in seconds.
wst2_encoder = Encoder("wst", J=2)
wst3_encoder = Encoder("wst", J=3)
tc = {"wst2": [], "wst3": []}
for cls in classes_idx:
    for generator in ["bgan", "midj", "sd_15", "nature", "nature_2"]:
        input_path = f"../Data/GenImage/{cls}/{generator}"
        save_path_wst2 = f"../Data/Features/wst2/{cls}/{generator}.pt"
        save_path_wst3 = f"../Data/Features/wst3/{cls}/{generator}.pt"
        os.makedirs(os.path.dirname(save_path_wst2), exist_ok=True)
        os.makedirs(os.path.dirname(save_path_wst3), exist_ok=True)

        # perf_counter() is a monotonic, high-resolution clock; time.time() can jump
        # when the system clock is adjusted and would then corrupt the measurement.
        start = time.perf_counter()
        tensor1 = wst2_encoder(input_path)
        tc["wst2"].append(time.perf_counter() - start)
        # Saving is deliberately outside the timed region.
        torch.save(tensor1, save_path_wst2)

        start = time.perf_counter()
        tensor2 = wst3_encoder(input_path)
        tc["wst3"].append(time.perf_counter() - start)
        torch.save(tensor2, save_path_wst3)



WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.90it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.24it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.35s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.39s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.02it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.66it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.48it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.84it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  3.10it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.40it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.61s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.60s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.01s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.06s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.92it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.59it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.67it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  3.07it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.41it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.51s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.68s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.76it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.45it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.94it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.60s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.64s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.02it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.04s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.72it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.65it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.98it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.29it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.60s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.71s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.26it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.18it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.77it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.77it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  3.15it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.44it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.28s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.11it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.07s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.18it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.16it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  3.10it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.41it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.35s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.49s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.80it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.78it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.74it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.42s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:04<00:00,  2.33s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.12it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.02it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.98it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.66it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.68it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.91it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.30it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.78s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.77s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.02it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.71it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.52it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.73it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.64it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.87it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.32it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.77s/it]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:05<00:00,  2.72s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:02<00:00,  1.06s/it]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:00<00:00,  2.14it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s]


torch.Size([162, 217, 32, 32])


WST-2 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.94it/s]


torch.Size([162, 81, 64, 64])


WST-3 Direct Encoding: 100%|██████████| 2/2 [00:01<00:00,  1.68it/s]


torch.Size([162, 217, 32, 32])


In [9]:
# Mean duration of one encoder call for the J=2 and J=3 encoders, in that order.
print(np.mean(tc["wst2"]), np.mean(tc["wst3"]))

2.0060737657547 2.12442666053772


## CLIP and DINOv2-mean re-evaluation with the Mahalanobis detector

In [7]:
# Evaluate the Mahalanobis detector on the stored CLIP features.
# For each class, nature.pt is passed as the first argument of the detector, nature_2.pt
# supplies the real samples (label 0) and each generator file the fake samples (label 1).
# One row of metrics per (class, generator) pair is collected and written to CSV.
results_csv = "results_v2/clip_mahalanobis.csv"
# Create the output folder up front: to_csv() does not create missing directories, and
# failing here is cheaper than failing after every class has been scored.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    gt_tensor = torch.load(f"../Data/Features/clip/{cls}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"../Data/Features/clip/{cls}/nature_2.pt", weights_only=True).to(DEVICE)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"../Data/Features/clip/{cls}/{generator}.pt", weights_only=True).to(DEVICE)
        # Labels are laid out as all real samples followed by all fake samples.
        # NOTE(review): presumably mahalanobis_detector returns its scores in that same
        # order - confirm against padim.py.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = mahalanobis_detector(gt_tensor, real_tensor, fake_tensor)
        m = Metrics(labels, scores)
        # The three returned values are stored as AUROC, AUPRC and FPR95, in that order.
        r, p, f = m.computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(r)
        data["AUPRC"].append(p)
        data["FPR95"].append(f)

df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"clip mahalanobis auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 354.41it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 533.22it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 522.17it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 505.11it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 533.59it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 553.04it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 503.18it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 512.24it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 509.44it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 506.21it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 489.71it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 470.90it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 498.86it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:

clip mahalanobis auroc: 0.5141365645480872, auprc: 0.512197574958008, fpr95: 0.8090534979423866





In [10]:
# Mahalanobis evaluation on the dinov2_mean features.
# Per class: nature.pt is passed to the detector as its first argument
# (presumably the reference set — confirm in padim.py), nature_2.pt supplies
# the real test samples (label 0) and each generator file supplies the fake
# test samples (label 1).
results_csv = "results_v2/dinov2_mean_mahalanobis.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    feature_dir = f"../Data/Features/dinov2_mean/{cls}"
    gt_tensor = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = mahalanobis_detector(gt_tensor, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"dinov2 mean mahalanobis auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

Mahalanobis scoring:: 100%|██████████| 324/324 [00:01<00:00, 311.39it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 448.36it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 443.89it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 451.80it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 439.50it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 430.63it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 438.21it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 439.61it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 435.68it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 427.94it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 436.21it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 436.82it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 461.27it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:

dinov2 mean mahalanobis auroc: 0.6954757912919779, auprc: 0.6718992831746406, fpr95: 0.6034979423868312





## Clip, Dinov2, WST-3 recalculation with Memory Coreset and PatchCore

### Memory Coreset

In [None]:
# Memory-coreset evaluation on CLIP features.
# Per class a coreset is selected from nature.pt and handed to
# coreset_detector together with the real test samples (nature_2.pt, label 0)
# and one generator's fake samples (label 1).
results_csv = "results_v2/clip_coreset.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    feature_dir = f"../Data/Features/clip/{cls}"
    gt_tensor = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    # The selected indices are not needed here, only the coreset itself.
    # l=0.5 is presumably the fraction of samples to keep — confirm in patchcore.py.
    memory_coreset, _ = greedy_coreset_selection(gt_tensor, l=0.5)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = coreset_detector(memory_coreset, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"clip coreset auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

clip coreset auroc: 0.5391962607326118, auprc: 0.532540669358274, fpr95: 0.8164609053497943


In [9]:
# Memory-coreset evaluation on the dinov2_mean features.
# Per class a coreset is selected from nature.pt and handed to
# coreset_detector together with the real test samples (nature_2.pt, label 0)
# and one generator's fake samples (label 1).
results_csv = "results_v2/dinov2_mean_coreset.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    feature_dir = f"../Data/Features/dinov2_mean/{cls}"
    gt_tensor = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    # The selected indices are not needed here, only the coreset itself.
    # l=0.5 is presumably the fraction of samples to keep — confirm in patchcore.py.
    memory_coreset, _ = greedy_coreset_selection(gt_tensor, l=0.5)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = coreset_detector(memory_coreset, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"dinov2 mean coreset auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

dinov2 mean coreset auroc: 0.7117309353248997, auprc: 0.6806247266903488, fpr95: 0.6123456790123457


In [10]:
# Memory-coreset evaluation on WST-3 features reduced with a sparse projector.
# Each split is average-pooled to 16x16, passed through utils.wst_m with the
# class's projector, and then scored with coreset_detector against a coreset
# selected from nature.pt.
results_csv = "results_v2/wst3_sparse_coreset.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    # A new projector is built for every class and shared by all splits of that class.
    # 217 matches the WST-3 channel count printed by the encoding cell; 16 is
    # presumably the projected dimension — confirm in rpo.py.
    proj_s = SparseProjector(217, 16)
    feature_dir = f"../Data/Features/wst3/{cls}"
    gt_tensor = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    gt_tensor = torch.nn.functional.adaptive_avg_pool2d(gt_tensor, output_size=(16, 16))
    real_tensor = torch.nn.functional.adaptive_avg_pool2d(real_tensor, output_size=(16, 16))
    gt_tensor = utils.wst_m(gt_tensor, proj_s)
    real_tensor = utils.wst_m(real_tensor, proj_s)
    # The selected indices are not needed here, only the coreset itself.
    # l=0.5 is presumably the fraction of samples to keep — confirm in patchcore.py.
    memory_coreset, _ = greedy_coreset_selection(gt_tensor, l=0.5)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        fake_tensor = torch.nn.functional.adaptive_avg_pool2d(fake_tensor, output_size=(16, 16))
        fake_tensor = utils.wst_m(fake_tensor, proj_s)
        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = coreset_detector(memory_coreset, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"wst3 sparse coreset auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

wst3 sparse coreset auroc: 0.44994157394706086, auprc: 0.4743008738162199, fpr95: 0.9211934156378601


### PatchCore

In [None]:
# PatchCore evaluation on CLIP patch features.
# Per class the nature.pt patches are flattened to a 2-D list of patch vectors,
# a coreset is selected on their projected version, and the un-projected
# patches at the selected indices form the memory bank given to
# patchcore_detector.
results_csv = "results_v2/clip_patchcore.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
# A single projector is shared by all classes in this cell.
# 768 is presumably the patch embedding width and 100 the projected
# dimension — confirm against rpo.py and the clip_patch encoder.
projector = SparseProjector(768, 100)
for cls in classes_idx:
    feature_dir = f"../Data/Features/clip_patch/{cls}"
    gt_patches = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_patches = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    # Collapse every leading dimension so each row is one patch vector.
    gt_patches = gt_patches.reshape(-1, gt_patches.shape[-1])
    gt_patches_proj = projector.project(gt_patches)
    # Selection runs on the projected patches; only the indices are used, so
    # that the memory bank can keep the original (un-projected) vectors.
    _, indices = greedy_coreset_selection(gt_patches_proj, l=0.5)
    memory_coreset = gt_patches[indices]
    for generator in ["bgan", "midj", "sd_15"]:
        fake_patches = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        # One label per entry along dim 0 (presumably one per image — confirm):
        # zeros for real, ones for fake.
        labels = np.concatenate((np.zeros(real_patches.shape[0]), np.ones(fake_patches.shape[0])))
        scores = patchcore_detector(memory_coreset, real_patches, fake_patches)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"clip patchcore auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

clip patchcore auroc: 0.38810839302951794, auprc: 0.42072097432556643, fpr95: 0.9390946502057613


In [12]:
# PatchCore evaluation on DINOv2 patch features.
# Per class the nature.pt patches are flattened to a 2-D list of patch vectors,
# a coreset is selected on their projected version, and the un-projected
# patches at the selected indices form the memory bank given to
# patchcore_detector.
results_csv = "results_v2/dinov2_patchcore.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
# A single projector is shared by all classes in this cell.
# 768 is presumably the patch embedding width and 100 the projected
# dimension — confirm against rpo.py and the dinov2_patch encoder.
projector = SparseProjector(768, 100)
for cls in classes_idx:
    feature_dir = f"../Data/Features/dinov2_patch/{cls}"
    gt_patches = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_patches = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    # Collapse every leading dimension so each row is one patch vector.
    gt_patches = gt_patches.reshape(-1, gt_patches.shape[-1])
    gt_patches_proj = projector.project(gt_patches)
    # Selection runs on the projected patches; only the indices are used, so
    # that the memory bank can keep the original (un-projected) vectors.
    _, indices = greedy_coreset_selection(gt_patches_proj, l=0.5)
    memory_coreset = gt_patches[indices]
    for generator in ["bgan", "midj", "sd_15"]:
        fake_patches = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        # One label per entry along dim 0 (presumably one per image — confirm):
        # zeros for real, ones for fake.
        labels = np.concatenate((np.zeros(real_patches.shape[0]), np.ones(fake_patches.shape[0])))
        scores = patchcore_detector(memory_coreset, real_patches, fake_patches)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"dinov2 patchcore auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

dinov2 patchcore auroc: 0.5911281308743586, auprc: 0.5762045610478104, fpr95: 0.7818930041152262


## WST-2 & WST-3 Mahalanobis

In [None]:
# Mahalanobis evaluation on WST-2 features reduced with a sparse projector.
# Each split is average-pooled to 16x16 and passed through utils.wst_m with
# the class's projector before being scored by mahalanobis_detector.
results_csv = "results_v2/wst2_sparse_mahalanobis.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    # A new projector is built for every class and shared by all splits of that class.
    # 81 matches the WST-2 channel count printed by the encoding cell; 8 is
    # presumably the projected dimension — confirm in rpo.py.
    proj_s = SparseProjector(81, 8)
    feature_dir = f"../Data/Features/wst2/{cls}"
    gt_tensor = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    gt_tensor = torch.nn.functional.adaptive_avg_pool2d(gt_tensor, output_size=(16, 16))
    real_tensor = torch.nn.functional.adaptive_avg_pool2d(real_tensor, output_size=(16, 16))
    gt_tensor = utils.wst_m(gt_tensor, proj_s)
    real_tensor = utils.wst_m(real_tensor, proj_s)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        fake_tensor = torch.nn.functional.adaptive_avg_pool2d(fake_tensor, output_size=(16, 16))
        fake_tensor = utils.wst_m(fake_tensor, proj_s)
        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = mahalanobis_detector(gt_tensor, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"wst2 mahalanobis auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

        



wst2 mahalanobis auroc: 0.4254648681603414, auprc: 0.47431262034051974, fpr95: 0.9049382716049384


In [22]:
# Mahalanobis evaluation on WST-3 features reduced with a sparse projector.
# Each split is average-pooled to 16x16 and passed through utils.wst_m with
# the class's projector before being scored by mahalanobis_detector.
results_csv = "results_v2/wst3_sparse_mahalanobis.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    # A new projector is built for every class and shared by all splits of that class.
    # 217 matches the WST-3 channel count printed by the encoding cell; 16 is
    # presumably the projected dimension — confirm in rpo.py.
    proj_s = SparseProjector(217, 16)
    feature_dir = f"../Data/Features/wst3/{cls}"
    gt_tensor = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    gt_tensor = torch.nn.functional.adaptive_avg_pool2d(gt_tensor, output_size=(16, 16))
    real_tensor = torch.nn.functional.adaptive_avg_pool2d(real_tensor, output_size=(16, 16))
    gt_tensor = utils.wst_m(gt_tensor, proj_s)
    real_tensor = utils.wst_m(real_tensor, proj_s)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        fake_tensor = torch.nn.functional.adaptive_avg_pool2d(fake_tensor, output_size=(16, 16))
        fake_tensor = utils.wst_m(fake_tensor, proj_s)
        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = mahalanobis_detector(gt_tensor, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"wst3 sparse mahalanobis auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

wst3 sparse mahalanobis auroc: 0.4263742823756541, auprc: 0.48064977056659886, fpr95: 0.8965020576131687


In [23]:
# Mahalanobis evaluation on WST-2 features reduced with a Gaussian projector.
# Each split is average-pooled to 16x16 and passed through utils.wst_m with
# the class's projector before being scored by mahalanobis_detector.
results_csv = "results_v2/wst2_gaussian_mahalanobis.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    # A new projector is built for every class and shared by all splits of that class.
    # 81 matches the WST-2 channel count printed by the encoding cell; 8 is
    # presumably the projected dimension — confirm in rpo.py.
    proj_g = GaussianProjector(81, 8)
    feature_dir = f"../Data/Features/wst2/{cls}"
    gt_tensor = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    gt_tensor = torch.nn.functional.adaptive_avg_pool2d(gt_tensor, output_size=(16, 16))
    real_tensor = torch.nn.functional.adaptive_avg_pool2d(real_tensor, output_size=(16, 16))
    gt_tensor = utils.wst_m(gt_tensor, proj_g)
    real_tensor = utils.wst_m(real_tensor, proj_g)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        fake_tensor = torch.nn.functional.adaptive_avg_pool2d(fake_tensor, output_size=(16, 16))
        fake_tensor = utils.wst_m(fake_tensor, proj_g)
        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = mahalanobis_detector(gt_tensor, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"wst2 gaussian mahalanobis auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

wst2 gaussian mahalanobis auroc: 0.42860844383478136, auprc: 0.4762146218749842, fpr95: 0.905349794238683


In [24]:
# Mahalanobis evaluation on WST-3 features reduced with a Gaussian projector.
# Each split is average-pooled to 16x16 and passed through utils.wst_m with
# the class's projector before being scored by mahalanobis_detector.
results_csv = "results_v2/wst3_gaussian_mahalanobis.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
for cls in classes_idx:
    # A new projector is built for every class and shared by all splits of that class.
    # 217 matches the WST-3 channel count printed by the encoding cell; 16 is
    # presumably the projected dimension — confirm in rpo.py.
    proj_g = GaussianProjector(217, 16)
    feature_dir = f"../Data/Features/wst3/{cls}"
    gt_tensor = torch.load(f"{feature_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{feature_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    gt_tensor = torch.nn.functional.adaptive_avg_pool2d(gt_tensor, output_size=(16, 16))
    real_tensor = torch.nn.functional.adaptive_avg_pool2d(real_tensor, output_size=(16, 16))
    gt_tensor = utils.wst_m(gt_tensor, proj_g)
    real_tensor = utils.wst_m(real_tensor, proj_g)
    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{feature_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        fake_tensor = torch.nn.functional.adaptive_avg_pool2d(fake_tensor, output_size=(16, 16))
        fake_tensor = utils.wst_m(fake_tensor, proj_g)
        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = mahalanobis_detector(gt_tensor, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"wst3 gaussian mahalanobis auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

wst3 gaussian mahalanobis auroc: 0.42946959304984006, auprc: 0.4792573234349718, fpr95: 0.8997942386831274


## WST2VIT Mahalanobis

In [8]:
# WST-2 -> DINOv2 pipeline with Mahalanobis scoring.
# Per class, every split is passed through utils.wst_m with a sparse
# projector, reshaped with utils.reshape_to_3, encoded by dinov2_encoder and
# cached under ../Data/Features/wst2dinov2/ before being scored.
results_csv = "results_v2/wst2dinov2_mahalanobis.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole encoding and evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

dinov2_model = AutoModel.from_pretrained("facebook/dinov2-with-registers-base").to(DEVICE)
data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}

for cls in classes_idx:
    # A new projector is built for every class and shared by all splits of that class.
    # 81 matches the WST-2 channel count printed by the encoding cell; 48 is
    # presumably the projected channel count — confirm in rpo.py.
    proj_s = SparseProjector(81, 48)
    wst_dir = f"../Data/Features/wst2/{cls}"
    encoded_dir = f"../Data/Features/wst2dinov2/{cls}"

    gt_tensor = torch.load(f"{wst_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{wst_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    # NOTE(review): the meaning of the third positional argument (False) is
    # defined in utils.wst_m — confirm there.
    gt_tensor = utils.wst_m(gt_tensor, proj_s, False)
    real_tensor = utils.wst_m(real_tensor, proj_s, False)
    # Presumably rearranges the projected maps into 3-channel images for the ViT — TODO confirm.
    gt_tensor = utils.reshape_to_3(gt_tensor)
    real_tensor = utils.reshape_to_3(real_tensor)
    gt_tensor = dinov2_encoder(gt_tensor, dinov2_model)
    real_tensor = dinov2_encoder(real_tensor, dinov2_model)
    # Cache the encoded features on CPU, detached from any autograd graph.
    os.makedirs(encoded_dir, exist_ok=True)
    torch.save(gt_tensor.detach().cpu(), f"{encoded_dir}/nature.pt")
    torch.save(real_tensor.detach().cpu(), f"{encoded_dir}/nature_2.pt")

    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{wst_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        fake_tensor = utils.wst_m(fake_tensor, proj_s, False)
        fake_tensor = utils.reshape_to_3(fake_tensor)
        fake_tensor = dinov2_encoder(fake_tensor, dinov2_model)
        torch.save(fake_tensor.detach().cpu(), f"{encoded_dir}/{generator}.pt")

        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = mahalanobis_detector(gt_tensor, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"wst2dinov2 mahalanobis auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

  attn_output = torch.nn.functional.scaled_dot_product_attention(
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s]
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.86it/s]
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.84it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 388.26it/s]
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.84it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 431.58it/s]
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.84it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 419.95it/s]
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.83it/s]
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.84it/s]
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.85it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 433.12it/s]
WSTDinov2 Encoding: 100%|██████████| 3/3 [00:01<00:00,  1.88it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [

wst2dinov2 mahalanobis auroc: 0.5155438703449677, auprc: 0.530614766708069, fpr95: 0.9125514403292183





In [10]:
# WST-2 -> CLIP pipeline with Mahalanobis scoring.
# Per class, every split is passed through utils.wst_m with a sparse
# projector, reshaped with utils.reshape_to_3, encoded by the wst2vit CLIP
# encoder and cached under ../Data/Features/wst2clip/ before being scored.

# The encoding section rebinds the notebook-level name `clip_encoder` to
# Encoder("clip"), which hides the function imported from wst2vit. Import that
# function under its own alias here so this cell calls the right object no
# matter which cells ran before it (e.g. after Restart & Run All).
from wst2vit import clip_encoder as wst_clip_encoder

results_csv = "results_v2/wst2clip_mahalanobis.csv"
# Create the output folder up front so to_csv cannot fail on a missing
# directory after the whole encoding and evaluation has already been computed.
os.makedirs(os.path.dirname(results_csv), exist_ok=True)

data = {"CLASS": [], "GENERATOR": [], "AUROC": [], "AUPRC": [], "FPR95": []}
# The preprocessing transform returned by clip.load is not used in this cell.
clip_model, _ = clip.load("ViT-B/32", device=DEVICE)

for cls in classes_idx:
    # A new projector is built for every class and shared by all splits of that class.
    # 81 matches the WST-2 channel count printed by the encoding cell; 48 is
    # presumably the projected channel count — confirm in rpo.py.
    proj_s = SparseProjector(81, 48)
    wst_dir = f"../Data/Features/wst2/{cls}"
    encoded_dir = f"../Data/Features/wst2clip/{cls}"

    gt_tensor = torch.load(f"{wst_dir}/nature.pt", weights_only=True).to(DEVICE)
    real_tensor = torch.load(f"{wst_dir}/nature_2.pt", weights_only=True).to(DEVICE)
    # NOTE(review): the meaning of the third positional argument (False) is
    # defined in utils.wst_m — confirm there.
    gt_tensor = utils.wst_m(gt_tensor, proj_s, False)
    real_tensor = utils.wst_m(real_tensor, proj_s, False)
    # Presumably rearranges the projected maps into 3-channel images for the ViT — TODO confirm.
    gt_tensor = utils.reshape_to_3(gt_tensor)
    real_tensor = utils.reshape_to_3(real_tensor)
    gt_tensor = wst_clip_encoder(gt_tensor, clip_model)
    real_tensor = wst_clip_encoder(real_tensor, clip_model)
    # Cache the encoded features on CPU, detached from any autograd graph.
    os.makedirs(encoded_dir, exist_ok=True)
    torch.save(gt_tensor.detach().cpu(), f"{encoded_dir}/nature.pt")
    torch.save(real_tensor.detach().cpu(), f"{encoded_dir}/nature_2.pt")

    for generator in ["bgan", "midj", "sd_15"]:
        fake_tensor = torch.load(f"{wst_dir}/{generator}.pt", weights_only=True).to(DEVICE)
        fake_tensor = utils.wst_m(fake_tensor, proj_s, False)
        fake_tensor = utils.reshape_to_3(fake_tensor)
        fake_tensor = wst_clip_encoder(fake_tensor, clip_model)
        torch.save(fake_tensor.detach().cpu(), f"{encoded_dir}/{generator}.pt")

        # One label per sample: zeros for the real rows, ones for the fake rows.
        # Assumes the detector returns scores in the same real-then-fake order — TODO confirm.
        labels = np.concatenate((np.zeros(real_tensor.shape[0]), np.ones(fake_tensor.shape[0])))
        scores = mahalanobis_detector(gt_tensor, real_tensor, fake_tensor)
        auroc, auprc, fpr95 = Metrics(labels, scores).computation()
        data["CLASS"].append(cls)
        data["GENERATOR"].append(generator)
        data["AUROC"].append(auroc)
        data["AUPRC"].append(auprc)
        data["FPR95"].append(fpr95)

# One row per (class, generator) pair; the printed numbers are plain means over all rows.
df = pd.DataFrame(data)
df.to_csv(results_csv, index=False)
print(f"wst2clip mahalanobis auroc: {np.mean(data['AUROC'])}, auprc: {np.mean(data['AUPRC'])}, fpr95: {np.mean(data['FPR95'])}")

WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 18.73it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 24.37it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 23.96it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 488.18it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 23.85it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 509.02it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 24.36it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 526.92it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 22.84it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 23.98it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 23.70it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 516.70it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 23.75it/s]
Mahalanobis scoring:: 100%|██████████| 324/324 [00:00<00:00, 519.96it/s]
WSTClip Encoding: 100%|██████████| 3/3 [00:00<00:00, 23.37i

wst2clip mahalanobis auroc: 0.5798379312096734, auprc: 0.5807227509781118, fpr95: 0.8676954732510289



