1. Imports & Device Setup

In [None]:
import os
import glob
from tqdm import tqdm
import numpy as np
import cv2
from PIL import Image
import torch
import matplotlib.pyplot as plt

import clip
from diffusers import StableDiffusionImg2ImgPipeline
from sklearn.metrics import roc_auc_score
from skimage.metrics import structural_similarity as ssim

# Run on the GPU whenever PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


2. MVTec Dataset Loader

In [None]:
DATASET_PATH = "/kaggle/input/mvtec-ad"


def load_images(category="carpet", img_type="train", resize=256, limit=None,
                dataset_path=None):
    """Load the PNG images of one dataset split as RGB arrays with binary labels.

    Every entry of ``<dataset_path>/<category>/<img_type>`` is treated as a
    class folder. Images found in the folder named ``"good"`` are labelled 0;
    images found in any other folder are labelled 1.

    Args:
        category: Sub-directory of the dataset root to read.
        img_type: Split sub-directory inside the category (the notebook uses
            "train" and "test").
        resize: Side length in pixels; every image is resized to
            ``(resize, resize)``.
        limit: Maximum number of files taken from EACH class folder. ``None``
            (or any other falsy value, including 0) disables the limit.
        dataset_path: Dataset root directory. Defaults to ``DATASET_PATH``.

    Returns:
        A tuple ``(images, labels)`` of NumPy arrays holding the resized RGB
        images and their 0/1 labels, in the same order.

    Raises:
        FileNotFoundError: If the split directory does not exist.
        OSError: If OpenCV cannot read one of the PNG files.
    """
    root = DATASET_PATH if dataset_path is None else dataset_path
    path = os.path.join(root, category, img_type)
    images, labels = [], []

    for defect_type in sorted(os.listdir(path)):
        defect_path = os.path.join(path, defect_type)
        # glob returns matches in arbitrary order; sort them so that `limit`
        # always selects the same files and runs are reproducible.
        files = sorted(glob.glob(os.path.join(defect_path, "*.png")))
        if limit:
            files = files[:limit]
        label = 0 if defect_type == "good" else 1
        for f in files:
            img = cv2.imread(f)
            if img is None:
                # cv2.imread signals failure by returning None instead of
                # raising; fail here with the offending path rather than with
                # an opaque error inside cvtColor.
                raise OSError(f"Could not read image: {f}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (resize, resize))
            images.append(img)
            labels.append(label)

    return np.array(images), np.array(labels)


3. Load CLIP + Stable Diffusion img2img Pipeline

In [None]:
# Load the CLIP model together with the preprocessing callable returned by clip.load.
print("Loading CLIP...")
clip_model, preprocess = clip.load("ViT-B/32", device=device)

print("Loading Stable Diffusion img2img pipeline...")
pipe_img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    # Half precision on the GPU, full precision on the CPU.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None,
    feature_extractor=None,
)

pipe_img2img = pipe_img2img.to(device)

# Best-effort step: a failure here must not abort the notebook, but it should
# be visible instead of being silently discarded.
try:
    pipe_img2img.enable_attention_slicing()
except Exception as exc:
    print(f"Warning: could not enable attention slicing ({exc!r}); continuing without it.")

# Seeded generator created on the active device ("cuda" or "cpu"), shared by
# every later pipeline call.
generator = torch.Generator(device=device).manual_seed(42)


4. Stable Diffusion Reconstruction Helper

In [None]:
def reconstruct_img2img(np_img,
                        prompt,
                        negative_prompt,
                        strength=0.35,
                        guidance_scale=7.5,
                        num_inference_steps=30,
                        target_size=512):
    """Run one image through the global img2img pipeline and return the result.

    Args:
        np_img: Image array; it is cast to uint8 and converted to an RGB PIL
            image before being resized.
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        strength: Forwarded to the pipeline as ``strength``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``.
        target_size: Side length the input is resized to (LANCZOS) before the
            pipeline call.

    Returns:
        The first generated image as a float32 NumPy array divided by 255.
    """
    source = Image.fromarray(np_img.astype(np.uint8)).convert("RGB")
    source = source.resize((target_size, target_size), resample=Image.LANCZOS)

    shared_kwargs = {
        "prompt": prompt,
        "strength": strength,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "generator": generator,
        "output_type": "pil",
    }

    def _run_pipeline():
        # Try the "image" keyword first; if the call raises TypeError, retry
        # with the "init_image" keyword instead.
        try:
            return pipe_img2img(image=source, **shared_kwargs)
        except TypeError:
            return pipe_img2img(init_image=source, **shared_kwargs)

    if device == "cuda":
        # On the GPU both attempts run under float16 autocast.
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            output = _run_pipeline()
    else:
        output = _run_pipeline()

    return np.array(output.images[0]).astype(np.float32) / 255.0


