In [6]:
# Hyperparameters for the feature autoencoder — tune as needed.
# NOTE(review): EPOCHS, BATCH_SIZE and LR are not referenced by the cells visible here;
# presumably they are consumed by a training cell — confirm before changing.
EPOCHS = 15
BATCH_SIZE = 8
LR = 1e-3
# Bottleneck width passed as `latent_dim` to FeatCAE when checkpoints are loaded for
# evaluation; it has to equal the value the checkpoints were trained with.
LATENT_DIM = 100

In [7]:
import torch.nn as nn

class FeatCAE(nn.Module):
    """Autoencoder over feature maps, built only from 1x1 convolutions.

    Channel widths through the encoder are
    ``in_channels -> (in_channels + 2 * latent_dim) // 2 -> 2 * latent_dim -> latent_dim``
    and the decoder mirrors them back to ``in_channels``. All convolutions use
    ``kernel_size=1, stride=1, padding=0``, so the spatial size of the input is
    preserved.

    Args:
        in_channels: number of channels of the input (and of the reconstruction).
        latent_dim: number of channels at the bottleneck.
        is_bn: if true, insert a ``BatchNorm2d`` after every convolution that is
            followed by a ReLU (i.e. all but the last conv of encoder and decoder).
    """

    def __init__(self, in_channels=1000, latent_dim=50, is_bn=True):
        super().__init__()

        mid_channels = (in_channels + 2 * latent_dim) // 2
        wide_latent = 2 * latent_dim

        def pointwise(c_in, c_out):
            # 1x1 convolution: mixes channels only, never neighbouring positions.
            return nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0)

        def stage(c_in, c_out):
            # conv -> (optional batch norm) -> ReLU, returned as a flat module list so
            # the Sequential indices (and therefore state_dict keys) stay flat.
            modules = [pointwise(c_in, c_out)]
            if is_bn:
                modules.append(nn.BatchNorm2d(num_features=c_out))
            modules.append(nn.ReLU())
            return modules

        # The last encoder conv has no norm/activation: the latent code is left linear.
        self.encoder = nn.Sequential(
            *stage(in_channels, mid_channels),
            *stage(mid_channels, wide_latent),
            pointwise(wide_latent, latent_dim),
        )

        # The last decoder conv is likewise left without an activation, so the
        # reconstruction is a linear combination of the previous layer's features.
        self.decoder = nn.Sequential(
            *stage(latent_dim, wide_latent),
            *stage(wide_latent, mid_channels),
            pointwise(mid_channels, in_channels),
        )

    def forward(self, x):
        """Encode ``x`` to the latent space and return its reconstruction."""
        return self.decoder(self.encoder(x))

In [None]:
import torch, os, numpy as np, pandas as pd
import torch
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights

