In [1]:
import pickle

import numpy as np

In [2]:
def load_pickle(file_path, strict=False):
    """
    Load a pickle file and return the object contained within it.

    SECURITY: unpickling can execute arbitrary code. Only call this on files
    you produced yourself or otherwise trust.

    :param file_path: Path to the pickle file (str or os.PathLike).
    :param strict: If True, re-raise any error from opening/unpickling instead
                   of printing it and returning None.
    :return: The object loaded from the pickle file, or None if loading failed
             and ``strict`` is False.
    :raises Exception: Whatever ``open``/``pickle.load`` raised, only when
                       ``strict`` is True.
    """
    try:
        with open(file_path, 'rb') as file:
            return pickle.load(file)
    except Exception as e:
        if strict:
            raise
        # Best-effort mode: name the failing file so the message is actionable,
        # and return None so the caller can detect the failure.
        print(f"Error loading pickle file {file_path!s}: {e}")
        return None

In [5]:
# Evaluation split whose results are compared across training sets.
TEST_DATASET_NAME = "test_on_hard_test_case"

# Result sub-directory per training set. Later cells zip this list with the
# short names ["nail", "mvtec", "visa"], so the order here must match.
train_set_names = ["trained_on_nail_dataset_v5_train_default",
                   "trained_on_mvtec_default",
                   "trained_on_visa_default"]

In [None]:

def _load_results(train_set_dir, label):
    """
    Load results.pkl of one training run evaluated on TEST_DATASET_NAME.

    :param train_set_dir: Result sub-directory of the training run.
    :param label: Human-readable name used in the status messages.
    :return: The unpickled results object.
    :raises RuntimeError: If load_pickle could not load the file.
    """
    path = f'/workspace/poc_jci_Crane/results/{train_set_dir}/{TEST_DATASET_NAME}/magnesium-massachusetts/epoch_5/results.pkl'
    result = load_pickle(path)
    # load_pickle returns None on any error. Check here so a bad path fails
    # immediately instead of as a TypeError when the DataFrame cell indexes it.
    if result is None:
        raise RuntimeError(f"{label} dataset results could not be loaded from {path}")
    print(f"{label} dataset results loaded successfully.")
    return result


result_dict_nail = _load_results("trained_on_nail_dataset_v5_train_default", "Nail")
result_dict_mvtec = _load_results("trained_on_mvtec_default", "MVTec")
result_dict_visa = _load_results("trained_on_visa_default", "VISA")

Nail dataset results loaded successfully.
MVTec dataset results loaded successfully.
VISA dataset results loaded successfully.


In [7]:
# Get each sample's path, prediction and ground truth: one DataFrame per training
# set with columns img_paths, gt_sp (ground-truth label) and pr_sp (predicted score).
import pandas as pd

# Explicit mapping instead of eval(f"result_dict_{name}"). eval hides which
# variables are read and fails obscurely on a typo.
loaded_results = {
    "nail": result_dict_nail,
    "mvtec": result_dict_mvtec,
    "visa": result_dict_visa,
}

result_df_dict = {}
for train_set_name, result_dict in loaded_results.items():
    # NOTE(review): assumes result_dict[0] maps "img_paths"/"gt_sp"/"pr_sp" to
    # equal-length sequences whose gt/pr entries are scalars exposing .item()
    # (numpy/torch) -- confirm against the code that writes results.pkl.
    sample_results = result_dict[0]
    result_df_dict[train_set_name] = pd.DataFrame(
        {
            "img_paths": sample_results["img_paths"],
            "gt_sp": [gt.item() for gt in sample_results["gt_sp"]],
            "pr_sp": [pr.item() for pr in sample_results["pr_sp"]],
        }
    )

In [8]:
# Preview the first rows of the nail-trained model's per-sample results.
result_df_dict["nail"].head(n=5)

Unnamed: 0,img_paths,gt_sp,pr_sp
0,/workspace/data/hard_test_case/test/nail/fault...,1,0.442821
1,/workspace/data/hard_test_case/test/nail/fault...,1,0.39586
2,/workspace/data/hard_test_case/test/nail/fault...,1,0.331715
3,/workspace/data/hard_test_case/test/nail/fault...,1,0.551333
4,/workspace/data/hard_test_case/test/nail/fault...,1,0.331385


In [9]:
from sklearn.metrics import precision_recall_curve, f1_score, accuracy_score, precision_score, recall_score

