In [None]:
from google.colab import drive

# Mount Google Drive so the zipped checkpoints and dataset under MyDrive are readable.
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!unzip "/content/drive/MyDrive/ViTskinmodel.zip"

In [None]:
!unzip "/content/drive/MyDrive/deitskinmodel.zip"

In [None]:
!unzip "/content/drive/MyDrive/dataset.zip"

Archive:  /content/drive/MyDrive/dataset.zip
   creating: content/dataset.hf/
   creating: content/dataset.hf/validation/
  inflating: content/dataset.hf/validation/data-00000-of-00001.arrow  
  inflating: content/dataset.hf/validation/state.json  
  inflating: content/dataset.hf/validation/dataset_info.json  
   creating: content/dataset.hf/test/
  inflating: content/dataset.hf/test/data-00000-of-00001.arrow  
  inflating: content/dataset.hf/test/state.json  
  inflating: content/dataset.hf/test/dataset_info.json  
   creating: content/dataset.hf/train/
  inflating: content/dataset.hf/train/data-00000-of-00007.arrow  
  inflating: content/dataset.hf/train/data-00006-of-00007.arrow  
  inflating: content/dataset.hf/train/data-00001-of-00007.arrow  
  inflating: content/dataset.hf/train/data-00003-of-00007.arrow  
  inflating: content/dataset.hf/train/data-00002-of-00007.arrow  
  inflating: content/dataset.hf/train/state.json  
  inflating: content/dataset.hf/train/data-00005-of-00007.

In [None]:
!pip install datasets
!pip install transformers

In [None]:
import json
import math

import numpy as np
import torch
from tqdm import tqdm

In [None]:
from datasets import load_from_disk

# Location of the unzipped dataset; the archive's content/ prefix nests it under /content/content.
DATASET_DIR = "/content/content/dataset.hf"

ds = load_from_disk(DATASET_DIR)
# Held-out split used by the evaluation loops.
test_ds = ds["test"]

In [None]:
# Class names, indexed by integer label id (taken from the train split's label feature).
labels = ds["train"].features["label"].names

In [None]:
# Show every class name the classifiers can output.
print(repr(labels))

