In [1]:
from pathlib import Path
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image

from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    fbeta_score,
    confusion_matrix,
)

import json

import os

# The project root can be overridden with the MVTEC_PROJECT_ROOT environment
# variable so the notebook runs on other machines; the original local path
# remains the default.
PROJECT_ROOT = Path(os.environ.get("MVTEC_PROJECT_ROOT", r"C:\Users\othni\Projects\mvtec_ad"))
DATA_ROOT = PROJECT_ROOT / "data"
EXPERIMENTS_DIR = PROJECT_ROOT / "experiments"
MODELS_DIR = PROJECT_ROOT / "models"
AD_MODELS_DIR = MODELS_DIR / "ad_resnet_mahalanobis"  # per-category *_gaussian_stats.npz files

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [2]:
# Image-level index of the whole dataset. Only the test split is kept, and it is
# re-indexed 0..N-1 (reset_index already returns an independent frame).
df_all = pd.read_csv(EXPERIMENTS_DIR / "image_level_df.csv")
df_test_all = df_all.loc[df_all["split"].eq("test")].reset_index(drop=True)

for shape in (df_all.shape, df_test_all.shape):
    print(shape)
print(df_test_all["label"].value_counts())

# Per-category results of the ResNet + Mahalanobis anomaly detector.
with (EXPERIMENTS_DIR / "ad_resnet_mahalanobis_mvtec_all.json").open("r") as f:
    ad_results = json.load(f)

(5354, 6)
(1725, 6)
label
1    1258
0     467
Name: count, dtype: int64


In [3]:
from torchvision import transforms

# Deterministic evaluation preprocessing: resize to 256x256, center-crop to
# 224, convert to tensor, then normalise with the standard ImageNet channel
# mean/std (matching the ImageNet-pretrained ResNet-18 weights used later).
eval_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

class MVTecTestDataset(Dataset):
    """Image-level MVTec dataset backed by a DataFrame.

    Each row must provide ``path`` (image file), ``label`` (int-castable) and
    ``category``. Items are returned as ``(image, label, category)``, where
    ``image`` is the RGB image passed through ``transform`` (a PIL image when
    ``transform`` is None).
    """

    def __init__(self, df, transform=None):
        # __getitem__ uses positional access; resetting keeps it 0..N-1.
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = Path(row["path"])
        label = int(row["label"])
        category = row["category"]

        # Image.open is lazy and keeps the file handle open until garbage
        # collection; the context manager releases it right away. convert()
        # returns a fully loaded copy, so the result is usable after closing.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img, label, category

test_ds_global = MVTecTestDataset(df_test_all, transform=eval_transform)
# shuffle=False keeps batches in df_test_all row order, which later cells rely
# on to attach predictions positionally; num_workers=0 loads in-process.
test_loader_global = DataLoader(
    test_ds_global,
    batch_size=32,
    shuffle=False,
    num_workers=0,
)

len(test_ds_global)

1725

In [4]:
# Load the global image-level classifier: ResNet-18 with a single-logit head.
from torchvision import models

cls_model = models.resnet18(weights=None)
in_features = cls_model.fc.in_features
cls_model.fc = nn.Linear(in_features, 1)

cls_ckpt = MODELS_DIR / "resnet18_image_level_best.pt"
# The checkpoint is passed straight to load_state_dict (a plain state_dict), so
# weights_only=True restricts unpickling to tensors/containers instead of
# allowing arbitrary pickled objects to execute code.
cls_model.load_state_dict(torch.load(cls_ckpt, map_location=device, weights_only=True))
cls_model = cls_model.to(device)
cls_model.eval()

# Inference over the whole global test set (unshuffled loader => df_test_all order).
all_probs = []
all_labels = []
all_categories = []

with torch.no_grad():
    for imgs, labels, cats in test_loader_global:
        imgs = imgs.to(device)
        logits = cls_model(imgs).squeeze(1)       # (B,)
        probs = torch.sigmoid(logits).cpu().numpy()

        all_probs.append(probs)
        all_labels.append(labels.numpy())
        all_categories.extend(list(cats))

cls_probs_all = np.concatenate(all_probs)      # (N_test,)
y_test_all = np.concatenate(all_labels)        # (N_test,)

print("cls_probs_all shape:", cls_probs_all.shape)
print("same size as df_test_all ?", cls_probs_all.shape[0] == len(df_test_all))

