In [1]:
from pathlib import Path
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image

from sklearn.covariance import EmpiricalCovariance
from sklearn.metrics import (
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score,
    fbeta_score,
    confusion_matrix,
)

import json
import matplotlib.pyplot as plt

PROJECT_ROOT = Path(r"C:\Users\othni\Projects\mvtec_ad")
DATA_ROOT = PROJECT_ROOT / "data"
EXPERIMENTS_DIR = PROJECT_ROOT / "experiments"
MODELS_DIR = PROJECT_ROOT / "models"
AD_MODELS_DIR = MODELS_DIR / "ad_resnet_mahalanobis"
AD_MODELS_DIR.mkdir(exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [2]:
df_all = pd.read_csv(EXPERIMENTS_DIR / "image_level_df.csv")
print(df_all.shape)
print(df_all["category"].value_counts())

(5354, 6)
category
hazelnut      501
screw         480
pill          434
carpet        397
zipper        391
cable         374
leather       369
capsule       351
tile          347
grid          342
metal_nut     335
wood          326
transistor    313
bottle        292
toothbrush    102
Name: count, dtype: int64


In [3]:
class MVTecADCategoryDataset(Dataset):
    """
    Dataset pour une catégorie donnée (train ou test),
    pour l'anomaly detection image-level.
    """
    def __init__(self, df_split, transform=None):
        self.df = df_split.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = Path(row["path"])
        label = int(row["label"])  # 0 = normal, 1 = défaut

        img = Image.open(img_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img, label

In [4]:
# Transforms "eval" (comme avant)
eval_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Extracteur de features ResNet-18 (sans la couche fc)
resnet_feat = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
modules = list(resnet_feat.children())[:-1]  # enlever la couche fc
resnet_feat = nn.Sequential(*modules).to(device)
resnet_feat.eval()

# Test rapide
dummy = torch.zeros(1, 3, 224, 224).to(device)
with torch.no_grad():
    feat_dummy = resnet_feat(dummy)
print("Feature dummy shape:", feat_dummy.shape)  # (1, 512, 1, 1)

Feature dummy shape: torch.Size([1, 512, 1, 1])


In [5]:
def extract_features(loader):
    """
    Retourne:
      feats  : (N, 512)
      labels : (N,)
    """
    all_feats = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            feat = resnet_feat(imgs)              # (B, 512, 1, 1)
            feat = feat.view(feat.size(0), -1)    # (B, 512)
            all_feats.append(feat.cpu().numpy())
            all_labels.append(labels.numpy())

    feats = np.concatenate(all_feats, axis=0)
    labels = np.concatenate(all_labels, axis=0)
    return feats, labels

In [6]:
def fit_gaussian(normal_feats):
    """
    Apprend la moyenne et la matrice de précision (Sigma^{-1})
    sur les features des images normales.
    """
    cov = EmpiricalCovariance().fit(normal_feats)
    mu = cov.location_          # (D,)
    Sigma_inv = cov.precision_  # (D, D)
    return mu, Sigma_inv

def mahalanobis_scores(X, mu, Sigma_inv):
    """
    X : (N, D)
    mu : (D,)
    Sigma_inv : (D, D)
    """
    diff = X - mu
    left = np.dot(diff, Sigma_inv)
    m = np.sqrt(np.sum(left * diff, axis=1))
    return m

def sweep_thresholds(scores, labels, n_steps=500):
    """
    Balaye les seuils entre min(scores) et max(scores)
    et calcule precision, recall, F1, F2.
    """
    taus = np.linspace(scores.min(), scores.max(), n_steps)
    rows = []

    for tau in taus:
        y_pred = (scores >= tau).astype(int)  # score élevé => anomalie

        precision = precision_score(labels, y_pred, zero_division=0)
        recall    = recall_score(labels, y_pred, zero_division=0)
        f1        = f1_score(labels, y_pred, zero_division=0)
        f2        = fbeta_score(labels, y_pred, beta=2, zero_division=0)

        rows.append({
            "tau": tau,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "f2": f2,
        })

    return pd.DataFrame(rows)

In [7]:
categories = sorted(df_all["category"].unique())
categories

['bottle',
 'cable',
 'capsule',
 'carpet',
 'grid',
 'hazelnut',
 'leather',
 'metal_nut',
 'pill',
 'screw',
 'tile',
 'toothbrush',
 'transistor',
 'wood',
 'zipper']

In [8]:
ad_results = {}

for cat in categories:
    print(f"\n====================== {cat} ======================")
    df_cat = df_all[df_all["category"] == cat].copy()

    # Train: uniquement train/good, label=0
    df_train_cat = df_cat[(df_cat["split"] == "train") & (df_cat["label"] == 0)].copy()
    # Test: toutes les images de test (good + défauts)
    df_test_cat  = df_cat[df_cat["split"] == "test"].copy()

    print("Train normals :", df_train_cat.shape[0])
    print("Test total    :", df_test_cat.shape[0])
    print("Test labels:\n", df_test_cat["label"].value_counts())

    # Datasets & loaders
    train_ds = MVTecADCategoryDataset(df_train_cat, transform=eval_transform)
    test_ds  = MVTecADCategoryDataset(df_test_cat,  transform=eval_transform)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=False, num_workers=0)
    test_loader  = DataLoader(test_ds,  batch_size=32, shuffle=False, num_workers=0)

    # Features
    train_feats, train_labels = extract_features(train_loader)
    test_feats,  test_labels  = extract_features(test_loader)

    # Fit gaussienne sur train normals
    normal_feats = train_feats[train_labels == 0]
    print("Normal_feats shape:", normal_feats.shape)

    mu, Sigma_inv = fit_gaussian(normal_feats)

    # Scores d'anomalie sur le test
    test_scores = mahalanobis_scores(test_feats, mu, Sigma_inv)

    # AUROC
    auc_test = roc_auc_score(test_labels, test_scores)
    print("Test AUROC:", auc_test)

    # Sweep de seuils
    res_thr = sweep_thresholds(test_scores, test_labels, n_steps=500)

    # Seuils intéressants
    i_f1 = res_thr["f1"].idxmax()
    i_f2 = res_thr["f2"].idxmax()
    row_f1 = res_thr.loc[i_f1]
    row_f2 = res_thr.loc[i_f2]

    # Mode SAFE : rappel = 1.0 si possible
    safe_candidates = res_thr[res_thr["recall"] == 1.0]
    if len(safe_candidates) > 0:
        i_safe = safe_candidates["f1"].idxmax()
        row_safe = res_thr.loc[i_safe]
    else:
        # fallback: max F2
        row_safe = row_f2

    print("  Max F1 :", dict(row_f1))
    print("  Max F2 :", dict(row_f2))
    print("  SAFE   :", dict(row_safe))

    # Sauvegarde des paramètres gaussiens pour cette catégorie
    npz_path = AD_MODELS_DIR / f"{cat}_gaussian_stats.npz"
    np.savez(
        npz_path,
        mu=mu,
        precision=Sigma_inv,
    )

    # Sauvegarde des résultats dans un dict
    ad_results[cat] = {
        "auc_test": float(auc_test),
        "tau_F1": float(row_f1["tau"]),
        "tau_F2": float(row_f2["tau"]),
        "tau_safe": float(row_safe["tau"]),
        "metrics_F1": {
            "precision": float(row_f1["precision"]),
            "recall":    float(row_f1["recall"]),
            "f1":        float(row_f1["f1"]),
            "f2":        float(row_f1["f2"]),
        },
        "metrics_F2": {
            "precision": float(row_f2["precision"]),
            "recall":    float(row_f2["recall"]),
            "f1":        float(row_f2["f1"]),
            "f2":        float(row_f2["f2"]),
        },
        "metrics_safe": {
            "precision": float(row_safe["precision"]),
            "recall":    float(row_safe["recall"]),
            "f1":        float(row_safe["f1"]),
            "f2":        float(row_safe["f2"]),
        },
        "n_train_normals": int(normal_feats.shape[0]),
        "n_test": int(test_labels.shape[0]),
    }


Train normals : 209
Test total    : 83
Test labels:
 label
1    63
0    20
Name: count, dtype: int64
Normal_feats shape: (209, 512)
Test AUROC: 0.9952380952380953
  Max F1 : {'tau': np.float64(32.04131840277652), 'precision': np.float64(1.0), 'recall': np.float64(0.9682539682539683), 'f1': np.float64(0.9838709677419355), 'f2': np.float64(0.9744408945686901)}
  Max F2 : {'tau': np.float64(24.861816682001834), 'precision': np.float64(0.9402985074626866), 'recall': np.float64(1.0), 'f1': np.float64(0.9692307692307692), 'f2': np.float64(0.987460815047022)}
  SAFE   : {'tau': np.float64(24.861816682001834), 'precision': np.float64(0.9402985074626866), 'recall': np.float64(1.0), 'f1': np.float64(0.9692307692307692), 'f2': np.float64(0.987460815047022)}

Train normals : 224
Test total    : 150
Test labels:
 label
1    92
0    58
Name: count, dtype: int64
Normal_feats shape: (224, 512)
Test AUROC: 0.8834332833583208
  Max F1 : {'tau': np.float64(22.489762399576815), 'precision': np.float64(0.

In [9]:
ad_results_path = EXPERIMENTS_DIR / "ad_resnet_mahalanobis_mvtec_all.json"
with open(ad_results_path, "w") as f:
    json.dump(ad_results, f, indent=2)

ad_results_path

WindowsPath('C:/Users/othni/Projects/mvtec_ad/experiments/ad_resnet_mahalanobis_mvtec_all.json')

In [10]:
import json
import pandas as pd

ad_results_path = PROJECT_ROOT / "experiments" / "ad_resnet_mahalanobis_mvtec_all.json"

with open(ad_results_path, "r") as f:
    ad_results = json.load(f)

rows = []
for cat, info in ad_results.items():
    row = {
        "category": cat,
        "auc_test": info["auc_test"],
        "tau_F1": info["tau_F1"],
        "tau_F2": info["tau_F2"],
        "tau_safe": info["tau_safe"],
        "prec_safe": info["metrics_safe"]["precision"],
        "recall_safe": info["metrics_safe"]["recall"],
        "f1_safe": info["metrics_safe"]["f1"],
        "f2_safe": info["metrics_safe"]["f2"],
        "n_train_normals": info["n_train_normals"],
        "n_test": info["n_test"],
    }
    rows.append(row)

df_ad = pd.DataFrame(rows)
df_ad = df_ad.sort_values("prec_safe", ascending=False).reset_index(drop=True)
df_ad

Unnamed: 0,category,auc_test,tau_F1,tau_F2,tau_safe,prec_safe,recall_safe,f1_safe,f2_safe,n_train_normals,n_test
0,bottle,0.995238,32.041318,24.861817,24.861817,0.940299,1.0,0.969231,0.987461,209,83
1,toothbrush,0.955556,7.3834,7.3834,7.3834,0.909091,1.0,0.952381,0.980392,60,42
2,metal_nut,0.937439,20.349246,20.165997,20.165997,0.902913,1.0,0.94898,0.978947,220,115
3,capsule,0.878341,19.341192,19.341192,19.341192,0.872,1.0,0.931624,0.97148,219,132
4,pill,0.81533,23.869418,19.386349,19.386349,0.849398,1.0,0.918567,0.965753,267,167
5,leather,0.971467,34.658128,31.070109,27.48209,0.821429,1.0,0.901961,0.958333,245,124
6,wood,0.935088,45.873717,28.664655,28.664655,0.8,1.0,0.888889,0.952381,247,79
7,zipper,0.96166,26.418735,23.794852,19.59664,0.798658,1.0,0.88806,0.952,240,151
8,hazelnut,0.981786,49.739843,44.907217,40.880028,0.786517,1.0,0.880503,0.948509,391,110
9,tile,0.979076,26.739869,26.739869,19.477031,0.763636,1.0,0.865979,0.941704,230,117


In [11]:
print("=== Triées par précision en mode SAFE ===")
display(df_ad[["category", "auc_test", "prec_safe", "recall_safe", "n_test"]])

print("\nCatégories AD très fortes (prec_safe >= 0.9) :")
print(df_ad[df_ad["prec_safe"] >= 0.9]["category"].tolist())

print("\nCatégories AD moyennes (0.7 <= prec_safe < 0.9) :")
print(df_ad[(df_ad["prec_safe"] >= 0.7) & (df_ad["prec_safe"] < 0.9)]["category"].tolist())

print("\nCatégories AD faibles (prec_safe < 0.7) :")
print(df_ad[df_ad["prec_safe"] < 0.7]["category"].tolist())

=== Triées par précision en mode SAFE ===


Unnamed: 0,category,auc_test,prec_safe,recall_safe,n_test
0,bottle,0.995238,0.940299,1.0,83
1,toothbrush,0.955556,0.909091,1.0,42
2,metal_nut,0.937439,0.902913,1.0,115
3,capsule,0.878341,0.872,1.0,132
4,pill,0.81533,0.849398,1.0,167
5,leather,0.971467,0.821429,1.0,124
6,wood,0.935088,0.8,1.0,79
7,zipper,0.96166,0.798658,1.0,151
8,hazelnut,0.981786,0.786517,1.0,110
9,tile,0.979076,0.763636,1.0,117



Catégories AD très fortes (prec_safe >= 0.9) :
['bottle', 'toothbrush', 'metal_nut']

Catégories AD moyennes (0.7 <= prec_safe < 0.9) :
['capsule', 'pill', 'leather', 'wood', 'zipper', 'hazelnut', 'tile', 'carpet', 'screw', 'grid']

Catégories AD faibles (prec_safe < 0.7) :
['cable', 'transistor']
