In [1]:
import os
import json
import re
import shutil
from pathlib import Path
import logging


import time
import random
from queue import Queue
import sys

import numpy as np
import pandas as pd
from IPython.display import display
from PIL import Image
from dotenv import load_dotenv
from tqdm import tqdm
from sklearn.metrics import accuracy_score, recall_score, confusion_matrix

# Put the parent directory on the import path so the local helper modules imported
# below (gemini_api, cls_eval) can be found.
sys.path.append(os.path.abspath(os.pardir))
# Render every row of displayed tables instead of truncating them; set an integer
# (e.g. 20) here to cap long outputs.
pd.set_option('display.max_rows', None)

from gemini_api import GeminiAPIHandler
from cls_eval import cal_specificity, save_eval

  from pandas.core import (


In [2]:
img_total = 500  # Number of images to sample and classify

# Dataset paths
dataset = "derm12345"
dataset_path = f"../../datasets/{dataset}"

# Classification task; the commented line selects the 5-class malignancy variant.
cls_criteria = "melanocytic_binary"
#cls_criteria = "melanocytic_malignancy"
if cls_criteria == "melanocytic_binary":
    class_labels = ["melanocytic", "nonmelanocytic"]
elif cls_criteria == "melanocytic_malignancy":
    class_labels = ["melanocytic benign",
                    "melanocytic malignant",
                    "nonmelanocytic benign",
                    "nonmelanocytic indeterminate",
                    "nonmelanocytic malignant"]
else:
    # Fail here instead of with a NameError on `class_labels` several cells later.
    raise ValueError(f"Unknown cls_criteria: {cls_criteria!r}")

metadata_path = os.path.join(dataset_path, "metadata.csv")
metadata = pd.read_csv(metadata_path)

# Per-image JSON responses, the log file and evaluation outputs all live here.
output_dir = f"../api_cls/{dataset}_{cls_criteria}"
# Create output_dir itself (not only its parent); later cells list and write into it.
os.makedirs(output_dir, exist_ok=True)

log_filename = os.path.join(output_dir, "api_clscont.log")
# force=True replaces handlers installed by an earlier run of this cell. Without it,
# basicConfig does nothing on re-run and logs keep going to the previous log file
# (e.g. after switching cls_criteria).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    filename=log_filename,
    force=True,
)
logger = logging.getLogger(__name__)

In [None]:
load_dotenv(dotenv_path="../../.env")

# Each index N selects the environment variable GEMINI_API_KEY_N.
api_key_indices = [2, 3, 4]
#api_key_indices = [1, 2, 3]
api_handler_queue = Queue()  # Pool of handlers that toggle_api_handler rotates through

model_name = "gemini-2.0-flash-exp"
request_interval = 3
#model_name = "gemini-2.0-flash-thinking-exp"
#request_interval = 12

# Fail fast on missing keys. os.getenv returns None for an unset variable. A handler
# built with api_key=None would presumably fail on every request, and the processing
# loop only logs such failures as warnings and retries.
missing_key_vars = [f"GEMINI_API_KEY_{index}" for index in api_key_indices
                    if not os.getenv(f"GEMINI_API_KEY_{index}")]
if missing_key_vars:
    raise RuntimeError(f"Missing API key environment variable(s): {', '.join(missing_key_vars)}")

for index in api_key_indices:
    api_key = os.getenv(f"GEMINI_API_KEY_{index}")
    api_handler = GeminiAPIHandler(
        api_key=api_key,
        index=index,
        model_name=model_name,
        request_interval=request_interval,
        logger=logger
    )
    api_handler_queue.put(api_handler)

In [None]:
def toggle_api_handler(old_handler, api_handler_queue, toggle_reason):
    """Rotate to the next API handler in the round-robin pool.

    Args:
        old_handler: Handler currently in use; it goes back to the end of the queue.
        api_handler_queue: Queue holding the idle handlers.
        toggle_reason: Short text explaining the switch, used in the log message.

    Returns:
        The handler taken from the front of the queue. With a single handler in the
        pool, this is ``old_handler`` itself.
    """
    # Requeue first, so get() never blocks even when the pool has only one handler.
    api_handler_queue.put(old_handler)
    next_handler = api_handler_queue.get()
    logger.info(f"{toggle_reason}. Switch to API_Handler_{next_handler.index}")
    return next_handler

In [None]:
img_queue = Queue()

# Keep only images with a non-empty label (blank or whitespace-only labels count as missing).
metadata = metadata[metadata['label'].replace(r'^\s*$', pd.NA, regex=True).notna()]