# Probabilities are attached by position below: fail loudly if the loader order
# ever diverges from df_test_all instead of silently misaligning rows.
assert cls_probs_all.shape[0] == len(df_test_all), "loader/df size mismatch"
assert all_categories == df_test_all["category"].tolist(), "loader/df category order mismatch"
assert np.array_equal(y_test_all, df_test_all["label"].to_numpy().astype(int)), "loader/df label mismatch"

df_test_all["cls_prob"] = cls_probs_all
df_test_all["category_from_loader"] = all_categories

cls_probs_all shape: (1725,)
same size as df_test_all ? True


In [5]:
# ImageNet-pretrained ResNet-18 used as a feature extractor (same setup as
# notebook 06): every child module is kept except the final fc layer.
resnet_feat = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
modules = [*resnet_feat.children()][:-1]
resnet_feat = nn.Sequential(*modules)
resnet_feat = resnet_feat.to(device).eval()

def extract_features_single_loader(loader, model=None, target_device=None):
    """Run a feature extractor over ``loader`` and stack the outputs.

    Parameters
    ----------
    loader : iterable of (imgs, labels, cats) batches
        e.g. a DataLoader over MVTecTestDataset: ``imgs`` is a tensor batch,
        ``labels`` a tensor, ``cats`` a sequence of category names.
    model : torch.nn.Module, optional
        Feature extractor; defaults to the notebook-global ``resnet_feat``.
    target_device : torch.device, optional
        Device the images are moved to; defaults to the global ``device``.

    Returns
    -------
    feats : ndarray of shape (N, D)
        Model output flattened per sample (D = 512 for the truncated ResNet-18).
    labels : ndarray of shape (N,)
    cats : list, length N

    Raises
    ------
    ValueError
        If ``loader`` yields no batches.
    """
    # Defaults are resolved at call time so other models/devices can be used
    # (and the function tested) without touching the notebook globals.
    model = resnet_feat if model is None else model
    target_device = device if target_device is None else target_device

    all_feats, all_labels, all_cats = [], [], []
    with torch.no_grad():
        for imgs, labels, cats in loader:
            feat = model(imgs.to(target_device))
            # (B, C, 1, 1) -> (B, C); reshape also tolerates non-contiguous outputs.
            all_feats.append(feat.reshape(feat.size(0), -1).cpu().numpy())
            all_labels.append(labels.numpy())
            all_cats.extend(cats)

    if not all_feats:
        raise ValueError("loader yielded no batches; nothing to extract")
    return np.concatenate(all_feats, axis=0), np.concatenate(all_labels, axis=0), all_cats

# Embeddings for every test image, in df_test_all row order (unshuffled loader).
(feats_test_all,
 labels_test_all,
 cats_test_all) = extract_features_single_loader(loader=test_loader_global)
print(*(feats_test_all.shape, labels_test_all.shape), len(cats_test_all))

(1725, 512) (1725,) 1725


In [6]:
def mahalanobis_scores(X, mu, Sigma_inv):
    """Mahalanobis distance of each row of ``X`` to the mean ``mu``.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
    mu : array of shape (n_features,)
        Mean vector.
    Sigma_inv : array of shape (n_features, n_features)
        Precision matrix (inverse covariance).

    Returns
    -------
    ndarray of shape (n_samples,)
        sqrt((x - mu)^T Sigma_inv (x - mu)) for every row x.
    """
    diff = np.asarray(X) - np.asarray(mu)
    # Row-wise quadratic form diff_i^T Sigma_inv diff_i.
    sq = np.sum((diff @ Sigma_inv) * diff, axis=1)
    # Floating-point round-off can make the quadratic form slightly negative
    # (points very close to mu, nearly singular precision); clip so sqrt never
    # yields NaN, which would silently fail the >= threshold tests downstream.
    return np.sqrt(np.maximum(sq, 0.0))

# Initialise ad_score to NaN; every test row is filled per category below.
df_test_all["ad_score"] = np.nan

for cat in sorted(df_test_all["category"].unique()):
    print(f"Catégorie {cat} ...")
    # Test-image indices of this category. df_test_all was reset to a 0..N-1
    # index matching feats_test_all rows, so labels double as positions.
    idx_cat = df_test_all.index[df_test_all["category"] == cat].tolist()

    X_cat = feats_test_all[idx_cat]

    # Load mu and the precision matrix for this category. Using the NpzFile as
    # a context manager closes the underlying file handle.
    npz_path = AD_MODELS_DIR / f"{cat}_gaussian_stats.npz"
    with np.load(npz_path) as data:
        mu = data["mu"]
        precision = data["precision"]

    scores_cat = mahalanobis_scores(X_cat, mu, precision)

    df_test_all.loc[idx_cat, "ad_score"] = scores_cat

