In [1]:
import json
import os
from functools import lru_cache
from pathlib import Path

import numpy as np
from PIL import Image

import torch
from torch import nn
from torchvision import models, transforms
from scipy.spatial import distance

# Project layout. The root defaults to the author's local checkout but can be
# overridden with the MVTEC_AD_ROOT environment variable, so the notebook can
# run on another machine without editing code.
PROJECT_ROOT = Path(os.environ.get("MVTEC_AD_ROOT", r"C:\Users\othni\Projects\mvtec_ad"))
DATA_ROOT = PROJECT_ROOT / "data"
EXPERIMENTS_DIR = PROJECT_ROOT / "experiments"        # policy config JSON + image-level CSV
MODELS_DIR = PROJECT_ROOT / "models"                  # image-level classifier checkpoint
AD_MODELS_DIR = MODELS_DIR / "ad_resnet_mahalanobis"  # <category>_gaussian_stats.npz files

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [2]:
# 1) Policy/threshold configuration written by earlier notebooks
#    (keys read later: "resnet_image_level", "ad_mahalanobis", "policies").
with open(EXPERIMENTS_DIR / "image_level_policies_config.json", "r") as f:
    cfg = json.load(f)

# 2) Image-level ResNet-18 classifier.
#    Architecture only (weights=None): the trained weights come from the checkpoint.
cls_model = models.resnet18(weights=None)
in_features = cls_model.fc.in_features
# Replace the head with a single-logit layer; compute_scores_for_image applies
# a sigmoid to this logit to get the classifier probability.
cls_model.fc = nn.Linear(in_features, 1)
cls_ckpt = MODELS_DIR / "resnet18_image_level_best.pt"
# NOTE(review): depending on the torch version, torch.load may unpickle arbitrary
# objects (weights_only default); only load checkpoints from a trusted source.
cls_model.load_state_dict(torch.load(cls_ckpt, map_location=device))
cls_model = cls_model.to(device)
# Inference mode: BatchNorm layers use their running statistics.
cls_model.eval()

# 3) ResNet-18 feature extractor for anomaly detection (AD).
#    Uses torchvision's default pretrained weights (downloaded on first use if
#    not already cached).
#    NOTE(review): the Gaussian stats in AD_MODELS_DIR presumably were fitted on
#    features from these same weights and this same transform — confirm.
resnet_feat = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# children()[:-1] drops the final fc layer, so the network ends at the global
# average pool and returns pooled backbone features.
modules = list(resnet_feat.children())[:-1]
resnet_feat = nn.Sequential(*modules).to(device)
resnet_feat.eval()

# 4) Evaluation transform shared by both models: resize to 256x256, center crop
#    224x224, convert to tensor, normalize with the ImageNet mean/std.
eval_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

In [3]:
@lru_cache(maxsize=None)
def _load_gaussian_stats(npz_path: str):
    """Load and cache the ``(mu, precision)`` arrays stored in a Gaussian-stats .npz file.

    Cached per path, so each category's stats are read from disk once instead of
    once per image. Call ``_load_gaussian_stats.cache_clear()`` if the .npz files
    are regenerated during the same kernel session.
    """
    # The context manager closes the NpzFile; np.load otherwise keeps the archive open.
    with np.load(npz_path) as data:
        return data["mu"], data["precision"]


def compute_scores_for_image(img_path: str, category: str):
    """Compute the classifier probability and the Mahalanobis AD score of one image.

    Args:
        img_path: path to the image file. It is converted to RGB and passed
            through ``eval_transform``.
        category: object category. It selects ``<category>_gaussian_stats.npz``
            in ``AD_MODELS_DIR``, which holds the arrays ``mu`` and ``precision``.

    Returns:
        ``(cls_prob, ad_score)``:
            - ``cls_prob`` is the sigmoid of the image-level classifier's single logit.
            - ``ad_score`` is the Mahalanobis distance
              ``sqrt((f - mu) @ precision @ (f - mu).T)`` of the pooled
              ResNet-18 features ``f`` to the category's Gaussian.
    """
    # Close the file handle right after decoding; convert() loads the pixel data first.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    x = eval_transform(img).unsqueeze(0).to(device)  # (1, 3, 224, 224)

    with torch.no_grad():
        # Classifier score
        logit = cls_model(x).squeeze(1)
        cls_prob = torch.sigmoid(logit).item()
        # AD features: (1, 512, 1, 1) -> (1, 512)
        feat = resnet_feat(x).view(1, -1).cpu().numpy()

    mu, precision = _load_gaussian_stats(str(AD_MODELS_DIR / f"{category}_gaussian_stats.npz"))

    diff = feat - mu
    squared = np.sum((diff @ precision) * diff, axis=1)
    # Clamp tiny negative round-off so sqrt never yields NaN.
    m = float(np.sqrt(np.maximum(squared, 0.0))[0])

    return cls_prob, m

