In [None]:
import os
import sys
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.utils import make_grid
from sklearn.metrics import confusion_matrix

# Put the parent directory on the import path before importing `HAM` from
# `datasets` (presumably a project-local module living there -- confirm).
root_dir = "../"
sys.path.append(root_dir)
from datasets import HAM

# Input/output locations, all expressed relative to `root_dir`.
data_dir, results_dir = (
    os.path.join(root_dir, subdir) for subdir in ("data", "results")
)
ham_dir = os.path.join(data_dir, "HAM10000")

# Resize, center-crop to 224 and convert each image to a tensor.
transform = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
    ]
)
# NOTE(review): `train=False` presumably selects the held-out split -- confirm
# against the `HAM` implementation.
dataset = HAM(data_dir, train=False, transform=transform)
classes = dataset.classes
# Second element of every `dataset.samples` entry, collected into one array.
label = np.array([target for _, target in dataset.samples])

# Global seaborn styling for every figure in the notebook.
sns.set_theme()
sns.set_context("paper")

In [None]:
# Number of example images drawn per class (one grid row per draw).
m = 4
# Dedicated, seeded generator: the figure is identical on every
# Restart & Run All and the global NumPy random state is left untouched.
rng = np.random.default_rng(0)

image = []
for class_idx in range(len(classes)):
    idx = np.where(label == class_idx)[0]
    # Draw distinct images so a column never shows the same sample twice.
    # Sampling with replacement is used only when the class has fewer than
    # `m` samples, so every class still contributes exactly `m` tensors and
    # the stack below stays rectangular.
    idx = rng.choice(idx, size=m, replace=len(idx) < m)
    image.append(torch.stack([dataset[i][0] for i in idx]))
# Stack classes along dim 1 -> (m, n_classes, ...), then merge the two leading
# dims so consecutive images cycle through the classes: with
# `nrow=len(classes)` each grid column is one class, each row one draw.
image = torch.stack(image, dim=1)
image = image.flatten(0, 1)

# Class names in the same left-to-right order as the grid columns.
print(classes)
_, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
grid = make_grid(image, nrow=len(classes))
# Move the first dimension of the grid to the end for `imshow`.
ax.imshow(grid.permute(1, 2, 0))
ax.axis("off")
plt.show()

In [None]:
# Per-sample labels and predictions read from the results CSV.
biomedclip_results = pd.read_csv(os.path.join(results_dir, "ham_biomedclip.csv"))
# Kept under its own name so the global `label` array built from
# `dataset.samples` is not replaced by a Python list (which would break
# `label == class_idx` comparisons when other cells are re-run).
true_label = biomedclip_results["label"].values.tolist()
prediction = biomedclip_results["prediction"].values.tolist()

# Pin the label set so the matrix always has one row/column per entry of
# `classes`, even if a class never occurs in the CSV; otherwise the matrix
# shrinks and tick labels / `sorted_class_idx` silently shift.
# Assumes the CSV stores integer class indices 0..len(classes)-1, as the
# positional indexing into `classes` below already implies -- TODO confirm.
cm = confusion_matrix(
    true_label,
    prediction,
    labels=np.arange(len(classes)),
    normalize="true",
)
# With normalize="true" each row is divided by its true-class count, so the
# diagonal is the per-class accuracy (recall).
class_accuracy = np.diag(cm)
# Classes ordered from highest to lowest accuracy.
sorted_class_idx = np.argsort(class_accuracy)[::-1]
sorted_class_accuracy = class_accuracy[sorted_class_idx]
sorted_classes = [classes[i] for i in sorted_class_idx]

_, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(cm, annot=True, fmt=".2f", cmap="viridis", ax=ax)
ax.set_xticklabels(classes, rotation=45, ha="right")
ax.set_yticklabels(classes, rotation=0)
# scikit-learn puts true labels on rows and predictions on columns.
ax.set_xlabel("Predicted class")
ax.set_ylabel("True class")
plt.show()

In [None]:
# SECURITY NOTE: unpickling executes arbitrary code embedded in the file; only
# load weight files that were produced by this project / a trusted source.
# The context manager closes the file handle, which the bare
# `pickle.load(open(...))` form left open.
with open(
    os.path.join(root_dir, "weights", "ham_biomedclip_skincon_cbm.pkl"), "rb"
) as weights_file:
    biomedclip_skincon_cbm = pickle.load(weights_file)