# A NaN score would silently count as "normal" in the >= threshold comparisons
# of the policy cells, so require every row to be scored.
assert df_test_all["ad_score"].notna().all(), "some test rows have no ad_score"

df_test_all[["category", "label", "cls_prob", "ad_score"]].head()

Catégorie bottle ...
Catégorie cable ...
Catégorie capsule ...
Catégorie carpet ...
Catégorie grid ...
Catégorie hazelnut ...
Catégorie leather ...
Catégorie metal_nut ...
Catégorie pill ...
Catégorie screw ...
Catégorie tile ...
Catégorie toothbrush ...
Catégorie transistor ...
Catégorie wood ...
Catégorie zipper ...


Unnamed: 0,category,label,cls_prob,ad_score
0,bottle,1,1.0,89.614496
1,bottle,1,1.0,92.867646
2,bottle,1,1.0,128.321736
3,bottle,1,0.999992,145.374881
4,bottle,1,0.999787,89.16088


In [7]:
# Global decision thresholds of the image-level classifier.
with (EXPERIMENTS_DIR / "image_level_thresholds.json").open("r") as f:
    cls_thr = json.load(f)

tau_img_F1, tau_img_safe = (cls_thr["image_level_global"][key] for key in ("tau_F1", "tau_safe"))

print("tau_img_F1  =", tau_img_F1)
print("tau_img_safe=", tau_img_safe)

# Per-category AD results are reloaded so this cell does not depend on the
# data-loading cell having been run.
with (EXPERIMENTS_DIR / "ad_resnet_mahalanobis_mvtec_all.json").open("r") as f:
    ad_results = json.load(f)

tau_img_F1  = 0.8200000000000001
tau_img_safe= 0.023


In [8]:
from sklearn.metrics import precision_score, recall_score, f1_score, fbeta_score, confusion_matrix

def eval_policy(df, y_pred, name=""):
    """Print and return binary metrics of a decision policy evaluated on ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ground-truth ``label`` column (cast to int; 1 is the
        positive class for the sklearn scores).
    y_pred : array-like of 0/1
        Predictions aligned positionally with the rows of ``df``.
    name : str
        Title printed above the metrics.

    Returns
    -------
    dict
        precision, recall, f1, f2 (0 when undefined) and the confusion counts
        tn, fp, fn, tp as plain Python ints (JSON-serialisable).
    """
    y_true = df["label"].values.astype(int)

    precision = precision_score(y_true, y_pred, zero_division=0)
    recall    = recall_score(y_true, y_pred, zero_division=0)
    f1        = f1_score(y_true, y_pred, zero_division=0)
    f2        = fbeta_score(y_true, y_pred, beta=2, zero_division=0)
    # labels=[0, 1] forces a 2x2 matrix: if y_true and y_pred contain a single
    # class, sklearn otherwise returns 1x1 and the 4-way unpacking fails.
    tn, fp, fn, tp = (int(v) for v in confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel())

    print(f"\n=== {name} ===")
    print(f"Precision : {precision:.4f}")
    print(f"Recall    : {recall:.4f}")
    print(f"F1        : {f1:.4f}")
    print(f"F2        : {f2:.4f}")
    print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")

    return dict(precision=precision, recall=recall, f1=f1, f2=f2,
                tn=tn, fp=fp, fn=fn, tp=tp)

In [9]:
# Policy A: an image is anomalous when its Mahalanobis ad_score reaches the
# SAFE threshold of its own category. The comparison is vectorised instead of
# using iterrows(): the per-row threshold is built once and compared
# positionally, so this no longer depends on df_test_all having a 0..N-1 index.
tau_safe_row = np.array(
    [ad_results[cat]["tau_safe"] for cat in df_test_all["category"]], dtype=float
)
y_pred_AD_SAFE = (df_test_all["ad_score"].to_numpy() >= tau_safe_row).astype(int)

res_A = eval_policy(df_test_all, y_pred_AD_SAFE, "Policy A : AD_SAFE seul (toutes catégories)")


=== Policy A : AD_SAFE seul (toutes catégories) ===
Precision : 0.7833
Recall    : 1.0000
F1        : 0.8785
F2        : 0.9476
TN=119, FP=348, FN=0, TP=1258