def find_best_threshold(gt, scores, method="max_f1"):
    """
    Choose a decision threshold for anomaly scores and report metrics at it.

    A sample is predicted positive (1) when its score is >= the threshold. This
    is the same convention sklearn's precision_recall_curve / roc_curve use.

    Args:
        gt: Ground truth labels (list or np.array), shape (N,)
        scores: Model anomaly scores (list or np.array), shape (N,)
        method: "max_f1" maximizes F1 over the precision-recall curve;
            "youden_j" maximizes Youden's J statistic (TPR - FPR) over the ROC curve.

    Returns:
        best_threshold: The selected threshold.
        best_metrics: dict with keys "threshold", "precision", "recall", "f1",
            "accuracy", computed on the binarized predictions.

    Raises:
        ValueError: If ``method`` is not "max_f1" or "youden_j".
    """
    # Accept plain lists as documented: `list >= float` raises TypeError,
    # while an ndarray comparison broadcasts element-wise.
    gt = np.asarray(gt)
    scores = np.asarray(scores)

    if method == "max_f1":
        precision, recall, thresholds = precision_recall_curve(gt, scores)
        # precision/recall have len(thresholds) + 1 entries. The final
        # (precision=1, recall=0) point has no threshold, so drop it to keep
        # the argmax index valid for `thresholds`.
        p, r = precision[:-1], recall[:-1]
        f1_scores = 2 * p * r / (p + r + 1e-8)
        best_threshold = thresholds[f1_scores.argmax()]
    elif method == "youden_j":
        from sklearn.metrics import roc_curve
        fpr, tpr, roc_thresholds = roc_curve(gt, scores)
        # Youden's J: pick the threshold maximizing TPR - FPR.
        j_scores = tpr - fpr
        # roc_thresholds[0] is an artificial "predict nothing" point (np.inf in
        # recent sklearn, max(score) + 1 in older versions). Skip it so the
        # chosen threshold is always one of the observed scores.
        best_threshold = roc_thresholds[1 + j_scores[1:].argmax()]
    else:
        raise ValueError(f"Unknown method: {method!r} (expected 'max_f1' or 'youden_j')")

    # Convert scores to binary predictions.
    pred_binary = (scores >= best_threshold).astype(int)

    best_metrics = {
        "threshold": best_threshold,
        "precision": precision_score(gt, pred_binary),
        "recall": recall_score(gt, pred_binary),
        "f1": f1_score(gt, pred_binary),
        "accuracy": accuracy_score(gt, pred_binary),
    }

    return best_threshold, best_metrics


In [10]:
# Show the result directories the threshold cells below iterate over.
list(train_set_names)

['trained_on_nail_dataset_v5_train_default',
 'trained_on_mvtec_default',
 'trained_on_visa_default']

In [None]:
# Youden's J threshold per training set: print metrics, add binary prediction /
# failure columns, and save the per-sample table to CSV.
# NOTE: the columns are written into the shared DataFrames in result_df_dict.
# The max-F1 cell overwrites the same columns, so later "failure" cells reflect
# whichever of the two cells ran last.
for train_set_name, train_set_dir in zip(["nail", "mvtec", "visa"], train_set_names):
    result_df = result_df_dict[train_set_name]
    best_th, metrics = find_best_threshold(result_df["gt_sp"], result_df["pr_sp"], method="youden_j")

    # Label each block of output; otherwise the three runs are indistinguishable.
    print(f"=== {train_set_name} (youden_j) ===")
    print("Best Threshold:", best_th)
    print("Metrics at best threshold:")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")
    result_df["pr_sp_binary"] = (result_df["pr_sp"] >= best_th).astype(int)
    result_df["failure"] = result_df["pr_sp_binary"] != result_df["gt_sp"]
    # Save the result DataFrame with binary predictions.
    output_dir = f"/workspace/poc_jci_Crane/results/{train_set_dir}/{TEST_DATASET_NAME}/magnesium-massachusetts/epoch_5"
    result_df.to_csv(f"{output_dir}/results_with_binary_best_threshold_{best_th}_youden_j.csv", index=False)

Best Threshold: 0.3128008544445038
Metrics at best threshold:
threshold: 0.3128
precision: 0.7288
recall: 0.8600
f1: 0.7890
accuracy: 0.7677
Best Threshold: 0.4002387225627899
Metrics at best threshold:
threshold: 0.4002
precision: 0.8000
recall: 0.8800
f1: 0.8381
accuracy: 0.8283
Best Threshold: 0.335653692483902
Metrics at best threshold:
threshold: 0.3357
precision: 0.8070
recall: 0.9200
f1: 0.8598
accuracy: 0.8485


In [None]:
# Max-F1 threshold per training set: print metrics, add binary prediction /
# failure columns, and save the per-sample table to CSV.
# NOTE: this overwrites the "pr_sp_binary"/"failure" columns written by the
# Youden's J cell, since both write into the shared DataFrames in result_df_dict.
# Unlike that cell, the CSV name carries no method suffix.
for train_set_name, train_set_dir in zip(["nail", "mvtec", "visa"], train_set_names):
    result_df = result_df_dict[train_set_name]
    best_th, metrics = find_best_threshold(result_df["gt_sp"], result_df["pr_sp"], method="max_f1")

    # Label each block of output; otherwise the three runs are indistinguishable.
    print(f"=== {train_set_name} (max_f1) ===")
    print("Best Threshold:", best_th)
    print("Metrics at best threshold:")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")
    result_df["pr_sp_binary"] = (result_df["pr_sp"] >= best_th).astype(int)
    result_df["failure"] = result_df["pr_sp_binary"] != result_df["gt_sp"]
    # Save the result DataFrame with binary predictions.
    output_dir = f"/workspace/poc_jci_Crane/results/{train_set_dir}/{TEST_DATASET_NAME}/magnesium-massachusetts/epoch_5"
    result_df.to_csv(f"{output_dir}/results_with_binary_best_threshold_{best_th}.csv", index=False)