In [4]:
def predict_image(img_path: str, category: str, mode: str = "SAFE_0FN", cfg: dict = None):
    """Score one image and apply an image-level decision policy (first version).

    NOTE(review): ``predict_image`` is redefined later in this notebook. That later
    version makes PRECISION use the per-category AD ``tau_F1`` instead of the AD
    ``tau_safe`` used here, so PRECISION results from the two versions are not
    comparable.

    Args:
        img_path: path to the image (coerced to ``str``).
        category: object category (coerced to ``str``). It selects
            ``cfg["ad_mahalanobis"][category]["tau_safe"]``.
        mode: one of ``"SAFE_0FN"``, ``"INDUSTRIAL_BALANCED"``, ``"PRECISION"``.
        cfg: policy configuration (as loaded from image_level_policies_config.json).

    Returns:
        dict with the following keys:
            - ``img_path``
            - ``category``
            - ``cls_prob``: ResNet probability of a defect.
            - ``ad_score``: Mahalanobis score.
            - ``decision``: 0/1, according to ``mode``.
            - ``mode``
            - ``reason``: short textual explanation.

    Raises:
        ValueError: if ``cfg`` is None or ``mode`` is unknown. Both are checked
            before any image is loaded.
    """
    if cfg is None:
        raise ValueError("cfg (config) ne doit pas être None")
    # Validate the mode up front, so an unknown mode fails with ValueError (not a
    # KeyError from the config lookup) and without paying for inference.
    if mode not in ("SAFE_0FN", "INDUSTRIAL_BALANCED", "PRECISION"):
        raise ValueError(f"Mode inconnu : {mode}")

    img_path = str(img_path)
    category = str(category)

    cls_prob, ad_score = compute_scores_for_image(img_path, category)

    # Per-category AD threshold, shared by every mode
    tau_safe_cat = cfg["ad_mahalanobis"][category]["tau_safe"]
    pred_ad_safe = (ad_score >= tau_safe_cat)

    # ------------ Mode SAFE_0FN ------------
    if mode == "SAFE_0FN":
        is_anomaly = int(pred_ad_safe)
        reason = f"Mode SAFE_0FN: y=1 si ad_score >= tau_safe({tau_safe_cat:.3f})"

    # ------------ Mode INDUSTRIAL_BALANCED ------------
    elif mode == "INDUSTRIAL_BALANCED":
        # Only this mode has (and needs) its own block in cfg["policies"].
        policy_cfg = cfg["policies"][mode]
        strong_cats = policy_cfg["strong_cats"]
        tau_stage2 = policy_cfg["tau_stage2"]

        if category in strong_cats:
            # Strong categories: the AD score alone decides.
            is_anomaly = int(pred_ad_safe)
            reason = (
                f"Mode INDUSTRIAL_BALANCED (cat forte): "
                f"y=1 si AD_SAFE (ad_score >= {tau_safe_cat:.3f})"
            )
        else:
            # Other categories: AD must fire AND the classifier must confirm.
            pred_cls_stage2 = (cls_prob >= tau_stage2)
            is_anomaly = int(pred_ad_safe and pred_cls_stage2)
            reason = (
                f"Mode INDUSTRIAL_BALANCED (cat moyenne/faible): "
                f"y=1 si AD_SAFE AND cls_prob >= {tau_stage2:.3f}"
            )

    # ------------ Mode PRECISION ------------
    else:
        tau_img_F1 = cfg["resnet_image_level"]["tau_F1"]
        pred_cls_F1 = (cls_prob >= tau_img_F1)
        is_anomaly = int(pred_ad_safe and pred_cls_F1)
        reason = (
            f"Mode PRECISION: y=1 si AD_SAFE AND cls_prob >= tau_F1({tau_img_F1:.3f})"
        )

    return {
        "img_path": img_path,
        "category": category,
        "cls_prob": cls_prob,
        "ad_score": ad_score,
        "decision": is_anomaly,
        "mode": mode,
        "reason": reason,
    }

In [6]:
import pandas as pd
from pathlib import Path

# PROJECT_ROOT / EXPERIMENTS_DIR come from the configuration cell at the top of
# the notebook. They are deliberately not redefined here, so one setting drives
# every path.

# Reload the image-level CSV built in the previous notebooks.
df_all = pd.read_csv(EXPERIMENTS_DIR / "image_level_df.csv")

print(df_all.shape)
print(df_all["final_split"].value_counts())

# Keep only the test split.
df_test_all = df_all[df_all["final_split"] == "test"].copy().reset_index(drop=True)

# Invariants relied on by every evaluation cell below.
assert len(df_test_all) > 0, "no rows with final_split == 'test'"
assert {"path", "category", "label"} <= set(df_test_all.columns), "missing expected columns"

print("df_test_all :", df_test_all.shape)
df_test_all.head()

(5354, 6)
final_split
train    3629
test      869
val       856
Name: count, dtype: int64
df_test_all : (869, 6)


Unnamed: 0,path,category,split,label,defect_type,final_split
0,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test
1,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test
2,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test
3,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test
4,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test


In [7]:
# Sanity check on the first test image: run it through the three v1 policies.
test_row = df_test_all.iloc[0]
print(test_row["category"], test_row["defect_type"], test_row["path"])

res_safe, res_ind, res_prec = (
    predict_image(test_row["path"], test_row["category"], mode=policy_mode, cfg=cfg)
    for policy_mode in ("SAFE_0FN", "INDUSTRIAL_BALANCED", "PRECISION")
)

res_safe, res_ind, res_prec

bottle broken_large C:\Users\othni\Projects\mvtec_ad\data\bottle\test\broken_large\000.png


