In [1]:
import os
import json
import re
import shutil
from pathlib import Path
import logging

import time
import threading
import random
from queue import Queue, Empty
from collections import defaultdict
import sys

import numpy as np
import pandas as pd
from IPython.display import display
from PIL import Image
from dotenv import load_dotenv
from tqdm import tqdm
from sklearn.metrics import accuracy_score, recall_score, confusion_matrix

sys.path.append(os.path.abspath('..'))
pd.set_option('display.max_rows', None) # Force to display all rows of tables
#pd.set_option('display.max_rows', 20)

from dataset_utility import img_col_map, label_col_map, get_metadata_row, filter_one_condition_scin_clinical
from gemini_api import GeminiAPIHandler
from cls_generation import call_gemini_api_cls
from cls_eval import eval_gemini_cls

  from pandas.core import (


In [2]:
# Dataset paths
dataset = "derm12345"
#dataset = "bcn20000"
#dataset = "hiba"
#dataset = "pad-ufes-20"
#dataset = "scin_clinical"
dataset_path = f"../../datasets/{dataset}"

run_cls = True
#run_cls = False

isMultiStep = False
if dataset == "derm12345":
    #isMultiStep = False
    isMultiStep = True

use_context = True
#use_context = False

In [3]:
metadata_path = os.path.join(dataset_path, "metadata.csv")
metadata = pd.read_csv(metadata_path)

label_col = label_col_map[dataset]
img_col = img_col_map[dataset]

output_dir = f"../api_cls/{dataset}"
if isMultiStep: 
    output_dir = output_dir + "_multistep"
os.makedirs(os.path.dirname(output_dir), exist_ok=True)

log_filename = os.path.join(output_dir, "api_clscont.log")
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 
    filename=log_filename
)
logger = logging.getLogger(__name__)


In [4]:
# Flatten the dataset directory
for dirpath, dirnames, filenames in os.walk(dataset_path, topdown=False):
        if dirpath == dataset_path:
            continue  # Skip the root itself

        for file in filenames:
            src_path = os.path.join(dirpath, file)
            dest_path = os.path.join(dataset_path, file)
            shutil.move(src_path, dest_path)

        # Optionally remove the now-empty subfolders
        os.rmdir(dirpath)

In [None]:
load_dotenv(dotenv_path="../../.env")

api_key_indices = [8, 9, 10]
#api_key_indices = [1, 2, 3]
api_handler_queue = Queue()

model_name = "gemini-2.0-flash-exp"
request_interval = 6
#model_name = "gemini-2.0-flash-thinking-exp"
#request_interval = 12

for index in api_key_indices:
    api_key = os.getenv(f"GEMINI_API_KEY_{index}")
    api_handler = GeminiAPIHandler(
        api_key=api_key, 
        index=index, 
        model_name=model_name, 
        request_interval=request_interval, 
        logger=logger
    )
    api_handler_queue.put(api_handler)

# Multi-thread configuration
NUM_WORKERS = len(api_handler_queue.queue)
TIMEOUT_BOUND = 300
HANDLER_COOLDOWN = 60

# Shared structures
worker_queues = [Queue() for _ in range(NUM_WORKERS)]
results = []
results_lock = threading.Lock()
worker_threads = []

In [6]:
def handle_timeout(worker_id, old_handler):
    try:
        new_handler = api_handler_queue.get_nowait()
        logger.info(f"[Worker_{worker_id}] Swapping to API_Handler_{new_handler.index}.")
        api_handler_queue.put(old_handler)
        return new_handler
    except Empty:
        logger.info(f"[Worker_{worker_id}] No other available handler. Sleeping {HANDLER_COOLDOWN}s before retry...")
        time.sleep(HANDLER_COOLDOWN)
        return old_handler

In [7]:
if dataset == "scin_clinical":
    # Filter the images with only one dignosis label
    metadata = filter_one_condition_scin_clinical(metadata)
elif dataset in {"derm12345", "bcn20000", "hiba", "pad-ufes-20"}:
    # Filter the images with non-empty label
    metadata = metadata[metadata[label_col].replace(r'^\s*$', pd.NA, regex=True).notna()]
else:
    raise ValueError("Invalid dataset name. Please provide a valid dataset name.")

img_ids = [f.removesuffix('.jpg').removesuffix('.png') for f in metadata[img_col].dropna().unique()]

# Distribute image IDs across worker queues
for i, image_id in enumerate(img_ids):
    worker_queues[i % NUM_WORKERS].put(image_id)

# Parse all existing response files in the output directory
old_response_files = [f for f in os.listdir(output_dir) if f.endswith('.json')]
for filename in old_response_files:
    file_path = os.path.join(output_dir, filename)
    with open(file_path, 'r') as f:
        stored_result = json.load(f)
    stored_result_df = pd.DataFrame([stored_result])
    results.append(stored_result_df)

