In [88]:
import os
import json
import pandas as pd
import math
from pathlib import Path
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from PIL import Image
from torchvision import transforms


from kymatio.torch import Scattering2D
import pywt
from collections import defaultdict


import clip
from transformers import AutoImageProcessor, AutoModel, AutoProcessor, AutoModelForImageClassification
import torch.nn.functional as F
from sklearn.random_projection import SparseRandomProjection
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve
from scipy.special import softmax

# import cuml
# import cudf
import torchmetrics
import torchmetrics.classification as tmc
import torchmetrics.functional as tmf

In [20]:
from einops import rearrange

In [3]:
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
# `device` is also used as the default argument of several functions below,
# so this cell must run before those definitions.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset root. NOTE(review): the evaluation cell further down hard-codes
# "../Data/GenImage/" instead of reading this constant — confirm intent.
SRC_PATH = "../Data/GenImage/"
# Generator sub-folder names. NOTE(review): padim_detector only evaluates
# "bgan", "midj" and "sd_15"; this list is not read by the code visible here.
generator_names = ["adm", "bgan", "glide", "midj", "sd_14", "sd_15", "vqdm", "wukong"]
# classes.json is expected to provide the keys "1k_idx" and "21k_idx".
with open("classes.json", "r", encoding="utf-8") as f:
    data = json.load(f)
# Iterated by the evaluation cell; each entry is used as a class folder name.
classes_idx = data["1k_idx"]
# NOTE(review): named "names" but loaded from the "21k_idx" key — verify.
classes_names = data["21k_idx"]

In [None]:
def compute_fixed_pca(tensor, n_components=48):
    """Fit a channel-wise PCA basis on a batch of feature maps.

    Every spatial location of every sample is treated as one observation with
    C features: the input is rearranged to a (B*H*W, C) matrix, centered, and
    decomposed with an SVD.

    Args:
        tensor: 4-D tensor of shape (B, C, H, W).
        n_components: number of leading principal directions to keep.

    Returns:
        (Vh_reduced, X_mean) where Vh_reduced has shape (n_components, C)
        (rows are principal directions) and X_mean has shape (1, C).

    Raises:
        ValueError: if `tensor` is not 4-D, or if `n_components` is outside
            [1, min(B*H*W, C)]. Slicing would otherwise silently return fewer
            rows than requested and only fail later, far from the cause.
    """
    if tensor.dim() != 4:
        raise ValueError(
            f"expected a 4-D (B, C, H, W) tensor, got shape {tuple(tensor.shape)}"
        )
    C = tensor.shape[1]
    X = tensor.permute(0, 2, 3, 1).reshape(-1, C)  # shape: (B*H*W, C)

    # The reduced SVD yields at most min(rows, C) directions.
    max_components = min(X.shape[0], C)
    if not 1 <= n_components <= max_components:
        raise ValueError(
            f"n_components must be in [1, {max_components}], got {n_components}"
        )

    # Center the observations.
    X_mean = X.mean(dim=0, keepdim=True)
    X_centered = X - X_mean

    # Only the right singular vectors are needed for the projection.
    _, _, Vh = torch.linalg.svd(X_centered, full_matrices=False)
    Vh_reduced = Vh[:n_components]  # shape: (n_components, C)

    return Vh_reduced, X_mean


def apply_fixed_pca(tensor, Vh_reduced, X_mean):
    """Project feature maps onto a previously fitted PCA basis.

    Args:
        tensor: 4-D tensor of shape (B, C, H, W).
        Vh_reduced: (D, C) matrix whose rows are the retained directions.
        X_mean: per-channel mean to subtract, as returned by
            compute_fixed_pca (shape (1, C)).

    Returns:
        Tensor of shape (B, D, H, W).
    """
    batch, channels, height, width = tensor.shape

    # Channels-last layout, one row per pixel: (B*H*W, C).
    pixels = tensor.permute(0, 2, 3, 1).reshape(-1, channels)

    # With a (1, C) mean the extra leading axis added here broadcasts to a
    # singleton dimension that the final view() folds away again.
    centered = pixels - X_mean[None, :]
    projected = torch.matmul(centered, Vh_reduced.T)  # (..., B*H*W, D)

    # Back to channels-first: (B, D, H, W).
    return projected.view(batch, height, width, -1).permute(0, 3, 1, 2)