({'img_path': 'C:\\Users\\othni\\Projects\\mvtec_ad\\data\\bottle\\test\\broken_large\\000.png',
  'category': 'bottle',
  'cls_prob': 0.9999997615814209,
  'ad_score': 89.61449598992749,
  'decision': 1,
  'mode': 'SAFE_0FN',
  'reason': 'Mode SAFE_0FN: y=1 si ad_score >= tau_safe(24.862)'},
 {'img_path': 'C:\\Users\\othni\\Projects\\mvtec_ad\\data\\bottle\\test\\broken_large\\000.png',
  'category': 'bottle',
  'cls_prob': 0.9999997615814209,
  'ad_score': 89.61449598992749,
  'decision': 1,
  'mode': 'INDUSTRIAL_BALANCED',
  'reason': 'Mode INDUSTRIAL_BALANCED (cat forte): y=1 si AD_SAFE (ad_score >= 24.862)'},
 {'img_path': 'C:\\Users\\othni\\Projects\\mvtec_ad\\data\\bottle\\test\\broken_large\\000.png',
  'category': 'bottle',
  'cls_prob': 0.9999997615814209,
  'ad_score': 89.61449598992749,
  'decision': 1,
  'mode': 'PRECISION',
  'reason': 'Mode PRECISION: y=1 si AD_SAFE AND cls_prob >= tau_F1(0.820)'})

In [8]:
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, fbeta_score

def run_policy_on_df(df, mode: str, cfg: dict, max_samples=None):
    """Apply ``predict_image`` to every row of ``df``.

    Args:
        df: DataFrame with ``path``, ``category`` and ``label`` columns.
        mode: policy name forwarded to ``predict_image``.
        cfg: policy configuration forwarded to ``predict_image``.
        max_samples: if not None, only the first ``max_samples`` rows are used.

    Returns:
        ``(y_true, y_pred)``: integer NumPy arrays holding the 0/1 ground-truth
        labels and the 0/1 decisions of the chosen mode, in row order.
    """
    subset = df.iloc[:max_samples] if max_samples is not None else df

    labels, decisions = [], []
    for _, sample in subset.iterrows():
        result = predict_image(sample["path"], sample["category"], mode=mode, cfg=cfg)
        labels.append(int(sample["label"]))
        decisions.append(int(result["decision"]))

    return np.asarray(labels, dtype=int), np.asarray(decisions, dtype=int)

In [9]:
def compute_metrics_global(y_true, y_pred, beta=2.0):
    """Binary metrics for 0/1 labels, where anomaly (1) is the positive class.

    Args:
        y_true: ground-truth labels (0/1).
        y_pred: predicted labels (0/1).
        beta: beta of the F-beta score reported under the ``"f2"`` key. The
            default of 2.0 weights recall more than precision.

    Returns:
        dict with ``precision``, ``recall``, ``f1``, ``f2`` (F-beta for ``beta``)
        and the confusion counts ``tn``, ``fp``, ``fn``, ``tp``. Undefined ratios
        are reported as 0.
    """
    # labels=[0, 1] pins the matrix to 2x2, so the unpacking also works when the
    # inputs contain a single class (e.g. a small or filtered subset).
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall    = recall_score(y_true, y_pred, zero_division=0)
    f1        = f1_score(y_true, y_pred, zero_division=0)
    f2        = fbeta_score(y_true, y_pred, beta=beta, zero_division=0)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f2": f2,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "tp": tp,
    }


def print_metrics(name, metrics):
    """Print a metrics dict from ``compute_metrics_global`` under a ``=== name ===`` header."""
    print(f"\n=== {name} ===")
    print(f"Precision : {metrics['precision']:.4f}")
    print(f"Recall    : {metrics['recall']:.4f}")
    print(f"F1        : {metrics['f1']:.4f}")
    print(f"F2        : {metrics['f2']:.4f}")
    print(f"TN={metrics['tn']}, FP={metrics['fp']}, FN={metrics['fn']}, TP={metrics['tp']}")

In [10]:
# Use the whole global test split (869 images in the recorded run). Each policy
# below runs full inference (predict_image) on every image of df_test_all.
print("Taille df_test_all :", df_test_all.shape)

# SAFE_0FN: AD score alone against the per-category tau_safe
y_true_A, y_pred_A = run_policy_on_df(df_test_all, mode="SAFE_0FN", cfg=cfg)
metrics_A = compute_metrics_global(y_true_A, y_pred_A)
print_metrics("Policy SAFE_0FN (via API)", metrics_A)

# INDUSTRIAL_BALANCED: AD only for strong_cats, AD AND classifier elsewhere
y_true_D, y_pred_D = run_policy_on_df(df_test_all, mode="INDUSTRIAL_BALANCED", cfg=cfg)
metrics_D = compute_metrics_global(y_true_D, y_pred_D)
print_metrics("Policy INDUSTRIAL_BALANCED (via API)", metrics_D)

# PRECISION: AD AND classifier above its global tau_F1
y_true_P, y_pred_P = run_policy_on_df(df_test_all, mode="PRECISION", cfg=cfg)
metrics_P = compute_metrics_global(y_true_P, y_pred_P)
print_metrics("Policy PRECISION (via API)", metrics_P)

Taille df_test_all : (869, 6)

=== Policy SAFE_0FN (via API) ===
Precision : 0.7824
Recall    : 1.0000
F1        : 0.8779
F2        : 0.9473
TN=60, FP=176, FN=0, TP=633

