In [None]:
import os
import sys
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, precision_recall_fscore_support

# Put the parent directory on the import path so the project-local
# `datasets` module can be imported from this notebook's location.
root_dir = "../"
sys.path.append(root_dir)
from datasets import SKINCON

# Input data and model outputs both live under the repository root.
data_dir, results_dir = (
    os.path.join(root_dir, subdir) for subdir in ("data", "results")
)

dataset = SKINCON(data_dir)
# Attribute names are kept lower-cased for the rest of the notebook.
attributes = [name.lower() for name in dataset.attributes]

# Basename of every sample path (first item of each `dataset.samples` entry).
filenames = np.array(
    [sample_path.rsplit("/", 1)[-1] for sample_path, _ in dataset.samples]
)
# NOTE(review): presumably one row per image and one column per attribute,
# given how later cells index it -- confirm against the SKINCON class.
image_attribute = np.array(dataset.image_attribute)

# Global plotting style.
sns.set_theme()
sns.set_context("paper")

In [None]:
model_name = "gpt-4o-mini"
response_path = os.path.join(
    results_dir, f"{model_name.replace('-', '_')}_responses.json"
)
# Read through a context manager so the file handle is closed as soon as the
# JSON has been parsed rather than whenever the garbage collector runs.
with open(response_path, "r") as response_file:
    responses = json.load(response_file)

# One row per dataset image, one column per attribute. -2 is a sentinel for
# "nothing recorded yet"; the assertion at the end of the cell checks that
# every entry was overwritten.
results = -2 * np.ones((len(filenames), len(attributes)))
for path, response in responses.items():
    filename = path.split("/")[-1]
    # Boolean selector over dataset rows whose basename matches this response.
    image_mask = filenames == filename

    # If a response holds several choices, later ones overwrite earlier ones.
    for choice in response.values():
        for annotation in choice["annotations"]:
            # `attributes` is lower-cased when it is built, so normalise the
            # reported name the same way; otherwise a capitalised name would
            # raise ValueError in `.index`.
            attribute_name = annotation["attribute"].strip().lower()
            attribute_idx = attributes.index(attribute_name)
            results[image_mask, attribute_idx] = annotation["label"]

# Entries labelled -1 are what this notebook reports as refusals.
n_refusals = np.sum(results == -1)
print(f"Number of refusals: {n_refusals} ({n_refusals/results.size:.2%})")

# Refusals are scored as 0.0 from here on.
results[results == -1] = 0.0
assert np.all(
    results != -2
), "some (image, attribute) entries were never filled from the responses"

In [None]:
# ROC curve and area under it for every attribute, using the column of
# `results` as the score and the matching column of `image_attribute` as the
# ground truth. The four lists are indexed by attribute position.
tpr, fpr, thresholds, auroc = [], [], [], []
for attribute_idx in range(len(attributes)):
    attribute_score = results[:, attribute_idx]
    ground_truth = image_attribute[:, attribute_idx]

    attribute_fpr, attribute_tpr, attribute_thresholds = roc_curve(
        ground_truth, attribute_score
    )

    tpr.append(attribute_tpr)
    fpr.append(attribute_fpr)
    thresholds.append(attribute_thresholds)
    auroc.append(auc(attribute_fpr, attribute_tpr))

# Attribute indices ordered from highest to lowest AUC; reused by later cells
# so that every figure draws attributes in the same order.
sorted_idx = np.argsort(auroc)[::-1]

_, ax = plt.subplots(figsize=(5, 5))
for attribute_idx in sorted_idx:
    ax.plot(
        fpr[attribute_idx],
        tpr[attribute_idx],
        label=f"{attributes[attribute_idx]} (AUC = {auroc[attribute_idx]:.2f})",
    )
# Diagonal reference line.
ax.plot([0, 1], [0, 1], "--", color="gray", label="Random")
# Label the axes so the figure is readable on its own, like the
# precision/recall/F1 figures elsewhere in this notebook.
ax.set_xlabel("False positive rate")
ax.set_ylabel("True positive rate")
ax.set_title(f"Mean AUC: {np.mean(auroc):.2f} +- {np.std(auroc):.2f}")
# Place the legend outside the axes, to the right of the plot.
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# Grid of 21 evenly spaced decision thresholds from 0 to 1 inclusive.
threshold = np.linspace(0, 1, 21)

# precision[i][j], recall[i][j], fscore[i][j]: metric for attribute i when a
# score >= threshold[j] is taken as a positive prediction.
precision, recall, fscore = [], [], []
for attribute_idx in range(len(attributes)):
    attribute_score = results[:, attribute_idx]
    ground_truth = image_attribute[:, attribute_idx]

    attribute_precision, attribute_recall, attribute_fscore = [], [], []
    for t in threshold:
        attribute_pred = attribute_score >= t

        # zero_division=0 reports 0 instead of warning when a metric is
        # undefined at this threshold.
        _precision, _recall, _fscore, _ = precision_recall_fscore_support(
            ground_truth, attribute_pred, average="binary", zero_division=0
        )
        attribute_precision.append(_precision)
        attribute_recall.append(_recall)
        attribute_fscore.append(_fscore)

    precision.append(attribute_precision)
    recall.append(attribute_recall)
    fscore.append(attribute_fscore)

# One panel per metric; attributes are drawn in the order given by
# `sorted_idx`, which is defined in an earlier cell.
_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.3})
for attribute_idx in sorted_idx:
    attribute = attributes[attribute_idx]
    axes[0].plot(threshold, precision[attribute_idx], label=attribute)
    axes[1].plot(threshold, recall[attribute_idx], label=attribute)
    axes[2].plot(threshold, fscore[attribute_idx], label=attribute)

for ax, metric_name in zip(axes, ("Precision", "Recall", "F1 score")):
    ax.set_xlabel("Threshold")
    ax.set_ylabel(metric_name)
plt.show()

# F1 averaged over attributes at each threshold.
mean_f1 = np.mean(fscore, axis=0)
_, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
ax.plot(threshold, mean_f1)
ax.set_xlabel("Threshold")
ax.set_ylabel("Mean F1 score")
plt.show()

In [None]:
# Report each attribute's F1 score at one operating threshold.
t = 0.65
# `threshold` is built with np.linspace, so its entries are not guaranteed to
# be bit-identical to decimal literals. Match with a tolerance instead of
# exact float equality so any grid value can be selected.
matching_idx = np.flatnonzero(np.isclose(threshold, t))
if matching_idx.size == 0:
    raise ValueError(f"{t} is not in the threshold grid")
threshold_idx = int(matching_idx[0])
# Column of the (attribute x threshold) F1 table for the chosen threshold.
threshold_f1 = np.array(fscore)[:, threshold_idx]
for attribute_idx, attribute in enumerate(attributes):
    print(f"{attribute}: {threshold_f1[attribute_idx]:.2f}")