5. Anomaly Scoring Functions

In [None]:
def pixel_mse(img, recon):
    """Return the mean squared error between ``img / 255`` and ``recon`` as a float."""
    residual = img / 255.0 - recon
    return float(np.mean(residual ** 2))

def pixel_ssim(img, recon):
    """Return ``1 - SSIM`` between ``img / 255`` and ``recon`` as a float.

    Args:
        img: Image array on a 0-255 scale; it is divided by 255 before the
            comparison.
        recon: Reconstruction array compared against ``img / 255``; the
            channels are expected on axis 2.

    Returns:
        ``1.0 - ssim`` as a Python float, so larger values mean less similar
        images.
    """
    # Both operands are floating point, so the value range must be passed
    # explicitly: scikit-image cannot infer it reliably for float images and
    # recent versions raise when it is missing. The scale here is 1.0 because
    # `img` is divided by 255.
    similarity = ssim(img / 255.0, recon, channel_axis=2, data_range=1.0)
    return float(1.0 - similarity)

def clip_similarity_error(img, recon):
    """Return ``1 - cosine similarity`` of the CLIP image features of two images.

    ``img`` is cast to uint8 as-is; ``recon`` is clipped to [0, 1] and scaled
    by 255 before the cast. Both are encoded with the global ``clip_model``
    after the global ``preprocess`` transform, with gradients disabled.
    """
    original = Image.fromarray(img.astype(np.uint8))
    rebuilt = Image.fromarray((np.clip(recon, 0, 1) * 255).astype(np.uint8))

    def _embed(pil_image):
        # Preprocess one image, add a batch dimension and encode it on `device`.
        batch = preprocess(pil_image).unsqueeze(0).to(device)
        return clip_model.encode_image(batch)

    with torch.no_grad():
        original_feat = _embed(original)
        rebuilt_feat = _embed(rebuilt)
        similarity = torch.cosine_similarity(original_feat, rebuilt_feat).item()

    return float(1.0 - similarity)

def hybrid_score(img, recon, alpha=0.8):
    """Blend the pixel MSE and the CLIP similarity error into one score.

    ``alpha`` weights ``pixel_mse``; the remaining ``1 - alpha`` weights
    ``clip_similarity_error``.
    """
    pixel_term = alpha * pixel_mse(img, recon)
    semantic_term = (1.0 - alpha) * clip_similarity_error(img, recon)
    return pixel_term + semantic_term


6. Evaluation Loop (AUROC)