In [None]:
def resize_short_side(img, target_short_side):
    """Resize a PIL image so its shorter side equals `target_short_side`.

    The aspect ratio is preserved (the longer side is floored to an integer)
    and bilinear resampling is used.

    Args:
        img: PIL image (anything exposing `.size` as (w, h) and `.resize`).
        target_short_side: desired length, in pixels, of the shorter side.

    Returns:
        The resized image.
    """
    w, h = img.size
    # Pure integer arithmetic: the float form int(w * (target / h)) can land
    # just below the exact value (e.g. 49 * (256 / 49) == 255.99999999999997),
    # giving a side shorter than the target so that a following CenterCrop
    # pads with zeros instead of cropping.
    if w < h:
        new_w = target_short_side
        new_h = (h * target_short_side) // w
    else:
        new_h = target_short_side
        new_w = (w * target_short_side) // h
    return img.resize((new_w, new_h), Image.BILINEAR)

# Spatial size (H, W) of the images fed to the scattering transform; the same
# tuple is passed as `shape` to Scattering2D inside wst_transform.
wst_shape = (256, 256)
# Preprocessing applied to each PIL image before the scattering transform.
wst_preprocess = transforms.Compose(
    [
        # Scale so that the shorter side equals min(wst_shape).
        transforms.Lambda(lambda img: resize_short_side(img, min(wst_shape))),  
        # Cut the central wst_shape window out of the resized image.
        transforms.CenterCrop(wst_shape),  
        # Convert the PIL image to a torch tensor.
        transforms.ToTensor(),  
    ]
)

def normalize_image(tensor):
    """Min-max scale every sample of a batch to [0, 1].

    The minimum and maximum are taken over all channels and pixels of one
    sample, so the relative scale between channels is preserved.

    Args:
        tensor: 4-D tensor of shape (B, C, H, W); may be non-contiguous.

    Returns:
        Tensor of the same shape, with each sample mapped to
        (x - min) / (max - min + 1e-8).
    """
    B, C, H, W = tensor.shape
    # reshape (not view) so permuted/sliced inputs work; the explicit size
    # instead of -1 keeps an empty batch (B == 0) valid.
    flat = tensor.reshape(B, C * H * W)
    min_val = flat.amin(dim=1, keepdim=True)
    max_val = flat.amax(dim=1, keepdim=True)
    # The epsilon keeps constant samples from dividing by zero.
    normalized = (flat - min_val) / (max_val - min_val + 1e-8)
    return normalized.reshape(B, C, H, W)


def normalize_channel(tensor):
    """Min-max scale every channel of every sample to [0, 1] independently.

    Args:
        tensor: 4-D tensor of shape (B, C, H, W); may be non-contiguous.

    Returns:
        Tensor of the same shape, with each (sample, channel) plane mapped to
        (x - min) / (max - min + 1e-8).
    """
    B, C, H, W = tensor.shape
    # reshape (not view) so permuted/transposed inputs work; the explicit size
    # instead of -1 keeps empty batches valid.
    flat = tensor.reshape(B, C, H * W)
    min_val = flat.amin(dim=2, keepdim=True)
    max_val = flat.amax(dim=2, keepdim=True)
    # The epsilon keeps constant planes from dividing by zero.
    normalized = (flat - min_val) / (max_val - min_val + 1e-8)
    return normalized.reshape(B, C, H, W)

In [74]:
class WSTDataset(Dataset):
    """Dataset that loads image files as grayscale and applies a transform.

    Items are `(transformed_image, path_as_str)`. Files that cannot be read
    or transformed yield `None`, so the collate function must drop them
    (see wst_collate_fn).
    """

    def __init__(self, image_paths, transform):
        """Store the sequence of image paths and the per-image transform."""
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Number of paths, including ones that may later fail to load."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return `(image, str(path))`, or `None` if loading fails."""
        path = self.image_paths[idx]
        try:
            # convert() returns a new in-memory image, so the file handle can
            # be released as soon as the with-block exits.
            with Image.open(path) as raw:
                image = raw.convert("L")
            image = self.transform(image)
            return image, str(path)
        except Exception as exc:
            # `Exception` rather than a bare except, so KeyboardInterrupt and
            # SystemExit still stop a long run; name the file that failed.
            print(f"Failure open image: {path} ({exc})")
            return None