class ResNetFeatureExtractorFast(torch.nn.Module):
    """Frozen ResNet-50 that returns concatenated layer2/layer3 feature maps.

    Forward hooks on the last block of ``layer2`` and ``layer3`` capture the
    intermediate activations. ``forward`` smooths each captured map with a 3x3
    average pool, resizes it to a common spatial size and concatenates the maps
    along the channel dimension.

    Args:
        pretrained: load ``ResNet50_Weights.DEFAULT`` when true, random weights otherwise.
        device: optional device the module is moved to after construction.
    """

    def __init__(self, pretrained=True, device=None):
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.model = resnet50(weights=weights)
        # The backbone is a fixed feature extractor: eval mode, no gradients.
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

        # Hooks are registered once; `features` is refilled on every forward pass.
        self.features = []
        self.handles = []
        self.handles.append(self.model.layer2[-1].register_forward_hook(self._hook))
        self.handles.append(self.model.layer3[-1].register_forward_hook(self._hook))

        if device is not None:
            self.to(device)

    def train(self, mode=True):
        """Set the wrapper's mode but always keep the backbone in eval mode.

        Without this, calling ``.train()`` on the wrapper (or on a parent module)
        would switch the backbone's BatchNorm layers to training mode, and they
        would update their running statistics on every forward pass even though
        the forward runs under ``torch.no_grad()``.
        """
        super().train(mode)
        self.model.eval()
        return self

    def _hook(self, module, input, output):
        # Keep an on-device, detached copy of the hooked layer's output.
        self.features.append(output.detach())

    def forward(self, x, target_spatial=None):
        """Extract the concatenated feature map for a batch of images.

        Args:
            x: input batch fed to the ResNet backbone.
            target_spatial: optional ``(h, w)`` every captured map is resized to.
                Defaults to the spatial size of the first captured map (layer2).

        Returns:
            Tensor with the captured maps concatenated along dim 1.

        Raises:
            RuntimeError: if no feature map was captured, e.g. because
                ``remove_hooks()`` has been called.
        """
        # Drop maps captured by a previous call.
        self.features = []

        # Run the backbone only for its side effect of firing the hooks.
        with torch.no_grad():
            _ = self.model(x)

        if not self.features:
            raise RuntimeError(
                "No feature maps were captured: the forward hooks are not registered "
                "(was remove_hooks() called?)."
            )

        if target_spatial is None:
            h = self.features[0].shape[-2]
            w = self.features[0].shape[-1]
            target_spatial = (h, w)

        resized = []
        for fmap in self.features:
            # Local smoothing; kernel 3 / stride 1 / padding 1 keeps the spatial size.
            fmap_smoothed = F.avg_pool2d(fmap, kernel_size=3, stride=1, padding=1)
            # Bring every map to the common spatial size before concatenation.
            if fmap_smoothed.shape[-2:] != target_spatial:
                fmap_resized = F.interpolate(fmap_smoothed, size=target_spatial,
                                             mode='bilinear', align_corners=False)
            else:
                fmap_resized = fmap_smoothed
            resized.append(fmap_resized)

        patch = torch.cat(resized, dim=1)
        return patch

    def remove_hooks(self):
        """Detach all forward hooks; ``forward`` cannot be used afterwards."""
        for h in self.handles:
            h.remove()
        self.handles = []


        

In [None]:
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, confusion_matrix, accuracy_score
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset

import os

# ----- USER EDITABLE / CHECK these paths -----
# Root of the original MVTec dataset (one sub-folder per category). Set the MVTEC_ROOT
# environment variable to override the machine-specific default below.
MVTEC_ROOT = Path(os.environ.get("MVTEC_ROOT", "/Users/mrinalseth13331/Downloads/archive"))
CAE_DIR = Path("saved_spatial_caes_simple")    # where CAE checkpoints are stored, named cae_<cat>.pth
RESULTS_CSV = Path("cae_image_results.csv")    # per-image results saved here
PERF_CSV = Path("cae_per_category_summary.csv")# per-category summary
# ----------------------------------------------

# device & transform (must match training preprocessing)
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print("Device:", device)

IMAGENET_MEAN = [0.485,0.456,0.406]
IMAGENET_STD  = [0.229,0.224,0.225]
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Reuse an `extractor` left in the kernel by an earlier cell; otherwise build one.
# Building one requires the ResNetFeatureExtractorFast class cell to have been run.
if "extractor" in globals():
    print("Using existing 'extractor' in memory.")
else:
    print("No 'extractor' found in memory — creating a fresh ResNetFeatureExtractorFast.")
    extractor = ResNetFeatureExtractorFast(pretrained=True, device=device)
    # Defensive: make sure the backbone is in eval mode and frozen for evaluation.
    extractor.model.eval()
    for p in extractor.model.parameters():
        p.requires_grad = False

# check FeatCAE is defined
if 'FeatCAE' not in globals():
    raise RuntimeError("FeatCAE class not found in notebook. Define FeatCAE (same as used when training) before running evaluation.")

# Fail with an actionable message instead of a bare iterdir() error.
if not MVTEC_ROOT.is_dir():
    raise RuntimeError(f"MVTec root not found at {MVTEC_ROOT}. Edit MVTEC_ROOT or set the MVTEC_ROOT environment variable.")

# build category list from MVTec root (every sub-folder is treated as a category)
category_list = sorted([p.name for p in MVTEC_ROOT.iterdir() if p.is_dir()])
print("Found categories:", category_list)