=== Policy INDUSTRIAL_BALANCED (via API) ===
Precision : 0.9828
Recall    : 0.9921
F1        : 0.9874
F2        : 0.9902
TN=225, FP=11, FN=5, TP=628

=== Policy PRECISION (via API) ===
Precision : 0.9952
Recall    : 0.9826
F1        : 0.9889
F2        : 0.9851
TN=233, FP=3, FN=11, TP=622


In [11]:
def per_category_stats_via_api(df, mode: str, cfg: dict):
    """Per-category confusion counts and scores for a config-driven policy.

    Runs ``run_policy_on_df`` separately on each category of ``df`` and
    returns one row per category, sorted by ascending recall (weakest categories
    first). Ties keep the alphabetical order produced by ``groupby``.

    Args:
        df: DataFrame with ``path``, ``category`` and ``label`` columns.
        mode: policy name forwarded to ``predict_image``.
        cfg: policy configuration.

    Returns:
        DataFrame with columns category, n_test, TP, FP, TN, FN, precision,
        recall, f1, f2. It is empty, with these columns, when ``df`` is empty.
    """
    columns = ["category", "n_test", "TP", "FP", "TN", "FN", "precision", "recall", "f1", "f2"]
    rows = []
    for cat, df_cat in df.groupby("category"):
        y_true, y_pred = run_policy_on_df(df_cat, mode=mode, cfg=cfg)
        # labels=[0, 1] keeps a 2x2 matrix even if a category subset has one class.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        rows.append({
            "category": cat,
            "n_test": len(df_cat),
            "TP": tp, "FP": fp, "TN": tn, "FN": fn,
            "precision": precision_score(y_true, y_pred, zero_division=0),
            "recall": recall_score(y_true, y_pred, zero_division=0),
            "f1": f1_score(y_true, y_pred, zero_division=0),
            "f2": fbeta_score(y_true, y_pred, beta=2.0, zero_division=0),
        })
    # mergesort is stable, so equal recalls keep a deterministic (alphabetical) order.
    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values("recall", ascending=True, kind="mergesort")
        .reset_index(drop=True)
    )

stats_ind_api = per_category_stats_via_api(df_test_all, mode="INDUSTRIAL_BALANCED", cfg=cfg)
print("=== INDUSTRIAL_BALANCED (via API) - par catégorie ===")
display(stats_ind_api)

=== INDUSTRIAL_BALANCED (via API) - par catégorie ===


Unnamed: 0,category,n_test,TP,FP,TN,FN,precision,recall,f1,f2
0,wood,40,28,0,10,2,1.0,0.933333,0.965517,0.945946
1,transistor,50,19,0,30,1,1.0,0.95,0.974359,0.959596
2,tile,59,41,0,17,1,1.0,0.97619,0.987952,0.980861
3,carpet,59,44,0,14,1,1.0,0.977778,0.988764,0.982143
4,grid,40,29,0,11,0,1.0,1.0,1.0,1.0
5,bottle,42,32,1,9,0,0.969697,1.0,0.984615,0.993789
6,cable,75,46,0,29,0,1.0,1.0,1.0,1.0
7,capsule,67,55,3,9,0,0.948276,1.0,0.973451,0.989209
8,metal_nut,58,47,3,8,0,0.94,1.0,0.969072,0.987395
9,leather,62,46,0,16,0,1.0,1.0,1.0,1.0


In [12]:
import pandas as pd

def get_fn_for_mode(df, mode: str, cfg: dict, max_samples=None):
    """Collect the false negatives (label == 1, decision == 0) of a policy.

    Args:
        df: DataFrame with ``path``, ``category`` and ``label`` columns.
            ``defect_type`` is reported when present.
        mode: policy name forwarded to ``predict_image``.
        cfg: policy configuration forwarded to ``predict_image``.
        max_samples: if not None, only the first ``max_samples`` rows are examined.

    Returns:
        DataFrame with one row per false negative and columns:
            - ``idx``: the row's index label in ``df``.
            - ``category``, ``defect_type``, ``path``.
            - ``cls_prob``, ``ad_score``.
        The columns are present even when there is no false negative.
    """
    rows = df if max_samples is None else df.iloc[:max_samples]

    records = []
    for i, row in rows.iterrows():
        # A false negative requires a positive label, so normal images are
        # skipped without running inference on them.
        if int(row["label"]) != 1:
            continue
        out = predict_image(row["path"], row["category"], mode=mode, cfg=cfg)
        if int(out["decision"]) == 0:
            records.append({
                "idx": i,
                "category": row["category"],
                "defect_type": row.get("defect_type", None),
                "path": row["path"],
                "cls_prob": out["cls_prob"],
                "ad_score": out["ad_score"],
            })
    return pd.DataFrame(
        records,
        columns=["idx", "category", "defect_type", "path", "cls_prob", "ad_score"],
    )


# False negatives of INDUSTRIAL_BALANCED over the whole test split; their
# categories are what the next cell uses to choose CRITICAL_CATS.
fn_ind = get_fn_for_mode(df_test_all, mode="INDUSTRIAL_BALANCED", cfg=cfg)
print("Nombre de FN (INDUSTRIAL_BALANCED via API) :", len(fn_ind))
fn_ind

