In [1]:
import os
import json
import shutil
from pathlib import Path
import logging
import time
import random
from queue import Queue
import sys
import re
import copy

import numpy as np
import pandas as pd
from IPython.display import display
from PIL import Image
from dotenv import load_dotenv
from tqdm import tqdm
from sklearn.metrics import accuracy_score, recall_score, confusion_matrix

sys.path.append(os.path.abspath('..'))

from gemini_api import GeminiAPIHandler

In [2]:
# Dataset paths
dataset = "derm12345"
dataset_path = f"../../datasets/{dataset}"

hierarchy_path = f"../../datasets/{dataset}/dataset_hierarchy.json"
with open(hierarchy_path, 'r') as dh:
    derm12345_hier = json.load(dh)

metadata_path = f"../../datasets/{dataset}/metadata.csv"
metadata = pd.read_csv(metadata_path)

output_dir = "../api_cls/derm12345"
os.makedirs(os.path.dirname(output_dir), exist_ok=True)

log_filename = os.path.join(output_dir, f"api_clscont.log")
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 
    filename=log_filename
)
logger = logging.getLogger(__name__)

eval_dir = "../api_cls/derm12345/evaluation_metrics"
os.makedirs(os.path.dirname(eval_dir), exist_ok=True)

In [3]:
# Flatten the dataset directory
for dirpath, dirnames, filenames in os.walk(dataset_path, topdown=False):
        if dirpath == dataset_path:
            continue  # Skip the root itself

        for file in filenames:
            src_path = os.path.join(dirpath, file)
            dest_path = os.path.join(dataset_path, file)
            shutil.move(src_path, dest_path)

        # Optionally remove the now-empty subfolders
        os.rmdir(dirpath)

In [4]:
load_dotenv(dotenv_path="../../.env")

num_key = 5
api_handler_queue = Queue()
model_name = "gemini-2.0-flash-exp"
#model_name = "gemini-2.0-flash-thinking-exp"

for index in range(1, num_key + 1):
    api_key = os.getenv(f"GEMINI_API_KEY_{index}")
    api_handler = GeminiAPIHandler(api_key=api_key, index=index, model_name=model_name)
    api_handler_queue.put(api_handler)

In [5]:
prompt = (
        f"***** SYSTEM *****\n"
        f"You are a dermatologist examining a lesion/skin disease image.\n"
        f"The image is from the DERM12345 dataset and the dataset is organised hierarchically:\n"
         
        f"• **5 super-classes** (broad diagnostic families)\n"
        f"• **15 main classes** (finer diagnostic groups inside each super-class)\n"
        f"• **40 subclasses** (leaf-level, mutually-exclusive labels used for evaluation)\n"

        f"• Besides, the `label` field is a unique identifier assigned to each subclass\n"

        f"***** GOAL *****\n" 
        f"Determine which **subclass** it belongs to based on the **DATASET HIERARCHY** provided below." 

        f"***** DATASET HIERARCHY *****\n"
        f"{derm12345_hier}\n\n"

        f"***** RULES *****\n"  
        f"1. Traverse the dataset hierarchy and determine the correct classification path for the input image.\n"
        f"2. Output **exactly one line** in the following format:\n"
        f"   <SuperClass> - <MainClass> - <SubClass> - <Label>\n"
        f"3. The values must be in lower case and match the class names and labels from the **Dataset HIERARCHY** exactly.\n"
        f"4. Do **not** include quotes, extra punctuation, or reasoning.\n"
        f"5. If the image does **not** match any super-class, main class or subclass with ≥ 90 % confidence, output **\"unknow\"** in place.\n"
        f"   • If super-class undetermined, output \"unknown - unkown - unknown - unknown\";\n"
        f"   • Else if main class undetermined, output \"<SuperClass> - unkown - unknown - unknown\";\n"
        f"   • Else if subclass undetermined, output \"<SuperClass> - <MainClass> - unknown - unknown\";\n"
)

In [6]:
def call_gemini_api_cls(api_handler: GeminiAPIHandler, image_id: str, metadata) -> dict:
    """
    Call the Gemini API for a given request and dataset image path.
    """
    image_path = os.path.join(dataset_path, image_id + ".jpg")
    
    # Load the image
    with Image.open(image_path) as pil_image:
        # Generate response
        response = api_handler.generate_from_pil_image(pil_image, prompt)
    
    parts = response.rstrip("\n").split(" - ")
    if len(parts) == 4:
        super_class_pred = parts[0].lower()
        main_class_pred = super_class_pred + " - " + parts[1].lower()
        label_pred = parts[3].lower()
    else:
        super_class_pred = "malformed output"
        main_class_pred = "malformed output"
        label_pred = "malformed output"

    row = metadata[metadata['image_id'] == image_id]
    super_class_true = row['super_class'].values[0] + " " + row['malignancy'].values[0]
    if row['main_class_1'].values[0] == row['main_class_2'].values[0]:
        main_class_true = super_class_true + " - " + row['main_class_1'].values[0]
    else:
        main_class_true = super_class_true + " - " + row['main_class_1'].values[0] + " " + row['main_class_2'].values[0]
    label_true = row['label'].values[0]
    
    result = {
        "image_id": image_id,
        "image_path": image_path,
        "response": response,
        "super_class_pred": super_class_pred,
        "super_class_true": super_class_true,
        "main_class_pred": main_class_pred,
        "main_class_true": main_class_true,
        "label_pred": label_pred,
        "label_true": label_true
    }
    
    return result