# helper: data loader for training 'good' images of a given category
def loader_train_good_for_category(cat_name, bs=16):
    """Build a DataLoader over the defect-free training images of one category.

    Images are read from ``MVTEC_ROOT/<cat_name>/train/good`` (``*.png`` and
    ``*.jpg``, sorted), converted to RGB and passed through the module-level
    ``transform``.

    Args:
        cat_name: category folder name under ``MVTEC_ROOT``.
        bs: batch size of the returned loader.

    Returns:
        ``(loader, files)``: a non-shuffling ``DataLoader`` yielding
        ``(image_batch, label_batch)`` with every label equal to 0, and the
        sorted list of image paths it iterates over.

    Raises:
        RuntimeError: if the ``train/good`` folder does not exist. The message
            differs when an alternate ``mvtec_all/train/<cat_name>`` folder is
            present, because loading from that layout is not implemented.
    """
    good_dir = MVTEC_ROOT / cat_name / "train" / "good"
    if not good_dir.exists():
        # The alternate layout is only detected to give a more specific error;
        # nothing is loaded from it.
        alt_dir = Path("mvtec_all") / "train" / cat_name
        if alt_dir.exists():
            raise RuntimeError(f"No 'train/good' at {good_dir} and alternate {alt_dir} handling not implemented. Create folders or change path.")
        raise RuntimeError(f"No training-good folder for category {cat_name}: expected {good_dir}")

    class SimpleFolderDataset(torch.utils.data.Dataset):
        """Minimal dataset over a list of image paths; every item is labelled 0."""

        def __init__(self, files, transform):
            self.files = files
            self.transform = transform

        def __len__(self):
            return len(self.files)

        def __getitem__(self, idx):
            # Context manager releases the file handle as soon as the pixels are copied.
            with Image.open(self.files[idx]) as raw:
                img = raw.convert("RGB")
            return self.transform(img), 0

    files = sorted(list(good_dir.glob("*.png")) + list(good_dir.glob("*.jpg")))
    ds = SimpleFolderDataset(files, transform)
    return DataLoader(ds, batch_size=bs, shuffle=False, num_workers=0), files

# helper: gather test images recursively under mvtec/<cat>/test
def gather_test_images(cat_name):
    """Return the sorted test image paths of one category.

    Searches ``MVTEC_ROOT/<cat_name>/test`` recursively for ``*.png`` and
    ``*.jpg`` files and drops every path that has a ``ground_truth`` component.

    Raises:
        RuntimeError: if the category has no ``test`` folder.
    """
    test_root = MVTEC_ROOT / cat_name / "test"
    if test_root.exists():
        found = [*test_root.rglob("*.png"), *test_root.rglob("*.jpg")]
        found.sort()
        # Masks living under a ground_truth folder are not test inputs.
        return [path for path in found if "ground_truth" not in path.parts]
    raise RuntimeError(f"No test folder for {cat_name} at {test_root}")

# vectorized "decision_function" that takes a tensor seg_map of shape (B,1,H,W)
def topk_mean_score_from_segmap(seg_map, topk=10):
    """Reduce each map in a batch to the mean of its ``topk`` largest values.

    Args:
        seg_map: tensor of shape (B,1,H,W) or (B,H,W); any 3D or 4D tensor is
            accepted and everything after the batch dimension is flattened.
        topk: number of largest values averaged per item; clamped to the
            number of values available per item.

    Returns:
        Tensor of shape (B,) holding the top-k mean of each item.

    Raises:
        ValueError: if ``seg_map`` is neither 3D nor 4D.
    """
    if seg_map.dim() not in (3, 4):
        raise ValueError("seg_map must be 3D or 4D tensor")
    # flatten() copies only when needed, so non-contiguous inputs (e.g. transposed
    # or sliced maps) work too, which .view() would reject.
    flat = seg_map.flatten(start_dim=1)
    k = min(topk, flat.shape[1])
    # Order among the selected values is irrelevant for a mean, hence sorted=False.
    top_values = torch.topk(flat, k=k, dim=1, largest=True, sorted=False)[0]  # (B, k)
    return top_values.mean(dim=1)