Nombre de FN (INDUSTRIAL_BALANCED via API) : 5


Unnamed: 0,idx,category,defect_type,path,cls_prob,ad_score
0,237,carpet,thread,C:\Users\othni\Projects\mvtec_ad\data\carpet\t...,0.473622,44.9667
1,665,tile,gray_stroke,C:\Users\othni\Projects\mvtec_ad\data\tile\tes...,0.047593,20.732866
2,704,transistor,bent_lead,C:\Users\othni\Projects\mvtec_ad\data\transist...,0.201865,22.077504
3,775,wood,hole,C:\Users\othni\Projects\mvtec_ad\data\wood\tes...,0.019329,42.743958
4,782,wood,scratch,C:\Users\othni\Projects\mvtec_ad\data\wood\tes...,0.000361,29.506964


In [13]:
# Categories forced to AD_SAFE-only in the v2 policy. In the recorded run these
# are exactly the categories of the INDUSTRIAL_BALANCED false negatives in
# fn_ind (carpet, tile, transistor, wood). Adjust this list after inspecting fn_ind.
# NOTE(review): in the recorded run, adding carpet removed its single FN but turned
# all 14 of its normal test images from TN into FP (compare the per-category tables
# for INDUSTRIAL_BALANCED and INDUSTRIAL_BALANCED_v2). Confirm that trade-off is wanted.
CRITICAL_CATS = ["wood", "transistor", "tile", "carpet"]
CRITICAL_CATS

['wood', 'transistor', 'tile', 'carpet']

In [14]:
def predict_image_industrial_v2(img_path: str, category: str, cfg: dict, critical_cats=None):
    """Score one image and apply the INDUSTRIAL_BALANCED_v2 policy.

    Decision rule, with ``tau_safe_cat = cfg["ad_mahalanobis"][category]["tau_safe"]``:
      - category in INDUSTRIAL_BALANCED ``strong_cats`` or in ``critical_cats``:
        y = 1 iff ad_score >= tau_safe_cat (AD_SAFE only);
      - any other category: y = 1 iff AD_SAFE AND cls_prob >= ``tau_stage2``.

    Args:
        img_path: path to the image.
        category: object category, used for the per-category AD threshold.
        cfg: policy configuration. Only ``ad_mahalanobis`` and
            ``policies.INDUSTRIAL_BALANCED`` are read.
        critical_cats: extra categories forced to AD_SAFE only (default: none).

    Returns:
        dict with the following keys:
            - ``img_path``, ``category``, ``cls_prob``, ``ad_score``
            - ``decision``: 0/1.
            - ``mode``: always ``"INDUSTRIAL_BALANCED_v2"``.
            - ``reason``
    """
    if critical_cats is None:
        critical_cats = []

    # Scores
    cls_prob, ad_score = compute_scores_for_image(img_path, category)

    # Thresholds (only those actually used by this policy are read from cfg)
    tau_safe_cat = cfg["ad_mahalanobis"][category]["tau_safe"]
    pol = cfg["policies"]["INDUSTRIAL_BALANCED"]
    strong_cats = pol["strong_cats"]
    tau_stage2 = pol["tau_stage2"]

    pred_ad_safe = (ad_score >= tau_safe_cat)

    if (category in strong_cats) or (category in critical_cats):
        # AD_SAFE only
        is_anomaly = int(pred_ad_safe)
        reason = (
            f"INDUSTRIAL_BALANCED_v2: cat in strong/critical -> "
            f"y=1 si AD_SAFE (ad_score >= {tau_safe_cat:.3f})"
        )
    else:
        # Cascade: AD must fire AND the classifier must confirm.
        pred_cls_stage2 = (cls_prob >= tau_stage2)
        is_anomaly = int(pred_ad_safe and pred_cls_stage2)
        reason = (
            f"INDUSTRIAL_BALANCED_v2: cat normale -> "
            f"y=1 si AD_SAFE AND cls_prob >= {tau_stage2:.3f}"
        )

    return {
        "img_path": str(img_path),
        "category": category,
        "cls_prob": cls_prob,
        "ad_score": ad_score,
        "decision": is_anomaly,
        "mode": "INDUSTRIAL_BALANCED_v2",
        "reason": reason,
    }

In [15]:
def run_policy_industrial_v2_on_df(df, cfg: dict, critical_cats=None, max_samples=None):
    """Run the INDUSTRIAL_BALANCED_v2 policy over the rows of ``df``.

    Args:
        df: DataFrame with ``path``, ``category`` and ``label`` columns.
        cfg: policy configuration forwarded to ``predict_image_industrial_v2``.
        critical_cats: categories forced to AD_SAFE only; ``None`` means none.
        max_samples: if not None, only the first ``max_samples`` rows are used.

    Returns:
        ``(y_true, y_pred)`` as integer NumPy arrays, in row order.
    """
    forced_cats = [] if critical_cats is None else critical_cats
    subset = df.iloc[:max_samples] if max_samples is not None else df

    labels, decisions = [], []
    for _, sample in subset.iterrows():
        result = predict_image_industrial_v2(
            sample["path"], sample["category"], cfg=cfg, critical_cats=forced_cats
        )
        labels.append(int(sample["label"]))
        decisions.append(int(result["decision"]))

    return np.asarray(labels, dtype=int), np.asarray(decisions, dtype=int)