def wst_collate_fn(batch):
    """Collate `(image, path)` samples, discarding entries that are `None`.

    Returns:
        `(stacked_images, paths)` where `paths` is a tuple, or `(None, None)`
        when no sample in the batch survived.
    """
    usable = list(filter(lambda sample: sample is not None, batch))
    if len(usable) == 0:
        return None, None
    image_list = tuple(image for image, _ in usable)
    path_list = tuple(path for _, path in usable)
    return torch.stack(image_list), path_list

In [None]:
def wst_transform(input_path, transform, J=2, batch_size = 64, device = device):
    """Compute 2-D wavelet scattering coefficients for every image in a folder.

    Args:
        input_path: directory whose regular files are treated as images.
        transform: per-image preprocessing handed to WSTDataset.
        J: scale parameter forwarded to Scattering2D.
        batch_size: number of images per scattering call.
        device: device on which the scattering transform is evaluated.

    Returns:
        CPU tensor with the coefficients of all readable images concatenated
        along dim 0, in sorted-path order. Presumably (N, C', H', W') once the
        singleton input-channel axis is squeezed — assumes 1-channel input.

    Raises:
        RuntimeError: if no image in `input_path` could be loaded.
    """
    input_dir = Path(input_path)
    # Sort so the output order is reproducible (glob order is arbitrary) and
    # skip sub-directories, which could never be opened as images anyway.
    image_paths = sorted(p for p in input_dir.glob("*") if p.is_file())
    scattering = Scattering2D(J=J, shape=wst_shape).to(device)
    dataset = WSTDataset(image_paths, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=wst_collate_fn)
    wst_results = []
    with torch.no_grad():
        for batch_images, _batch_paths in tqdm(dataloader):
            # wst_collate_fn returns (None, None) when every image of a batch
            # failed to load; there is nothing to transform in that case.
            if batch_images is None:
                continue
            batch_images = batch_images.to(device)
            coeffs = scattering(batch_images)
            # Drop the input-channel axis (size 1 for grayscale input).
            coeffs = coeffs.squeeze(1)
            wst_results.append(coeffs.cpu())
    if not wst_results:
        raise RuntimeError(f"No readable images found in {input_dir}")

    return torch.cat(wst_results, dim=0)

def reshape_to_3(x_reduced):
    """Tile D = 3 * g * g feature maps into a 3-channel mosaic image.

    The D channels are split into 3 consecutive groups of g*g maps. Within a
    group, the maps are laid out row-major on a g x g grid, so map
    `c * g*g + r * g + s` becomes the tile at grid row r, grid column s of
    output channel c. For the usual D = 48 this is a 4 x 4 grid and a
    (B, 48, 64, 64) input becomes (B, 3, 256, 256).

    Args:
        x_reduced: tensor of shape (B, D, H, W); may be non-contiguous.

    Returns:
        Tensor of shape (B, 3, g * H, g * W).

    Raises:
        ValueError: if D is not of the form 3 * g * g with g >= 1.
    """
    B, D, H, W = x_reduced.shape
    maps_per_channel = D // 3
    grid = int(round(maps_per_channel ** 0.5))
    if D % 3 != 0 or grid < 1 or grid * grid != maps_per_channel:
        raise ValueError(
            f"channel count must equal 3 * g * g for an integer g >= 1, got {D}"
        )
    # (B, 3, grid_row, grid_col, H, W); reshape copes with non-contiguous input.
    x_split = x_reduced.reshape(B, 3, grid, grid, H, W)
    # Interleave grid and pixel axes: (B, 3, grid_row, H, grid_col, W).
    x_tiled = x_split.permute(0, 1, 2, 4, 3, 5)
    return x_tiled.reshape(B, 3, grid * H, grid * W)