In [8]:
# Thread-safe progress bar
progress_bar = tqdm(total=len(img_ids), desc="Processing images", unit="img")
progress_bar.n = len(old_response_files)
progress_bar.refresh()
progress_lock = threading.Lock()

def update_progress_bar(task_exists, progress_made, last_update):
    elapsed_time = time.time() - last_update
    if not task_exists and progress_made:
        last_update = time.time()
        with progress_lock:
            progress_bar.update(1)
            if progress_bar.n % 50 == 0:
                current_results = []
                with results_lock: current_results = results[:]
                current_results_df = pd.concat(current_results, ignore_index=True)
                eval_gemini_cls(current_results_df, dataset, isMultiStep, isprocessing=True)
    return last_update, elapsed_time

Processing images:  33%|███▎      | 809/2485 [00:00<00:00, 6124895.19img/s]

In [9]:
def worker_fn(worker_id, api_handler, stop_event):
    queue_ = worker_queues[worker_id - 1]
    last_update = time.time()

    try:
        while not queue_.empty() and not stop_event.is_set():
            image_id = None
            try:
                image_id = queue_.get_nowait()
            except Empty:
                continue

            task_exists = False
            progress_made = False

            try:
                metadata_row = get_metadata_row(metadata, dataset, image_id)
                generated_response_file = os.path.join(output_dir, f"{image_id}_cls.json")

                if os.path.exists(generated_response_file):
                    logger.info(f"[Worker_{worker_id}] Skipping image_id {image_id} as response already exists.")
                    task_exists = True
                else:
                    result = call_gemini_api_cls(api_handler, dataset, image_id, metadata_row, isMultiStep, use_context)
                    result_df = pd.DataFrame([result])
                    with results_lock: results.append(result_df)

                    with open(generated_response_file, 'w') as out_file:
                        json.dump(result, out_file, indent=4)

                    logger.info(f"[Worker_{worker_id}] Successfully processed {image_id} " + 
                                f"using Handler_{api_handler.index}.")
                    progress_made = True

            except Exception as e:
                logger.warning(f"[Worker_{worker_id}] Error processing {image_id} " + 
                               f"using Handler_{api_handler.index}: {e}")
                if image_id:
                    queue_.put(image_id)

            finally:
                queue_.task_done()
                last_update, elapsed_time = update_progress_bar(task_exists, progress_made, last_update)
                if elapsed_time > TIMEOUT_BOUND:
                    logger.info(f"[Worker_{worker_id}] Handler_{api_handler.index} timeout. " + 
                                f"Elapsed time: {elapsed_time:.2f}s.")
                    api_handler = handle_timeout(worker_id, api_handler)
                    last_update = time.time()

    except KeyboardInterrupt:
        logger.warning(f"[Worker_{worker_id}] Received KeyboardInterrupt. Shutting down...")

    finally:
        api_handler_queue.put(api_handler)
        logger.info(f"[Worker_{worker_id}] Finished all tasks.")

In [10]:
def run_all():
    stop_event = threading.Event()  # Shared event to signal shutdown
    worker_threads.clear()  # Reset thread list in case this is re-run

    # Start all worker threads
    for i in range(1, NUM_WORKERS + 1):
        try:
            api_handler = api_handler_queue.get_nowait()
        except Empty:
            logger.error("Not enough API handlers available for workers.")
            break

        t = threading.Thread(target=worker_fn, args=(i, api_handler, stop_event), daemon=True)
        t.start()
        worker_threads.append(t)

    try:
        # Wait for all threads to complete
        for t in worker_threads:
            t.join()
    except KeyboardInterrupt:
        logger.warning("KeyboardInterrupt received. Signaling workers to stop...")
        stop_event.set()
        for t in worker_threads:
            t.join()

    progress_bar.close()
    logger.info("All workers completed.")


if run_cls:
    run_all()

Processing images:  33%|███▎      | 809/2485 [07:27<15:27,  1.81img/s]     


In [None]:
eval_dir = os.path.join(output_dir, "evaluation_metrics")
if os.path.exists(eval_dir):
    shutil.rmtree(eval_dir) # Delete the old evaluation file folder

In [None]:
all_response_files = [f for f in os.listdir(output_dir) if f.endswith('.json')]
print(f"{len(all_response_files)}/{len(img_ids)} of the images have been processed so far.\n")
results_df = pd.concat(results, ignore_index=True)
if not results_df.empty:
    eval_gemini_cls(results_df, dataset, isMultiStep, isprocessing=False)

In [None]:
'''
# Clean up new response files to roll back if needed
new_response_files = [f for f in all_response_files if f not in old_response_files]
for filename in new_response_files:
    file_path = os.path.join(output_dir, filename)
    os.remove(file_path)
'''