# compute heatmap upsample and scores (returns heat as numpy HxW, and score float)
def compute_heat_and_topk_score(img_tensor, extractor, cae, up_size=(224,224), topk=10):
    """Score a single image by the autoencoder's feature reconstruction error.

    Args:
        img_tensor: image batch holding one image, given to ``extractor``.
        extractor: callable mapping the image to a feature map (1,C,h,w).
        cae: callable reconstructing that feature map.
        up_size: spatial size the error map is bilinearly upsampled to.
        topk: number of largest error values averaged into the score.

    Returns:
        ``(heat, score, up_np)``: the upsampled error map min-max scaled to
        [0, 1] as a numpy array, the top-k mean of the upsampled error as a
        float, and the unscaled upsampled error map as a numpy array.
    """
    with torch.no_grad():
        feats = extractor(img_tensor)
        rebuilt = cae(feats)
        # Squared error averaged over channels -> one error value per location.
        sq_err = (feats - rebuilt).pow(2).mean(dim=1, keepdim=True)
        upsampled = F.interpolate(sq_err, size=up_size, mode='bilinear', align_corners=False)
        score = topk_mean_score_from_segmap(upsampled, topk=topk).item()
        up_np = upsampled.squeeze().cpu().numpy()

    # Min-max scaling for display; the epsilon avoids dividing by zero on a flat map.
    lo = up_np.min()
    hi = up_np.max()
    heat = (up_np - lo) / (hi - lo + 1e-8)
    return heat, score, up_np

# Evaluation settings shared by the threshold pass (train/good) and the test pass.
SCORE_TOPK = 10             # number of highest error pixels averaged into an image score
HEATMAP_SIZE = (224, 224)   # error maps are upsampled to the network input resolution
THRESHOLD_N_SIGMA = 3       # threshold = mean + N_SIGMA * std of the train/good scores

# storage for per-image rows
rows = []

# per-category summary results
summary_rows = []

# Main loop: for each category, load CAE, compute threshold from train/good, then evaluate test images
for cat in category_list:
    ckpt = CAE_DIR / f"cae_{cat}.pth"
    if not ckpt.exists():
        print(f"CAE for {cat} not found at {ckpt}, skipping category.")
        continue

    print("=== Evaluating category:", cat)
    # instantiate CAE and load weights; LATENT_DIM must match the value used in training
    cae = FeatCAE(in_channels=1536, latent_dim=LATENT_DIM, is_bn=True)
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust;
    # on recent PyTorch versions consider passing weights_only=True.
    cae.load_state_dict(torch.load(ckpt, map_location=device))
    cae.to(device).eval()

    # Score the defect-free training images to calibrate the decision threshold.
    train_loader_cat, train_files = loader_train_good_for_category(cat, bs=8)
    recon_scores = []
    for imgs, _ in tqdm(train_loader_cat, desc=f"{cat} train->scores"):
        imgs = imgs.to(device)
        with torch.no_grad():
            patch = extractor(imgs)   # (B,C,h,w)
            recon = cae(patch)
            err = ((patch - recon) ** 2).mean(dim=1, keepdim=True)  # (B,1,h,w)
            up = F.interpolate(err, size=HEATMAP_SIZE, mode='bilinear', align_corners=False)
            topk_mean = topk_mean_score_from_segmap(up, topk=SCORE_TOPK)  # (B,)
            recon_scores.extend(topk_mean.cpu().numpy().tolist())

    recon_scores = np.array(recon_scores)
    if recon_scores.size == 0:
        print(f"No training scores computed for {cat} (empty train). Skipping.")
        continue

    # compute threshold: mean + 3*std of the normal-image scores
    best_threshold = float(recon_scores.mean() + THRESHOLD_N_SIGMA * recon_scores.std())
    print(f"{cat}: train_scores mean={recon_scores.mean():.6f} std={recon_scores.std():.6f} threshold={best_threshold:.6f}")

    # Evaluate on test images
    img_paths = gather_test_images(cat)
    if not img_paths:
        # Metrics are undefined on an empty set; skip instead of failing downstream.
        print(f"No test images found for {cat}. Skipping.")
        continue

    y_true = []
    y_score = []
    y_pred = []

    for p in tqdm(img_paths, desc=f"{cat} test"):
        # determine label from folder name: parent folder 'good' => normal else anomalous
        label_folder = p.parent.name.lower()
        is_normal = ("good" in label_folder)
        y_true_label = 0 if is_normal else 1

        # load image and compute score
        img_tensor = transform(Image.open(p).convert("RGB")).unsqueeze(0).to(device)
        heat, score, raw = compute_heat_and_topk_score(img_tensor, extractor, cae, up_size=HEATMAP_SIZE, topk=SCORE_TOPK)
        pred = 1 if score >= best_threshold else 0

        y_true.append(y_true_label)
        y_score.append(score)
        y_pred.append(pred)

        rows.append({
            "category": cat,
            "image": str(p),
            "label": y_true_label,
            "score": score,
            "pred": pred
        })

    # AUROC is only defined when both classes are present in the test set.
    img_auroc = None
    if len(set(y_true)) > 1:
        try:
            img_auroc = roc_auc_score(y_true, y_score)
        except ValueError as exc:
            # Report why the metric is missing instead of dropping it silently.
            print(f"{cat}: AUROC could not be computed ({exc})")
    acc = accuracy_score(y_true, y_pred)
    # Fixed labels keep the matrix 2x2 (rows: true normal/anomalous) even when a
    # class is absent from the labels or the predictions.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    print(f"{cat}  AUROC(image)={img_auroc}  Acc={acc:.4f}  ConfMat:\n{cm}")

    summary_rows.append({
        "category": cat,
        "train_mean": float(recon_scores.mean()),
        "train_std": float(recon_scores.std()),
        "threshold": best_threshold,
        "image_AUROC": float(img_auroc) if img_auroc is not None else None,
        "accuracy": float(acc),
        "n_test": len(y_true)
    })