img_ids = metadata['image_id'].dropna().unique().tolist()
# Shuffle so the img_total-image sample is random rather than in metadata order.
# NOTE(review): unseeded, so each run draws a different sample; resumed runs still reuse
# the responses already saved on disk.
random.shuffle(img_ids)

for img_id in img_ids[:img_total]:
    img_queue.put(img_id)

# Reload every response saved by earlier runs, so a restarted notebook resumes instead
# of re-querying the API. Rows are collected first and the DataFrame is built once,
# because concatenating inside the loop is quadratic.
existing_response_files = [f for f in os.listdir(output_dir) if f.endswith('.json')]
stored_results = []
for filename in existing_response_files:
    with open(os.path.join(output_dir, filename), 'r') as f:
        stored_results.append(json.load(f))
results = pd.DataFrame(stored_results)
    

In [3]:
# Instruction prompt sent with every image. The allowed classes are listed as
# indented bullets, one per line, between the task statement and the rules.
prompt = "\n".join([
    "You are a dermatologist examining a skin lesion image.",
    "Classify the lesion into exactly one of the following classes:",
    "\n".join(f"    • {label}" for label in class_labels),
    "Rules:",
    "1. Output only the class name.",
    "2. Do not include any extra text, punctuation, or explanation.",
])
print(prompt)

You are a dermatologist examining a skin lesion image.
Classify the lesion into exactly one of the following classes:
    • melanocytic
    • nonmelanocytic
Rules:
1. Output only the class name.
2. Do not include any extra text, punctuation, or explanation.


In [None]:
def call_gemini_api_cls_lite(image_id:str, handler=None):
    """Classify one image with the Gemini API and pair the answer with its ground truth.

    Args:
        image_id: Value from the ``image_id`` column of ``metadata``. The image is read
            from ``<dataset_path>/<image_id>.jpg``.
        handler: API handler to query. Defaults to the notebook-level ``api_handler``
            (the one currently selected by the processing loop), looked up at call time.

    Returns:
        dict with keys ``image_id``, ``answer`` and ``truth``. ``answer`` is the
        normalized model output if it is one of ``class_labels``, otherwise
        ``"malformed output"``.

    Raises:
        ValueError: if ``cls_criteria`` is not a supported criterion.
    """
    if handler is None:
        handler = api_handler

    image_path = os.path.join(dataset_path, image_id + '.jpg')
    metadata_row = metadata[metadata['image_id'] == image_id].iloc[0]

    # Resolve the ground truth before spending an API request on the image.
    # NOTE(review): truth is not lower-cased; this assumes the metadata values already
    # match the casing of class_labels. Confirm against metadata.csv.
    if cls_criteria == "melanocytic_binary":
        truth = str(metadata_row['super_class'])
    elif cls_criteria == "melanocytic_malignancy":
        truth = str(metadata_row['super_class']) + ' ' + str(metadata_row['malignancy'])
    else:
        raise ValueError(f"Unknown cls_criteria: {cls_criteria!r}")

    with Image.open(image_path) as pil_image:
        raw_answer = handler.generate_from_pil_image(pil_image, prompt=prompt)
    # Trim surrounding whitespace and collapse inner runs, so answers like
    # " Melanocytic \n" or "melanocytic  malignant" still match a label.
    answer = " ".join(raw_answer.split()).lower()
    # Validate against the labels of the active criterion, not only the binary ones.
    if answer not in class_labels:
        answer = "malformed output"

    return {
        "image_id": image_id,
        "answer": answer,
        "truth": truth
    }

