In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, precision_recall_fscore_support

# Make the project root importable so the local `datasets` package resolves.
root_dir = "../"
sys.path.append(root_dir)
from datasets import SKINCON

data_dir = os.path.join(root_dir, "data")
results_dir = os.path.join(root_dir, "results")

# Ground-truth concept labels (one column per attribute) and the precomputed
# per-attribute model scores they are evaluated against.
dataset = SKINCON(data_dir)
attributes = dataset.attributes
image_attribute = np.array(dataset.image_attribute)
results = np.load(os.path.join(results_dir, "skincon_monet.npy"))

# Fail fast if scores and labels are not aligned: the analysis cells slice both
# arrays column-wise per attribute and pair the resulting vectors row for row.
assert results.ndim == 2 and image_attribute.ndim == 2, (
    f"expected 2-D arrays, got scores {results.shape} and labels {image_attribute.shape}"
)
assert results.shape[0] == image_attribute.shape[0], (
    f"score rows ({results.shape[0]}) != label rows ({image_attribute.shape[0]})"
)
assert min(results.shape[1], image_attribute.shape[1]) >= len(attributes), (
    "score/label arrays have fewer columns than there are attributes"
)

sns.set_theme()
sns.set_context("paper")

In [None]:
# Per-attribute ROC analysis: column i of `results` is scored against the
# binary labels in column i of `image_attribute`.
tpr, fpr, thresholds, auroc = [], [], [], []
for attribute_idx in range(len(attributes)):
    attribute_score = results[:, attribute_idx]
    ground_truth = image_attribute[:, attribute_idx]

    # roc_curve returns thresholds in decreasing order; thresholds[0] is a
    # sentinel above every score (nothing predicted positive) so the curve
    # starts at (0, 0). Later cells index these arrays in lockstep.
    attribute_fpr, attribute_tpr, attribute_thresholds = roc_curve(
        ground_truth, attribute_score
    )
    attribute_auc = auc(attribute_fpr, attribute_tpr)

    tpr.append(attribute_tpr)
    fpr.append(attribute_fpr)
    thresholds.append(attribute_thresholds)
    auroc.append(attribute_auc)

# Draw curves from best to worst AUC so the legend reads as a ranking.
sorted_idx = np.argsort(auroc)[::-1]

_, ax = plt.subplots(figsize=(5, 5))
for attribute_idx in sorted_idx:
    attribute = attributes[attribute_idx]
    ax.plot(
        fpr[attribute_idx],
        tpr[attribute_idx],
        label=f"{attribute} (AUC = {auroc[attribute_idx]:.2f})",
    )
ax.plot([0, 1], [0, 1], "--", color="gray", label="Random")
ax.set_xlabel("False positive rate")
ax.set_ylabel("True positive rate")
ax.set_title(f"Mean AUC: {np.mean(auroc):.2f} +- {np.std(auroc):.2f}")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

In [None]:
def _threshold_metrics(ground_truth, scores, candidate_thresholds):
    """Precision, recall and F1 of the rule ``scores >= t`` for each threshold ``t``.

    Returns three lists aligned index-for-index with ``candidate_thresholds``.
    Ratios with a zero denominator (e.g. no predicted positives) are reported as 0.
    """
    precisions, recalls, fscores = [], [], []
    for threshold in candidate_thresholds:
        predicted = scores >= threshold
        _precision, _recall, _fscore, _ = precision_recall_fscore_support(
            ground_truth, predicted, average="binary", zero_division=0
        )
        precisions.append(_precision)
        recalls.append(_recall)
        fscores.append(_fscore)
    return precisions, recalls, fscores


# Evaluate every attribute at each of its ROC thresholds; the resulting lists
# stay aligned with `thresholds` so the next cell can index them together.
precision, recall, fscore = [], [], []
for attribute_idx in range(len(attributes)):
    attribute_precision, attribute_recall, attribute_fscore = _threshold_metrics(
        image_attribute[:, attribute_idx],
        results[:, attribute_idx],
        thresholds[attribute_idx],
    )
    precision.append(attribute_precision)
    recall.append(attribute_recall)
    fscore.append(attribute_fscore)

metric_panels = (("Precision", precision), ("Recall", recall), ("F1 score", fscore))
_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.3})
for ax, (ylabel, curves) in zip(axes, metric_panels):
    for attribute_idx in sorted_idx:
        # Skip thresholds[0]: roc_curve's "predict nothing" sentinel (np.inf in
        # recent scikit-learn, max(score) + 1 in older releases), which either
        # vanishes from the plot or stretches the x-axis past every real score.
        ax.plot(
            thresholds[attribute_idx][1:],
            curves[attribute_idx][1:],
            label=attributes[attribute_idx],
        )
    ax.set_xlabel("Threshold")
    ax.set_ylabel(ylabel)
# Curves were labelled above; without a legend they cannot be identified.
axes[-1].legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# For each attribute, pick the ROC threshold that maximises F1. np.argmax keeps
# the first maximum, and roc_curve orders thresholds from high to low, so ties
# resolve to the highest such threshold.
optimal_idx = [int(np.argmax(attribute_fscore)) for attribute_fscore in fscore]
optimal_threshold = [
    attribute_thresholds[idx]
    for attribute_thresholds, idx in zip(thresholds, optimal_idx)
]
optimal_precision = [
    attribute_precision[idx]
    for attribute_precision, idx in zip(precision, optimal_idx)
]
optimal_recall = [
    attribute_recall[idx] for attribute_recall, idx in zip(recall, optimal_idx)
]
optimal_fscore = [
    attribute_fscore[idx] for attribute_fscore, idx in zip(fscore, optimal_idx)
]

# Pad names to a common width so the numeric columns line up; tab stops after
# names of varying length do not.
name_width = max((len(str(attribute)) for attribute in attributes), default=0) + 1
print("Optimal thresholds:")
for (
    attribute,
    attribute_threshold,
    attribute_precision,
    attribute_recall,
    attribute_f1,
) in zip(
    attributes, optimal_threshold, optimal_precision, optimal_recall, optimal_fscore
):
    print(
        f"\t{str(attribute) + ':':<{name_width}}  "
        f"{attribute_threshold:.2f} threshold  "
        f"{attribute_precision:.2f} precision  "
        f"{attribute_recall:.2f} recall  "
        f"{attribute_f1:.2f} f1 score"
    )