In [23]:
import torch
import clip
from PIL import Image
import pandas as pd
import numpy as np

In [55]:
# Read the HAM10000 per-image metadata table from disk and preview its
# first five rows.
metadata = pd.read_csv(filepath_or_buffer='dataverse_files/HAM10000_metadata.csv')
metadata.head(n=5)

Unnamed: 0,lesion_id,image_id,dx,dx_type,age,sex,localization,dataset
0,HAM_0000118,ISIC_0027419,bkl,histo,80.0,male,scalp,vidir_modern
1,HAM_0000118,ISIC_0025030,bkl,histo,80.0,male,scalp,vidir_modern
2,HAM_0002730,ISIC_0026769,bkl,histo,80.0,male,scalp,vidir_modern
3,HAM_0002730,ISIC_0025661,bkl,histo,80.0,male,scalp,vidir_modern
4,HAM_0001466,ISIC_0031633,bkl,histo,75.0,male,ear,vidir_modern


In [12]:
# Distinct diagnosis codes that occur in the `dx` column.
metadata['dx'].unique()

array(['bkl', 'nv', 'df', 'mel', 'vasc', 'bcc', 'akiec'], dtype=object)

In [30]:
# How many different image ids the metadata table references.
len(pd.unique(metadata['image_id']))

10015

In [15]:
# Long-form lesion description (used as the text prompt) -> short diagnosis
# code, using the same codes that appear in the metadata `dx` column.
disease_dict = {
    "Actinic keratoses and intraepithelial carcinoma / Bowen's disease": "akiec",
    "basal cell carcinoma": "bcc",
    "benign keratosis-like lesions (solar lentigines / seborrheic keratoses and lichen-planus like keratoses)": "bkl",
    "dermatofibroma": "df",
    "melanoma": "mel",
    "melanocytic nevi": "nv",
    "vascular lesions (angiomas, angiokeratomas, pyogenic granulomas and hemorrhage)": "vasc",
}

In [81]:
# List the model names that the installed `clip` package can load.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [86]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the ViT-B/32 CLIP model together with its image preprocessing function.
model, preprocess = clip.load("ViT-B/32", device=device)

# Preprocess one sample image, add a leading batch dimension and move it to
# the selected device.
image = preprocess(Image.open("dataverse_files/HAM10000_images_part_1/ISIC_0024306.jpg"))
image = image.unsqueeze(0).to(device)

# Text prompts: one long-form description per diagnosis class.
diseases = [
    "Actinic keratoses and intraepithelial carcinoma / Bowen's disease",
    "basal cell carcinoma",
    "benign keratosis-like lesions (solar lentigines / seborrheic keratoses and lichen-planus like keratoses)",
    "dermatofibroma",
    "melanoma",
    "melanocytic nevi",
    "vascular lesions (angiomas, angiokeratomas, pyogenic granulomas and hemorrhage)",
]

# Tokenize the prompts and place the tokens on the same device as the model.
text = clip.tokenize(diseases).to(device)

In [27]:
# Metadata row(s) for the sample image, to compare against the prediction.
metadata.loc[metadata['image_id'] == 'ISIC_0024306']

Unnamed: 0,lesion_id,image_id,dx,dx_type,age,sex,localization,dataset
4349,HAM_0000550,ISIC_0024306,nv,follow_up,45.0,male,trunk,vidir_molemax


In [83]:
# Score the sample image against every text prompt with gradient tracking
# disabled.
with torch.no_grad():
    # Separate image/text embeddings; stored but not used further in this cell.
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    # Softmax over the last axis of the per-image logits, then move the
    # result to a NumPy array on the CPU.
    image_scores = logits_per_image.softmax(dim=-1)
    probs = image_scores.cpu().numpy()

# Key each probability of the first (only) image by its short diagnosis code.
disease_probs = {
    disease_dict[description]: probs[0][position]
    for position, description in enumerate(diseases)
}
print("Label probs:", disease_probs)

Label probs: {'akiec': 0.04109688, 'bcc': 0.1986009, 'bkl': 0.10070173, 'df': 0.01837381, 'mel': 0.4465283, 'nv': 0.1378731, 'vasc': 0.056825276}


In [84]:
# Position of the largest probability, translated to its short diagnosis code.
top_index = np.argmax(probs)
disease_dict[diseases[top_index]]

'mel'

In [54]:
from glob import glob
from pathlib import Path


def image_id_from_path(path):
    """Return the image id encoded in an image file path.

    The id is the file name with its directory and final extension removed,
    e.g. ``'.../HAM10000_images_part_1/ISIC_0026784.jpg'`` -> ``'ISIC_0026784'``.

    Unlike splitting on ``'/'`` and taking a fixed component, this does not
    depend on how deep the data directory is nested or on the platform's
    path separator.
    """
    return Path(path).stem


