In [2]:
# # ================= LOAD THRESHOLDS FROM CSV =================
# import pandas as pd
# import json
# from pathlib import Path

# CSV_PATH = Path("/Users/mrinalseth13331/Downloads/cae_per_category_summary.csv")
# CAE_DIR = Path("saved_spatial_caes_simple")
# CAE_DIR.mkdir(exist_ok=True)
# OUT_JSON = CAE_DIR / "thresholds.json"

# category_list = [
#     'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 
#     'leather', 'metal_nut', 'pill', 'screw', 'tile', 
#     'toothbrush', 'transistor', 'wood', 'zipper'
# ]

# df = pd.read_csv(CSV_PATH)
# print("CSV columns:", df.columns.tolist())

# # find threshold column automatically
# threshold_col = [c for c in df.columns if "threshold" in c.lower()][0]
# cat_col = [c for c in df.columns if "category" in c.lower()][0]

# thresholds = {}
# missing = []

# for cat in category_list:
#     row = df[df[cat_col] == cat]
#     if row.empty:
#         print("Missing category in CSV:", cat)
#         missing.append(cat)
#         continue
#     thresholds[cat] = float(row[threshold_col].values[0])

# print("\nLoaded thresholds for:")
# for k,v in thresholds.items():
#     print(f"{k}: {v}")

# # Save as JSON for Gradio
# with open(OUT_JSON, "w") as f:
#     json.dump(thresholds, f, indent=2)

# print("\nSaved thresholds.json to:", OUT_JSON)


CSV columns: ['category', 'train_mean', 'train_std', 'threshold', 'image_AUROC', 'accuracy', 'n_test']

Loaded thresholds for:
bottle: 0.1071691784986634
cable: 0.1449074127336039
capsule: 0.0881808400108139
carpet: 0.0585161529997106
grid: 0.0914179283961397
hazelnut: 0.1328290603011438
leather: 0.0502316191441985
metal_nut: 0.1131659973496359
pill: 0.0923540407455983
screw: 0.1081089502004047
tile: 0.0905780693737589
toothbrush: 0.1283455279455947
transistor: 0.1239226004720422
wood: 0.0961848321741979
zipper: 0.0660800238644721

Saved thresholds.json to: saved_spatial_caes_simple/thresholds.json


In [1]:
# ===== Gradio app: classifier → extractor → CAE → threshold decision =====

import os, json
from pathlib import Path
import torch, torch.nn.functional as F
import numpy as np
from PIL import Image
import gradio as gr
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights
import torch.nn as nn

# ---------------- CONFIG ----------------
CAE_DIR = Path("saved_spatial_caes_simple")    # directory read for cae_<category>.pth and thresholds.json
CLASSIFIER_CKPT = Path("best_classifier.pth")  # checkpoint loaded into the ResNet-18 classifier
LATENT_DIM = 100                               # latent_dim passed to every FeatCAE that is built
DEVICE_PREFERENCE = "mps"                      # accelerator to try first: "mps", "cuda" or "cpu"
TOPK = 10                                      # default k of topk_score
# predict() indexes this list with the classifier's argmax, so position i is the
# name reported for classifier output i.
category_list = [
    'bottle','cable','capsule','carpet','grid','hazelnut',
    'leather','metal_nut','pill','screw','tile',
    'toothbrush','transistor','wood','zipper'
]


def _resolve_device(preference):
    """Return the torch.device for `preference`, falling back to CPU.

    Args:
        preference: "mps", "cuda" or anything else (case-insensitive).

    Returns:
        torch.device("mps") / torch.device("cuda") when that backend was
        requested and reports itself available, otherwise torch.device("cpu").
    """
    wanted = str(preference).lower()
    if wanted == "mps":
        # Older torch builds have no torch.backends.mps attribute at all.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
    elif wanted == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = _resolve_device(DEVICE_PREFERENCE)
print("Device:", device)

# transforms
# Per-channel statistics handed to Normalize (presumably the ImageNet values the
# pretrained ResNets expect -- confirm against how the CAEs were trained).
_NORM_MEAN = [0.485,0.456,0.406]
_NORM_STD = [0.229,0.224,0.225]