In [7]:
def toggle_api_handler(old_handler, api_handler_queue, toggle_reason):
    api_handler_queue.put(old_handler) # Requeue the old API handler
    new_handler = api_handler_queue.get()
    logger.info(f"{toggle_reason}. Switch to API_Handler_{new_handler.index}")
    return new_handler

In [8]:
def parse_api_result(result, super_true, super_pred, main_true, main_pred, sub_true, sub_pred):
    super_true.append(result['super_class_true'])
    super_pred.append(result['super_class_pred'])
    main_true.append(result['main_class_true'])
    main_pred.append(result['main_class_pred'])
    sub_true.append(result['label_true'])
    sub_pred.append(result['label_pred'])

In [9]:
def eval_overall(y_true, y_pred, granularity, isprocessing):
    accuracy = accuracy_score(y_true, y_pred)
    sensitivity = recall_score(y_true, y_pred, average='macro', zero_division=0)
    eval_metrics = {
        'Granularity': f'{granularity}',
        'Accuracy': f'{accuracy:.3%}',
        'Sensitivity': f'{sensitivity:.3%}'
    }

    message = f"• **{granularity} Level: Accuracy={accuracy:.3%}, Sensitivity={sensitivity:.3%}"
    if isprocessing:
        tqdm.write(message) # Temporarily overwrite the progress bar
    else:
        print(message)
    return eval_metrics

In [10]:
def eval_per_class(y_true, y_pred, granularity):   
    labels = sorted(set(y_true))  # all unique labels in ground truth
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    sensitivities = recall_score(y_true, y_pred, average=None, labels=labels)
    specificities = []

    for i, label in enumerate(labels):
        TP = cm[i, i]
        FN = cm[i, :].sum() - TP
        FP = cm[:, i].sum() - TP
        TN = cm.sum() - (TP + FN + FP)
        specificity = TN / (TN + FP) if (TN + FP) > 0 else 0.0
        specificities.append(specificity)

    eval_data = {
        'Class': labels,
        'Sensitivity': sensitivities,
        'Specificity': specificities
    }

    eval_df = pd.DataFrame(eval_data)
    display(eval_df)
    eval_file_path = os.path.join(eval_dir, f"per_{granularity}_class.csv")
    eval_df.to_csv(eval_file_path, index=False)

    return eval_df

In [11]:
img_queue = Queue()
for idx, row in metadata.iterrows():
    img_queue.put(row['image_id'])
img_total = img_queue.qsize()

super_true = []
super_pred = []
main_true = []
main_pred = []
sub_true = []
sub_pred = []

# Parse all existing response files in the output directory
existing_response_files = [f for f in os.listdir(output_dir) if f.endswith('.json')]
for filename in existing_response_files:
    file_path = os.path.join(output_dir, filename)
    with open(file_path, 'r') as f:
        stored_result = json.load(f)
    parse_api_result(stored_result, super_true, super_pred, main_true, main_pred, sub_true, sub_pred)

In [12]:
# Create a tqdm progress bar
progress_bar = tqdm(total=img_total, desc="Processing images", unit="img")
progress_bar.n = len(existing_response_files)
progress_bar.refresh()

# Monitor the progress
now = time.time()
latest_update = now

api_handler = api_handler_queue.get()
handler_task_count = 0