# save results
df = pd.DataFrame(rows)
df.to_csv(RESULTS_CSV, index=False)
summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(PERF_CSV, index=False)
print("Saved per-image results to", RESULTS_CSV)
print("Saved per-category summary to", PERF_CSV)

print("\nSummary per-category:")
print(summary_df)

Device: mps
Using existing 'extractor' in memory.
Found categories: ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
=== Evaluating category: bottle


bottle train->scores:   0%|          | 0/27 [00:00<?, ?it/s]

bottle: train_scores mean=0.056990 std=0.016726 threshold=0.107169


bottle test:   0%|          | 0/83 [00:00<?, ?it/s]

bottle  AUROC(image)=1.0  Acc=1.0000  ConfMat:
[[20  0]
 [ 0 63]]
=== Evaluating category: cable


cable train->scores:   0%|          | 0/28 [00:00<?, ?it/s]

cable: train_scores mean=0.108250 std=0.012219 threshold=0.144907


cable test:   0%|          | 0/150 [00:00<?, ?it/s]

cable  AUROC(image)=0.9576461769115443  Acc=0.8800  ConfMat:
[[56  2]
 [16 76]]
=== Evaluating category: capsule


capsule train->scores:   0%|          | 0/28 [00:00<?, ?it/s]

capsule: train_scores mean=0.053120 std=0.011687 threshold=0.088181


capsule test:   0%|          | 0/132 [00:00<?, ?it/s]

capsule  AUROC(image)=0.934982050259274  Acc=0.7424  ConfMat:
[[22  1]
 [33 76]]
=== Evaluating category: carpet


carpet train->scores:   0%|          | 0/35 [00:00<?, ?it/s]

carpet: train_scores mean=0.036568 std=0.007316 threshold=0.058516


carpet test:   0%|          | 0/118 [00:00<?, ?it/s]

carpet  AUROC(image)=0.9742063492063492  Acc=0.9407  ConfMat:
[[26  2]
 [ 5 85]]
=== Evaluating category: grid


grid train->scores:   0%|          | 0/33 [00:00<?, ?it/s]

grid: train_scores mean=0.069295 std=0.007374 threshold=0.091418


grid test:   0%|          | 0/78 [00:00<?, ?it/s]

grid  AUROC(image)=0.9741019214703426  Acc=0.9487  ConfMat:
[[21  0]
 [ 4 53]]
=== Evaluating category: hazelnut


hazelnut train->scores:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut: train_scores mean=0.092479 std=0.013450 threshold=0.132829


hazelnut test:   0%|          | 0/110 [00:00<?, ?it/s]

hazelnut  AUROC(image)=1.0  Acc=1.0000  ConfMat:
[[40  0]
 [ 0 70]]
=== Evaluating category: leather


leather train->scores:   0%|          | 0/31 [00:00<?, ?it/s]

leather: train_scores mean=0.034241 std=0.005330 threshold=0.050232