# Preprocessing for every uploaded image: resize to 224x224, convert to a
# tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# -------- feature extractor (ResNet50 hooks) --------
class ResNetFeatureExtractorFast(nn.Module):
    """Frozen ResNet-50 exposing intermediate activations as a single feature map.

    Forward hooks on the last block of ``layer2`` and ``layer3`` record their
    (detached) outputs. ``forward`` smooths each recorded map with a 3x3 average
    pool (stride 1, padding 1), bilinearly resizes it to the spatial size of the
    first recorded map and concatenates everything along the channel axis.

    Args:
        pretrained: when true, build the backbone with ``ResNet50_Weights.DEFAULT``;
            otherwise with no weights.
        device: optional device the whole module is moved to.
    """

    def __init__(self, pretrained=True, device=None):
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.model = resnet50(weights=weights)
        self.model.eval()
        # The backbone is used purely as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        self.features = []
        self.hooks = [
            stage[-1].register_forward_hook(self._hook)
            for stage in (self.model.layer2, self.model.layer3)
        ]
        if device is not None:
            self.to(device)

    def _hook(self, module, inputs, output):
        """Forward-hook callback: stash a detached copy of the block output."""
        self.features.append(output.detach())

    def forward(self, x):
        """Run the backbone on ``x`` and return the concatenated hooked features."""
        # Start from an empty buffer so maps from a previous call never leak in.
        self.features = []
        with torch.no_grad():
            self.model(x)  # output discarded; the hooks fill self.features
        target_hw = self.features[0].shape[-2:]
        resized = [
            F.interpolate(
                F.avg_pool2d(fmap, 3, 1, 1),
                size=target_hw,
                mode='bilinear',
            )
            for fmap in self.features
        ]
        return torch.cat(resized, dim=1)

# -------- CAE --------
class FeatCAE(nn.Module):
    """Autoencoder over feature maps built only from 1x1 convolutions.

    The encoder narrows ``in_channels -> (in_channels + 2*latent_dim)//2 ->
    2*latent_dim -> latent_dim``; the decoder mirrors those widths back to
    ``in_channels``. Every convolution except the last of each half is followed
    by an optional BatchNorm2d and a ReLU.

    Args:
        in_channels: channel count of the input (and reconstructed) feature map.
        latent_dim: channel count of the bottleneck.
        is_bn: insert BatchNorm2d after each non-final convolution when true.
    """

    def __init__(self, in_channels=1536, latent_dim=100, is_bn=True):
        super().__init__()
        hidden = (in_channels + 2 * latent_dim) // 2
        wide = 2 * latent_dim
        # Encoder is built before the decoder, and layers are appended in the
        # same order as before, so Sequential indices (state_dict keys) are unchanged.
        self.encoder = self._pointwise_stack([in_channels, hidden, wide, latent_dim], is_bn)
        self.decoder = self._pointwise_stack([latent_dim, wide, hidden, in_channels], is_bn)

    @staticmethod
    def _pointwise_stack(widths, is_bn):
        """Chain 1x1 convs through `widths`; BN/ReLU follow all but the final conv."""
        layers = []
        final_idx = len(widths) - 2
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 1))
            if idx == final_idx:
                break
            if is_bn:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and return its reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)

# -------- extractor --------
extractor = ResNetFeatureExtractorFast(pretrained=True, device=device)
extractor.model.eval()  # already done in __init__; kept as an explicit inference-mode guard

# -------- classifier --------
if not CLASSIFIER_CKPT.exists():
    raise RuntimeError("best_classifier.pth not found.")

# The strict load_state_dict below overwrites every parameter and buffer, so the
# architecture is built without pretrained weights: same final model, but no
# ImageNet download (and no network access needed) at startup.
clf = resnet18(weights=None)
# Replace the head so there is one output per entry of category_list.
clf.fc = nn.Linear(clf.fc.in_features, len(category_list))
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
state = torch.load(CLASSIFIER_CKPT, map_location=device)
# Accept either a bare state_dict or a training checkpoint wrapping one.
if isinstance(state, dict) and "model_state_dict" in state:
    state = state["model_state_dict"]
clf.load_state_dict(state)
clf.to(device).eval()

# -------- load CAEs --------
# Maps category name -> its FeatCAE in eval mode on `device`; categories whose
# checkpoint file is absent are reported and left out.
trained_caes = {}
for cat in category_list:
    ckpt_path = CAE_DIR / f"cae_{cat}.pth"
    if not ckpt_path.exists():
        print("Missing CAE:", cat)
        continue
    # 1536 input channels -- presumably layer2 + layer3 of the ResNet-50
    # extractor concatenated; confirm against the training notebook.
    model = FeatCAE(1536, LATENT_DIM)
    weights = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(weights)
    trained_caes[cat] = model.to(device).eval()

# -------- load thresholds.json --------
# Decision thresholds keyed by category name, parsed from thresholds.json.
thr_file = CAE_DIR / "thresholds.json"
if thr_file.exists():
    thresholds = json.loads(thr_file.read_text())
