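This notebook implements the paper "DIRE for Diffusion-Generated Image Detection" (https://arxiv.org/pdf/2303.09295). The diffusion model's noise prediction is simulated with random Gaussian noise rather than produced by a trained network.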

In [None]:
import os
import pandas as pd
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.linear_model import LogisticRegression

import torch
from torchvision import transforms

In [None]:
def add_noise(image_tensor, timesteps=20):
    """Return a copy of ``image_tensor`` perturbed by repeated Gaussian noise.

    On each of ``timesteps`` iterations, standard-normal noise scaled by 0.05
    is added to the running result. The input tensor itself is not modified.

    Args:
        image_tensor: Tensor to perturb.
        timesteps: Number of noise increments to apply.

    Returns:
        A new tensor with the same shape as ``image_tensor``.
    """
    result = image_tensor.clone()
    for _step in range(timesteps):
        increment = 0.05 * torch.randn_like(result)
        result = result + increment
    return result


def denoise(noisy_image_tensor, timesteps=10):
    """Return a copy of ``noisy_image_tensor`` with scaled Gaussian noise subtracted.

    On each of ``timesteps`` iterations, standard-normal noise scaled by 0.05
    is subtracted from the running result. The input tensor is not modified.

    NOTE(review): despite the name, this does not remove noise. Subtracting
    fresh zero-mean Gaussian samples adds variance, just like ``add_noise``.
    It looks like a placeholder and is not called by the visible cells.

    Args:
        noisy_image_tensor: Tensor to process.
        timesteps: Number of subtraction steps.

    Returns:
        A new tensor with the same shape as ``noisy_image_tensor``.
    """
    out = noisy_image_tensor.clone()
    for _step in range(timesteps):
        perturbation = 0.05 * torch.randn_like(out)
        out = out - perturbation
    return out

In [None]:
def ddim_inversion(image_tensor, timesteps=10, alpha_start=0.1, alpha_end=0.9):
    """Simulate DDIM inversion by iteratively rescaling and perturbing an image.

    A linear schedule of ``timesteps`` alpha values from ``alpha_start`` to
    ``alpha_end`` is built. At step ``t`` the running tensor is multiplied by
    ``sqrt(alpha_next / alpha_t)``. Gaussian noise scaled by
    ``sqrt((1 - alpha_next) - (1 - alpha_t))`` is then added. Both square-root
    arguments are clamped to at least 1e-6. The model's noise prediction is
    simulated with ``torch.randn_like``, so the output is random.

    Args:
        image_tensor: Image tensor to invert; it is not modified.
        timesteps: Number of schedule steps.
        alpha_start: First alpha value of the schedule.
        alpha_end: Last alpha value; also used as ``alpha_next`` on the final step.

    Returns:
        A new tensor with the same shape, dtype and device as ``image_tensor``.
    """
    device = image_tensor.device
    noisy_image = image_tensor.clone()
    alphas = (
        torch.linspace(alpha_start, alpha_end, timesteps)
        .to(device)
        .type_as(image_tensor)
    )
    # Successor alpha for the final step, built once from the Python float.
    final_alpha = torch.tensor(alpha_end, device=device).type_as(image_tensor)

    for t in range(timesteps):
        # Use the schedule entries directly. Re-wrapping a 0-d tensor with
        # torch.tensor() copies it and emits a UserWarning on every step.
        alpha_t = alphas[t]
        alpha_next = alphas[t + 1] if t < timesteps - 1 else final_alpha

        epsilon_theta = torch.randn_like(noisy_image)  # Simulating model prediction noise

        alpha_ratio = torch.sqrt(torch.clamp(alpha_next / alpha_t, min=1e-6))
        # NOTE(review): this argument equals alpha_t - alpha_next. With an
        # increasing schedule (alpha_start < alpha_end, the defaults) it is
        # always negative. The clamp therefore always applies and the noise
        # scale is sqrt(1e-6) = 1e-3, while alpha_ratio > 1 scales the image up.
        # Confirm the intended math before relying on these values.
        diff_sqrt = torch.sqrt(torch.clamp((1 - alpha_next) - (1 - alpha_t), min=1e-6))

        noisy_image = noisy_image * alpha_ratio + epsilon_theta * diff_sqrt

    return noisy_image


def ddim_reconstruction(
    noisy_image_tensor, timesteps=10, alpha_start=0.1, alpha_end=0.9
):
    """Simulate the DDIM reverse process on ``noisy_image_tensor``.

    A linear schedule of ``timesteps`` alpha values from ``alpha_start`` to
    ``alpha_end`` is walked from the last entry down to the first. At each step:

    1. Fresh Gaussian noise is drawn to stand in for the model's noise prediction.
    2. The current estimate is mixed with that noise to form ``x_t``.
    3. ``x_t`` is updated towards ``alpha_prev``.

    Square-root arguments are clamped to at least 1e-6. Because the noise is
    random, the output is random.

    Args:
        noisy_image_tensor: Tensor to reconstruct from; it is not modified.
        timesteps: Number of schedule steps.
        alpha_start: First alpha value; also used as ``alpha_prev`` on the last step.
        alpha_end: Last alpha value of the schedule.

    Returns:
        A new tensor with the same shape, dtype and device as ``noisy_image_tensor``.
    """
    device = noisy_image_tensor.device
    reconstructed_image = noisy_image_tensor.clone()
    alphas = (
        torch.linspace(alpha_start, alpha_end, timesteps)
        .to(device)
        .type_as(noisy_image_tensor)
    )
    # Predecessor alpha for the t == 0 step, built once from the Python float.
    first_alpha = torch.tensor(alpha_start, device=device).type_as(noisy_image_tensor)

    for t in range(timesteps - 1, -1, -1):
        # Use the schedule entries directly. Re-wrapping a 0-d tensor with
        # torch.tensor() copies it and emits a UserWarning on every step.
        alpha_t = alphas[t]
        alpha_prev = alphas[t - 1] if t > 0 else first_alpha

        epsilon_theta = torch.randn_like(reconstructed_image)  # Simulating model prediction noise

        sqrt_alpha_t = torch.sqrt(torch.clamp(alpha_t, min=1e-6))
        sqrt_one_minus_alpha_t = torch.sqrt(torch.clamp(1 - alpha_t, min=1e-6))

        # NOTE(review): the current estimate is re-noised here on every step,
        # which a deterministic DDIM reverse pass would not do. Confirm intent.
        x_t = (
            reconstructed_image * sqrt_alpha_t + epsilon_theta * sqrt_one_minus_alpha_t
        )

        sqrt_alpha_ratio = torch.sqrt(torch.clamp(alpha_prev / alpha_t, min=1e-6))
        # NOTE(review): the deterministic DDIM step (sigma = 0) uses
        # sqrt(alpha_prev * (1 - alpha_t) / alpha_t) as the subtracted term.
        # The alpha_prev factor is missing below. It is kept as-is to preserve
        # the notebook's outputs; verify against the DDIM update rule.
        diff_term = torch.sqrt(torch.clamp(1 - alpha_prev, min=1e-6)) - torch.sqrt(
            torch.clamp((1 - alpha_t) / alpha_t, min=1e-6)
        )

        reconstructed_image = x_t * sqrt_alpha_ratio + diff_term * epsilon_theta

    return reconstructed_image


# DIRE (DIffusion Reconstruction Error) of a single image tensor
def calculate_dire(image_tensor, timesteps=10):
    """Return the DIRE score of ``image_tensor`` as a Python float.

    The image is passed through ``ddim_inversion`` and then
    ``ddim_reconstruction``, both with ``timesteps`` steps. The score is the
    mean absolute difference between the input and its reconstruction. If the
    reconstruction contains NaN, a warning is printed and ``float("nan")`` is
    returned instead.
    """
    reconstruction = ddim_reconstruction(
        ddim_inversion(image_tensor, timesteps=timesteps),
        timesteps=timesteps,
    )

    has_nan = bool(torch.isnan(reconstruction).any())
    if not has_nan:
        residual = (image_tensor - reconstruction).abs()
        return residual.mean().item()

    print("Warning: NaN values found in reconstructed image")
    return float("nan")


In [None]:
# Directory paths.
# Raw strings are required. In a normal literal "\t" is a TAB character and
# "\R" is an invalid escape, so "\test\REAL" would not be the intended path.
# The leading backslash makes these drive-root paths on Windows; adjust them
# to your dataset location.
real_dir = r"\test\REAL"
fake_dir = r"\test\FAKE"

# Dataframe to store DIRE values and labels
# NOTE(review): not used by the visible cells, which build `df` directly.
data = {"image_path": [], "DIRE": [], "label": []}

# Preprocessing for every image: resize to 32x32 pixels, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)

In [None]:
# Load images, calculate DIRE, and collect one row per successfully processed file.
# Rows are gathered in a list and converted to a DataFrame once. DataFrame.append
# was removed in pandas 2.0 (and was quadratic before that). Previously the
# resulting AttributeError was swallowed by the except clause for every file,
# which left `df` silently empty.
rows = []
for label, folder in [("REAL", real_dir), ("FAKE", fake_dir)]:
    # Sorting makes the processing order, and hence the random draws and the
    # later train/test split, reproducible across runs.
    for filename in tqdm(
        sorted(os.listdir(folder)), desc=f"Processing {label} images", unit="file"
    ):
        file_path = os.path.join(folder, filename)
        try:
            # The context manager guarantees the file handle is released.
            with Image.open(file_path) as image:
                image_tensor = transform(image.convert("RGB"))
            dire_value = calculate_dire(image_tensor)
        except Exception as e:
            # Best effort per file: report and skip unreadable or non-image entries.
            print(f"Error processing {file_path}: {e}")
            continue

        if pd.isna(dire_value):
            # calculate_dire returns NaN on numerical failure. LogisticRegression
            # rejects NaN inputs, so such rows are dropped here.
            print(f"Skipping {file_path}: DIRE is NaN")
            continue

        # Bookkeeping stays outside the try so its errors are not swallowed.
        rows.append(
            {
                "image_path": file_path,
                "DIRE": dire_value,
                "label": 1 if label == "REAL" else 0,
            }
        )

df = pd.DataFrame(rows, columns=["image_path", "DIRE", "label"])

In [None]:
# Fail early with a clear message instead of an opaque error from train_test_split.
if df.empty:
    raise ValueError("No images were processed; check real_dir / fake_dir.")

# Split the data into training and testing sets.
# Cast explicitly. If the frame was grown from an empty, column-only DataFrame,
# its columns can end up with dtype object. scikit-learn treats non-string
# object labels as an "unknown" target type and refuses to fit.
X = df[["DIRE"]].astype(float)
y = df["label"].astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a Logistic Regression classifier on the single DIRE feature
classifier = LogisticRegression()
classifier.fit(X_train, y_train)

In [None]:
# Predict labels for the held-out split (1 = REAL, 0 = FAKE in this notebook's encoding)
y_pred = classifier.predict(X=X_test)

In [None]:
# Evaluate the held-out predictions. Precision, recall and F1 use scikit-learn's
# default pos_label=1, which is the REAL class in this notebook's label encoding.
metric_fns = (accuracy_score, precision_score, recall_score, f1_score)
accuracy, precision, recall, f1 = [fn(y_test, y_pred) for fn in metric_fns]

report = (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
    ("F1 Score", f1),
)
for metric_name, value in report:
    print(f"{metric_name}: {value:.4f}")