In [10]:
from sklearn.metrics import precision_score, recall_score, f1_score, fbeta_score, confusion_matrix

def eval_policy(df, y_pred, name=""):
    """Print and return binary metrics of a decision policy evaluated on ``df``.

    This cell re-defines ``eval_policy`` from an earlier cell with the same
    behaviour, so it stays correct whichever definition is active.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ground-truth ``label`` column (cast to int; 1 is the
        positive class for the sklearn scores).
    y_pred : array-like of 0/1
        Predictions aligned positionally with the rows of ``df``.
    name : str
        Title printed above the metrics.

    Returns
    -------
    dict
        precision, recall, f1, f2 (0 when undefined) and the confusion counts
        tn, fp, fn, tp as plain Python ints (JSON-serialisable).
    """
    y_true = df["label"].values.astype(int)

    precision = precision_score(y_true, y_pred, zero_division=0)
    recall    = recall_score(y_true, y_pred, zero_division=0)
    f1        = f1_score(y_true, y_pred, zero_division=0)
    f2        = fbeta_score(y_true, y_pred, beta=2, zero_division=0)
    # labels=[0, 1] forces a 2x2 matrix: if y_true and y_pred contain a single
    # class, sklearn otherwise returns 1x1 and the 4-way unpacking fails.
    tn, fp, fn, tp = (int(v) for v in confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel())

    print(f"\n=== {name} ===")
    print(f"Precision : {precision:.4f}")
    print(f"Recall    : {recall:.4f}")
    print(f"F1        : {f1:.4f}")
    print(f"F2        : {f2:.4f}")
    print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")

    return dict(precision=precision, recall=recall, f1=f1, f2=f2,
                tn=tn, fp=fp, fn=fn, tp=tp)

In [11]:
import numpy as np

N = len(df_test_all)

# Policy A: anomaly if ad_score >= the SAFE threshold of the image's category.
# Vectorised and positional, so it does not rely on a 0..N-1 index.
tau_safe_row = np.array(
    [ad_results[cat]["tau_safe"] for cat in df_test_all["category"]], dtype=float
)
y_pred_A = (df_test_all["ad_score"].to_numpy() >= tau_safe_row).astype(int)
assert y_pred_A.shape == (N,)

res_A = eval_policy(df_test_all, y_pred_A, "Policy A : AD_SAFE seul (toutes catégories)")
res_A


=== Policy A : AD_SAFE seul (toutes catégories) ===
Precision : 0.7833
Recall    : 1.0000
F1        : 0.8785
F2        : 0.9476
TN=119, FP=348, FN=0, TP=1258


{'precision': 0.7833125778331258,
 'recall': 1.0,
 'f1': 0.8784916201117319,
 'f2': 0.9475745706538113,
 'tn': np.int64(119),
 'fp': np.int64(348),
 'fn': np.int64(0),
 'tp': np.int64(1258)}

In [12]:
# Policy B: anomaly if EITHER detector fires at its SAFE threshold
# (per-category AD tau_safe, global classifier tau_img_safe). Vectorised.
tau_safe_row = np.array(
    [ad_results[cat]["tau_safe"] for cat in df_test_all["category"]], dtype=float
)

pred_ad_safe = df_test_all["ad_score"].to_numpy() >= tau_safe_row
pred_cls_safe = df_test_all["cls_prob"].to_numpy() >= tau_img_safe

y_pred_B = (pred_ad_safe | pred_cls_safe).astype(int)

res_B = eval_policy(df_test_all, y_pred_B, "Policy B : AD_SAFE OR ResNet_SAFE")
res_B


=== Policy B : AD_SAFE OR ResNet_SAFE ===
Precision : 0.7601
Recall    : 1.0000
F1        : 0.8637
F2        : 0.9406
TN=70, FP=397, FN=0, TP=1258


{'precision': 0.7601208459214501,
 'recall': 1.0,
 'f1': 0.8637143837967731,
 'f2': 0.9406310752205772,
 'tn': np.int64(70),
 'fp': np.int64(397),
 'fn': np.int64(0),
 'tp': np.int64(1258)}

In [13]:
# Policy C: anomaly only if BOTH the per-category AD SAFE threshold and the
# classifier's F1-optimal threshold tau_img_F1 fire. Vectorised.
tau_safe_row = np.array(
    [ad_results[cat]["tau_safe"] for cat in df_test_all["category"]], dtype=float
)