while not img_queue.empty():
    image_id = img_queue.get()
    task_exists = False
    progress_made = False
    try:
        generated_response_file = os.path.join(output_dir, f"{image_id}_cls.json")
        
        if os.path.exists(generated_response_file):
            logger.info(f"Skipping image_id {image_id} as response already exists.")
            task_exists = True

        else:
            result = call_gemini_api_cls(api_handler, image_id, metadata)
            parse_api_result(result, super_true, super_pred, main_true, main_pred, sub_true, sub_pred)

            with open(generated_response_file, "w") as out_file:
                json.dump(result, out_file, indent=4)
            logger.info(f"Successfully processed image_id {image_id}, saved to {generated_response_file}")
            progress_made = True
            handler_task_count += 1

    except Exception as e:
        logger.warning(f"Error processing image_id {image_id}: {e}")
        
        # Requeue the request
        img_queue.put(image_id)

    finally:
        img_queue.task_done()

        now = time.time()
        if (now - latest_update > 180):
            api_handler = toggle_api_handler(api_handler, api_handler_queue, "API request timeout")
        elif handler_task_count >= 100:
            api_handler = toggle_api_handler(api_handler, api_handler_queue, "API_Handler_{api_handler.index} sleeps")
            handler_task_count = 0

        if not task_exists and progress_made:    
            progress_bar.update(1)
            if progress_bar.n % 50 == 0: 
                eval_overall(super_true, super_pred, "SuperClass", isprocessing=True)
                eval_overall(main_true, main_pred, "MainClass", isprocessing=True)
                eval_overall(sub_true, sub_pred, "SubClass", isprocessing=True)
                latest_update = now

progress_bar.close()

Processing images: 100%|██████████| 2485/2485 [00:00<00:00, 77856.22img/s]   


In [13]:
existing_response_files = [f for f in os.listdir(output_dir) if f.endswith('.json')]
print(f"{len(existing_response_files)}/{img_total} of the images have been processed so far.\n")
super_overall = eval_overall(super_true, super_pred, "SuperClass", isprocessing=False)
main_overall = eval_overall(main_true, main_pred, "MainClass", isprocessing=False)
sub_overall = eval_overall(sub_true, sub_pred, "SubClass", isprocessing=False)

overall_eval = pd.DataFrame([super_overall, main_overall, sub_overall])
overall_eval_file = os.path.join(eval_dir, f"overall_evaluation.csv")
os.makedirs(os.path.dirname(overall_eval_file), exist_ok=True)
overall_eval.to_csv(overall_eval_file, index=False)

2485/2485 of the images have been processed so far.

• **SuperClass Level: Accuracy=81.650%, Sensitivity=27.978%
• **MainClass Level: Accuracy=13.803%, Sensitivity=13.064%
• **SubClass Level: Accuracy=7.686%, Sensitivity=9.418%


In [14]:
per_super_class = eval_per_class(super_true, super_pred, 'super')
per_main_class = eval_per_class(main_true, main_pred, 'main')
per_sub_class = eval_per_class(sub_true, sub_pred, 'sub')


# Orgnise the evaluation files hierarchically
for super_class_key, super_class_dict in derm12345_hier.items():
    one_super_eval = per_super_class[per_super_class['Class'] == super_class_key]

    for main_class_key, main_class_dict in super_class_dict.items():
        main_class_key = super_class_key + " - " + main_class_key
        one_main_eval = per_main_class[per_main_class['Class'] == main_class_key]
        one_super_eval = pd.concat([one_super_eval, one_main_eval], ignore_index=True)

        for subclass in main_class_dict['subclasses']:
            one_sub_eval = per_sub_class[per_sub_class['Class'] == subclass.get('label')]
            one_main_eval = pd.concat([one_main_eval, one_sub_eval], ignore_index=True)
        
        one_main_eval_file = os.path.join(eval_dir, f"{main_class_key}.csv")
        one_main_eval.to_csv(one_main_eval_file, index=False)

    one_super_eval_file = os.path.join(eval_dir, f"{super_class_key}.csv")
    one_super_eval.to_csv(one_super_eval_file, index=False)


Unnamed: 0,Class,Sensitivity,Specificity
0,melanocytic benign,0.949975,0.526882
1,melanocytic malignant,0.170732,0.984975
2,nonmelanocytic benign,0.384259,0.916446
3,nonmelanocytic indeterminate,0.090909,0.998784
4,nonmelanocytic malignant,0.082803,0.999569


Unnamed: 0,Class,Sensitivity,Specificity
0,melanocytic benign - banal compound,0.837209,0.230874
1,melanocytic benign - banal dermal,0.075342,0.997422
2,melanocytic benign - banal junctional,0.011204,0.995743
3,melanocytic benign - dysplastic compound,0.135135,0.969851
4,melanocytic benign - dysplastic junctional,0.0,1.0
5,melanocytic benign - dysplastic recurrent,0.0,1.0
6,melanocytic benign - lentigo,0.052632,0.95709
7,melanocytic malignant - melanoma,0.170732,0.984899
8,nonmelanocytic benign - fibro_histiocytic,0.0,1.0
9,nonmelanocytic benign - keratinocytic,0.146341,0.96927


Unnamed: 0,Class,Sensitivity,Specificity
0,acb,0.0,1.0
1,acd,0.0,1.0
2,ajb,0.011628,0.999582
3,ajd,0.0,1.0
4,ak,0.090909,0.998783
5,alm,0.0,1.0
6,angk,0.0,0.99514
7,anm,0.0,1.0
8,bcc,0.023529,0.998746
9,bd,0.0,1.0
