## Experiment: classification accuracy when the agent has access to the saliency map
## Setup

In [None]:
import warnings
from typing import *
import os
from dotenv import load_dotenv
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
# NOTE(review): this binding of `logging` is overwritten by `import logging`
# further down, so the rest of the notebook uses the standard-library module.
from transformers import logging
import sys

# Repository root (placeholder — fill in before running). It is appended to
# sys.path so the project packages imported below can be resolved.
ROOT = ...
sys.path.append(ROOT)


# Project code. These are star imports, so names used later such as Agent,
# MuteClassifierTool, ExplanationTool and load_prompts_from_file presumably
# come from here — confirm against the medDerm package.
from medDerm.agent import *
from medDerm.tools import *
from medDerm.utils import *
from experiments.Ham10k.experiment_utils import evaluate_gpt4o_reliability


import json
import os
import glob
import logging
from datetime import datetime
import re
import tqdm
import base64

# Silence every warning category for the whole notebook.
warnings.filterwarnings("ignore")
# Load variables from a .env file into the process environment (python-dotenv's default lookup).
_ = load_dotenv()

In [None]:
# Text file holding the named system prompts; get_agent selects one by key.
PROMPT_FILE = f"{ROOT}/medDerm/docs/system_prompts.txt"
BENCHMARK_DIR = ... # Path to the benchmark dataset directory
# ISIC 2018 Task 3 test-set ground truth, parsed by read_csv_to_dict.
BENCHMARK_GT_FILE = f"{ROOT}/datasets/ISIC2018_Task3_Test_GroundTruth/ISIC2018_Task3_Test_GroundTruth.csv"
# Metadata for the same test set; handed to evaluate_gpt4o_reliability in the final cell.
BENCHMARK_GT_METADATA_FILE = f"{ROOT}/datasets/ISIC2018_Task3_Test_GroundTruth/ISIC2018_Task3_Metadata.csv"




# In this cell model_name only names the log file; the final cell rebinds it
# from the OPENAI_MODEL_NAME environment variable.
model_name = "medDerm"
# Sampling temperature read by get_agent when it builds the ChatOpenAI model.
temperature = 0.2
medDerm_logs = f"{ROOT}/experiments/medDerm_logs"

os.makedirs(medDerm_logs, exist_ok=True)

# Timestamped log file for this run. force=True removes any root-logger
# handlers configured earlier (e.g. by imported libraries).
# NOTE(review): the format is the bare message, so this ".json" file is only
# valid JSON if every logged message is itself JSON — confirm intent.
log_filename = f"{medDerm_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s", force=True)
# Device passed to the classification tool in get_tools.
device = "cuda:2"

Class descriptions:
- MEL: Melanoma is a malignant neoplasm derived from melanocytes that may appear in different variants.

- NV: Melanocytic nevi are benign neoplasms of melanocytes and appear in a myriad of variants, which all are included in our series.

- BCC: Basal cell carcinoma is a common variant of epithelial skin cancer that rarely metastasizes but grows destructively if untreated.

- AKIEC: Actinic keratoses (solar keratoses) and intraepithelial carcinoma (Bowen’s disease) are common non-invasive variants of squamous cell carcinoma that can be treated locally without surgery.

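- BKL: "Benign keratosis" is a generic class that includes seborrheic keratoses ("senile wart"), solar lentigo (which can be regarded as a flat variant of seborrheic keratosis), and lichen planus-like keratoses (LPLK).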

- DF: Dermatofibroma is a benign skin lesion regarded as either a benign proliferation or an inflammatory reaction to minimal trauma.

- VASC: Vascular skin lesions in the dataset range from cherry angiomas to angiokeratomas and pyogenic granulomas. Hemorrhage is also included in this category.

## Utility functions

In [None]:
def get_tools():
    """Build the list of tools handed to the agent.

    Returns:
        A two-element list: a ``MuteClassifierTool`` configured from the
        project checkpoint YAML on the global ``device`` (all output heads),
        followed by an ``ExplanationTool`` (presumably the saliency-map
        explainer this experiment is about — confirm in medDerm.tools).
    """
    checkpoint_cfg = f"{ROOT}/checkpoints/exp-HAM+Derm7pt-all+BCN+HAM-bin+DermNet+Fitzpatrick.yaml"
    classifier = MuteClassifierTool(
        pretrained=False,
        device=device,
        config_path=checkpoint_cfg,
        output_head="All",
    )
    explainer = ExplanationTool()
    return [classifier, explainer]

def get_agent(openai_kwargs, model_name="gpt-4o", tools=None, prompt_type="ISIC_CLASSIFICATION_GPT_BINARY"):
    """Create a fresh medDerm agent plus the thread config used to drive it.

    Args:
        openai_kwargs: Extra keyword arguments forwarded to ``ChatOpenAI``
            (e.g. ``api_key``, ``base_url``). ``None`` means no extra arguments.
        model_name: Chat model name passed to ``ChatOpenAI``.
        tools: Tools made available to the agent. ``None`` (the default) means
            no tools.
        prompt_type: Key of the system prompt to load from ``PROMPT_FILE``.

    Returns:
        ``(agent, thread)`` where ``thread`` is the config dict selecting
        thread ``"1"``. Each call builds a new in-memory checkpointer, so the
        fixed thread id never carries history between agents.
    """
    # Avoid a shared mutable default list; accept None for both optional inputs.
    tools = [] if tools is None else tools
    openai_kwargs = openai_kwargs or {}

    prompts = load_prompts_from_file(PROMPT_FILE)
    # NOTE(review): presumably a dict keyed by prompt name; an unknown
    # prompt_type will fail here.
    prompt = prompts[prompt_type]
    checkpointer = MemorySaver()
    model = ChatOpenAI(model=model_name, temperature=temperature, top_p=0.95, **openai_kwargs)
    agent = Agent(
        model,
        tools=tools,
        log_tools=True,
        log_dir=f"{ROOT}/logs",
        system_prompt=prompt,
        checkpointer=checkpointer,
    )
    thread = {"configurable": {"thread_id": "1"}}
    return agent, thread

def run_medrax(agent, thread, prompt, image_path, use_tools=False):
    """
    Send one image plus a text prompt to the agent and return its final answer.

    The image is embedded as a base64 ``data:image/jpeg`` URL so the chat model
    can see it. When ``use_tools`` is true, the image's filesystem path is also
    sent as a separate plain-text user message, presumably so tools that load
    images from disk (e.g. the classifier) can find it.

    Args:
        agent: The medDerm agent (must expose ``workflow.stream`` and
            ``workflow.get_state``).
        thread: The thread configuration passed to the workflow.
        prompt: The text prompt sent after the image.
        image_path: Path of the image file to send.
        use_tools: Whether to also send the image path as text for the tools.

    Returns:
        ``(final_response, agent_state)``: the stripped text content of the
        last message in the last streamed update, and ``str()`` of the
        workflow state for ``thread``.

    Raises:
        RuntimeError: If the workflow stream produces no updates.
    """
    with open(image_path, "rb") as img_file:
        img_base64 = base64.b64encode(img_file.read()).decode("utf-8")

    messages = []
    if use_tools:
        messages.append({"role": "user", "content": f"the image is located at: {image_path}"})
    messages.append(
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                }
            ],
        }
    )
    messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})

    # Keep only the last node update emitted by the stream; it holds the final answer.
    last_update = None
    for event in agent.workflow.stream({"messages": messages}, thread):
        for update in event.values():
            last_update = update
    if last_update is None:
        raise RuntimeError(f"Agent produced no output for {image_path}")

    final_response = last_update["messages"][-1].content.strip()
    agent_state = agent.workflow.get_state(thread)
    return final_response, str(agent_state)


## Postprocess

In [None]:
def process_data(data, log_path="wrong_answers.txt"):
    """
    Extract the intended class ('MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC') from a free-text answer.

    Abbreviations are searched first, case-sensitively and as whole words, in
    the order MEL, NV, BCC, AKIEC, BKL, DF, VASC. If none is found, a few full
    class names are searched case-insensitively and mapped to their
    abbreviation. The first match wins.

    Args:
    data: Answer to be processed. Non-string values are converted with str().
    log_path: File to which unparseable answers are appended, one per line.

    Returns:
    answer: The class abbreviation, or -1 if no class could be found.
    """
    abbreviations = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC']
    full_names = {
        "Melanocytic Nevus": "NV",
        "benign keratosis-like lesions": "BKL",
        "Melanoma": "MEL",
        "Basal Cell Carcinoma": "BCC",
    }
    text = data if isinstance(data, str) else str(data)

    for abbr in abbreviations:
        if re.search(rf"\b{re.escape(abbr)}\b", text):
            return abbr
    # Full names appear with arbitrary capitalisation in model output
    # ("melanoma", "Melanocytic nevus"), so match them case-insensitively.
    for name, abbr in full_names.items():
        if re.search(rf"\b{re.escape(name)}\b", text, flags=re.IGNORECASE):
            return abbr

    with open(log_path, 'a', encoding='utf-8') as file:
        file.write(text + '\n')
    return -1

def process_data_binary(data):
    """
    Map a free-text yes/no answer to 1 (yes) or 0 (no), or -1 if neither is present.

    Matching is case-insensitive and whole-word; "YES" wins when both words
    occur. Unparseable answers are appended to ``wrong_answers.txt`` in the
    current working directory.

    Args:
    data: Answer to be processed
    Returns:
    answer: 1, 0, or -1
    """
    shouted = data.upper()
    if re.search(r'\bYES\b', shouted):
        return 1
    if re.search(r'\bNO\b', shouted):
        return 0

    with open('wrong_answers.txt', 'a') as log:
        log.write(data + '\n')
    return -1

def read_csv_to_dict(csv_file_path):
    """
    Read a one-hot encoded ground-truth CSV into ``{image_name: class_label}``.

    The first row is treated as a header and skipped. Each remaining row must
    be an image name followed by exactly seven numeric columns in the order
    MEL, NV, BCC, AKIEC, BKL, DF, VASC. The label is the first column whose
    value equals 1. The following rows are skipped with a logged warning:
    malformed rows, rows with non-numeric values, rows with the wrong number of
    columns, and rows with no positive column.

    Args:
    csv_file_path: Path to the CSV file

    Returns:
    A dictionary with image names as keys and class names as values, or an
    empty dictionary if the file is missing or cannot be read.
    """
    import csv  # local import: the notebook's setup cell does not import csv

    class_labels = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC']
    image_class_dict = {}

    try:
        with open(csv_file_path, 'r', newline='', encoding='utf-8') as csvfile:
            reader = csv.reader(csvfile)
            next(reader, None)  # Skip the header row.
            for row in reader:
                line = ','.join(row)  # Re-joined only for log messages.
                if len(row) < 2:
                    logging.warning(f"Skipping malformed line: {line}")
                    continue
                image_name = row[0].strip()
                try:
                    class_values = [float(value) for value in row[1:]]
                except ValueError:
                    logging.warning(f"Skipping line with invalid class values: {line}")
                    continue
                if len(class_values) != len(class_labels):
                    logging.warning(f"Skipping line due to length mismatch: {line}")
                    continue
                label = next(
                    (cls for cls, value in zip(class_labels, class_values) if value == 1),
                    None,
                )
                if label is None:
                    logging.warning(f"Skipping line with no positive class: {line}")
                    continue
                image_class_dict[image_name] = label
        return image_class_dict
    except FileNotFoundError:
        logging.error(f"File not found: {csv_file_path}")
        return {}
    except Exception as e:
        # Best-effort contract kept from the original: callers get {} instead of an exception.
        logging.error(f"Error reading CSV file: {e}")
        return {}
def create_gt_dict(dataset_path):
    """
    Label every image in a directory by the dataset it comes from.

    ``*.jpg`` files are labelled 1 (HAM10k) and ``*.JPEG`` files 0 (ImageNet).
    Keys are file names without their extension. If the same stem exists with
    both extensions, the ImageNet label wins because it is assigned last.

    Args:
    dataset_path: Path to the dataset directory
    Returns:
    A dictionary with image names as keys and 1/0 source labels as values
    """
    labelled = {}
    for pattern, label in (("*.jpg", 1), ("*.JPEG", 0)):
        for path in glob.glob(os.path.join(dataset_path, pattern)):
            stem, _extension = os.path.splitext(os.path.basename(path))
            labelled[stem] = label
    return labelled

def _score_predictions(data, gt_dict):
    """Return ``(corrects, wrongs, bads)`` for a list of ``{image_name: answer}`` dicts."""
    corrects = wrongs = bads = 0
    for item in data:
        for image_name, answer in item.items():
            if answer == "Error":
                bads += 1
                continue
            predicted_class = process_data(answer)
            if predicted_class == -1:
                bads += 1
                continue
            if image_name.endswith(".jpg"):
                image_name = image_name[:-4]  # Ground-truth keys have no extension.
            if image_name not in gt_dict:
                # Not counted in any bucket, matching the original behaviour.
                print(f"Image {image_name} not found in ground truth.")
            elif gt_dict[image_name] == predicted_class:
                corrects += 1
            else:
                wrongs += 1
    return corrects, wrongs, bads


def _count_tool_calls(log_dir, tool_name="explanation_tool"):
    """Count logged tool calls named ``tool_name`` across every ``*.json`` file in ``log_dir``."""
    count = 0
    for log_path in glob.glob(os.path.join(log_dir, "*.json")):
        with open(log_path, 'r') as file:
            tool_calls = json.load(file)
        # Tolerate entries without a "name" key instead of aborting the whole report.
        count += sum(
            1 for call in tool_calls
            if isinstance(call, dict) and call.get("name") == tool_name
        )
    return count


def postProcessingResults(file_path_results, classification_type="multi-class", use_tools=False):
    """
    Print (and return) the accuracy of saved predictions against the ISIC ground truth.

    Answers equal to "Error", and answers that process_data cannot map to a
    class, count as "bads" and are included in the accuracy denominator.
    Predictions for images missing from the ground truth are reported but not
    counted. The explanation-tool count covers every ``ROOT/logs/*.json`` file,
    regardless of which run produced it.

    Args:
    file_path_results: Path to the JSON results file written by evaluate_medDerm
    classification_type: Only "multi-class" is supported
    use_tools: Unused; kept for backward compatibility

    Returns:
    accuracy: corrects / (corrects + wrongs + bads), 0 if there are no
    predictions, or -1 on an unsupported type or a read error
    """
    if classification_type != "multi-class":
        print("Error: classification type not supported")
        return -1

    GT_dict = read_csv_to_dict(BENCHMARK_GT_FILE)
    try:
        with open(file_path_results, 'r') as file:
            data = json.load(file)
        corrects, wrongs, bads = _score_predictions(data, GT_dict)
        cnt = _count_tool_calls(os.path.join(ROOT, "logs"))
    except Exception as e:
        print(f"Error reading JSON file: {e}")
        return -1

    total = corrects + wrongs + bads
    accuracy = corrects / total if total > 0 else 0
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Corrects: {corrects}")
    print(f"Wrongs: {wrongs}")
    print(f"Bads: {bads}")
    print(f"Explanation tool called {cnt} times")
    return accuracy


## Evaluate the model

In [None]:
def _load_previous_results(output_file):
    """Return the results already saved in ``output_file``, or [] if there are none."""
    if output_file is None or not os.path.exists(output_file):
        return []
    with open(output_file, "r") as f:
        return json.load(f)


def _save_results(output_file, results):
    """Overwrite ``output_file`` with ``results`` as indented JSON; no-op when it is None."""
    if output_file is None:
        return
    with open(output_file, "w") as f:
        json.dump(results, f, indent=4)


def _select_images(image_dir, only_errors):
    """List the ``*.jpg``/``*.JPEG`` images in ``image_dir`` in sorted order, optionally restricted to past errors."""
    image_paths = sorted(
        glob.glob(os.path.join(image_dir, "*.jpg"))
        + glob.glob(os.path.join(image_dir, "*.JPEG"))
    )
    if only_errors:
        error_log_file = f"{ROOT}/experiments/results/Ham10kClassification/error_logs_Mute.json"
        with open(error_log_file, "r") as f:
            error_results = json.load(f)
        error_image_names = {item["image_name"] for item in error_results}
        image_paths = [
            p for p in image_paths
            if os.path.basename(p).rsplit('.', 1)[0] in error_image_names
        ]
    return image_paths


def evaluate_medDerm(model_name, image_dir, n_samples=10, output_file=None, classification_type="multi-class", openai_kwargs=None, use_tools=False, prompt_type="ISIC_CLASSIFICATION_GPT_BINARY", only_errors=False):
    """
    Classify images with medDerm into one of the seven ISIC classes.

    Results are resumable. Images already recorded in ``output_file`` are
    skipped by file name. Up to ``n_samples`` new images are processed in
    sorted order, and the file is rewritten after every image, including
    failures, which are recorded as "Error".

    Args:
        model_name: Chat model used as the central agent; None uses get_agent's default.
        image_dir: Directory containing the images to classify.
        n_samples: Maximum number of new images to classify in this call.
        output_file: JSON file used for resuming and saving; None disables persistence.
        classification_type: Only "multi-class" is supported.
        openai_kwargs: Extra keyword arguments for ChatOpenAI; None means none.
        use_tools: Whether to give the agent the classifier/explanation tools.
        prompt_type: Key of the system prompt to use.
        only_errors: Restrict to images listed in the Mute error log.

    Returns:
        A list of ``{image_name: answer}`` dicts, previous results first.

    Raises:
        ValueError: If ``classification_type`` is not supported.
    """
    if classification_type != "multi-class":
        raise ValueError(f"Unsupported classification_type: {classification_type!r}")
    # NOTE: the missing space after "VASC." is kept so prompts stay identical to earlier runs.
    prompt = ("Classify this image into one of the seven classes: MEL, NV, BCC, AKIEC, BKL, DF, VASC."
              "output the class name ***only***, without any additional text or explanation.")
    openai_kwargs = openai_kwargs or {}

    results = _load_previous_results(output_file)
    if results:
        print(f"Already processed {len(results)} images, starting from {len(results)}")
    done = {name for item in results for name in item}
    pending = [p for p in _select_images(image_dir, only_errors) if os.path.basename(p) not in done]
    image_paths = pending[:max(n_samples, 0)]
    print(f"number of samples to execute: {len(image_paths)}")

    tools = get_tools() if use_tools else []
    # Always pass the tools; only override the model when one was given.
    agent_kwargs = {"tools": tools, "prompt_type": prompt_type}
    if model_name is not None:
        agent_kwargs["model_name"] = model_name

    for image_path in tqdm.tqdm(image_paths, desc="Processing images"):
        image_key = os.path.basename(image_path)
        # A fresh agent per image keeps each conversation (and its prompt) short.
        agent, thread = get_agent(openai_kwargs, **agent_kwargs)
        try:
            response, _ = run_medrax(agent, thread, prompt, image_path=image_path, use_tools=use_tools)
            results.append({image_key: response.strip()})
        except Exception as e:
            print(f"Error processing {image_path}: {e}")
            results.append({image_key: "Error"})
        _save_results(output_file, results)

    if results and output_file is not None:
        print(f"Classification results saved to {output_file}")
    return results

## HAM10K Classification

In [None]:

model = "gpt"
# Credentials and endpoint for the chat model are read from this env file.
# NOTE(review): `model` is not part of the path; confirm whether
# f"{ROOT}/{model}_env.env" was intended.
env_file = f"{ROOT}/model_env.env"
if not load_dotenv(env_file):
    # load_dotenv returns False when it set no variables (e.g. the file is missing).
    print(f"Error loading environment variables from {env_file}")
    # sys.exit raises SystemExit, which a notebook reports without shutting down the kernel.
    sys.exit(1)

openai_kwargs = {}
if api_key := os.getenv("OPENAI_API_KEY"):
    openai_kwargs["api_key"] = api_key
if base_url := os.getenv("OPENAI_BASE_URL"):
    openai_kwargs["base_url"] = base_url

# None if unset, in which case evaluate_medDerm falls back to get_agent's default model.
model_name = os.getenv("OPENAI_MODEL_NAME")
classification_type = "multi-class"
prompt_type = "ISIC_CLASSIFICATION_EXPLANATIONS"
only_errors = True

output_file = os.path.join(ROOT, "experiments", "results", "medDerm_classification_results_gpt4o_explanation_overErrors.json")
evaluate_medDerm(
    model_name,
    BENCHMARK_DIR,
    output_file=output_file,
    n_samples=10000,
    classification_type=classification_type,
    prompt_type=prompt_type,
    use_tools=True,
    openai_kwargs=openai_kwargs,
    only_errors=only_errors,
)
postProcessingResults(output_file, classification_type=classification_type)
evaluate_gpt4o_reliability(output_file, os.path.join(ROOT, "logs"), BENCHMARK_GT_METADATA_FILE)