In [1]:
import glob
from timeit import default_timer as timer

import numpy as np
import torch
import clip
from PIL import Image
import scipy.special

import pandas as pd
from tqdm import tqdm
import os
from pathlib import Path

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# CONSTANTS
"""
Choose accordingly to the Model/Dataset being evaluated:
MODEL= {ViT-L-14} 
DATASET= {ISIC_2018}
"""

MODEL = "ViT-L-14"
DATASET = "ISIC_2018"
ADDITIONAL_COMMENTS = "CLASS_LABELS_PROMPTS"

In [4]:
CLASS_LABELS_PROMPTS = {
    "BKL": ["This is dermatoscopy of pigmented benign keratosis", 'This is dermoscopy of pigmented benign keratosis'],
    "NV": ["This is dermatoscopy of nevus", 'This is dermoscopy of nevus'],
    "DF": ['This is dermatoscopy of dermatofibroma', 'This is dermoscopy of dermatofibroma'],
    "MEL": ['This is dermatoscopy of melanoma', 'This is dermoscopy of melanoma'],
    "VASC": ['This is dermatoscopy of vascular lesion', 'This is dermoscopy of vascular lesion'],
    "BCC": ['This is dermatoscopy of basal cell carcinoma', 'This is dermoscopy of basal cell carcinoma'],
    "AKIEC": ['This is dermatoscopy of actinic keratosis', 'This is dermoscopy of actinic keratosis']
}

In [5]:
def calculate_similarity_score(image_features_norm,
                               prompt_target_embedding_norm,
                               prompt_ref_embedding_norm,
                               temp=1,
                               top_k=-1,
                               normalize=True):
    """
    Similarity Score used in "Fostering transparent medical image AI via an image-text foundation model grounded in medical literature"
    https://www.medrxiv.org/content/10.1101/2023.06.07.23291119v1.full.pdf
    """
    

    target_similarity = prompt_target_embedding_norm.float() @ image_features_norm.T.float()
    ref_similarity = prompt_ref_embedding_norm.float() @ image_features_norm.T.float()


    if top_k > 0:
        idx_target = target_similarity.argsort(dim=1, descending=True)
        target_similarity_mean = target_similarity[:,idx_target.squeeze()[:top_k]].mean(dim=1)
        
        ref_similarity_mean = ref_similarity.mean(dim=1)
    else:
        target_similarity_mean = target_similarity.mean(dim=1)
        ref_similarity_mean = ref_similarity.mean(dim=1)
    
    if normalize:
        similarity_score = scipy.special.softmax([target_similarity_mean.numpy(), ref_similarity_mean.numpy()], axis=0)[0, :].mean(axis=0)
    else:
        similarity_score = target_similarity_mean.mean(axis=0)

    return similarity_score

In [6]:
print(f"[INFO] DATASET: {DATASET}")
print(f"[INFO] MODEL: {MODEL}")

# Load image embeddings 
img_embeddings = np.load(f"img_embeddings/image_embeddings_{DATASET}_MONET_{MODEL}_Segmented_Norm.npy", allow_pickle=True).item()

# Load reference embeddings
reference_embeddings = torch.from_numpy(np.load(f"reference_embeddings/reference_concept_embeddings.npy")).unsqueeze(0)

results = dict()
# Iterate over images and calculate similarity
for im in img_embeddings.keys():
    img_feats = torch.from_numpy(img_embeddings[im]).unsqueeze(0)

    similarity_scores = []
    for disease_label in CLASS_LABELS_PROMPTS.keys():
        # Load text embeddings
        text_feats = torch.from_numpy(np.load(f"text_embeddings/class_label_embeddings_{disease_label}.npy")).unsqueeze(0)

        # Calculate similarity
        similarity = calculate_similarity_score(image_features_norm=img_feats,
                                                prompt_target_embedding_norm=text_feats,
                                                prompt_ref_embedding_norm=reference_embeddings,
                                                top_k=-1,
                                                temp=(1/np.exp(4.5944)),
                                                normalize=True)

        similarity_scores.append(similarity[0])

    # Save score into a dictionary w.r.t. to image
    results[im] = similarity_scores

[INFO] DATASET: ISIC_2018
[INFO] MODEL: ViT-L-14


In [8]:
# Evaluation
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, balanced_accuracy_score, auc

if DATASET == "ISIC_2018":
    gt = pd.read_csv("../data/ISIC_2018/image_classes_ISIC_2018.csv")
    
    train_images_df = pd.read_csv("../data/ISIC_2018/ISIC_2018_train.csv")
    train_images = train_images_df["images"].tolist()
    
    valiadtion_images_df = pd.read_csv("../data/ISIC_2018/ISIC_2018_validation.csv")
    validation_images = valiadtion_images_df["images"].tolist()
    
    test_images_df = pd.read_csv("../data/ISIC_2018/ISIC_2018_test.csv")
    test_images = test_images_df["images"].tolist()

y_true = []
y_pred = []
y_pred_probs = []
for im in results.keys():
    
    if str(im) in test_images:
        #y_true.append(gt.loc[gt['images'] == str(im)]['labels'].tolist()[0])
        y_true.append(1 if gt.loc[gt['images'] == str(im)]['labels'].tolist()[0] == 3 else 0)
        y_pred.append(1 if np.argmax(results[im]) == 3 else 0)
        y_pred_probs.append(results[im])

print("Classification Report:")
print(classification_report(y_true=y_true, y_pred=y_pred, target_names=["MEL", "NON-MEL"]))

# Calculate the confusion matrix
conf_matrix = confusion_matrix(y_true, y_pred)
TP = conf_matrix[1][1]
TN = conf_matrix[0][0]
FP = conf_matrix[0][1]
FN = conf_matrix[1][0]

print("Confusion Matrix:")
print(conf_matrix, "\n")

# BACC
bacc = balanced_accuracy_score(y_true, y_pred)
print(f"BACC: {bacc}")

# Sensitivity
SE = TP / (TP + FN)
print(f"Sensitivity: {SE}")

# Specificity
SP = TN / (TN + FP)
print(f"Specificity: {SP}")

Classification Report:
              precision    recall  f1-score   support

         MEL       0.91      0.90      0.91      1340
     NON-MEL       0.29      0.32      0.30       171

    accuracy                           0.84      1511
   macro avg       0.60      0.61      0.61      1511
weighted avg       0.84      0.84      0.84      1511

Confusion Matrix:
[[1209  131]
 [ 117   54]] 

BACC: 0.6090141398271799
Sensitivity: 0.3157894736842105
Specificity: 0.9022388059701493