# Collect the image files from both dataset folders into one list
# (`first_part` ends up holding the paths of both parts).
first_part = glob('dataverse_files/HAM10000_images_part_1/*')
second_part = glob('dataverse_files/HAM10000_images_part_2/*')
first_part.extend(second_part)
first_part_id = [image_id_from_path(x) for x in first_part]

# Lookup table: image id -> file path, later joined onto the metadata.
df_paths = pd.DataFrame({'id': first_part_id, 'path': first_part})
df_paths.head()

Unnamed: 0,id,path
0,ISIC_0026784,dataverse_files/HAM10000_images_part_1/ISIC_00...
1,ISIC_0028971,dataverse_files/HAM10000_images_part_1/ISIC_00...
2,ISIC_0026948,dataverse_files/HAM10000_images_part_1/ISIC_00...
3,ISIC_0026790,dataverse_files/HAM10000_images_part_1/ISIC_00...
4,ISIC_0028965,dataverse_files/HAM10000_images_part_1/ISIC_00...


In [58]:
# Attach each image's file path to its metadata row (inner join on the image
# id, so rows without a matching file are dropped).
# Any `path` column left over from a previous run of this cell is removed
# first; otherwise re-running would produce `path_x`/`path_y` instead of
# `path` and break the cells that read `metadata['path']`.
metadata = (
    metadata
    .drop(columns=['path'], errors='ignore')
    .merge(df_paths, left_on='image_id', right_on='id')
    .drop(columns=['id'])
)

In [66]:
from tqdm.notebook import tqdm

In [69]:
# Row boundaries for batched inference: batch k covers metadata rows
# batches[k] .. batches[k+1]-1. The boundaries are derived from the metadata
# table itself (the table whose `path` column is read during inference)
# rather than from an undefined `images` variable and a hard-coded row count.
BATCH_SIZE = 64
n_images = len(metadata)
batches = np.append(np.arange(0, n_images, BATCH_SIZE), n_images)
batches

array([    0,    64,   128,   192,   256,   320,   384,   448,   512,
         576,   640,   704,   768,   832,   896,   960,  1024,  1088,
        1152,  1216,  1280,  1344,  1408,  1472,  1536,  1600,  1664,
        1728,  1792,  1856,  1920,  1984,  2048,  2112,  2176,  2240,
        2304,  2368,  2432,  2496,  2560,  2624,  2688,  2752,  2816,
        2880,  2944,  3008,  3072,  3136,  3200,  3264,  3328,  3392,
        3456,  3520,  3584,  3648,  3712,  3776,  3840,  3904,  3968,
        4032,  4096,  4160,  4224,  4288,  4352,  4416,  4480,  4544,
        4608,  4672,  4736,  4800,  4864,  4928,  4992,  5056,  5120,
        5184,  5248,  5312,  5376,  5440,  5504,  5568,  5632,  5696,
        5760,  5824,  5888,  5952,  6016,  6080,  6144,  6208,  6272,
        6336,  6400,  6464,  6528,  6592,  6656,  6720,  6784,  6848,
        6912,  6976,  7040,  7104,  7168,  7232,  7296,  7360,  7424,
        7488,  7552,  7616,  7680,  7744,  7808,  7872,  7936,  8000,
        8064,  8128,

In [85]:
# Zero-shot classify every image listed in `metadata`, one batch at a time.
# Consecutive entries of `batches` are the start/stop row positions of a batch.
# `probs` ends up as a list with one probability row per processed image.
probs = []
with torch.no_grad():
    for start, stop in tqdm(zip(batches[:-1], batches[1:]), total=len(batches) - 1):
        # Positional slice, so it stays correct whatever the index labels are.
        batch_paths = metadata['path'].iloc[start:stop]

        batch_tensors = []
        for image_path in batch_paths:
            # Close each file as soon as it has been preprocessed instead of
            # leaving thousands of open handles to the garbage collector.
            with Image.open(image_path) as pil_image:
                batch_tensors.append(preprocess(pil_image))

        if not batch_tensors:
            # Nothing to score (e.g. an empty metadata table).
            continue

        # Stack into a single batch tensor and move it to the device once.
        images_batch = torch.stack(batch_tensors).to(device)
        logits_per_image, logits_per_text = model(images_batch, text)
        probs.extend(logits_per_image.softmax(dim=-1).cpu().numpy())

  0%|          | 0/157 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [78]:
# For every probability row, take the highest-scoring prompt and translate it
# to its short diagnosis code; store the result as a new metadata column.
predicted_disease = [disease_dict[diseases[np.argmax(row)]] for row in probs]
metadata['predicted_disease'] = predicted_disease

In [80]:
# Number of rows whose predicted code equals the recorded diagnosis.
(metadata['predicted_disease'] == metadata['dx']).sum()

1357