pred_ad_safe = df_test_all["ad_score"].to_numpy() >= tau_safe_row
pred_cls_F1 = df_test_all["cls_prob"].to_numpy() >= tau_img_F1

y_pred_C = (pred_ad_safe & pred_cls_F1).astype(int)

res_C = eval_policy(df_test_all, y_pred_C, "Policy C : AD_SAFE AND ResNet_F1*")
res_C


=== Policy C : AD_SAFE AND ResNet_F1* ===
Precision : 0.9943
Recall    : 0.9714
F1        : 0.9827
F2        : 0.9759
TN=460, FP=7, FN=36, TP=1222


{'precision': 0.9943043124491456,
 'recall': 0.9713831478537361,
 'f1': 0.9827100924809007,
 'f2': 0.9758824468934675,
 'tn': np.int64(460),
 'fp': np.int64(7),
 'fn': np.int64(36),
 'tp': np.int64(1222)}

In [14]:
# Per-category precision of the AD detector at its SAFE threshold.
prec_safe_by_cat = {
    cat: ad_results[cat]["metrics_safe"]["precision"]
    for cat in ad_results.keys()
}

strong_threshold = 0.9
tau_stage2 = 0.60   # ResNet threshold applied to the medium/weak categories

strong_cats = [cat for cat, p in prec_safe_by_cat.items() if p >= strong_threshold]
# Printed threshold follows strong_threshold instead of a duplicated literal.
print(f"Catégories 'fortes' AD (prec_safe >= {strong_threshold}):", strong_cats)

# Vectorised per-row lookups (positional; unknown categories raise KeyError).
cat_col = df_test_all["category"]
tau_safe_row = np.array([ad_results[cat]["tau_safe"] for cat in cat_col], dtype=float)
is_strong_row = np.array([prec_safe_by_cat[cat] >= strong_threshold for cat in cat_col], dtype=bool)

pred_ad_safe = df_test_all["ad_score"].to_numpy() >= tau_safe_row
pred_cls_stage2 = df_test_all["cls_prob"].to_numpy() >= tau_stage2

# Strong categories: AD_SAFE alone is enough.
# Medium/weak categories: cascade AD_SAFE AND ResNet(tau_stage2).
y_pred_D = np.where(is_strong_row, pred_ad_safe, pred_ad_safe & pred_cls_stage2).astype(int)

res_D = eval_policy(df_test_all, y_pred_D,
                    f"Policy D : AD_SAFE seul (fortes) / AD_SAFE AND ResNet({tau_stage2}) (autres)")
res_D

Catégories 'fortes' AD (prec_safe >= 0.9): ['bottle', 'metal_nut', 'toothbrush']

=== Policy D : AD_SAFE seul (fortes) / AD_SAFE AND ResNet(0.6) (autres) ===
Precision : 0.9756
Recall    : 0.9873
F1        : 0.9814
F2        : 0.9849
TN=436, FP=31, FN=16, TP=1242


{'precision': 0.9756480754124116,
 'recall': 0.9872813990461049,
 'f1': 0.9814302647175029,
 'f2': 0.9849325931800159,
 'tn': np.int64(436),
 'fp': np.int64(31),
 'fn': np.int64(16),
 'tp': np.int64(1242)}

In [15]:
# Attach each policy's predictions as a column (assumes the Policy A-D cells
# have already produced y_pred_A .. y_pred_D).
for col_name, preds in (("y_A", y_pred_A), ("y_B", y_pred_B), ("y_C", y_pred_C), ("y_D", y_pred_D)):
    df_test_all[col_name] = preds

df_test_all.head()

Unnamed: 0,path,category,split,label,defect_type,final_split,cls_prob,category_from_loader,ad_score,y_A,y_B,y_C,y_D
0,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,1.0,bottle,89.614496,1,1,1,1
1,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,val,1.0,bottle,92.867646,1,1,1,1
2,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,val,1.0,bottle,128.321736,1,1,1,1
3,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999992,bottle,145.374881,1,1,1,1
4,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999787,bottle,89.16088,1,1,1,1