Best Threshold: 0.30018168687820435
Metrics at best threshold:
threshold: 0.3002
precision: 0.6970
recall: 0.9200
f1: 0.7931
accuracy: 0.7576
Best Threshold: 0.4002387225627899
Metrics at best threshold:
threshold: 0.4002
precision: 0.8000
recall: 0.8800
f1: 0.8381
accuracy: 0.8283
Best Threshold: 0.335653692483902
Metrics at best threshold:
threshold: 0.3357
precision: 0.8070
recall: 0.9200
f1: 0.8598
accuracy: 0.8485


In [13]:
# Misclassified samples of the nail-trained model, under the threshold set by
# whichever threshold cell ran last.
result_df_dict["nail"].loc[lambda df: df["failure"]]

Unnamed: 0,img_paths,gt_sp,pr_sp,pr_sp_binary,failure
14,/workspace/data/hard_test_case/test/nail/fault...,1,0.281871,0,True
23,/workspace/data/hard_test_case/test/nail/fault...,1,0.235527,0,True
24,/workspace/data/hard_test_case/test/nail/fault...,1,0.255743,0,True
28,/workspace/data/hard_test_case/test/nail/fault...,1,0.250162,0,True
50,/workspace/data/hard_test_case/test/nail/good/...,0,0.581607,1,True
52,/workspace/data/hard_test_case/test/nail/good/...,0,0.311107,1,True
55,/workspace/data/hard_test_case/test/nail/good/...,0,0.496947,1,True
56,/workspace/data/hard_test_case/test/nail/good/...,0,0.470409,1,True
57,/workspace/data/hard_test_case/test/nail/good/...,0,0.497625,1,True
58,/workspace/data/hard_test_case/test/nail/good/...,0,0.335207,1,True


In [None]:
# Write each model's misclassified rows next to its results. The "failure"
# column comes from whichever threshold cell (Youden's J or max-F1) ran last;
# the file name does not record which.
for train_set_name, train_set_dir in zip(["nail", "mvtec", "visa"], train_set_names):
    failure_csv_path = f"/workspace/poc_jci_Crane/results/{train_set_dir}/{TEST_DATASET_NAME}/magnesium-massachusetts/epoch_5/failure_samples.csv"
    failed_rows = result_df_dict[train_set_name].loc[lambda df: df["failure"]]
    failed_rows.to_csv(failure_csv_path, index=False)

In [15]:
# Misclassified samples of the MVTec-trained model, under the threshold set by
# whichever threshold cell ran last.
result_df_dict["mvtec"].loc[lambda df: df["failure"]]

Unnamed: 0,img_paths,gt_sp,pr_sp,pr_sp_binary,failure
14,/workspace/data/hard_test_case/test/nail/fault...,1,0.357465,0,True
17,/workspace/data/hard_test_case/test/nail/fault...,1,0.321199,0,True
23,/workspace/data/hard_test_case/test/nail/fault...,1,0.217706,0,True
24,/workspace/data/hard_test_case/test/nail/fault...,1,0.361454,0,True
28,/workspace/data/hard_test_case/test/nail/fault...,1,0.191071,0,True
38,/workspace/data/hard_test_case/test/nail/fault...,1,0.152973,0,True
55,/workspace/data/hard_test_case/test/nail/good/...,0,0.875139,1,True
56,/workspace/data/hard_test_case/test/nail/good/...,0,0.81709,1,True
57,/workspace/data/hard_test_case/test/nail/good/...,0,0.868113,1,True
58,/workspace/data/hard_test_case/test/nail/good/...,0,0.534236,1,True


In [16]:
# Misclassified samples of the VisA-trained model, under the threshold set by
# whichever threshold cell ran last.
result_df_dict["visa"].loc[lambda df: df["failure"]]

Unnamed: 0,img_paths,gt_sp,pr_sp,pr_sp_binary,failure
23,/workspace/data/hard_test_case/test/nail/fault...,1,0.194131,0,True
28,/workspace/data/hard_test_case/test/nail/fault...,1,0.219168,0,True
38,/workspace/data/hard_test_case/test/nail/fault...,1,0.187131,0,True
42,/workspace/data/hard_test_case/test/nail/fault...,1,0.25731,0,True
55,/workspace/data/hard_test_case/test/nail/good/...,0,0.713414,1,True
56,/workspace/data/hard_test_case/test/nail/good/...,0,0.65473,1,True
57,/workspace/data/hard_test_case/test/nail/good/...,0,0.763688,1,True
58,/workspace/data/hard_test_case/test/nail/good/...,0,0.395134,1,True
60,/workspace/data/hard_test_case/test/nail/good/...,0,0.64072,1,True
64,/workspace/data/hard_test_case/test/nail/good/...,0,0.533246,1,True