['Acne and Rosacea Photos', 'Actinic cheilitis', 'Normal Skin', 'acanthosis nigricans', 'actinic keratosis', 'alopecia', 'angiokeratoma', 'atopic dermatitis', 'atypical melanocytic proliferation', 'basal cell carcinoma', 'biting insects', 'bowens disease', 'bullous disease', 'candida', 'candidiasis', 'chondrodermatitis nodularis', 'ctcl', 'cutaneceous larva migrans', 'dermatofibroma', 'distal subungual onychomycosis', 'drug eruptions', 'eczema', 'epidermal cyst', 'fixed drug eruption', 'folliculitis', 'granuloma annulare', 'hemangioma', 'herpes', 'impetigo', 'intertrigo', 'keloids', 'keratoacanthoma', 'lentigo', 'lichen planus', 'lichenoid keratosis', 'lupus', 'melanocytic nevi', 'melanoma', 'molluscum contagiosum', 'necrobiosis lipoidica', 'neurofibromatosis', 'nevus', 'other connective tissue diseases', 'other lichen related diseases', 'other light diseases', 'other nail related diseases', 'other psoraisis related diseases', 'perleche', 'pigmented benign keratosis', 'pityriasis and r

In [None]:
from transformers import ViTImageProcessor
from transformers import ViTForImageClassification


def load_classifier(path):
  """Load a fine-tuned ViT-architecture image classifier from `path` with our class names.

  The label mappings use string keys to match the `id2label[str(idx)]` lookups used when
  decoding predictions.

  `ignore_mismatched_sizes` is deliberately left at its default (False): if the checkpoint's
  classifier head does not have len(labels) outputs, loading fails loudly instead of silently
  re-initialising the head with random weights, which would invalidate the evaluation.

  NOTE(review): if a checkpoint was saved from a DeiT* class (config model_type "deit"), loading
  it with ViTForImageClassification will not match its weights -- check the load warnings.
  """
  return ViTForImageClassification.from_pretrained(
      path,
      num_labels=len(labels),
      id2label={str(i): c for i, c in enumerate(labels)},
      label2id={c: str(i) for i, c in enumerate(labels)},
  )


vit_model = load_classifier("/content/content/vit-base-SKINMODEL")
deit_model = load_classifier("/content/content/deit-base-SKINMODEL")

# ViTImageProcessor replaces the deprecated ViTFeatureExtractor (same call interface).
vit_feature_extractor = ViTImageProcessor.from_pretrained("/content/content/vit-base-SKINMODEL")
deit_feature_extractor = ViTImageProcessor.from_pretrained("/content/content/deit-base-SKINMODEL")



In [None]:
# Inspect how the ViT config maps class indices to class names.
vit_id2label = vit_model.config.id2label
vit_id2label

In [None]:
def norm(a):
  """Clamp `a` into the closed interval [0, 1]."""
  return min(max(a, 0), 1)


def top_k_predictions(model, feature_extractor, PIL_Image, k=5):
  """Return `model`'s k highest-scoring classes for one image, best first.

  Args:
    model: image classifier whose output has `.logits` indexed as [batch, class]
      (a batch of one image is assumed) and whose `config.id2label` is keyed by
      the class index as a string.
    feature_extractor: preprocessor matching `model`.
    PIL_Image: the image to classify.
    k: number of classes to return.

  Returns:
    List of (label, score) tuples. `score` is the raw logit divided by 10 and
    clamped to [0, 1] -- a heuristic, not a calibrated probability.
  """
  inputs = feature_extractor(images=PIL_Image, return_tensors="pt")
  # Inference only: skip building the autograd graph (saves memory and time).
  with torch.no_grad():
    logits = model(**inputs).logits[0]

  top_k_idx = torch.argsort(logits, descending=True)[:k].tolist()
  # NOTE(review): torch.softmax(logits, dim=-1) would give probabilities comparable
  # across models; the /10 scaling is kept so existing results are reproduced.
  return [(model.config.id2label[str(idx)], norm(logits[idx].item() / 10))
          for idx in top_k_idx]


def run_vit(PIL_Image, k=5):
  """Top-k (label, score) predictions of the fine-tuned ViT model."""
  return top_k_predictions(vit_model, vit_feature_extractor, PIL_Image, k)


def run_deit(PIL_Image, k=5):
  """Top-k (label, score) predictions of the fine-tuned DeiT model.

  Labels are decoded with the DeiT model's own config.
  """
  return top_k_predictions(deit_model, deit_feature_extractor, PIL_Image, k)

In [None]:
def ensemble_2(t1, t2):
  """Combine two (label, confidence) predictions into one.

  If both predictions name the same label, return that label with the mean of
  the two confidences. Otherwise return whichever prediction is more confident;
  ties go to `t1`.

  Args:
    t1, t2: (label, confidence) tuples, e.g. same-rank entries of the ViT and
      DeiT top-k lists.

  Returns:
    A (label, confidence) tuple -- always, even if a confidence is NaN.
  """
  label_1, conf_1 = t1[0], t1[1]
  label_2, conf_2 = t2[0], t2[1]

  if label_1 == label_2:
    return (label_1, (conf_1 + conf_2) / 2)
  if conf_1 >= conf_2:
    return (label_1, conf_1)
  # Unconditional fallback: a second `if conf_1 < conf_2` would fall through and
  # return None when a confidence is NaN (neither comparison holds).
  return (label_2, conf_2)

In [None]:
def predict(PIL_Image):
  """Run both models on one image and merge their ranked predictions.

  The i-th ranked ViT prediction is combined with the i-th ranked DeiT
  prediction via `ensemble_2`, giving one (label, confidence) entry per ranked
  ViT prediction (run_vit returns its top 5). Index 0 is the ensemble's top-1.
  """
  deit_prediction = run_deit(PIL_Image)
  vit_prediction = run_vit(PIL_Image)

  return [
      ensemble_2(vit_top, deit_prediction[rank])
      for rank, vit_top in enumerate(vit_prediction)
  ]

## Statistics

Evaluate the ViT + DeiT ensemble on a single example and then on the full test split.

In [None]:
# Sanity check on one example (from the *training* split, so the models have presumably
# seen it): compare the ensemble's top-1 guess with the ground-truth label.
ex = ds["train"][3123]
ex_image = ex['pixel_values']
ex_label = ds['train'].features['label'].int2str(ex['label'])

# `predict` returns a ranked list of (label, confidence); index 0 is the top-1 guess.
# (`ensemble_2` merges two predictions and cannot be called on an image.)
print(predict(ex_image)[0])
print(ex_label)

('basal cell carcinoma', 0.9919523239135742)
basal cell carcinoma


In [None]:
# Per-class evaluation of the ensemble's top-1 prediction on the test split.
#   sum[label]   -> test images of that class predicted correctly
#   total[label] -> test images of that class
# NOTE(review): `sum` shadows the built-in sum(); the name is kept because later cells
# read it. Rename it (e.g. `correct`) in all cells together.
label_feature = ds['test'].features['label']
sum = {label: 0 for label in label_feature.names}
# Integer counters: they are incremented with `+= 1` below.
total = {label: 0 for label in label_feature.names}

for test_ex in tqdm(test_ds):
  test_ex_label = label_feature.int2str(test_ex['label'])

  # `predict` returns a ranked list of (label, confidence); index 0 is the top-1.
  prediction = predict(test_ex["pixel_values"])[0]

  total[test_ex_label] += 1

  if prediction[0] == test_ex_label:
    sum[test_ex_label] += 1

100%|██████████| 6051/6051 [22:15<00:00,  4.53it/s]


In [None]:
# Per-class top-1 accuracy on the test split. Classes with no test images have no
# defined accuracy and are skipped rather than dividing by zero.
accuracy = {}

for label in ds["test"].features["label"].names:
  if total[label]:
    accuracy[label] = sum[label] / total[label]

In [None]:
# Number of test images per class. Reading only the integer "label" column avoids
# decoding every image, which iterating over `test_ds` rows would do.
label_feature = ds['test'].features['label']
num_images = {label: 0 for label in label_feature.names}

for label_id in test_ds["label"]:
  num_images[label_feature.int2str(label_id)] += 1

In [None]:
# Overall accuracy: each class's accuracy weighted by that class's share of the test set.
n_test = len(test_ds)
weighted_accuracy = 0
for disease in accuracy:
  share = num_images[disease] / n_test
  weighted_accuracy += accuracy[disease] * share

In [None]:
# Overall (class-share-weighted) test accuracy.
print(f"{weighted_accuracy}")

0.7909436456784006


In [None]:
import json

# Persist the per-class test accuracy for later analysis.
with open("accuracy.json", mode="w") as accuracy_file:
  json.dump(accuracy, accuracy_file)

In [None]:
# Persist the number of test images per class.
with open("num_test_images.json", mode="w") as counts_file:
  json.dump(num_images, counts_file)

In [None]:
# How many test images belong to the healthy-skin class.
NORMAL_SKIN = "Normal Skin"
num_images[NORMAL_SKIN]

76

In [None]:
# Top-1 accuracy restricted to test images whose ground truth is not "Normal Skin".
# NOTE(review): the exact class must match, so this is accuracy on the diseased subset,
# not binary disease-vs-normal sensitivity.
label_feature = ds['test'].features['label']
total_disease = 0
predict_disease = 0

for test_ex in tqdm(test_ds):
  test_ex_label = label_feature.int2str(test_ex['label'])

  if test_ex_label != "Normal Skin":
    # `predict` returns a ranked list of (label, confidence); index 0 is the top-1.
    prediction = predict(test_ex["pixel_values"])[0]

    total_disease += 1

    if prediction[0] == test_ex_label:
      predict_disease += 1

100%|██████████| 6051/6051 [21:53<00:00,  4.61it/s]


In [None]:
# Percentage of diseased (non-"Normal Skin") test images whose exact class was
# predicted; undefined when there are no diseased images.
if total_disease:
  print(f"Sensitivity: {100 * (predict_disease/total_disease)}")
else:
  print("Sensitivity: undefined (no diseased test images)")

Sensativity: 78.82845188284519


In [None]:
# Diseased test images, then how many of them were classified exactly right.
print(total_disease, predict_disease, sep="\n")

5975
4710