In [16]:
# Evaluate INDUSTRIAL_BALANCED_v2 on the whole test split. Here strong_cats and
# CRITICAL_CATS both use AD_SAFE only. This runs full inference for every image.
y_true_v2, y_pred_v2 = run_policy_industrial_v2_on_df(
    df_test_all,
    cfg=cfg,
    critical_cats=CRITICAL_CATS
)

metrics_v2 = compute_metrics_global(y_true_v2, y_pred_v2)
print_metrics("Policy INDUSTRIAL_BALANCED_v2 (via API custom)", metrics_v2)


=== Policy INDUSTRIAL_BALANCED_v2 (via API custom) ===
Precision : 0.8966
Recall    : 1.0000
F1        : 0.9455
F2        : 0.9775
TN=163, FP=73, FN=0, TP=633


In [17]:
def per_category_stats_industrial_v2(df, cfg: dict, critical_cats=None):
    """Per-category confusion counts and scores for INDUSTRIAL_BALANCED_v2.

    Runs ``run_policy_industrial_v2_on_df`` separately on each category of ``df``
    and returns one row per category, sorted by ascending recall (weakest
    first). Ties keep the alphabetical order produced by ``groupby``.

    Args:
        df: DataFrame with ``path``, ``category`` and ``label`` columns.
        cfg: policy configuration.
        critical_cats: categories forced to AD_SAFE only; ``None`` means none.

    Returns:
        DataFrame with columns category, n_test, TP, FP, TN, FN, precision,
        recall, f1, f2. It is empty, with these columns, when ``df`` is empty.
    """
    if critical_cats is None:
        critical_cats = []
    columns = ["category", "n_test", "TP", "FP", "TN", "FN", "precision", "recall", "f1", "f2"]
    rows = []
    for cat, df_cat in df.groupby("category"):
        y_true, y_pred = run_policy_industrial_v2_on_df(df_cat, cfg=cfg, critical_cats=critical_cats)
        # labels=[0, 1] keeps a 2x2 matrix even if a category subset has one class.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        rows.append({
            "category": cat,
            "n_test": len(df_cat),
            "TP": tp, "FP": fp, "TN": tn, "FN": fn,
            "precision": precision_score(y_true, y_pred, zero_division=0),
            "recall": recall_score(y_true, y_pred, zero_division=0),
            "f1": f1_score(y_true, y_pred, zero_division=0),
            "f2": fbeta_score(y_true, y_pred, beta=2.0, zero_division=0),
        })
    # mergesort is stable, so equal recalls keep a deterministic (alphabetical) order.
    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values("recall", ascending=True, kind="mergesort")
        .reset_index(drop=True)
    )


stats_v2 = per_category_stats_industrial_v2(df_test_all, cfg=cfg, critical_cats=CRITICAL_CATS)
print("=== INDUSTRIAL_BALANCED_v2 (via API custom) - par catégorie ===")
display(stats_v2)


=== INDUSTRIAL_BALANCED_v2 (via API custom) - par catégorie ===


Unnamed: 0,category,n_test,TP,FP,TN,FN,precision,recall,f1,f2
0,bottle,42,32,1,9,0,0.969697,1.0,0.984615,0.993789
1,cable,75,46,0,29,0,1.0,1.0,1.0,1.0
2,capsule,67,55,3,9,0,0.948276,1.0,0.973451,0.989209
3,carpet,59,45,14,0,0,0.762712,1.0,0.865385,0.941423
4,grid,40,29,0,11,0,1.0,1.0,1.0,1.0
5,hazelnut,55,35,0,20,0,1.0,1.0,1.0,1.0
6,leather,62,46,0,16,0,1.0,1.0,1.0,1.0
7,metal_nut,58,47,3,8,0,0.94,1.0,0.969072,0.987395
8,pill,84,71,4,9,0,0.946667,1.0,0.972603,0.988858
9,screw,81,60,0,21,0,1.0,1.0,1.0,1.0


In [18]:
# Rebuild the comparison table for the 4 policies. Every policy is re-run from
# scratch on df_test_all (full inference per image), so the table reflects the
# predict_image / cfg currently defined in the kernel.

# A: SAFE_0FN
y_true_A, y_pred_A = run_policy_on_df(df_test_all, mode="SAFE_0FN", cfg=cfg)
metrics_A = compute_metrics_global(y_true_A, y_pred_A)

# D: INDUSTRIAL_BALANCED (v1)
y_true_D, y_pred_D = run_policy_on_df(df_test_all, mode="INDUSTRIAL_BALANCED", cfg=cfg)
metrics_D = compute_metrics_global(y_true_D, y_pred_D)

# P: PRECISION
y_true_P, y_pred_P = run_policy_on_df(df_test_all, mode="PRECISION", cfg=cfg)
metrics_P = compute_metrics_global(y_true_P, y_pred_P)

# v2: INDUSTRIAL_BALANCED_v2 (strong_cats + CRITICAL_CATS on AD_SAFE only)
y_true_v2, y_pred_v2 = run_policy_industrial_v2_on_df(df_test_all, cfg=cfg, critical_cats=CRITICAL_CATS)
metrics_v2 = compute_metrics_global(y_true_v2, y_pred_v2)