In [None]:
def eval_cls_lite(results:pd.DataFrame, cls_criteria:str, labels:list, isprocessing:bool=False):
    """Report accuracy and macro sensitivity/specificity of the collected answers.

    Args:
        results: One row per classified image, with ``truth`` and ``answer`` columns.
        cls_criteria: Name of the classification task. Not used in the computation;
            kept for existing callers.
        labels: Class labels. They also fix the row/column order of the confusion matrix.
        isprocessing: True while the processing loop runs; only the one-line summary is
            written, via ``tqdm.write`` so the progress bar is not broken. False for the
            final evaluation; the summary is printed and the confusion matrix and
            overall metrics are saved under ``<output_dir>/evaluation_metrics``.

    Returns:
        None. With an empty ``results`` only a notice is emitted, because there is
        nothing to score and the malformed rate would divide by zero.
    """
    if len(results) == 0:
        notice = "No results to evaluate yet."
        if isprocessing:
            tqdm.write(notice)
        else:
            print(notice)
        return

    truth, answer = results['truth'], results['answer']
    accuracy = accuracy_score(truth, answer)
    cm = confusion_matrix(truth, answer, labels=labels)
    # Restricting to `labels` keeps "malformed output" out of the class average; such
    # answers still count as misses for their true class.
    sensitivity = recall_score(truth, answer, average="macro", labels=labels, zero_division=0)
    specificity = cal_specificity(labels, cm, average="macro")
    message = f"Accuracy={accuracy:.3%}, Sensitivity={sensitivity:.3%}, Specificity={specificity:.3%}"

    if isprocessing:
        tqdm.write(message)
        return

    print(message)

    eval_dir = os.path.join(output_dir, "evaluation_metrics")
    # Create eval_dir itself, not only its parent: the cleanup cell may have deleted it.
    os.makedirs(eval_dir, exist_ok=True)

    cm_df = pd.DataFrame(cm, index=labels, columns=labels)
    save_eval(cm_df, eval_dir, "confusion_matrix", format='xlsx', index=True)

    malformed_rate = (answer == "malformed output").mean()
    eval_metrics = {
        'Accuracy': f"{accuracy:.3%}",
        'Sensitivity (Macro)': f"{sensitivity:.3%}",
        'Specificity (Macro)': f"{specificity:.3%}",
        'Malformed Rate': f"{malformed_rate:.3%}",
    }
    save_eval(pd.DataFrame([eval_metrics]), eval_dir, "overall_evaluation", format='xlsx', index=False)
    

In [None]:
# Create a tqdm progress bar
progress_bar = tqdm(total=img_total, desc="Processing images", unit="img")
progress_bar.n = len(existing_response_files)
progress_bar.refresh()

# Monitor the progress
now = time.time()
latest_update = now

api_handler = api_handler_queue.get()
handler_task_count = 0

In [None]:
TIMEOUT_SECONDS = 180  # Switch API handlers after this long without a successful request
EVAL_EVERY = 50        # Print interim metrics every N newly processed images

while progress_bar.n < img_total:
    if img_queue.empty():
        # Every queued image was either answered or already on disk; a blocking get()
        # would hang forever here.
        logger.warning("Image queue exhausted before reaching img_total; stopping.")
        break
    image_id = img_queue.get()
    progress_made = False

    try:
        generated_response_file = os.path.join(output_dir, f"{image_id}_cls.json")

        if os.path.exists(generated_response_file):
            # Already counted in the progress bar's initial value; don't count it again.
            logger.info(f"Skipping image_id {image_id} as response already exists.")

        else:
            result = call_gemini_api_cls_lite(image_id)

            # Write to a temp file and rename it into place, so an interrupted write never
            # leaves a truncated JSON that is skipped here and fails to load on resume.
            tmp_response_file = generated_response_file + ".tmp"
            with open(tmp_response_file, 'w') as out_file:
                json.dump(result, out_file, indent=4)
            os.replace(tmp_response_file, generated_response_file)

            # Append in memory only after the file is saved. A failed save requeues the
            # image, and appending first would leave a duplicate row.
            results = pd.concat([results, pd.DataFrame([result])], ignore_index=True)
            logger.info(f"Successfully processed image_id {image_id}, saved to {generated_response_file}")
            progress_made = True
            handler_task_count += 1

    except Exception as e:
        logger.warning(f"Error processing image_id {image_id}: {e}")

        # Requeue the request for a later retry
        img_queue.put(image_id)

    finally:
        img_queue.task_done()

        now = time.time()
        if progress_made:
            progress_bar.update(1)
            if progress_bar.n % EVAL_EVERY == 0:
                eval_cls_lite(results, cls_criteria, class_labels, isprocessing=True)
            latest_update = now
        elif now - latest_update > TIMEOUT_SECONDS:
            api_handler = toggle_api_handler(api_handler, api_handler_queue, "API request timeout")
            handler_task_count = 0
            # Give the new handler a full timeout window; otherwise every later
            # failure would switch handlers again immediately.
            latest_update = now

progress_bar.close()

In [None]:
# Start the final evaluation from a clean slate: drop metric files left by earlier runs.
eval_dir = os.path.join(output_dir, "evaluation_metrics")
has_stale_eval = os.path.exists(eval_dir)
if has_stale_eval:
    shutil.rmtree(eval_dir)

In [None]:
# Recount the saved responses, run the final evaluation (which saves the metric files),
# then export every collected answer next to the responses.
results_path = os.path.join(output_dir, "results.csv")
existing_response_files = list(filter(lambda name: name.endswith('.json'), os.listdir(output_dir)))
processed_count = len(existing_response_files)
print(f"{processed_count}/{img_total} of the images have been processed so far.\n")
eval_cls_lite(results, cls_criteria, class_labels)
results.to_csv(results_path, index=False)