In [None]:
def evaluate_with_img2img(category="carpet", resize=256, limit_train=200, limit_test=None,
                          prompt_template="a defect-free {category}",
                          negative_prompt=None,
                          strength=0.35, guidance_scale=7.5,
                          num_inference_steps=30, alpha=0.90):
    """Score every test image by its img2img reconstruction error and report AUROC.

    Args:
        category: Dataset category to load; also substituted into
            ``prompt_template``.
        resize: Side length passed to ``load_images`` for both splits.
        limit_train: ``limit`` passed to ``load_images`` for the train split.
        limit_test: ``limit`` passed to ``load_images`` for the test split.
        prompt_template: Format string with a ``{category}`` placeholder.
        negative_prompt: Forwarded to ``reconstruct_img2img``.
        strength: Forwarded to ``reconstruct_img2img``.
        guidance_scale: Forwarded to ``reconstruct_img2img``.
        num_inference_steps: Forwarded to ``reconstruct_img2img``.
        alpha: Weight forwarded to ``hybrid_score``.

    Returns:
        Dict with keys ``"auc"`` (AUROC of the scores against the labels),
        ``"scores"`` and ``"labels"`` (NumPy arrays, one entry per test image)
        and ``"test_imgs"`` (the loaded test images).
    """
    print(f"Loading data for category {category}...")
    train_imgs, train_labels = load_images(category, "train", resize=resize, limit=limit_train)
    test_imgs, test_labels = load_images(category, "test", resize=resize, limit=limit_test)

    # The train split is only used for this summary line.
    n_train_normals = int((train_labels == 0).sum())
    print("Train normals:", n_train_normals, "Test samples:", len(test_imgs))

    prompt = prompt_template.format(category=category)
    scores, labels = [], []

    samples = zip(test_imgs, test_labels)
    for img, label in tqdm(samples, total=len(test_imgs), desc="Eval img2img reconstructions"):
        recon = reconstruct_img2img(
            img,
            prompt=prompt,
            negative_prompt=negative_prompt,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            target_size=512,
        )

        # Bring the reconstruction back to the input resolution before scoring.
        height, width = img.shape[:2]
        recon_u8 = (recon * 255).astype(np.uint8)
        recon_resized = cv2.resize(recon_u8, (width, height)) / 255.0

        scores.append(hybrid_score(img, recon_resized, alpha=alpha))
        labels.append(int(label))

    auc = roc_auc_score(labels, scores)
    print(f"Hybrid AUROC for {category}: {auc:.4f}")

    return {
        "auc": auc,
        "scores": np.array(scores),
        "labels": np.array(labels),
        "test_imgs": test_imgs,
    }


7. Heatmap Visualization

In [None]:
def anomaly_map(img, recon):
    """Return a per-pixel error map between ``img / 255`` and ``recon``.

    The absolute difference is averaged over axis 2 and divided by its own
    maximum (plus 1e-8 to avoid division by zero).
    """
    abs_error = np.abs(img / 255.0 - recon)
    per_pixel = abs_error.mean(axis=2)
    peak = per_pixel.max()
    return per_pixel / (peak + 1e-8)

def show_sample_result(img, recon, idx=0, save_path=None):
    """Plot the original, its reconstruction and an anomaly overlay side by side.

    Args:
        img: Original image array; shown after a cast to uint8.
        recon: Reconstruction array; resized to the spatial size of ``img``
            when the two differ, and clipped to [0, 1] for display.
        idx: Not used by this function.
        save_path: When truthy, the figure is also saved to this path.
    """
    height, width = img.shape[:2]
    if recon.shape[:2] != (height, width):
        recon = cv2.resize((recon * 255).astype(np.uint8), (width, height)) / 255.0

    heatmap = anomaly_map(img, recon)
    original_u8 = img.astype(np.uint8)
    recon_u8 = (np.clip(recon, 0, 1) * 255).astype(np.uint8)

    _, (ax_original, ax_recon, ax_overlay) = plt.subplots(1, 3, figsize=(12, 4))

    ax_original.imshow(original_u8)
    ax_original.set_title("Original")

    ax_recon.imshow(recon_u8)
    ax_recon.set_title("Reconstruction")

    # Third panel: the original with the heatmap drawn half-transparent on top.
    ax_overlay.imshow(original_u8)
    ax_overlay.imshow(heatmap, cmap='jet', alpha=0.5)
    ax_overlay.set_title("Anomaly Map")

    for axis in (ax_original, ax_recon, ax_overlay):
        axis.axis('off')

    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()


8. Optuna Hyperparameter Search

In [None]:
def objective(trial):
    """Optuna objective: sample three hyperparameters and return the AUROC.

    ``strength``, ``guidance_scale`` and ``alpha`` are drawn from the trial and
    passed to ``evaluate_with_img2img`` on the "carpet" category with both
    limits set to 50.
    """
    sampled = {
        "strength": trial.suggest_float("strength", 0.1, 0.7),
        "guidance_scale": trial.suggest_float("guidance_scale", 2.0, 10.0),
        "alpha": trial.suggest_float("alpha", 0.5, 0.95),
    }

    result = evaluate_with_img2img(
        category="carpet",
        limit_train=50,
        limit_test=50,
        **sampled,
    )
    return result["auc"]

# `optuna` is used below but is not among the notebook's imports; import it
# here so this cell does not fail with a NameError.
import optuna

# Run 64 trials of `objective`, looking for the parameters that maximise its
# return value.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=64, show_progress_bar=True)

print("Best AUROC:", study.best_value)
print("Best parameters:", study.best_params)