# Long-format records: one row per (attribute, validation score) pair.
# Collected under a separate name so `accuracy_df` is only ever a DataFrame.
accuracy_records = {"attribute": [], "val_accuracy": []}
for attribute, cbm_entry in biomedclip_skincon_cbm.items():
    # `val_score` is used as a sized sequence with `.tolist()`
    # (presumably a NumPy array of per-run scores -- confirm).
    val_accuracy = cbm_entry["val_score"]
    accuracy_records["attribute"].extend(
        [attribute.replace("_", " ")] * len(val_accuracy)
    )
    accuracy_records["val_accuracy"].extend(val_accuracy.tolist())
accuracy_df = pd.DataFrame(accuracy_records)

# Distribution of validation scores for each attribute.
_, ax = plt.subplots(figsize=(16, 9 / 2))
sns.boxplot(data=accuracy_df, x="attribute", y="val_accuracy", ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_xlabel("Attribute")
ax.set_ylabel("Validation accuracy")
plt.show()

In [None]:
# Rebuild the per-sample class index array from `dataset.samples`.
label = np.array([target for _, target in dataset.samples])

# Attribute names, one per line of the text file.
attribute_file = os.path.join(ham_dir, "cbm_attributes.txt")
with open(attribute_file, "r") as f:
    cbm_attributes = f.read().splitlines()

margin_file = os.path.join(ham_dir, "val_ham_biomedclip_skincon_cbm_margin.npy")
biomedclip_skincon_margin = np.load(margin_file)
# An attribute is treated as present for a sample when its margin is
# strictly positive.
# NOTE(review): assumes rows of the margin array follow the same order as
# `dataset.samples` -- confirm.
skincon_annotation = biomedclip_skincon_margin > 0

# Row c, column a: fraction of class-c samples for which attribute a is present.
n_classes, n_attributes = len(classes), skincon_annotation.shape[1]
skincon_class_frequency = np.empty((n_classes, n_attributes))
for class_idx, frequency_row in enumerate(skincon_class_frequency):
    # `frequency_row` is a view, so this writes into the matrix in place.
    frequency_row[:] = np.mean(skincon_annotation[label == class_idx], axis=0)

# Binarize: an attribute is "on" for a class when its frequency exceeds that
# class's mean frequency over all attributes.
per_class_mean = np.mean(skincon_class_frequency, axis=-1, keepdims=True)
skincon_class_binary = skincon_class_frequency > per_class_mean

# Heatmap of the raw frequencies, annotated with two decimals.
_, ax = plt.subplots(figsize=(16, 9 / 2))
sns.heatmap(skincon_class_frequency, annot=True, fmt=".2f", cmap="viridis", ax=ax)
ax.set_xticklabels(cbm_attributes, rotation=45, ha="right")
ax.set_yticklabels(classes, rotation=0)
plt.show()

# Heatmap of the binarized class/attribute matrix.
_, ax = plt.subplots(figsize=(16, 9 / 2))
sns.heatmap(skincon_class_binary, linewidths=0.5, ax=ax)
ax.set_xticklabels(cbm_attributes, rotation=45, ha="right")
ax.set_yticklabels(classes, rotation=0)
plt.show()

In [None]:
# Two rows of panels, one panel per class, ordered by decreasing accuracy.
n_cols = int(np.ceil(len(classes) / 2))
_, axes = plt.subplots(
    2,
    n_cols,
    figsize=(16, 9),
    # Always return a 2-D array of axes so `axes[row, col]` also works when
    # there is only a single column.
    squeeze=False,
    gridspec_kw={"hspace": 0.3, "wspace": 0.75},
)
for i, (class_idx, class_name) in enumerate(zip(sorted_class_idx, sorted_classes)):
    n_images = np.sum(label == class_idx)
    # Local name on purpose: the global `class_accuracy` array is left intact.
    accuracy = sorted_class_accuracy[i]

    ax = axes[i // n_cols, i % n_cols]

    # Attributes of this class, ordered from most to least frequent.
    class_frequency = skincon_class_frequency[class_idx]
    attribute_order = np.argsort(class_frequency)[::-1]
    sorted_frequency = class_frequency[attribute_order]
    sorted_attribute = [cbm_attributes[j] for j in attribute_order]
    sns.barplot(x=sorted_frequency, y=sorted_attribute, ax=ax)
    ax.set_xlabel("Frequency")
    # Frequencies are fractions, so every panel shares the [0, 1] range.
    ax.set_xlim(0, 1)
    ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_title(f"Class: {class_name}\n({n_images} images, {accuracy:.0%} accuracy)")

# With an odd number of classes the last panel has nothing to show; hide it
# instead of leaving an empty framed axis in the figure.
for unused_ax in axes.ravel()[len(sorted_classes):]:
    unused_ax.set_visible(False)
plt.show()