else:
    raise RuntimeError("thresholds.json missing.")

# -------- scoring --------
def topk_score(seg, k=TOPK):
    """Score each sample by the mean of its k largest values.

    Args:
        seg: tensor whose first dimension indexes samples; all remaining
            dimensions are flattened together.
        k: number of largest values to average; clamped to the number of
            values available per sample.

    Returns:
        1-D tensor with one score per sample.
    """
    # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
    # maps, are accepted instead of raising.
    flat = seg.reshape(seg.size(0), -1)
    k = min(k, flat.size(1))
    return torch.topk(flat, k, dim=1).values.mean(dim=1)

# -------- heat overlay --------
def overlay_heat(img, heat, alpha=0.6):
    """Blend a "jet"-coloured heat map over an image.

    Args:
        img: PIL image; converted to RGB and resized to the heat map's size.
        heat: 2-D array of floats (the caller passes values scaled to [0, 1]).
        alpha: weight of the heat map in the blend (0 -> image only,
            1 -> heat map only).

    Returns:
        PIL RGB image with the same width/height as `heat`.
    """
    import matplotlib

    try:
        # Colormap registry; matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        cmap = matplotlib.colormaps["jet"]
    except AttributeError:
        # Older Matplotlib without the registry.
        import matplotlib.cm as cm
        cmap = cm.get_cmap("jet")

    # The colormap returns RGBA floats; drop alpha and convert to 8-bit RGB.
    heat_rgb = (cmap(np.asarray(heat))[:, :, :3] * 255).astype(np.uint8)
    heat_img = Image.fromarray(heat_rgb)
    # Image.blend needs equal sizes, so follow the heat map instead of assuming
    # 224x224 (identical result for 224x224 heat maps).
    base = img.convert("RGB").resize(heat_img.size)
    return Image.blend(base, heat_img, alpha)

# -------- pipeline --------
def predict(img_pil):
    """Classify an image, reconstruct its features and decide if it is anomalous.

    Steps: the classifier picks a category; the extractor produces a feature
    map; that category's CAE reconstructs it; the per-location squared error
    (averaged over channels) is upsampled to 224x224; the top-k score of that
    map is compared with the category's threshold.

    Args:
        img_pil: PIL image from the UI, or None when nothing was uploaded.

    Returns:
        (overlay, heat_norm, text): the heat map blended over the image, the
        min-max normalised error map as a numpy array, and a textual summary.
        When no result can be produced (no image, no CAE, or no threshold for
        the predicted category) the first two items are None and `text`
        explains why.
    """
    # Gradio passes None when the form is submitted without an image.
    if img_pil is None:
        return None, None, "No image provided"

    # Normalize expects exactly three channels; grayscale / RGBA would fail.
    x = transform(img_pil.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_idx = clf(x).argmax(1).item()
    cat = category_list[pred_idx]

    if cat not in trained_caes:
        return None, None, f"No CAE for {cat}"
    # thresholds.json may lack a category; report it instead of a KeyError.
    if cat not in thresholds:
        return None, None, f"No threshold for {cat}"

    thr = thresholds[cat]

    with torch.no_grad():
        f = extractor(x)
        r = trained_caes[cat](f)
        # Squared reconstruction error averaged over the channel dimension.
        err = ((f - r)**2).mean(1, keepdim=True)
        up = F.interpolate(err, (224,224), mode='bilinear')

    heat = up.squeeze().cpu().numpy()
    # Min-max scale for display only; the epsilon guards a constant map.
    heat_norm = (heat - heat.min())/(heat.max()-heat.min()+1e-8)

    # The decision uses the un-normalised upsampled error map.
    score = float(topk_score(up).cpu())
    anomaly = score >= thr

    overlay = overlay_heat(img_pil, heat_norm)
    text = f"Category: {cat}\nScore: {score:.6f}\nThreshold: {thr:.6f}\nAnomaly: {'YES' if anomaly else 'NO'}"
    return overlay, heat_norm, text

# -------- Gradio UI --------
# One image goes in; predict() returns three values, shown in this order:
# the overlay, the raw heat map and the textual verdict.
image_input = gr.Image(type="pil", label="Upload image")
result_outputs = [
    gr.Image(type="pil", label="Overlay"),
    gr.Image(type="numpy", label="Heatmap"),
    gr.Textbox(label="Result"),
]

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=result_outputs,
    title="MVTec CAE-Based Anomaly Detection",
)

# share=False: no public link is requested.
iface.launch(share=False)


Device: mps
* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