def prepare_for_clip_batch(wst_tensor: torch.Tensor) -> torch.Tensor:
    """Resize a 3-channel batch to 224x224 and standardize it per channel.

    Args:
        wst_tensor: tensor of shape (B, 3, H, W).

    Returns:
        Tensor of shape (B, 3, 224, 224): bicubic-resized input with the
        per-channel mean subtracted and divided by the per-channel std.
    """
    target_device = wst_tensor.device

    # Per-channel statistics (these values match the normalization constants
    # published with OpenAI CLIP), shaped to broadcast over batch and pixels.
    channel_mean = torch.tensor(
        [0.48145466, 0.4578275, 0.40821073], device=target_device
    ).view(1, 3, 1, 1)
    channel_std = torch.tensor(
        [0.26862954, 0.26130258, 0.27577711], device=target_device
    ).view(1, 3, 1, 1)

    upsampled = F.interpolate(wst_tensor, size=224, mode='bicubic', align_corners=False)
    return (upsampled - channel_mean) / channel_std


In [98]:
def pca_pipeline(input_path, transform, cls, device = device):
    """Fit the PCA basis on one folder and return its images as model input.

    Steps: scattering transform -> per-image min-max normalization -> PCA fit
    (48 components) -> projection -> tiling into a 3-channel image. The fitted
    basis is also written to ``../Data/Features/wst_pca/<cls>.pt``.

    Args:
        input_path: folder of reference images.
        transform: per-image preprocessing for the scattering stage.
        cls: class identifier, used only for the name of the saved basis file.
        device: device for the scattering transform and the PCA.

    Returns:
        (model_input, Vh, mean), all on CPU.
    """
    # Forward `device` so the scattering stage honours the caller's choice
    # instead of silently falling back to its own default.
    wst_tensor = wst_transform(input_path, transform, device=device).to(device)
    wst_tensor = normalize_image(wst_tensor)
    Vh, mean = compute_fixed_pca(wst_tensor, n_components=48)

    save_path = f'../Data/Features/wst_pca/{cls}.pt'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Store CPU copies so the file can be loaded on a machine without a GPU.
    torch.save({'Vh': Vh.cpu(), 'mean': mean.cpu()}, save_path)

    wst_reduced = apply_fixed_pca(wst_tensor, Vh, mean)
    model_input = reshape_to_3(wst_reduced)
    return model_input.cpu(), Vh.cpu(), mean.cpu()

def compute_pipeline(input_path, transform, Vh, mean, device = device):
    """Turn a folder of images into model input using an already fitted PCA.

    Same stages as pca_pipeline (scattering -> per-image min-max
    normalization -> projection -> 3-channel tiling) but the PCA basis is
    supplied by the caller instead of being fitted.

    Args:
        input_path: folder of images to process.
        transform: per-image preprocessing for the scattering stage.
        Vh: PCA directions, as returned by pca_pipeline.
        mean: PCA mean, as returned by pca_pipeline.
        device: device for the scattering transform and the projection.

    Returns:
        The tiled model input on CPU.
    """
    # Forward `device` so the scattering stage honours the caller's choice
    # instead of silently falling back to its own default.
    wst_tensor = wst_transform(input_path, transform, device=device).to(device)
    wst_tensor = normalize_image(wst_tensor)
    wst_reduced = apply_fixed_pca(wst_tensor, Vh.to(device), mean.to(device))
    model_input = reshape_to_3(wst_reduced)
    return model_input.cpu()





In [None]:
def clip_encode(input_tensor, model, batch_size):
    """Encode a tensor of images with `model.encode_image`, batch by batch.

    Each batch is moved to the notebook-level `device`, passed through
    prepare_for_clip_batch, and encoded without gradient tracking.

    Args:
        input_tensor: images stacked along dim 0.
        model: object exposing `encode_image(batch)`.
        batch_size: number of images per forward pass.

    Returns:
        CPU tensor with the features of all images concatenated along dim 0,
        in input order.
    """
    loader = DataLoader(TensorDataset(input_tensor), batch_size=batch_size, shuffle=False)
    encoded_chunks = []
    for (chunk,) in tqdm(loader, desc="Image Encoding"):
        # Resize + normalize on the compute device.
        model_ready = prepare_for_clip_batch(chunk.to(device))
        with torch.no_grad():
            chunk_features = model.encode_image(model_ready)
        encoded_chunks.append(chunk_features.cpu())
    return torch.cat(encoded_chunks, dim=0)

