In [3]:
from pathlib import Path
import json
import numpy as np
from PIL import Image

import torch
from torch import nn
from torchvision import models, transforms
from scipy.spatial import distance

# Project layout: every directory is derived from PROJECT_ROOT.
# NOTE(review): PROJECT_ROOT is an absolute, machine-specific path; the notebook
# only runs as-is on a machine with this exact directory.
PROJECT_ROOT = Path(r"C:\Users\othni\Projects\mvtec_ad")
DATA_ROOT, EXPERIMENTS_DIR, MODELS_DIR = (
    PROJECT_ROOT.joinpath(sub) for sub in ("data", "experiments", "models")
)
AD_MODELS_DIR = MODELS_DIR.joinpath("ad_resnet_mahalanobis")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [4]:
# 1) Config: policy configuration written by an earlier step of the project.
with open(EXPERIMENTS_DIR / "image_level_policies_config.json", "r") as f:
    cfg = json.load(f)

# 2) Image-level ResNet classifier: ResNet-18 whose final layer is replaced by
#    a single-logit linear head, then restored from the saved checkpoint.
cls_model = models.resnet18(weights=None)
in_features = cls_model.fc.in_features
cls_model.fc = nn.Linear(in_features, 1)
cls_ckpt = MODELS_DIR / "resnet18_image_level_best.pt"
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (or pass weights_only=True on PyTorch versions that support it).
cls_model.load_state_dict(torch.load(cls_ckpt, map_location=device))
cls_model = cls_model.to(device)
cls_model.eval()  # inference mode

# 3) ResNet feature extractor for anomaly detection (AD): ResNet-18 with the
#    torchvision DEFAULT weights, minus its last child module, so the network
#    returns features instead of class scores.
resnet_feat = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
modules = list(resnet_feat.children())[:-1]
resnet_feat = nn.Sequential(*modules).to(device)
resnet_feat.eval()  # inference mode

# 4) Shared preprocessing, applied to every image before both models:
#    resize to 256x256, centre-crop to 224x224, convert to tensor, normalise.
#    (The mean/std values are the usual ImageNet channel statistics.)
eval_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

In [5]:
# Per-category Gaussian statistics, loaded lazily and kept in memory so each
# .npz file is read once per category instead of once per image.
_GAUSSIAN_STATS_CACHE = {}


def _load_gaussian_stats(category: str):
    """Return ``(mu, precision)`` for ``category``, reading its .npz at most once."""
    if category not in _GAUSSIAN_STATS_CACHE:
        npz_path = AD_MODELS_DIR / f"{category}_gaussian_stats.npz"
        # np.load on an .npz returns a lazy archive object that keeps the file
        # open; the context manager closes it once both arrays are read.
        with np.load(npz_path) as data:
            _GAUSSIAN_STATS_CACHE[category] = (data["mu"], data["precision"])
    return _GAUSSIAN_STATS_CACHE[category]


def compute_scores_for_image(img_path: str, category: str):
    """
    Score one image with both models.

    Parameters
    ----------
    img_path : path of the image file to score.
    category : category name, used to pick the per-category Gaussian
        statistics file ``{category}_gaussian_stats.npz`` in ``AD_MODELS_DIR``.

    Returns
    -------
    (cls_prob, m) : tuple of floats
        ``cls_prob`` is the sigmoid of the classifier's single logit.
        ``m`` is the Mahalanobis distance between the image's ResNet features
        and the category mean, ``sqrt(diff @ precision @ diff.T)``.
    """
    # Close the image file as soon as the RGB copy has been made.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    x = eval_transform(img).unsqueeze(0).to(device)  # (1,3,224,224)

    with torch.no_grad():
        # Classifier score
        logit = cls_model(x).squeeze(1)
        cls_prob = torch.sigmoid(logit).item()

        # AD features
        feat = resnet_feat(x)                   # (1,512,1,1)
        feat = feat.view(1, -1).cpu().numpy()   # (1,512)

    # AD score (Mahalanobis distance to the category's Gaussian)
    mu, precision = _load_gaussian_stats(category)
    diff = feat - mu
    left = diff @ precision
    m = float(np.sqrt(np.sum(left * diff, axis=1))[0])

    return cls_prob, m

In [7]:
import pandas as pd

# Reload the global image-level table built by the earlier notebooks.
IMAGE_DF_PATH = EXPERIMENTS_DIR.joinpath("image_level_df.csv")
print("Chargement :", IMAGE_DF_PATH)

df = pd.read_csv(IMAGE_DF_PATH)
print(df.shape)
print(df["final_split"].value_counts())

# Only the test split is analysed in the rest of the notebook.
is_test_row = df["final_split"].eq("test")
df_test_all = df.loc[is_test_row].reset_index(drop=True)
print("df_test_all :", df_test_all.shape)
df_test_all.head()

Chargement : C:\Users\othni\Projects\mvtec_ad\experiments\image_level_df.csv
(5354, 6)
final_split
train    3629
test      869
val       856
Name: count, dtype: int64
df_test_all : (869, 6)


Unnamed: 0,path,category,split,label,defect_type,final_split
0,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test
1,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test
2,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test
3,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test
4,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test


In [8]:
from tqdm import tqdm

scores_csv = EXPERIMENTS_DIR / "df_test_all_with_scores.csv"

# Reuse the cached scores only when they describe exactly the current test
# set (same images, same order, both score columns present). A stale cache
# would otherwise silently pair scores with the wrong rows.
df_scores = None
if scores_csv.exists():
    print("Recharge des scores depuis :", scores_csv)
    cached_scores = pd.read_csv(scores_csv)
    cache_is_valid = (
        {"path", "cls_prob", "ad_score"} <= set(cached_scores.columns)
        and cached_scores["path"].tolist() == df_test_all["path"].tolist()
    )
    if cache_is_valid:
        df_scores = cached_scores
    else:
        print("CSV de scores incohérent avec df_test_all, on recalcule les scores...")
else:
    print("Pas de CSV de scores trouvé, on calcule cls_prob et ad_score avec compute_scores_for_image...")

if df_scores is None:
    cls_probs = []
    ad_scores = []

    # Only two columns are needed, so iterate over them directly rather than
    # materialising a Series per row with iterrows().
    image_specs = zip(df_test_all["path"], df_test_all["category"])
    for img_path, category in tqdm(image_specs, total=len(df_test_all)):
        p, m = compute_scores_for_image(img_path, category)
        cls_probs.append(p)
        ad_scores.append(m)

    df_scores = df_test_all.copy()
    df_scores["cls_prob"] = cls_probs
    df_scores["ad_score"] = ad_scores

    df_scores.to_csv(scores_csv, index=False)
    print("Scores sauvegardés dans :", scores_csv)

print(df_scores.shape)
df_scores.head()

Pas de CSV de scores trouvé, on calcule cls_prob et ad_score avec compute_scores_for_image...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 869/869 [02:42<00:00,  5.35it/s]

Scores sauvegardés dans : C:\Users\othni\Projects\mvtec_ad\experiments\df_test_all_with_scores.csv
(869, 8)





Unnamed: 0,path,category,split,label,defect_type,final_split,cls_prob,ad_score
0,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,1.0,89.614496
1,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999992,145.374881
2,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999787,89.16088
3,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999999,56.286167
4,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,1.0,102.899164


In [9]:
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Global classifier thresholds (the ones found on the validation split).
# NOTE(review): these values, and the per-category table below, are typed in
# by hand; a later cell of this notebook reloads the same thresholds from
# image_level_policies_config.json and ad_resnet_mahalanobis_mvtec_all.json.
# Keep the two in sync, or drop this copy.
tau_img_F1   = 0.82   # global F1* threshold
tau_img_safe = 0.023  # global "SAFE" threshold

print("tau_img_F1  =", tau_img_F1)
print("tau_img_safe=", tau_img_safe)

# Per-category "safe" AD thresholds (taken from the prec_safe / recall=1.0 table).
TAU_AD_SAFE = {
    "bottle":      24.861817,
    "toothbrush":   7.383400,
    "metal_nut":   20.165997,
    "capsule":     19.341192,
    "pill":        19.386349,
    "leather":     27.482090,
    "wood":        28.664655,
    "zipper":      19.596640,
    "hazelnut":    40.880028,
    "tile":        19.477031,
    "carpet":      23.721512,
    "screw":       25.333641,
    "grid":        17.018129,
    "cable":       18.284675,
    "transistor":  17.372243,
}

# Sanity check: every category present in the scores must have a threshold.
# This only prints the missing ones; it does not stop the notebook.
missing_cats = set(df_scores["category"].unique()) - set(TAU_AD_SAFE.keys())
print("Catégories manquantes dans TAU_AD_SAFE :", missing_cats)

tau_img_F1  = 0.82
tau_img_safe= 0.023
Catégories manquantes dans TAU_AD_SAFE : set()


In [10]:
# Flat arrays (row order = df_scores) for the ground truth and both raw scores.
y_true = df_scores["label"].values.astype(int)
categories = df_scores["category"].values
cls_prob = df_scores["cls_prob"].values
ad_score = df_scores["ad_score"].values

# Elementary decisions: 1 = flagged as defective, 0 = accepted as normal.
y_cls_F1, y_cls_safe = (
    (cls_prob >= tau).astype(int) for tau in (tau_img_F1, tau_img_safe)
)

# The AD threshold depends on the category, so expand it to one value per image.
tau_ad_per_image = np.array([TAU_AD_SAFE[cat] for cat in categories])
y_ad_safe = (ad_score >= tau_ad_per_image).astype(int)

print("Vérif shapes :", y_true.shape, y_cls_F1.shape, y_ad_safe.shape)

# Combined policies: each entry maps a policy name to its 0/1 decision vector.
policies = dict(
    CLS_F1_ONLY=y_cls_F1,
    AD_SAFE_ONLY=y_ad_safe,
    AND_AD_SAFE_AND_CLS_F1=np.bitwise_and(y_ad_safe, y_cls_F1).astype(int),
    OR_AD_SAFE_OR_CLS_SAFE=np.bitwise_or(y_ad_safe, y_cls_safe).astype(int),
)

def summarize_policy(name, y_pred):
    """
    Print and return the metrics of one policy against the notebook-level ``y_true``.

    Parameters
    ----------
    name : label used in the printed header and stored under ``"policy"``.
    y_pred : NumPy array of 0/1 decisions, aligned with the global ``y_true``.

    Returns
    -------
    dict with keys ``policy``, ``precision``, ``recall``, ``f1``, ``f2``,
    ``tn``, ``fp``, ``fn``, ``tp``.
    """
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )

    # F-beta with beta = 2; defined as 0 when precision and recall are both 0.
    beta_sq = 2.0 ** 2
    f2 = 0.0
    if precision + recall > 0:
        f2 = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)

    def _count(true_value, pred_value):
        # Number of rows with this (ground truth, prediction) combination.
        return int(((y_true == true_value) & (y_pred == pred_value)).sum())

    tn, fp, fn, tp = _count(0, 0), _count(0, 1), _count(1, 0), _count(1, 1)

    report_lines = (
        f"\n=== {name} ===",
        f"Precision : {precision:.4f}",
        f"Recall    : {recall:.4f}",
        f"F1        : {f1:.4f}",
        f"F2        : {f2:.4f}",
        f"TN={tn}, FP={fp}, FN={fn}, TP={tp}",
    )
    for line in report_lines:
        print(line)

    return {
        "policy": name,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f2": f2,
        "tn": tn, "fp": fp, "fn": fn, "tp": tp,
    }

# Evaluate (and print) every policy, collecting one metrics dict per policy.
results = [
    summarize_policy(policy_name, policy_pred)
    for policy_name, policy_pred in policies.items()
]

import pandas as pd
df_policies = pd.DataFrame(results)
df_policies

Vérif shapes : (869,) (869,) (869,)

=== CLS_F1_ONLY ===
Precision : 0.9952
Recall    : 0.9826
F1        : 0.9889
F2        : 0.9851
TN=233, FP=3, FN=11, TP=622

=== AD_SAFE_ONLY ===
Precision : 0.7824
Recall    : 1.0000
F1        : 0.8779
F2        : 0.9473
TN=60, FP=176, FN=0, TP=633

=== AND_AD_SAFE_AND_CLS_F1 ===
Precision : 0.9952
Recall    : 0.9826
F1        : 0.9889
F2        : 0.9851
TN=233, FP=3, FN=11, TP=622

=== OR_AD_SAFE_OR_CLS_SAFE ===
Precision : 0.7608
Recall    : 1.0000
F1        : 0.8642
F2        : 0.9408
TN=37, FP=199, FN=0, TP=633


Unnamed: 0,policy,precision,recall,f1,f2,tn,fp,fn,tp
0,CLS_F1_ONLY,0.9952,0.982622,0.988871,0.985112,233,3,11,622
1,AD_SAFE_ONLY,0.782447,1.0,0.877947,0.947321,60,176,0,633
2,AND_AD_SAFE_AND_CLS_F1,0.9952,0.982622,0.988871,0.985112,233,3,11,622
3,OR_AD_SAFE_OR_CLS_SAFE,0.760817,1.0,0.864164,0.940844,37,199,0,633


In [11]:
def add_err_type(df, y_pred, col_prefix):
    """
    Return a copy of ``df`` with two extra columns describing a policy's outcome.

    Added columns:
      - ``{col_prefix}_pred``: the 0/1 prediction of each row
      - ``{col_prefix}_err_type``: one of ``TP``, ``FP``, ``TN``, ``FN``
        (``UNK`` when the label or the prediction is neither 0 nor 1)

    Parameters
    ----------
    df : DataFrame with a ``label`` column holding the ground truth.
    y_pred : array-like of predictions, one per row of ``df`` and in the same
        row order. It is always matched by position, never by index label, so
        a list, a NumPy array or a pandas Series all give the same result.
    col_prefix : prefix used to name the two new columns.

    Raises
    ------
    ValueError
        If ``y_pred`` does not hold exactly one value per row of ``df``.
    """
    # Normalise once so both new columns are built from the same positional
    # vector (a Series assigned directly would be re-aligned on its index).
    pred = np.asarray(y_pred).astype(int)
    if pred.shape != (len(df),):
        raise ValueError(
            f"y_pred must hold one prediction per row: expected shape "
            f"({len(df)},), got {pred.shape}"
        )

    truth = df["label"].to_numpy().astype(int)
    outcome_by_pair = {(1, 1): "TP", (0, 0): "TN", (0, 1): "FP", (1, 0): "FN"}

    out = df.copy()
    out[f"{col_prefix}_pred"] = pred
    out[f"{col_prefix}_err_type"] = [
        outcome_by_pair.get(pair, "UNK")
        for pair in zip(truth.tolist(), pred.tolist())
    ]
    return out

# Enriched frame for the two key policies: each one adds a `<prefix>_pred`
# and a `<prefix>_err_type` column.
df_analysis = df_scores.copy()
for policy_key, column_prefix in (("CLS_F1_ONLY", "clsF1"), ("AD_SAFE_ONLY", "adSAFE")):
    df_analysis = add_err_type(df_analysis, policies[policy_key], column_prefix)

df_analysis.head()

Unnamed: 0,path,category,split,label,defect_type,final_split,cls_prob,ad_score,clsF1_pred,clsF1_err_type,adSAFE_pred,adSAFE_err_type
0,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,1.0,89.614496,1,TP,1,TP
1,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999992,145.374881,1,TP,1,TP
2,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999787,89.16088,1,TP,1,TP
3,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999999,56.286167,1,TP,1,TP
4,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,1.0,102.899164,1,TP,1,TP


In [12]:
# Test images the classifier (F1 threshold) missed: true defects predicted normal.
fns_cls = df_analysis.loc[df_analysis["clsF1_err_type"].eq("FN")].copy()
print("Nombre de FN (CLS_F1_ONLY) :", len(fns_cls))

cols_show = ["category", "defect_type", "path", "cls_prob", "ad_score"]
fns_cls.loc[:, cols_show].sort_values(by=["category", "defect_type"]).head(20)

Nombre de FN (CLS_F1_ONLY) : 11


Unnamed: 0,category,defect_type,path,cls_prob,ad_score
50,cable,cable_swap,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.812187,20.562237
54,cable,cable_swap,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.700812,18.549954
237,carpet,thread,C:\Users\othni\Projects\mvtec_ad\data\carpet\t...,0.473622,44.9667
450,metal_nut,scratch,C:\Users\othni\Projects\mvtec_ad\data\metal_nu...,0.705767,22.953692
665,tile,gray_stroke,C:\Users\othni\Projects\mvtec_ad\data\tile\tes...,0.047593,20.732866
683,toothbrush,defective,C:\Users\othni\Projects\mvtec_ad\data\toothbru...,0.752681,7.768863
704,transistor,bent_lead,C:\Users\othni\Projects\mvtec_ad\data\transist...,0.201865,22.077504
711,transistor,cut_lead,C:\Users\othni\Projects\mvtec_ad\data\transist...,0.713614,24.1424
775,wood,hole,C:\Users\othni\Projects\mvtec_ad\data\wood\tes...,0.019329,42.743958
782,wood,scratch,C:\Users\othni\Projects\mvtec_ad\data\wood\tes...,0.000361,29.506964


In [13]:
# Classifier false negatives counted per (category, defect_type), most frequent first.
fn_counts = fns_cls.groupby(["category", "defect_type"]).size()
fns_grouped = fn_counts.reset_index(name="n_FN").sort_values(by="n_FN", ascending=False)

fns_grouped

Unnamed: 0,category,defect_type,n_FN
0,cable,cable_swap,2
1,carpet,thread,1
2,metal_nut,scratch,1
3,tile,gray_stroke,1
4,toothbrush,defective,1
5,transistor,bent_lead,1
6,transistor,cut_lead,1
7,wood,hole,1
8,wood,scratch,1
9,zipper,fabric_interior,1


In [14]:
# Same classifier FNs, now shown next to the AD_SAFE outcome of each row.
cols_show_extended = (
    ["category", "defect_type", "path"]
    + ["label", "cls_prob", "ad_score"]
    + ["clsF1_err_type", "adSAFE_err_type"]
)

fns_cls.loc[:, cols_show_extended].sort_values(by=["category", "defect_type"])

Unnamed: 0,category,defect_type,path,label,cls_prob,ad_score,clsF1_err_type,adSAFE_err_type
50,cable,cable_swap,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,1,0.812187,20.562237,FN,TP
54,cable,cable_swap,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,1,0.700812,18.549954,FN,TP
237,carpet,thread,C:\Users\othni\Projects\mvtec_ad\data\carpet\t...,1,0.473622,44.9667,FN,TP
450,metal_nut,scratch,C:\Users\othni\Projects\mvtec_ad\data\metal_nu...,1,0.705767,22.953692,FN,TP
665,tile,gray_stroke,C:\Users\othni\Projects\mvtec_ad\data\tile\tes...,1,0.047593,20.732866,FN,TP
683,toothbrush,defective,C:\Users\othni\Projects\mvtec_ad\data\toothbru...,1,0.752681,7.768863,FN,TP
704,transistor,bent_lead,C:\Users\othni\Projects\mvtec_ad\data\transist...,1,0.201865,22.077504,FN,TP
711,transistor,cut_lead,C:\Users\othni\Projects\mvtec_ad\data\transist...,1,0.713614,24.1424,FN,TP
775,wood,hole,C:\Users\othni\Projects\mvtec_ad\data\wood\tes...,1,0.019329,42.743958,FN,TP
782,wood,scratch,C:\Users\othni\Projects\mvtec_ad\data\wood\tes...,1,0.000361,29.506964,FN,TP


In [15]:
# Normal images that the AD_SAFE policy flagged as defective.
fps_ad = df_analysis.loc[df_analysis["adSAFE_err_type"].eq("FP")].copy()
print("Nombre de FP (AD_SAFE_ONLY) :", len(fps_ad))

# False positives per category, largest count first.
fp_counts = fps_ad.groupby(["category"]).size()
fps_grouped = fp_counts.reset_index(name="n_FP").sort_values(by="n_FP", ascending=False)

fps_grouped

Nombre de FP (AD_SAFE_ONLY) : 176


Unnamed: 0,category,n_FP
11,transistor,24
1,cable,22
9,screw,20
10,tile,16
13,zipper,15
3,carpet,14
8,pill,13
6,leather,12
4,grid,11
2,capsule,9


In [16]:
# quelques exemples de FP pour inspection
fps_ad[["category", "defect_type", "path", "cls_prob", "ad_score"]].head(20)

Unnamed: 0,category,defect_type,path,cls_prob,ad_score
39,bottle,good,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,0.021703,29.199477
74,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.576278,21.38025
76,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.031419,19.902417
77,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.073952,20.410719
79,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.065705,22.918919
81,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.028365,22.445676
82,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.349497,22.916417
84,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.07604,20.315808
86,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.133053,24.858273
88,cable,good,C:\Users\othni\Projects\mvtec_ad\data\cable\te...,0.063647,19.323516


In [17]:
# catégories où le classif a fait au moins un FN
cats_with_FN_cls = sorted(fns_cls["category"].unique())
cats_with_FN_cls

['cable',
 'carpet',
 'metal_nut',
 'tile',
 'toothbrush',
 'transistor',
 'wood',
 'zipper']

In [18]:
# Hard-coded list of categories; judging by the name, those where the AD model
# is considered very strong (the selection criterion is not shown here).
# NOTE(review): not referenced by any other cell visible in this notebook —
# confirm it is still needed.
cats_AD_very_strong = ["bottle", "toothbrush", "metal_nut"]
cats_AD_very_strong

['bottle', 'toothbrush', 'metal_nut']

In [19]:
import numpy as np

y_true = df_scores["label"].values.astype(int)
cats = df_scores["category"].values

y_clsF1, y_adSAFE = policies["CLS_F1_ONLY"], policies["AD_SAFE_ONLY"]

cats_with_FN_cls = sorted(fns_cls["category"].unique())

# MIX_FN policy: CLS_F1 by default, AD_SAFE for every category in which the
# classifier produced at least one false negative ("hard" categories).
# NOTE(review): the hard categories are picked from errors on the very rows
# being scored here (df_scores is the test split), so the metrics of this
# policy are optimistic; choosing the categories on validation data would
# give an unbiased estimate.
routed_to_ad = np.array([cat in cats_with_FN_cls for cat in cats], dtype=bool)
y_mix_fn = np.where(routed_to_ad, y_adSAFE, y_clsF1).astype(y_true.dtype)

def compute_metrics(y_true, y_pred, name="policy"):
    """
    Print and return binary-classification metrics for one policy.

    Parameters
    ----------
    y_true, y_pred : NumPy arrays of 0/1 values (cast to int internally).
    name : label used in the printed header and stored under ``"policy"``.

    Returns
    -------
    dict with keys ``policy``, ``precision``, ``recall``, ``f1``, ``f2``,
    ``tn``, ``fp``, ``fn``, ``tp``. Every ratio is 0.0 when its denominator
    would be zero.
    """
    truth = y_true.astype(int)
    guess = y_pred.astype(int)

    def _count(true_value, pred_value):
        # Number of positions with this (ground truth, prediction) combination.
        return int(((truth == true_value) & (guess == pred_value)).sum())

    tp, tn, fp, fn = _count(1, 1), _count(0, 0), _count(0, 1), _count(1, 0)

    predicted_pos = tp + fp
    actual_pos = tp + fn
    precision = tp / predicted_pos if predicted_pos > 0 else 0.0
    recall = tp / actual_pos if actual_pos > 0 else 0.0

    # F1 and F2 (beta = 2, recall weighted more than precision) share the
    # same guard against a zero denominator.
    beta = 2.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
        f2 = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
    else:
        f1 = 0.0
        f2 = 0.0

    for line in (
        f"=== {name} ===",
        f"Precision : {precision:.4f}",
        f"Recall    : {recall:.4f}",
        f"F1        : {f1:.4f}",
        f"F2        : {f2:.4f}",
        f"TN={tn}, FP={fp}, FN={fn}, TP={tp}",
    ):
        print(line)

    return {
        "policy": name,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f2": f2,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "tp": tp,
    }

metrics_mix_fn = compute_metrics(y_true, y_mix_fn, name="MIX_FN (CLS_F1 sauf catégories à FN → AD_SAFE)")
metrics_mix_fn

=== MIX_FN (CLS_F1 sauf catégories à FN → AD_SAFE) ===
Precision : 0.8577
Recall    : 1.0000
F1        : 0.9234
F2        : 0.9679
TN=131, FP=105, FN=0, TP=633


{'policy': 'MIX_FN (CLS_F1 sauf catégories à FN → AD_SAFE)',
 'precision': 0.8577235772357723,
 'recall': 1.0,
 'f1': 0.9234135667396061,
 'f2': 0.9678899082568807,
 'tn': 131,
 'fp': 105,
 'fn': 0,
 'tp': 633}

In [20]:
df_mix = df_scores.copy()
df_mix["y_mix_fn"] = y_mix_fn

# Per-category confusion counts for MIX_FN.
# Built from 0/1 indicator columns and a named aggregation rather than
# groupby(...).apply(lambda ...): this is vectorised and avoids the pandas
# warning about apply operating on the grouping column.
is_defect = df_mix["label"] == 1
is_normal = df_mix["label"] == 0
flagged = df_mix["y_mix_fn"] == 1
accepted = df_mix["y_mix_fn"] == 0

group_mix = (
    pd.DataFrame({
        "category": df_mix["category"],
        "TP": (is_defect & flagged).astype(int),
        "FP": (is_normal & flagged).astype(int),
        "TN": (is_normal & accepted).astype(int),
        "FN": (is_defect & accepted).astype(int),
    })
    .groupby("category")
    .agg(
        n_test=("TP", "size"),
        TP=("TP", "sum"),
        FP=("FP", "sum"),
        TN=("TN", "sum"),
        FN=("FN", "sum"),
    )
    .reset_index()
)

# A zero denominator becomes NaN instead of raising or producing inf.
group_mix["precision"] = group_mix["TP"] / (group_mix["TP"] + group_mix["FP"]).replace(0, np.nan)
group_mix["recall"]    = group_mix["TP"] / (group_mix["TP"] + group_mix["FN"]).replace(0, np.nan)

group_mix.sort_values("precision", ascending=True)

  .apply(lambda d: pd.Series({


Unnamed: 0,category,n_test,TP,FP,TN,FN,precision,recall
12,transistor,50,20,24,6,0,0.454545,1.0
1,cable,75,46,22,7,0,0.676471,1.0
10,tile,59,42,16,1,0,0.724138,1.0
3,carpet,59,45,14,0,0,0.762712,1.0
13,wood,40,30,8,2,0,0.789474,1.0
14,zipper,76,60,15,1,0,0.8,1.0
7,metal_nut,58,47,3,8,0,0.94,1.0
8,pill,84,71,2,11,0,0.972603,1.0
2,capsule,67,55,1,11,0,0.982143,1.0
5,hazelnut,55,35,0,20,0,1.0,1.0


In [21]:
# Set of (category, defect_type) pairs for which the classifier produced a FN.
bad_pairs = set(zip(fns_cls["category"], fns_cls["defect_type"]))

# MIX_PAIRS policy: AD_SAFE for images whose (category, defect_type) pair is
# in bad_pairs, CLS_F1 for every other image.
# NOTE(review): defect_type is a ground-truth annotation ("good" for normal
# images), so this routing uses information that is not available when
# scoring a new image. bad_pairs is also built from false negatives on the
# same rows, so it only ever contains defective types: normal images always
# go through CLS_F1 and the missed defects always go through AD_SAFE. The
# resulting metrics are an upper bound, not a deployable estimate.
row_pairs = zip(df_scores["category"], df_scores["defect_type"])
routed_to_ad = np.array([pair in bad_pairs for pair in row_pairs], dtype=bool)
y_mix_pairs = np.where(routed_to_ad, y_adSAFE, y_clsF1).astype(y_true.dtype)

metrics_mix_pairs = compute_metrics(y_true, y_mix_pairs, name="MIX_PAIRS (AD_SAFE pour (cat,defect) problématiques)")
metrics_mix_pairs

=== MIX_PAIRS (AD_SAFE pour (cat,defect) problématiques) ===
Precision : 0.9953
Recall    : 1.0000
F1        : 0.9976
F2        : 0.9991
TN=233, FP=3, FN=0, TP=633


{'policy': 'MIX_PAIRS (AD_SAFE pour (cat,defect) problématiques)',
 'precision': 0.9952830188679245,
 'recall': 1.0,
 'f1': 0.9976359338061466,
 'f2': 0.9990530303030304,
 'tn': 233,
 'fp': 3,
 'fn': 0,
 'tp': 633}

In [22]:
df_pairs = df_scores.copy()
df_pairs["y_mix_pairs"] = y_mix_pairs

# Confusion counts of MIX_PAIRS per (category, defect_type).
# Built from 0/1 indicator columns and a named aggregation rather than
# groupby(...).apply(lambda ...): this is vectorised and avoids the pandas
# warning about apply operating on the grouping columns.
is_defect = df_pairs["label"] == 1
is_normal = df_pairs["label"] == 0
flagged = df_pairs["y_mix_pairs"] == 1
accepted = df_pairs["y_mix_pairs"] == 0

group_pairs = (
    pd.DataFrame({
        "category": df_pairs["category"],
        "defect_type": df_pairs["defect_type"],
        "TP": (is_defect & flagged).astype(int),
        "FP": (is_normal & flagged).astype(int),
        "TN": (is_normal & accepted).astype(int),
        "FN": (is_defect & accepted).astype(int),
    })
    .groupby(["category", "defect_type"])
    .agg(
        n_test=("TP", "size"),
        TP=("TP", "sum"),
        FP=("FP", "sum"),
        TN=("TN", "sum"),
        FN=("FN", "sum"),
    )
    .reset_index()
)

group_pairs.sort_values(["category", "defect_type"])

  .apply(lambda d: pd.Series({


Unnamed: 0,category,defect_type,n_test,TP,FP,TN,FN
0,bottle,broken_large,12,12,0,0,0
1,bottle,broken_small,8,8,0,0,0
2,bottle,contamination,12,12,0,0,0
3,bottle,good,10,0,0,10,0
4,cable,bent_wire,8,8,0,0,0
...,...,...,...,...,...,...,...
83,zipper,fabric_interior,8,8,0,0,0
84,zipper,good,16,0,0,16,0
85,zipper,rough,9,9,0,0,0
86,zipper,split_teeth,7,7,0,0,0


MIX_FN (CLS_F1 sauf catégories à FN → AD_SAFE)

Precision ≈ 0.858

Recall = 1.000

F1 ≈ 0.923
→ 0 FN mais encore pas mal de FP (105).

MIX_PAIRS (AD_SAFE pour (cat, defect_type) problématiques)

Precision ≈ 0.9953

Recall = 1.000

F1 ≈ 0.9976

F2 ≈ 0.9991

TN=233, FP=3, FN=0, TP=633
→ 0 FN et seulement 3 FP sur tout MVTec test

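Donc théoriquement, sur ce jeu de test, MIX_PAIRS domine toutes les autres policies. Attention toutefois : ce résultat est optimiste, car (1) les couples (category, defect_type) « problématiques » sont choisis à partir des FN observés sur ce même jeu de test, et (2) `defect_type` est une annotation vérité-terrain (« good » pour les images saines) qui n'est pas disponible à l'inférence. Avant de conclure, il faudrait sélectionner ces couples (ou catégories) sur la validation et ne router que sur des informations connues en production, par exemple la catégorie.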

In [29]:
from pathlib import Path
import json

import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

# --- Project paths ---
PROJECT_ROOT    = Path(r"C:\Users\othni\Projects\mvtec_ad")
EXPERIMENTS_DIR = PROJECT_ROOT.joinpath("experiments")

# --- 1) Load the policies config (image_level_policies_config.json) ---
cfg_path = EXPERIMENTS_DIR.joinpath("image_level_policies_config.json")
cfg = json.loads(cfg_path.read_text())

print("Clés top-level de cfg :", list(cfg))

# --- 2) Load the scores already computed on df_test_all ---
scores_csv = EXPERIMENTS_DIR.joinpath("df_test_all_with_scores.csv")
df_scores = pd.read_csv(scores_csv)

print("df_scores shape :", df_scores.shape)
df_scores.head()

Clés top-level de cfg : ['resnet_image_level', 'ad_mahalanobis', 'policies']
df_scores shape : (869, 8)


Unnamed: 0,path,category,split,label,defect_type,final_split,cls_prob,ad_score
0,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,1.0,89.614496
1,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999992,145.374881
2,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999787,89.16088
3,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,0.999999,56.286167
4,C:\Users\othni\Projects\mvtec_ad\data\bottle\t...,bottle,test,1,broken_large,test,1.0,102.899164


In [38]:
# ---------- 2.1. Seuils globaux (classification) depuis cfg ----------

def extract_global_thresholds_from_cfg(cfg: dict):
    """
    Extract the two global classification thresholds from the policies config.

    Candidate blocks are examined in this priority order, and the first one
    that holds BOTH thresholds wins:
      1. ``cfg["global"]`` (legacy format);
      2. the known policy blocks ``SAFE_0FN``, ``INDUSTRIAL_SAFE_0FN``,
         ``INDUSTRIAL_BALANCED``, ``PRECISION``;
      3. every other dict-valued entry of ``cfg``, in insertion order.
    For blocks of kind 2 and 3, a nested ``"global"`` dict is used in place of
    the block itself when present.

    Each threshold is accepted under two spellings: ``tau_F1`` / ``tau_img_F1``
    and ``tau_safe`` / ``tau_img_safe``.

    Returns
    -------
    (tau_img_F1, tau_img_safe) : tuple of floats.

    Raises
    ------
    ValueError
        If ``cfg`` contains no dict-valued entry at all.
    KeyError
        If no candidate block holds both thresholds.
    """
    f1_keys = ("tau_F1", "tau_img_F1")
    safe_keys = ("tau_safe", "tau_img_safe")
    known_policies = ("SAFE_0FN", "INDUSTRIAL_SAFE_0FN", "INDUSTRIAL_BALANCED", "PRECISION")

    def _first_present(block, keys):
        # First spelling of a threshold found in the block, or None.
        return next((k for k in keys if k in block), None)

    candidates = []

    def _add(block):
        # Keep the priority order while skipping blocks already listed.
        if not any(block is seen for seen in candidates):
            candidates.append(block)

    top_level_global = cfg.get("global")
    if isinstance(top_level_global, dict):
        _add(top_level_global)

    policy_blocks = [cfg[k] for k in known_policies if isinstance(cfg.get(k), dict)]
    policy_blocks += [v for v in cfg.values() if isinstance(v, dict)]
    for block in policy_blocks:
        nested_global = block.get("global")
        _add(nested_global if isinstance(nested_global, dict) else block)

    if not candidates:
        raise ValueError("Impossible de trouver un bloc de config pertinent dans cfg.")

    # Take the first block that actually holds both thresholds, instead of
    # failing because a higher-priority block happens not to carry them.
    for gb in candidates:
        f1_key = _first_present(gb, f1_keys)
        safe_key = _first_present(gb, safe_keys)
        if f1_key is not None and safe_key is not None:
            return float(gb[f1_key]), float(gb[safe_key])

    # Nothing usable: report what is missing from the highest-priority block.
    if _first_present(candidates[0], f1_keys) is None:
        raise KeyError("Impossible de trouver 'tau_F1' ou 'tau_img_F1' dans la config globale.")
    raise KeyError("Impossible de trouver 'tau_safe' ou 'tau_img_safe' dans la config globale.")


tau_img_F1, tau_img_safe = extract_global_thresholds_from_cfg(cfg)
print("tau_img_F1   =", tau_img_F1)
print("tau_img_safe =", tau_img_safe)

# ---------- 2.2. Per-category AD thresholds from ad_resnet_mahalanobis_mvtec_all.json ----------

ad_json_path = EXPERIMENTS_DIR / "ad_resnet_mahalanobis_mvtec_all.json"
df_ad = pd.read_json(ad_json_path)

print("df_ad shape :", df_ad.shape)
print(df_ad.head())

# df_ad has one column per category and one row per statistic
# ('auc_test', 'tau_F1', 'tau_F2', 'tau_safe', ...).
print("Index df_ad :", df_ad.index)

# The 'tau_safe' row is a Series indexed by category name.
row_tau_safe = df_ad.loc["tau_safe"]

# TAU_AD_SAFE = {category: tau_safe}
TAU_AD_SAFE = row_tau_safe.astype(float).to_dict()

print("\nTAU_AD_SAFE (quelques catégories) :")
for k in list(TAU_AD_SAFE.keys())[:10]:
    print(f"  {k}: {TAU_AD_SAFE[k]}")

# Check that every category of the test set has a threshold.
# (This report used to be printed twice by a duplicated block.)
cats_test = sorted(df_scores["category"].unique())
missing_cats = set(cats_test) - set(TAU_AD_SAFE.keys())
print("\nCatégories test :", cats_test)
print("Catégories manquantes dans TAU_AD_SAFE :", missing_cats)

tau_img_F1   = 0.8200000000000001
tau_img_safe = 0.023
df_ad shape : (9, 15)
                                                       bottle  \
auc_test                                             0.995238   
tau_F1                                              32.041318   
tau_F2                                              24.861817   
tau_safe                                            24.861817   
metrics_F1  {'precision': 1.0, 'recall': 0.968253968253968...   

                                                        cable  \
auc_test                                             0.883433   
tau_F1                                              22.489762   
tau_F2                                              18.284675   
tau_safe                                            18.284675   
metrics_F1  {'precision': 0.8915662650602411, 'recall': 0....   

                                                      capsule  \
auc_test                                             0.878341   
tau_F1     

In [39]:
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

print("df_scores shape :", df_scores.shape)
print(df_scores.head())

# ===========================
# 1) Extract the vectors
# ===========================
# Ground truth as ints, both score columns as floats.
y_true = df_scores["label"].values.astype(int)
cls_probs, ad_scores = (
    df_scores[column].values.astype(float) for column in ("cls_prob", "ad_score")
)

print("\ny_true.shape   :", y_true.shape)
print("cls_probs.shape:", cls_probs.shape)
print("ad_scores.shape:", ad_scores.shape)

# ===========================
# 2) Fonction utilitaire
# ===========================
def compute_metrics_from_binary(y_true, y_pred, policy_name):
    """Summarise a binary decision policy against the ground truth.

    Parameters
    ----------
    y_true : array-like of {0, 1}
        Ground-truth labels (1 is the positive class).
    y_pred : array-like of {0, 1}
        Binary decisions produced by the policy, aligned with ``y_true``.
    policy_name : str
        Label stored under the ``"policy"`` key of the result.

    Returns
    -------
    dict
        Keys ``policy``, ``precision``, ``recall``, ``f1``, ``f2`` and the
        confusion counts ``tn``, ``fp``, ``fn``, ``tp`` (plain ints).
    """
    # labels=[0, 1] forces a 2x2 matrix. Without it, inputs that contain a
    # single class (e.g. a policy that flags everything on an all-positive
    # subset) give a 1x1 matrix and the 4-way unpacking below raises.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    # F-beta with beta = 2: recall weighs four times as much as precision.
    beta = 2.0
    beta_sq = beta ** 2
    denom = beta_sq * precision + recall
    # Precision and recall both zero -> define F2 as 0 instead of dividing by 0.
    f2 = (1 + beta_sq) * precision * recall / denom if denom > 0 else 0.0
    return {
        "policy": policy_name,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f2": f2,
        "tn": int(tn),
        "fp": int(fp),
        "fn": int(fn),
        "tp": int(tp),
    }

# ===========================
# 3) Prédictions de base
# ===========================
# tau_img_F1, tau_img_safe and TAU_AD_SAFE are reused from the thresholds cell
# ("cellule 2" in the original notes); they are not redefined here.
print("\ntau_img_F1  =", tau_img_F1)
print("tau_img_safe=", tau_img_safe)

# Classifier alone: binarise the probabilities at each image-level threshold.
y_clsF1 = np.greater_equal(cls_probs, tau_img_F1).astype(int)
y_clsSAFE = np.greater_equal(cls_probs, tau_img_safe).astype(int)

# AD alone: each row is compared with the threshold of its own category.
cats = df_scores["category"].values
tau_ad = np.array([TAU_AD_SAFE[name] for name in cats])
y_adSAFE = np.greater_equal(ad_scores, tau_ad).astype(int)

print("\nQuelques y_clsF1   :", y_clsF1[:10])
print("Quelques y_clsSAFE :", y_clsSAFE[:10])
print("Quelques y_adSAFE  :", y_adSAFE[:10])

# ===========================
# 4) Four simple policies
# ===========================
# AND: flagged only when both CLS_F1 and AD_SAFE flag the image.
y_AND = np.bitwise_and(y_clsF1, y_adSAFE).astype(int)
# OR: flagged as soon as CLS_SAFE or AD_SAFE flags the image.
y_OR = np.logical_or(y_clsSAFE == 1, y_adSAFE == 1).astype(int)

metrics_clsF1only = compute_metrics_from_binary(y_true, y_clsF1, "CLS_F1_ONLY")
metrics_ADSAFEonly = compute_metrics_from_binary(y_true, y_adSAFE, "AD_SAFE_ONLY")
metrics_AND = compute_metrics_from_binary(y_true, y_AND, "AND_AD_SAFE_AND_CLS_F1")
metrics_OR = compute_metrics_from_binary(y_true, y_OR, "OR_AD_SAFE_OR_CLS_SAFE")

# One text report per policy, in the order they were defined.
for m in (metrics_clsF1only, metrics_ADSAFEonly, metrics_AND, metrics_OR):
    print(f"\n=== {m['policy']} ===")
    print(f"Precision : {m['precision']:.4f}")
    print(f"Recall    : {m['recall']:.4f}")
    print(f"F1        : {m['f1']:.4f}")
    print(f"F2        : {m['f2']:.4f}")
    print(f"TN={m['tn']}, FP={m['fp']}, FN={m['fn']}, TP={m['tp']}")


df_scores shape : (869, 8)
                                                path category split  label  \
0  C:\Users\othni\Projects\mvtec_ad\data\bottle\t...   bottle  test      1   
1  C:\Users\othni\Projects\mvtec_ad\data\bottle\t...   bottle  test      1   
2  C:\Users\othni\Projects\mvtec_ad\data\bottle\t...   bottle  test      1   
3  C:\Users\othni\Projects\mvtec_ad\data\bottle\t...   bottle  test      1   
4  C:\Users\othni\Projects\mvtec_ad\data\bottle\t...   bottle  test      1   

    defect_type final_split  cls_prob    ad_score  
0  broken_large        test  1.000000   89.614496  
1  broken_large        test  0.999992  145.374881  
2  broken_large        test  0.999787   89.160880  
3  broken_large        test  0.999999   56.286167  
4  broken_large        test  1.000000  102.899164  

y_true.shape   : (869,)
cls_probs.shape: (869,)
ad_scores.shape: (869,)

tau_img_F1  = 0.8200000000000001
tau_img_safe= 0.023

Quelques y_clsF1   : [1 1 1 1 1 1 1 1 1 1]
Quelques y_clsSAFE : 

In [40]:
# Builds on df_scores and on the y_clsF1 / y_adSAFE predictions computed in
# the base-policies cell ("cellule 3" in the original notes).
#
# NOTE(review): both MIX policies pick their categories / pairs from the false
# negatives observed on these very rows and are then scored against y_true
# (presumably the labels of the same rows), so the reported metrics are
# optimistic. Confirm on held-out data before relying on them.


def _print_policy_report(metrics):
    """Print the scores and confusion counts stored in one policy metrics dict."""
    print(f"\n=== {metrics['policy']} ===")
    print(f"Precision : {metrics['precision']:.4f}")
    print(f"Recall    : {metrics['recall']:.4f}")
    print(f"F1        : {metrics['f1']:.4f}")
    print(f"F2        : {metrics['f2']:.4f}")
    print(f"TN={metrics['tn']}, FP={metrics['fp']}, "
          f"FN={metrics['fn']}, TP={metrics['tp']}")


# Work on a copy so df_scores itself is never mutated by this cell.
df_tmp = df_scores.copy()
df_tmp["clsF1_pred"] = y_clsF1
df_tmp["adSAFE_pred"] = y_adSAFE

# ------------------------------
# A) MIX_FN (per category)
# ------------------------------
# False negatives of CLS_F1: positive rows (label == 1) predicted as 0.
fn_mask_clsF1 = (df_tmp["label"] == 1) & (df_tmp["clsF1_pred"] == 0)
fn_categories = sorted(df_tmp.loc[fn_mask_clsF1, "category"].unique())
print("Catégories où CLS_F1 a des FN :", fn_categories)

# MIX_FN policy: use CLS_F1 everywhere except in those categories, where the
# AD_SAFE decision is used instead.
use_ad_for_cat = df_tmp["category"].isin(fn_categories).values
y_mix_fn = np.where(use_ad_for_cat, y_adSAFE, y_clsF1).astype(int)

metrics_mix_fn = compute_metrics_from_binary(
    y_true, y_mix_fn, "MIX_FN (CLS_F1 sauf catégories à FN → AD_SAFE)"
)
_print_policy_report(metrics_mix_fn)

# ------------------------------
# B) MIX_PAIRS (per (category, defect_type) pair)
# ------------------------------
# Single .loc selection instead of chained indexing (df[mask][cols]).
fn_rows_clsF1 = df_tmp.loc[
    fn_mask_clsF1,
    ["category", "defect_type", "path", "label", "cls_prob", "ad_score"],
]

print("\nNombre de FN (CLS_F1_ONLY) :", len(fn_rows_clsF1))
print(fn_rows_clsF1[["category", "defect_type"]].value_counts().reset_index(name="n_FN"))

# Problematic (category, defect_type) pairs: those with at least one FN.
bad_pairs = set(
    zip(fn_rows_clsF1["category"], fn_rows_clsF1["defect_type"])
)
print("\nPaires (cat, defect_type) où CLS_F1 a des FN :")
# Sets of strings iterate in a hash-dependent order; sorting keeps the printed
# list identical from one kernel session to the next.
for p in sorted(bad_pairs):
    print("  ", p)

# MIX_PAIRS policy: AD_SAFE when (category, defect_type) is in bad_pairs,
# CLS_F1 otherwise.
pairs_all = list(zip(df_tmp["category"], df_tmp["defect_type"]))
# dtype=bool keeps the mask boolean even when there are no rows at all.
use_ad_for_pair = np.array([pair in bad_pairs for pair in pairs_all], dtype=bool)
y_mix_pairs = np.where(use_ad_for_pair, y_adSAFE, y_clsF1).astype(int)

metrics_mix_pairs = compute_metrics_from_binary(
    y_true, y_mix_pairs, "MIX_PAIRS (AD_SAFE pour (cat,defect) problématiques)"
)
_print_policy_report(metrics_mix_pairs)

Catégories où CLS_F1 a des FN : ['cable', 'carpet', 'metal_nut', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']

=== MIX_FN (CLS_F1 sauf catégories à FN → AD_SAFE) ===
Precision : 0.8577
Recall    : 1.0000
F1        : 0.9234
F2        : 0.9679
TN=131, FP=105, FN=0, TP=633

Nombre de FN (CLS_F1_ONLY) : 11
     category      defect_type  n_FN
0       cable       cable_swap     2
1      carpet           thread     1
2   metal_nut          scratch     1
3        tile      gray_stroke     1
4  toothbrush        defective     1
5  transistor        bent_lead     1
6  transistor         cut_lead     1
7        wood             hole     1
8        wood          scratch     1
9      zipper  fabric_interior     1

Paires (cat, defect_type) où CLS_F1 a des FN :
   ('wood', 'hole')
   ('transistor', 'bent_lead')
   ('toothbrush', 'defective')
   ('carpet', 'thread')
   ('metal_nut', 'scratch')
   ('cable', 'cable_swap')
   ('zipper', 'fabric_interior')
   ('wood', 'scratch')
   ('tile', 'gr

In [33]:
# NOTE(review): this cell repeats the MIX_FN / MIX_PAIRS cell located just
# above it in the notebook and carries an older execution count; keep a single
# copy so that "Restart & Run All" does not compute the same policies twice.
#
# Builds on df_scores and on the y_clsF1 / y_adSAFE predictions computed in
# the base-policies cell ("cellule 3" in the original notes).
#
# NOTE(review): both MIX policies pick their categories / pairs from the false
# negatives observed on these very rows and are then scored against y_true
# (presumably the labels of the same rows), so the reported metrics are
# optimistic. Confirm on held-out data before relying on them.


def _print_policy_report(metrics):
    """Print the scores and confusion counts stored in one policy metrics dict."""
    print(f"\n=== {metrics['policy']} ===")
    print(f"Precision : {metrics['precision']:.4f}")
    print(f"Recall    : {metrics['recall']:.4f}")
    print(f"F1        : {metrics['f1']:.4f}")
    print(f"F2        : {metrics['f2']:.4f}")
    print(f"TN={metrics['tn']}, FP={metrics['fp']}, "
          f"FN={metrics['fn']}, TP={metrics['tp']}")


# Work on a copy so df_scores itself is never mutated by this cell.
df_tmp = df_scores.copy()
df_tmp["clsF1_pred"] = y_clsF1
df_tmp["adSAFE_pred"] = y_adSAFE

# ------------------------------
# A) MIX_FN (per category)
# ------------------------------
# False negatives of CLS_F1: positive rows (label == 1) predicted as 0.
fn_mask_clsF1 = (df_tmp["label"] == 1) & (df_tmp["clsF1_pred"] == 0)
fn_categories = sorted(df_tmp.loc[fn_mask_clsF1, "category"].unique())
print("Catégories où CLS_F1 a des FN :", fn_categories)

# MIX_FN policy: use CLS_F1 everywhere except in those categories, where the
# AD_SAFE decision is used instead.
use_ad_for_cat = df_tmp["category"].isin(fn_categories).values
y_mix_fn = np.where(use_ad_for_cat, y_adSAFE, y_clsF1).astype(int)

metrics_mix_fn = compute_metrics_from_binary(
    y_true, y_mix_fn, "MIX_FN (CLS_F1 sauf catégories à FN → AD_SAFE)"
)
_print_policy_report(metrics_mix_fn)

# ------------------------------
# B) MIX_PAIRS (per (category, defect_type) pair)
# ------------------------------
# Single .loc selection instead of chained indexing (df[mask][cols]).
fn_rows_clsF1 = df_tmp.loc[
    fn_mask_clsF1,
    ["category", "defect_type", "path", "label", "cls_prob", "ad_score"],
]

print("\nNombre de FN (CLS_F1_ONLY) :", len(fn_rows_clsF1))
print(fn_rows_clsF1[["category", "defect_type"]].value_counts().reset_index(name="n_FN"))

# Problematic (category, defect_type) pairs: those with at least one FN.
bad_pairs = set(
    zip(fn_rows_clsF1["category"], fn_rows_clsF1["defect_type"])
)
print("\nPaires (cat, defect_type) où CLS_F1 a des FN :")
# Sets of strings iterate in a hash-dependent order; sorting keeps the printed
# list identical from one kernel session to the next.
for p in sorted(bad_pairs):
    print("  ", p)

# MIX_PAIRS policy: AD_SAFE when (category, defect_type) is in bad_pairs,
# CLS_F1 otherwise.
pairs_all = list(zip(df_tmp["category"], df_tmp["defect_type"]))
# dtype=bool keeps the mask boolean even when there are no rows at all.
use_ad_for_pair = np.array([pair in bad_pairs for pair in pairs_all], dtype=bool)
y_mix_pairs = np.where(use_ad_for_pair, y_adSAFE, y_clsF1).astype(int)

metrics_mix_pairs = compute_metrics_from_binary(
    y_true, y_mix_pairs, "MIX_PAIRS (AD_SAFE pour (cat,defect) problématiques)"
)
_print_policy_report(metrics_mix_pairs)


=== CLS_F1_ONLY ===
Precision : 0.9952
Recall    : 0.9826
F1        : 0.9889
F2        : 0.9851
TN=233, FP=3, FN=11, TP=622

=== AD_SAFE_ONLY ===
Precision : 0.7824
Recall    : 1.0000
F1        : 0.8779
F2        : 0.9473
TN=60, FP=176, FN=0, TP=633

=== AND_AD_SAFE_AND_CLS_F1 ===
Precision : 0.9952
Recall    : 0.9826
F1        : 0.9889
F2        : 0.9851
TN=233, FP=3, FN=11, TP=622

=== OR_AD_SAFE_OR_CLS_SAFE ===
Precision : 0.7608
Recall    : 1.0000
F1        : 0.8642
F2        : 0.9408
TN=37, FP=199, FN=0, TP=633



In [41]:
# One row per policy, in reporting order: the four simple policies first,
# then the two MIX policies.
all_policies_metrics = pd.DataFrame(
    data=[
        metrics_clsF1only,   # CLS_F1_ONLY
        metrics_ADSAFEonly,  # AD_SAFE_ONLY
        metrics_AND,         # AND_AD_SAFE_AND_CLS_F1
        metrics_OR,          # OR_AD_SAFE_OR_CLS_SAFE
        metrics_mix_fn,      # MIX_FN
        metrics_mix_pairs,   # MIX_PAIRS
    ]
)

display(all_policies_metrics)

# Persist the summary next to the other experiment artefacts (no index column).
out_csv = EXPERIMENTS_DIR.joinpath("image_level_policies_summary_all_v2.csv")
all_policies_metrics.to_csv(out_csv, index=False)
print("\nRésumé des policies sauvegardé dans :", out_csv)

Unnamed: 0,policy,precision,recall,f1,f2,tn,fp,fn,tp
0,CLS_F1_ONLY,0.9952,0.982622,0.988871,0.985112,233,3,11,622
1,AD_SAFE_ONLY,0.782447,1.0,0.877947,0.947321,60,176,0,633
2,AND_AD_SAFE_AND_CLS_F1,0.9952,0.982622,0.988871,0.985112,233,3,11,622
3,OR_AD_SAFE_OR_CLS_SAFE,0.760817,1.0,0.864164,0.940844,37,199,0,633
4,MIX_FN (CLS_F1 sauf catégories à FN → AD_SAFE),0.857724,1.0,0.923414,0.96789,131,105,0,633
5,"MIX_PAIRS (AD_SAFE pour (cat,defect) problémat...",0.995283,1.0,0.997636,0.999053,233,3,0,633



Résumé des policies sauvegardé dans : C:\Users\othni\Projects\mvtec_ad\experiments\image_level_policies_summary_all_v2.csv


In [35]:
# (category, defect_type) pairs where CLS_F1 produces false negatives
# (the 11 cases listed earlier).
# The set holds ten pairs: per the FN breakdown printed by the MIX_PAIRS cell,
# ("cable", "cable_swap") accounts for two of the eleven false negatives.
# NOTE(review): this hard-coded list mirrors the `bad_pairs` set that the
# MIX_PAIRS cell derives from the data; it will go stale if the thresholds or
# the scores change.
hard_pairs = {
    ("cable", "cable_swap"),
    ("carpet", "thread"),
    ("metal_nut", "scratch"),
    ("tile", "gray_stroke"),
    ("toothbrush", "defective"),
    ("transistor", "bent_lead"),
    ("transistor", "cut_lead"),
    ("wood", "hole"),
    ("wood", "scratch"),
    ("zipper", "fabric_interior"),
}

# Boolean mask, one entry per element of `cats`: True when the row's
# (category, defect) pair is one of the hard pairs.
# NOTE(review): `defects` is not defined in the visible cells; presumably the
# defect_type column aligned with `cats` — confirm it exists on a fresh kernel.
mask_pairs = np.array([
    (cats[i], defects[i]) in hard_pairs
    for i in range(len(cats))
])

# Take the AD_SAFE decision on hard pairs and the CLS_F1 decision elsewhere.
y_mix_pairs = np.where(mask_pairs, y_adSAFE, y_clsF1).astype(int)

# NOTE(review): `compute_metrics` is not defined in the visible cells (the
# visible helper is `compute_metrics_from_binary`); this call also rebinds
# `metrics_mix_pairs`, which the summary cell has already consumed — verify.
metrics_mix_pairs = compute_metrics(
    y_true,
    y_mix_pairs,
    "MIX_PAIRS (AD_SAFE pour (cat,defect) problématiques)",
)

=== MIX_PAIRS (AD_SAFE pour (cat,defect) problématiques) ===
Precision : 0.9953
Recall    : 1.0000
F1        : 0.9976
F2        : 0.9991
TN=233, FP=3, FN=0, TP=633



In [42]:
mode="INDUSTRIAL_MIX_PAIRS"