In [16]:
def per_category_stats(df, pred_col):
    """Per-category confusion counts and metrics for one prediction column.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``category``, ``label`` and ``pred_col`` columns (0/1 values).
    pred_col : str
        Name of the prediction column to evaluate.

    Returns
    -------
    pandas.DataFrame
        One row per category with n_test, TP/FP/TN/FN, precision, recall, f1
        and f2, sorted by ascending recall (worst first). Ties keep groupby's
        sorted category order. An empty ``df`` gives an empty frame with the
        same columns.
    """
    columns = ["category", "n_test", "TP", "FP", "TN", "FN",
               "precision", "recall", "f1", "f2"]
    rows = []
    for cat, df_cat in df.groupby("category"):
        y_true = df_cat["label"].values.astype(int)
        y_pred = df_cat[pred_col].values.astype(int)

        # labels=[0, 1] guarantees a 2x2 matrix even when a category contains
        # (or is predicted as) a single class.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall    = recall_score(y_true, y_pred, zero_division=0)
        f1        = f1_score(y_true, y_pred, zero_division=0)
        f2        = fbeta_score(y_true, y_pred, beta=2, zero_division=0)

        rows.append({
            "category": cat,
            "n_test": len(df_cat),
            "TP": tp, "FP": fp, "TN": tn, "FN": fn,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "f2": f2,
        })
    # Explicit columns keep the schema (and the sort key) even with no rows;
    # a stable sort makes the order of recall ties deterministic.
    stats = pd.DataFrame(rows, columns=columns)
    return stats.sort_values("recall", ascending=True, kind="stable").reset_index(drop=True)

In [17]:
# Per-category breakdown of Policy A (sorted by ascending recall).
stats_A = per_category_stats(df=df_test_all, pred_col="y_A")
print("=== Policy A : AD_SAFE seul (par catégorie) ===")
display(stats_A)

=== Policy A : AD_SAFE seul (par catégorie) ===


Unnamed: 0,category,n_test,TP,FP,TN,FN,precision,recall,f1,f2
0,bottle,83,63,4,16,0,0.940299,1.0,0.969231,0.987461
1,cable,150,92,43,15,0,0.681481,1.0,0.810573,0.914513
2,capsule,132,109,16,7,0,0.872,1.0,0.931624,0.97148
3,carpet,117,89,28,0,0,0.760684,1.0,0.864078,0.940803
4,grid,78,57,21,0,0,0.730769,1.0,0.844444,0.931373
5,hazelnut,110,70,19,21,0,0.786517,1.0,0.880503,0.948509
6,leather,124,92,20,12,0,0.821429,1.0,0.901961,0.958333
7,metal_nut,115,93,10,12,0,0.902913,1.0,0.94898,0.978947
8,pill,167,141,25,1,0,0.849398,1.0,0.918567,0.965753
9,screw,160,119,41,0,0,0.74375,1.0,0.853047,0.935535


In [18]:
# Per-category breakdown of Policy D (sorted by ascending recall).
stats_D = per_category_stats(df=df_test_all, pred_col="y_D")
print("=== Policy D : par catégorie ===")
display(stats_D)

=== Policy D : par catégorie ===


Unnamed: 0,category,n_test,TP,FP,TN,FN,precision,recall,f1,f2
0,transistor,100,37,0,60,3,1.0,0.925,0.961039,0.939086
1,grid,78,54,0,21,3,1.0,0.947368,0.972973,0.957447
2,tile,117,81,0,33,3,1.0,0.964286,0.981818,0.971223
3,wood,79,58,0,19,2,1.0,0.966667,0.983051,0.973154
4,cable,150,90,0,58,2,1.0,0.978261,0.989011,0.982533
5,carpet,117,88,0,28,1,1.0,0.988764,0.99435,0.990991
6,capsule,132,108,7,16,1,0.93913,0.990826,0.964286,0.980036
7,zipper,151,118,0,32,1,1.0,0.991597,0.995781,0.993266
8,metal_nut,115,93,10,12,0,0.902913,1.0,0.94898,0.978947
9,leather,124,92,0,32,0,1.0,1.0,1.0,1.0


In [19]:
# False negatives of Policy D: defective images (label 1) predicted as normal.
fn_D = df_test_all.loc[df_test_all["label"].eq(1) & df_test_all["y_D"].eq(0)].copy()
print("Nombre de FN Policy D :", len(fn_D))

# Group by category, then ascending ad_score and descending cls_prob.
fn_D = fn_D.sort_values(by=["category", "ad_score", "cls_prob"], ascending=[True, True, False])

# Columns worth inspecting; any that are missing (e.g. defect_type) are skipped.
cols = [c for c in ("category", "defect_type", "path", "ad_score", "cls_prob") if c in fn_D.columns]