In [65]:
def gt_compute(gt_tensors, eps=1e-3):
    """Estimate the mean and regularized covariance of reference embeddings.

    Args:
        gt_tensors: (N, D) tensor of embeddings, any floating dtype.
        eps: value added to the covariance diagonal so it stays invertible.

    Returns:
        (mean, cov) as float32 tensors on the notebook-level `device`, with
        shapes (1, D) and (D, D).

    Raises:
        ValueError: if fewer than two embeddings are given (the unbiased
            covariance divides by N - 1).
    """
    # Work in float32: half-precision embeddings would otherwise have their
    # mean and covariance accumulated in half precision, losing accuracy
    # before mahalanobis_distance up-casts them.
    embeddings = gt_tensors.to(device=device, dtype=torch.float32)
    n_samples = embeddings.size(0)
    if n_samples < 2:
        raise ValueError(
            f"need at least 2 embeddings to estimate a covariance, got {n_samples}"
        )
    mean = embeddings.mean(dim=0, keepdim=True)
    X = embeddings - mean
    cov = X.T @ X / (n_samples - 1)
    cov = cov + eps * torch.eye(cov.size(0), device=embeddings.device, dtype=cov.dtype)
    return mean, cov


def mahalanobis_distance(x, mean, cov):
    """Mahalanobis distance of one sample from a Gaussian (mean, cov).

    All inputs are cast to float32 and `x` / `mean` are flattened. The linear
    system cov @ s = (x - mean) is solved rather than inverting `cov`.

    Returns:
        0-dim tensor with the distance. If a RuntimeError is raised while
        solving (e.g. a singular covariance), the error is printed and a NaN
        tensor on `x`'s device is returned instead.
    """
    diff = x.to(torch.float32).view(-1) - mean.to(torch.float32).view(-1)
    cov32 = cov.to(torch.float32)

    try:
        solved = torch.linalg.solve(cov32, diff.unsqueeze(1))  # [D, 1]
        squared = torch.matmul(diff, solved.squeeze())
        # Round-off can push a tiny true value slightly below zero.
        if squared < 0:
            print("Warning: distance squared < 0", squared.item())
            squared = squared.clamp(min=0.0)
        return squared.sqrt()
    except RuntimeError as e:
        print("Runtime error in Mahalanobis:", e)
        return torch.tensor(float("nan"), device=x.device)