leather test:   0%|          | 0/124 [00:00<?, ?it/s]

leather  AUROC(image)=1.0  Acc=0.9758  ConfMat:
[[29  3]
 [ 0 92]]
=== Evaluating category: metal_nut


metal_nut train->scores:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut: train_scores mean=0.081729 std=0.010479 threshold=0.113166


metal_nut test:   0%|          | 0/115 [00:00<?, ?it/s]

metal_nut  AUROC(image)=1.0  Acc=1.0000  ConfMat:
[[22  0]
 [ 0 93]]
=== Evaluating category: pill


pill train->scores:   0%|          | 0/34 [00:00<?, ?it/s]

pill: train_scores mean=0.060661 std=0.010564 threshold=0.092354


pill test:   0%|          | 0/167 [00:00<?, ?it/s]

pill  AUROC(image)=0.9623567921440263  Acc=0.7964  ConfMat:
[[ 26   0]
 [ 34 107]]
=== Evaluating category: screw


screw train->scores:   0%|          | 0/40 [00:00<?, ?it/s]

screw: train_scores mean=0.076264 std=0.010615 threshold=0.108109


screw test:   0%|          | 0/160 [00:00<?, ?it/s]

screw  AUROC(image)=0.8466898954703833  Acc=0.6250  ConfMat:
[[40  1]
 [59 60]]
=== Evaluating category: tile


tile train->scores:   0%|          | 0/29 [00:00<?, ?it/s]

tile: train_scores mean=0.051160 std=0.013139 threshold=0.090578


tile test:   0%|          | 0/117 [00:00<?, ?it/s]

tile  AUROC(image)=1.0  Acc=1.0000  ConfMat:
[[33  0]
 [ 0 84]]
=== Evaluating category: toothbrush


toothbrush train->scores:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush: train_scores mean=0.097803 std=0.010181 threshold=0.128346


toothbrush test:   0%|          | 0/42 [00:00<?, ?it/s]

toothbrush  AUROC(image)=0.9361111111111111  Acc=0.8333  ConfMat:
[[ 9  3]
 [ 4 26]]
=== Evaluating category: transistor


transistor train->scores:   0%|          | 0/27 [00:00<?, ?it/s]

transistor: train_scores mean=0.089458 std=0.011488 threshold=0.123923


transistor test:   0%|          | 0/100 [00:00<?, ?it/s]

transistor  AUROC(image)=0.96875  Acc=0.9400  ConfMat:
[[59  1]
 [ 5 35]]
=== Evaluating category: wood


wood train->scores:   0%|          | 0/31 [00:00<?, ?it/s]

wood: train_scores mean=0.050232 std=0.015318 threshold=0.096185


wood test:   0%|          | 0/79 [00:00<?, ?it/s]

wood  AUROC(image)=0.9859649122807018  Acc=0.9494  ConfMat:
[[16  3]
 [ 1 59]]
=== Evaluating category: zipper


zipper train->scores:   0%|          | 0/30 [00:00<?, ?it/s]

zipper: train_scores mean=0.038016 std=0.009355 threshold=0.066080


zipper test:   0%|          | 0/151 [00:00<?, ?it/s]

zipper  AUROC(image)=0.9692752100840336  Acc=0.9404  ConfMat:
[[ 28   4]
 [  5 114]]
Saved per-image results to cae_image_results.csv
Saved per-category summary to cae_per_category_summary.csv

Summary per-category:
      category  train_mean  train_std  threshold  image_AUROC  accuracy  \
0       bottle    0.056990   0.016726   0.107169     1.000000  1.000000   
1        cable    0.108250   0.012219   0.144907     0.957646  0.880000   
2      capsule    0.053120   0.011687   0.088181     0.934982  0.742424   
3       carpet    0.036568   0.007316   0.058516     0.974206  0.940678   
4         grid    0.069295   0.007374   0.091418     0.974102  0.948718   
5     hazelnut    0.092479   0.013450   0.132829     1.000000  1.000000   
6      leather    0.034241   0.005330   0.050232     1.000000  0.975806   
7    metal_nut    0.081729   0.010479   0.113166     1.000000  1.000000   
8         pill    0.060661   0.010564   0.092354     0.962357  0.796407   
9        screw    0.076264   0.010