# One row per policy; metric columns in a fixed order after "policy".
summary = pd.DataFrame([
    {
        "policy": policy_name,
        **{
            col: policy_metrics[col]
            for col in ("precision", "recall", "f1", "f2", "tn", "fp", "fn", "tp")
        },
    }
    for policy_name, policy_metrics in [
        ("SAFE_0FN", metrics_A),
        ("INDUSTRIAL_BALANCED", metrics_D),
        ("PRECISION", metrics_P),
        ("INDUSTRIAL_SAFE_0FN (v2)", metrics_v2),
    ]
])

summary

Unnamed: 0,policy,precision,recall,f1,f2,tn,fp,fn,tp
0,SAFE_0FN,0.782447,1.0,0.877947,0.947321,60,176,0,633
1,INDUSTRIAL_BALANCED,0.982786,0.992101,0.987421,0.990224,225,11,5,628
2,PRECISION,0.9952,0.982622,0.988871,0.985112,233,3,11,622
3,INDUSTRIAL_SAFE_0FN (v2),0.896601,1.0,0.945482,0.977455,163,73,0,633


In [19]:
import json
from copy import deepcopy

# Register the v2 policy in the config. It starts from the existing
# INDUSTRIAL_BALANCED policy (same strong_cats / tau_stage2) and adds
# CRITICAL_CATS, which are forced to AD_SAFE only.
pol_balanced = cfg["policies"]["INDUSTRIAL_BALANCED"]

cfg["policies"]["INDUSTRIAL_SAFE_0FN"] = {
    "type": "MIXED_V2",
    # list(...) copies: the new policy must not share list objects with
    # INDUSTRIAL_BALANCED or with the CRITICAL_CATS global. Otherwise a later
    # in-place edit of either would silently change this policy too.
    "strong_cats": list(pol_balanced["strong_cats"]),  # author's note: bottle, metal_nut, toothbrush
    "critical_cats": list(CRITICAL_CATS),
    "tau_stage2": pol_balanced["tau_stage2"],           # author's note: stage-2 threshold found earlier (0.60)
}

# Persist to the JSON config. Writing to a sibling temp file and then replacing
# the original means a failure during json.dump cannot leave a truncated config.
config_path = EXPERIMENTS_DIR / "image_level_policies_config.json"
tmp_config_path = config_path.with_name(config_path.name + ".tmp")
with open(tmp_config_path, "w") as f:
    json.dump(cfg, f, indent=2)
tmp_config_path.replace(config_path)

print("Config mise à jour avec INDUSTRIAL_SAFE_0FN.")

Config mise à jour avec INDUSTRIAL_SAFE_0FN.