display(fn_D[cols].head(20))

Nombre de FN Policy D : 16


Unnamed: 0,category,defect_type,path,ad_score,cls_prob
225,cable,poke_insulation,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,18.373501,0.226542
103,cable,cable_swap,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,21.914021,0.16203
354,capsule,squeeze,C:\Users\othni\Projects\mvtec_ad\data\capsule\...,66.040046,0.489861
466,carpet,thread,C:\Users\othni\Projects\mvtec_ad\data\carpet\t...,44.9667,0.473622
514,grid,glue,C:\Users\othni\Projects\mvtec_ad\data\grid\tes...,18.352708,0.253491
507,grid,glue,C:\Users\othni\Projects\mvtec_ad\data\grid\tes...,26.98703,0.039458
552,grid,thread,C:\Users\othni\Projects\mvtec_ad\data\grid\tes...,34.47305,0.255354
1316,tile,gray_stroke,C:\Users\othni\Projects\mvtec_ad\data\tile\tes...,20.732866,0.047593
1306,tile,gray_stroke,C:\Users\othni\Projects\mvtec_ad\data\tile\tes...,24.806133,0.023159
1315,tile,gray_stroke,C:\Users\othni\Projects\mvtec_ad\data\tile\tes...,35.807127,0.199871


In [20]:
# Per-category counts of Policy A and Policy D side by side, plus the change D - A.
stats_compare = (
    pd.merge(
        stats_A[["category", "n_test", "TP", "FP", "FN"]],
        stats_D[["category", "TP", "FP", "FN"]],
        on="category",
        suffixes=("_A", "_D"),
    )
    .assign(
        d_FP=lambda t: t["FP_D"] - t["FP_A"],
        d_FN=lambda t: t["FN_D"] - t["FN_A"],
    )
)

print("Différence D - A (FP et FN) par catégorie :")
display(stats_compare.sort_values("d_FN"))

Différence D - A (FP et FN) par catégorie :


Unnamed: 0,category,n_test,TP_A,FP_A,FN_A,TP_D,FP_D,FN_D,d_FP,d_FN
0,bottle,83,63,4,0,63,4,0,0,0
7,metal_nut,115,93,10,0,93,10,0,0,0
6,leather,124,92,20,0,92,0,0,-20,0
5,hazelnut,110,70,19,0,70,0,0,-19,0
9,screw,160,119,41,0,119,0,0,-41,0
11,toothbrush,42,30,3,0,30,3,0,0,0
8,pill,167,141,25,0,141,7,0,-18,0
2,capsule,132,109,16,0,108,7,1,-9,1
3,carpet,117,89,28,0,88,0,1,-28,1
14,zipper,151,119,30,0,118,0,1,-30,1


In [21]:
import json

# Persist all thresholds and policy definitions needed to reproduce the
# image-level decisions outside this notebook.
config = {
    "resnet_image_level": {
        "tau_F1": float(tau_img_F1),
        "tau_safe": float(tau_img_safe),
    },
    "ad_mahalanobis": {
        cat: {
            "tau_F1": float(ad_results[cat]["tau_F1"]),
            "tau_F2": float(ad_results[cat]["tau_F2"]),
            "tau_safe": float(ad_results[cat]["tau_safe"]),
            "prec_safe": float(ad_results[cat]["metrics_safe"]["precision"]),
            "recall_safe": float(ad_results[cat]["metrics_safe"]["recall"]),
        }
        for cat in sorted(ad_results.keys())
    },
    "policies": {
        "SAFE_0FN": {
            "type": "AD_SAFE_ONLY",  # y = 1 if ad_score >= tau_safe(cat)
        },
        "INDUSTRIAL_BALANCED": {
            "type": "MIXED",
            # Taken from the Policy D cell instead of re-typed literals, so the
            # saved config always matches the policy that was evaluated.
            "strong_cats": list(strong_cats),
            "strong_threshold": float(strong_threshold),
            "tau_stage2": float(tau_stage2),  # ResNet threshold for the other categories
        },
        "PRECISION": {
            "type": "AND_F1",  # y = 1 if AD_SAFE AND (cls_prob >= tau_F1)
        },
    },
}

config_path = EXPERIMENTS_DIR / "image_level_policies_config.json"
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

config_path

WindowsPath('C:/Users/othni/Projects/mvtec_ad/experiments/image_level_policies_config.json')