In [None]:
def padim_detector(input_path, cls, model, batch_size = 64, device = device):
    """Evaluate a Gaussian (Mahalanobis) real-vs-generated detector for one class.

    Steps, as implemented below:
      1. Fit the PCA basis on ``<input_path>/nature`` and encode those images
         with `model` to obtain the reference features.
      2. Estimate mean and covariance of the reference features.
      3. For each of "bgan", "midj" and "sd_15": concatenate its features with
         those of ``<input_path>/nature_2``, score every sample with the
         negative Mahalanobis distance to the reference Gaussian (higher =
         closer to the reference), and compute AUROC, AUPRC and the FPR at
         the first ROC point whose TPR is >= 0.95. Labels are 0 for the
         generator's images and 1 for ``nature_2``.
      4. Save a ROC / precision-recall figure per generator under
         ``../Data/Padim_results/wst2/<cls>/``.

    Args:
        input_path: class folder containing the sub-folders named above.
        cls: class identifier used in output paths.
        model: encoder exposing `encode_image`.
        batch_size: batch size for feature encoding.
        device: device for the pipelines and the distance computation.

    Returns:
        dict ``{"AUROC": {...}, "AUPRC": {...}, "FPR95": {...}}``, each inner
        dict keyed by generator name.
    """
    generators = ["bgan", "midj", "sd_15"]
    data = {"AUROC": {}, "AUPRC": {}, "FPR95": {}}

    gt_tensor, Vh, mean = pca_pipeline(input_path + "/nature", wst_preprocess, cls, device=device)
    gt_features = clip_encode(gt_tensor, model, batch_size)
    gt_mean, gt_cov = gt_compute(gt_features)
    gt_mean, gt_cov = gt_mean.to(device), gt_cov.to(device)

    features = {}
    for generator in generators + ["nature_2"]:
        all_tensor = compute_pipeline(input_path + f"/{generator}", wst_preprocess, Vh, mean, device=device)
        features[generator] = clip_encode(all_tensor, model, batch_size)

    real_features = features["nature_2"]
    for generator in generators:
        save_path = f"../Data/Padim_results/wst2/{cls}/{generator}.png"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        fake_features = features[generator]
        mixed = torch.cat([fake_features, real_features], dim=0).to(device)
        # Build the labels from THIS generator's sample count: folders may
        # hold different numbers of (readable) images, so a label vector sized
        # for another generator would be misaligned or have the wrong length.
        labels = np.concatenate(
            (np.zeros(fake_features.shape[0]), np.ones(real_features.shape[0]))
        )

        # Negated so that a larger score means "more like the reference".
        scores = np.array(
            [-mahalanobis_distance(sample, gt_mean, gt_cov).item() for sample in mixed],
            dtype=np.float64,
        )

        fpr, tpr, roc_thresholds = roc_curve(labels, scores)
        # First operating point reaching at least 95% TPR.
        tpr95_idx = np.where(tpr >= 0.95)[0][0]
        fpr_95 = fpr[tpr95_idx]
        # Threshold of the ROC point closest to the ideal corner (FPR=0, TPR=1).
        corner_distances = np.sqrt((1 - tpr) ** 2 + fpr**2)
        best_threshold = roc_thresholds[np.argmin(corner_distances)]
        print("Best threshold(ROC):", best_threshold)

        roc_auc = roc_auc_score(labels, scores)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        ax1.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (area = {roc_auc:.2f})")
        ax1.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        ax1.set_xlim([0.0, 1.0])
        ax1.set_ylim([0.0, 1.05])
        ax1.set_xlabel("False Positive Rate (FPR)")
        ax1.set_ylabel("True Positive Rate (TPR)")
        ax1.set_title("Receiver Operating Characteristic (ROC) Curve")
        ax1.legend(loc="lower right")

        precision, recall, _pr_thresholds = precision_recall_curve(labels, scores)
        pr_auc = auc(recall, precision)

        ax2.plot(recall, precision, color="blue", lw=2, label=f"PR curve (area = {pr_auc:.2f})")
        ax2.set_xlim([0.0, 1.0])
        ax2.set_ylim([0.0, 1.05])
        ax2.set_xlabel("Recall")
        ax2.set_ylabel("Precision")
        ax2.set_title("Precision-Recall Curve")
        ax2.legend(loc="best")

        fig.tight_layout()
        fig.savefig(save_path)
        # Close this figure explicitly so figures do not pile up across classes.
        plt.close(fig)

        data["AUROC"][generator] = roc_auc
        data["AUPRC"][generator] = pr_auc
        data["FPR95"][generator] = fpr_95

    return data
    
    

In [None]:
# Load the CLIP image encoder once and reuse it for every class; the second
# value returned by clip.load is not needed here.
model, _ = clip.load("ViT-B/32", device=device)
# Run the detector for every class; results are keyed by class identifier.
# Note: `data` is rebound here, replacing the dict loaded from classes.json.
dict_all = {}
for cls in classes_idx:
    data = padim_detector(f"../Data/GenImage/{cls}", cls, model)
    dict_all[cls] = data

# Flatten {class: {metric: {generator: value}}} into one row per
# (class, generator) pair.
all_rows = []

for class_name, data in dict_all.items(): 
    # Only the keys are used; all three metric dicts share the generator keys.
    for generator, _ in data["AUROC"].items():
        row = {
            "CLASS": class_name,
            "GENERATOR": generator,
            "AUROC": data["AUROC"][generator],
            "AUPRC": data["AUPRC"][generator],
            "FPR95": data["FPR95"][generator],
        }
        all_rows.append(row)

df = pd.DataFrame(all_rows)
# The DataFrame index is written too (pandas default), so reading the file
# back yields an extra unnamed first column.
df.to_csv("wst_clip_results.csv")

In [None]:
# Reload the per-class / per-generator metrics written by the evaluation cell.
# NOTE(review): the table displayed beneath this cell looks like a
# per-generator aggregate, but the code producing it is not visible here —
# confirm how it was derived from this frame.
avg_df = pd.read_csv("wst_clip_results.csv")

Unnamed: 0_level_0,bgan,midj,sd_15
metric,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
AUPRC,0.497237,0.561964,0.644344
AUROC,0.457205,0.550591,0.632697
FPR95,0.97037,0.897531,0.880247