In [20]:
def predict_image(img_path: str, category: str, mode: str = "SAFE_0FN", cfg: dict = None):
    """Single image-level inference API supporting several decision modes.

    This redefines the earlier ``predict_image``. Compared with that version:
      - it adds the ``INDUSTRIAL_SAFE_0FN`` mode, read from ``cfg["policies"]``;
      - ``PRECISION`` now requires ad_score >= the per-category AD ``tau_F1``
        (instead of AD ``tau_safe``) AND cls_prob >= the global ResNet ``tau_F1``.

    Modes:
      - ``SAFE_0FN``: y = 1 iff ad_score >= per-category ``tau_safe``.
      - ``PRECISION``: y = 1 iff ad_score >= per-category ``tau_F1``
        AND cls_prob >= global ``tau_F1``.
      - ``INDUSTRIAL_BALANCED``:
          - ``strong_cats``: AD_SAFE only.
          - other categories: AD_SAFE AND cls_prob >= ``tau_stage2``.
      - ``INDUSTRIAL_SAFE_0FN``: like INDUSTRIAL_BALANCED, but ``critical_cats``
        are also AD_SAFE only.

    Args:
        img_path: path to the image (coerced to ``str``).
        category: object category (coerced to ``str``). It selects the
            per-category AD thresholds.
        mode: one of the modes above.
        cfg: policy configuration dict.

    Returns:
        dict with the following keys:
            - ``img_path``, ``category``
            - ``cls_prob``, ``ad_score``: as floats.
            - ``decision``: int 0/1.
            - ``mode``
            - ``reason``

    Raises:
        ValueError: if ``cfg`` is None, if ``mode`` is unknown, or if
            INDUSTRIAL_SAFE_0FN is requested but missing from ``cfg["policies"]``.
            All of these are checked before any image is loaded.
    """
    if cfg is None:
        raise ValueError("cfg (config) ne doit pas être None")
    # Validate everything cheap before paying for model inference.
    if mode not in ("SAFE_0FN", "PRECISION", "INDUSTRIAL_BALANCED", "INDUSTRIAL_SAFE_0FN"):
        raise ValueError(f"Mode inconnu: {mode}")
    if mode == "INDUSTRIAL_SAFE_0FN" and cfg["policies"].get("INDUSTRIAL_SAFE_0FN") is None:
        raise ValueError("Policy INDUSTRIAL_SAFE_0FN non trouvée dans cfg['policies'].")

    img_path = str(img_path)
    category = str(category)

    # --- base scores ---
    cls_prob, ad_score = compute_scores_for_image(img_path, category)

    # Per-category AD thresholds; tau_safe is used by every mode except PRECISION.
    ad_cat_cfg = cfg["ad_mahalanobis"][category]
    tau_safe_cat = ad_cat_cfg["tau_safe"]
    pred_ad_safe = (ad_score >= tau_safe_cat)

    # --- Mode 1: SAFE_0FN (AD SAFE only) ---
    if mode == "SAFE_0FN":
        is_anomaly = int(pred_ad_safe)
        reason = f"Mode SAFE_0FN: y=1 si ad_score >= tau_safe({tau_safe_cat:.3f})"

    # --- Mode 2: PRECISION (very few FP, more FN) ---
    elif mode == "PRECISION":
        # Stricter on both sides: AD F1 threshold + ResNet F1 threshold.
        tau_img_F1 = cfg["resnet_image_level"]["tau_F1"]
        tau_F1_cat = ad_cat_cfg["tau_F1"]
        is_anomaly = int((ad_score >= tau_F1_cat) and (cls_prob >= tau_img_F1))
        reason = (
            "Mode PRECISION: y=1 si ad_score >= tau_F1_cat "
            f"({tau_F1_cat:.3f}) AND cls_prob >= tau_F1_global({tau_img_F1:.3f})"
        )

    # --- Mode 3: INDUSTRIAL_BALANCED (v1) ---
    elif mode == "INDUSTRIAL_BALANCED":
        pol_balanced = cfg["policies"]["INDUSTRIAL_BALANCED"]
        strong_cats = pol_balanced["strong_cats"]
        tau_stage2 = pol_balanced["tau_stage2"]

        if category in strong_cats:
            # AD_SAFE only for the strongest categories
            is_anomaly = int(pred_ad_safe)
            reason = (
                "Mode INDUSTRIAL_BALANCED (cat forte): "
                f"y=1 si AD_SAFE (ad_score >= {tau_safe_cat:.3f})"
            )
        else:
            is_anomaly = int(pred_ad_safe and (cls_prob >= tau_stage2))
            reason = (
                "Mode INDUSTRIAL_BALANCED (cat normale): "
                f"y=1 si AD_SAFE (ad_score >= {tau_safe_cat:.3f}) "
                f"AND cls_prob >= tau_stage2({tau_stage2:.3f})"
            )

    # --- Mode 4: INDUSTRIAL_SAFE_0FN (v2); policy presence checked above ---
    else:
        pol_safe0 = cfg["policies"]["INDUSTRIAL_SAFE_0FN"]
        strong_cats = pol_safe0["strong_cats"]
        critical_cats = pol_safe0["critical_cats"]
        tau_stage2 = pol_safe0["tau_stage2"]

        if (category in strong_cats) or (category in critical_cats):
            # AD_SAFE only: aim for 0 FN whatever the FP cost
            is_anomaly = int(pred_ad_safe)
            reason = (
                "Mode INDUSTRIAL_SAFE_0FN: cat strong/critical -> "
                f"y=1 si AD_SAFE (ad_score >= {tau_safe_cat:.3f})"
            )
        else:
            # Other categories keep the cascade AD_SAFE AND cls_prob >= tau_stage2
            is_anomaly = int(pred_ad_safe and (cls_prob >= tau_stage2))
            reason = (
                "Mode INDUSTRIAL_SAFE_0FN: cat normale -> "
                f"y=1 si AD_SAFE AND cls_prob >= tau_stage2({tau_stage2:.3f})"
            )

    return {
        "img_path": img_path,
        "category": category,
        "cls_prob": float(cls_prob),
        "ad_score": float(ad_score),
        "decision": int(is_anomaly),
        "mode": mode,
        "reason": reason,
    }


In [21]:
# Compare all four modes of the redefined predict_image on the first test image.
test_row = df_test_all.iloc[0]
print(test_row["category"], test_row["defect_type"], test_row["path"])

for m in ("SAFE_0FN", "INDUSTRIAL_SAFE_0FN", "INDUSTRIAL_BALANCED", "PRECISION"):
    out = predict_image(test_row["path"], test_row["category"], mode=m, cfg=cfg)
    print(f"\n {m} -> decision: {out['decision']}")
    print(f"  cls_prob = {out['cls_prob']} | ad_score = {out['ad_score']}")
    print(f"  reason  = {out['reason']}")

bottle broken_large C:\Users\othni\Projects\mvtec_ad\data\bottle\test\broken_large\000.png

 SAFE_0FN -> decision: 1
  cls_prob = 0.9999997615814209 | ad_score = 89.61449598992749
  reason  = Mode SAFE_0FN: y=1 si ad_score >= tau_safe(24.862)

 INDUSTRIAL_SAFE_0FN -> decision: 1
  cls_prob = 0.9999997615814209 | ad_score = 89.61449598992749
  reason  = Mode INDUSTRIAL_SAFE_0FN: cat strong/critical -> y=1 si AD_SAFE (ad_score >= 24.862)

 INDUSTRIAL_BALANCED -> decision: 1
  cls_prob = 0.9999997615814209 | ad_score = 89.61449598992749
  reason  = Mode INDUSTRIAL_BALANCED (cat forte): y=1 si AD_SAFE (ad_score >= 24.862)

 PRECISION -> decision: 1
  cls_prob = 0.9999997615814209 | ad_score = 89.61449598992749
  reason  = Mode PRECISION: y=1 si ad_score >= tau_F1_cat (32.041) AND cls_prob >= tau_F1_global(0.820)
