In [1]:
from tqdm import tqdm
from openai import OpenAI, AsyncOpenAI
import re
from typing import Optional, Union, List, get_origin, get_args, Any, Dict, Literal
import inspect
import asyncio
import json
import logging
import pandas as pd
from pydantic import BaseModel, Field, create_model
import math
import demjson3

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
        logging.FileHandler('TNM_test.log', mode='w'),  # Write to file
        logging.StreamHandler()                     # Print to console
    ]
)
from MA_async import *

2025-03-11 07:41:20 - DEBUG - load_ssl_context verify=True cert=None trust_env=True http2=False
2025-03-11 07:41:20 - DEBUG - load_verify_locations cafile='/usr/lib/ssl/certs/ca-certificates.crt'


In [2]:
def safe_json_load(s: str) -> Any:
    """
    Attempts to parse a JSON string using the standard json.loads.
    If that fails (e.g. due to an unterminated string), it will try using
    a more forgiving parser (demjson3). If both attempts fail,
    the original string is returned.
    """
    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        logger.error("Standard json.loads failed: %s", e)
        try:
            logger.info("Attempting to parse with demjson3 as fallback.")
            result = demjson3.decode(s)
            logger.info("demjson3 successfully parsed the JSON.")
            return result
        except Exception as e2:
            logger.error("Fallback parsing with demjson3 also failed: %s. Returning original input.", e2)
            return s

In [35]:
def generate_tools_spec(*functions):
    # Mapping of Python types to JSON Schema types
    type_map = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
        type(None): "null"
    }
    tools = []
    for func in functions:
        # Basic function info
        func_name = func.__name__
        func_description = func.__doc__.strip() if func.__doc__ else ""
        sig = inspect.signature(func)
        
        properties = {}
        required = []
        for param in sig.parameters.values():
            # Skip *args and **kwargs as they cannot be described in JSON schema easily
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue
            param_name = param.name

            # Determine JSON schema type from annotation (if available)
            json_type = "string"  # default type
            annotation = param.annotation
            if annotation is not inspect._empty:
                origin = get_origin(annotation)
                # Handle Optional[X] or Union[X, None]
                if origin is Union:
                    args = [t for t in get_args(annotation) if t is not type(None)]
                    if len(args) == 1:
                        annotation = args[0]
                        origin = get_origin(annotation)
                # Map to JSON type if direct or via origin for generics
                if annotation in type_map:
                    json_type = type_map[annotation]
                elif origin in type_map:
                    json_type = type_map[origin]
                # Handle list item types for generics like list[int]
                if json_type == "array":
                    item_type = "string"  # default for items
                    args = get_args(annotation)
                    if args:
                        # Use first type argument for list item if present
                        item_type = type_map.get(args[0], "string")
                    properties[param_name] = {
                        "type": "array",
                        "items": {"type": item_type}
                    }
                elif json_type == "object":
                    # For dicts or unknown complex types, use object without specifics
                    properties[param_name] = {"type": "object"}
                else:
                    properties[param_name] = {"type": json_type}
            else:
                # No annotation, assume string
                properties[param_name] = {"type": "string"}

            # Mark required if no default value
            if param.default is inspect._empty:
                required.append(param_name)
        
        # Build the tool dictionary for this function
        tool_dict = {
            "type": "function",
            "function": {
                "name": func_name,
                "description": func_description,
                "parameters": {
                    "type": "object",
                    "properties": properties
                }
            }
        }
        if required:
            tool_dict["function"]["parameters"]["required"] = required
        tools.append(tool_dict)
    return tools

def provide_final_prediction(reasoning: str, prediction: Literal["T1", "T2", "T3", "T4"]) -> str:
    """
    Returns an answer string with reasoning and final prediction.
    Args:
        reasoning: A step-by-step explanation for how you arrived at the predicted T stage.
        prediction: The final predicted T stage (T1, T2, T3, or T4).
    """
    answer = f"""
    Reasoning: {reasoning}
    Final prediction: {prediction}
    """
    return answer




def retrieve_ajcc_criteria(
    cancer_type: str,
    staging_category: Literal["T", "N", "M"]
) -> dict:
    """
    Retrieves the AJCC staging criteria for a given cancer type and staging category.
    Args:
        cancer_type (str): The type of cancer (e.g., "lung", "breast").
        staging_category (Literal["T", "N", "M"]): "T" for tumor, "N" for regional lymph nodes, or "M" for distant metastasis.
    """
    sample_data = {
        "T": {
            "T1": {"max_size_cm": 2, "description": "Tumor ≤2 cm"},
            "T2": {"min_size_cm": 2, "max_size_cm": 5, "description": "Tumor >2 cm but ≤5 cm"},
            "T3": {"min_size_cm": 5, "description": "Tumor >5 cm"}
        },
        "N": {
            "N0": {"description": "No regional lymph node metastasis"},
            "N1": {"description": "Metastasis to 1-3 lymph nodes"},
            "N2": {"description": "Metastasis to 4 or more lymph nodes"}
        },
        "M": {
            "M0": {"description": "No distant metastasis"},
            "M1": {"description": "Distant metastasis present"}
        }
    }
    return sample_data.get(staging_category, {})

def extract_information(
    info_to_extract: list[str]
) -> dict:
    """
    Extracts relevent information for cancer staging from a pathology report.
    Args:
        items_to_extract (list[str]): A list of information fields to be extracted (e.g. ["tumor_size", "depth_of_invasion", ... etc]).
    """
    results = {}
    return results

def compare_numerical_values(
    value: float,
    min_value: float = None,
    max_value: float = None,
    inclusive_min: bool = True,
    inclusive_max: bool = True
) -> bool:
    """
    Compares a given numeric value against optional minimum and maximum thresholds.
    Args:
        value (float): The numeric value to compare.
        min_value (float, optional): The lower threshold. 
                                     If None, no lower bound check is performed.
        max_value (float, optional): The upper threshold. 
                                     If None, no upper bound check is performed.
        inclusive_min (bool): Whether the comparison with min_value 
                              should be inclusive (value >= min_value) 
                              or exclusive (value > min_value).
        inclusive_max (bool): Whether the comparison with max_value 
                              should be inclusive (value <= max_value) 
                              or exclusive (value < max_value).

    Returns:
        bool: True if the value satisfies all specified boundary conditions; 
              False otherwise.

    Examples:
        compare_numerical_values(3.2, min_value=2, max_value=5) 
            -> True (assuming inclusive checks)
        compare_numerical_values(2, min_value=2, max_value=5, inclusive_min=False)
            -> False, since 2 is not strictly > 2
    """
    if min_value is not None:
        if inclusive_min:
            if value < min_value:
                return False
        else:
            if value <= min_value:
                return False

    if max_value is not None:
        if inclusive_max:
            if value > max_value:
                return False
        else:
            if value >= max_value:
                return False

    return True


In [32]:
available_tools = {
    "retrieve_ajcc_criteria": retrieve_ajcc_criteria,
    "extract_information": extract_information,
    "compare_numerical_values": compare_numerical_values,
    "provide_final_prediction": provide_final_prediction
}
tools = generate_tools_spec(*available_tools.values())

In [36]:
tools

[{'type': 'function',
  'function': {'name': 'retrieve_ajcc_criteria',
   'description': 'Retrieves the AJCC staging criteria for a given cancer type and staging category.\n    Args:\n        cancer_type (str): The type of cancer (e.g., "lung", "breast").\n        staging_category (Literal["T", "N", "M"]): "T" for tumor, "N" for regional lymph nodes, or "M" for distant metastasis.',
   'parameters': {'type': 'object',
    'properties': {'cancer_type': {'type': 'string'},
     'staging_category': {'type': 'string'}},
    'required': ['cancer_type', 'staging_category']}}},
 {'type': 'function',
  'function': {'name': 'extract_information',
   'description': 'Extracts relevent information for cancer staging from a pathology report.\n    Args:\n        items_to_extract (list[str]): A list of information fields to be extracted (e.g. ["tumor_size", "depth_of_invasion", ... etc]).',
   'parameters': {'type': 'object',
    'properties': {'info_to_extract': {'type': 'array',
      'items': {'ty

In [6]:
class ResponseStage(BaseModel):
    reasoning: str = Field(
        description="A step-by-step explanation for how you arrived at the predicted T stage."
    )
    stage: Literal["T1", "T2", "T3", "T4"] = Field(
        description="The final predicted T stage (T1, T2, T3, or T4)."
    )
{"guided_json": ResponseStage.model_json_schema()}

{'guided_json': {'properties': {'reasoning': {'description': 'A step-by-step explanation for how you arrived at the predicted T stage.',
    'title': 'Reasoning',
    'type': 'string'},
   'stage': {'description': 'The final predicted T stage (T1, T2, T3, or T4).',
    'enum': ['T1', 'T2', 'T3', 'T4'],
    'title': 'Stage',
    'type': 'string'}},
  'required': ['reasoning', 'stage'],
  'title': 'ResponseStage',
  'type': 'object'}}

In [34]:
tools

[{'type': 'function',
  'function': {'name': 'retrieve_ajcc_criteria',
   'description': 'Retrieves the AJCC staging criteria for a given cancer type and staging category.\n    Args:\n        cancer_type (str): The type of cancer (e.g., "lung", "breast").\n        staging_category (Literal["T", "N", "M"]): "T" for tumor, "N" for regional lymph nodes, or "M" for distant metastasis.',
   'parameters': {'type': 'object',
    'properties': {'cancer_type': {'type': 'string'},
     'staging_category': {'type': 'string'}},
    'required': ['cancer_type', 'staging_category']}}},
 {'type': 'function',
  'function': {'name': 'extract_information',
   'description': 'Extracts relevent information for cancer staging from a pathology report.\n    Args:\n        items_to_extract (list[str]): A list of information fields to be extracted (e.g. ["tumor_size", "depth_of_invasion", ... etc]).',
   'parameters': {'type': 'object',
    'properties': {'info_to_extract': {'type': 'array',
      'items': {'ty

In [41]:
system_prompt = \
"""You are an AI agent specialized in accurately determining the pathologic T stage (T1, T2, T3, or T4) for breast cancer, according to the AJCC Cancer Staging Manual (7th edition). 
You have access to multiple tools to help extract and interpret relevant information from pathology reports. Your task is to identify the appropriate tools and plan a logical sequence of actions to achieve correct staging."""

user_prompt = \
"""Here is the pathology report for a breast cancer patient:

<Pathology Report>
{report}
</Pathology Report>

List clearly:
1. Which tools you would use.
2. The exact sequence in which you would use them.
3. Briefly explain why you chose this sequence."""


In [None]:
# system_prompt = """You are an AI assistant specialized in determining the T stage of breast cancer following the AJCC Cancer Staging Manual (7th edition).

# You have access to multiple tools to assist your reasoning if needed. 

# **Important Requirements**:
# - Use the above tools whenever you need additional information, staging criteria, or numeric comparisons.
# You must provide your final T stage by calling the provide_final_prediction function with two arguments:
# 1) reasoning: an explanation of how you arrived at the stage,
# 2) prediction: one of T1, T2, T3, or T4.
# """
# user_prompt = """You are provided with a pathology report of a breast cancer patient. Please review the report carefully and determine the pathologic T stage (T1, T2, T3, or T4) according to the AJCC Cancer Staging Manual (7th edition).

# <Pathology Report>
# {report}
# </Pathology Report>

# **Important Requirements**:
# 1. If you need any additional information or calculations, call the appropriate tool.
# 2. Conclude by calling the provide_final_prediction function with your reasoning and your final T stage.
# """


In [22]:
df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result_breast/results_t_stage_run_2.csv")[["patient_filename", "t", "text"]]
df.iloc[2]['text']

"FINAL PATHOLOGIC DIAGNOSIS. A. Lymph node, sentinel #1, right axillary, excision: - One lymph node, negative for metastatic carcinoma (0/1). - Biopsy site changes identified. B. Lymph node, sentinel #2, right axillary, excision: - One lymph node, negative for metastatic carcinoma (0/1). C. Lymph node, sentinel #3, right axillary, excision: - One lymph node, negative for metastatic carcinoma (0/1). D. Lymph node, sentinel #4, right axillary, excision: - One lymph node, negative for metastatic carcinoma (0/1). E. Breast, right, partial mastectomy: - Invasive lobular carcinoma, see breast pathologic. parameters. - Margins of excision free of tumor, distance to the closest margin >2 mm. to superior and anterior. - Lobular intraepithelial neoplasia (LCIS), classic type. - Atypical ductal hyperplasia. - Previous biopsy site identified with extensive fat necrosis. - Flat epithelial atypia, focally associated with microcalcifications. - Sclerosing adenosis, focally associated with microcalcif

In [42]:
logger.info("Loaded dataframe with %d rows.", len(df))

row = df.iloc[2]
report = row["text"]
t_stage = row["t"]
formatted_user_prompt = user_prompt.format(report=report)

client=OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": formatted_user_prompt}]
response = client.chat.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    messages=messages,
    temperature=0.4,
    tools=tools
)


2025-03-11 10:28:18 - INFO - Loaded dataframe with 1031 rows.
2025-03-11 10:28:18 - DEBUG - load_ssl_context verify=True cert=None trust_env=True http2=False
2025-03-11 10:28:18 - DEBUG - load_verify_locations cafile='/usr/lib/ssl/certs/ca-certificates.crt'
2025-03-11 10:28:18 - DEBUG - Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'system', 'content': 'You are an AI agent specialized in accurately determining the pathologic T stage (T1, T2, T3, or T4) for breast cancer, according to the AJCC Cancer Staging Manual (7th edition). \nYou have access to multiple tools to help extract and interpret relevant information from pathology reports. Your task is to identify the appropriate tools and plan a logical sequence of actions to achieve correct staging.'}, {'role': 'user', 'content': "Here is the pathology report for a breast cancer patient:\n\n<Pathology Report>\nFINAL PATHOLOGIC DIAGNOSIS. A. Lymph node, sentinel #1, ri

In [52]:
messages.append({"role": "system", "content":response.choices[0].message.content})
user_prompt2 = \
"""Great. Now, call each tool step-by-step with the necessary arguments."""
messages.append({"role": "user", "content": user_prompt2})
response = client.chat.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    messages=messages,
    temperature=0.4,
    tools=tools
)


2025-03-11 10:47:55 - DEBUG - Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'system', 'content': 'You are an AI agent specialized in accurately determining the pathologic T stage (T1, T2, T3, or T4) for breast cancer, according to the AJCC Cancer Staging Manual (7th edition). \nYou have access to multiple tools to help extract and interpret relevant information from pathology reports. Your task is to identify the appropriate tools and plan a logical sequence of actions to achieve correct staging.'}, {'role': 'user', 'content': "Here is the pathology report for a breast cancer patient:\n\n<Pathology Report>\nFINAL PATHOLOGIC DIAGNOSIS. A. Lymph node, sentinel #1, right axillary, excision: - One lymph node, negative for metastatic carcinoma (0/1). - Biopsy site changes identified. B. Lymph node, sentinel #2, right axillary, excision: - One lymph node, negative for metastatic carcinoma (0/1). C. Lymph node, sentinel #3, 

2025-03-11 10:47:55 - DEBUG - send_request_body.complete
2025-03-11 10:47:55 - DEBUG - receive_response_headers.started request=<Request [b'POST']>
2025-03-11 10:48:04 - DEBUG - receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'date', b'Tue, 11 Mar 2025 10:47:54 GMT'), (b'server', b'uvicorn'), (b'content-length', b'660'), (b'content-type', b'application/json')])
2025-03-11 10:48:04 - INFO - HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
2025-03-11 10:48:04 - DEBUG - receive_response_body.started request=<Request [b'POST']>
2025-03-11 10:48:04 - DEBUG - receive_response_body.complete
2025-03-11 10:48:04 - DEBUG - response_closed.started
2025-03-11 10:48:04 - DEBUG - response_closed.complete
2025-03-11 10:48:04 - DEBUG - HTTP Response: POST http://localhost:8000/v1/chat/completions "200 OK" Headers({'date': 'Tue, 11 Mar 2025 10:47:54 GMT', 'server': 'uvicorn', 'content-length': '660', 'content-type': 'application/json'})
2025-03-

In [54]:
print(response.choices[0].message.tool_calls)

[ChatCompletionMessageToolCall(id='chatcmpl-tool-024aaf5603cb40ccaa216ac6ff4afdfe', function=Function(arguments='{"info_to_extract": ["tumor_size", "depth_of_invasion", "lymph_node_status"]}', name='extract_information'), type='function')]


In [None]:
warning_prompt = \
    """Please stick to your original plan. You haven't finished calling all the required tools yet.  
Continue by calling the next tool according to your planned sequence."""

print(warning_prompt)

Please stick to your original plan. You haven't finished calling all the required tools yet.  
Continue by invoking the next tool according to your planned sequence.


In [None]:
# final node
def generate_structured_output() -> dict:
    """
    Generates a structured output containing the model's reasoning, determined stage, and confidence score.

    Args:
        last_model_answer (str): The final reasoning or explanation 
                                 from the LLM.
        stage_label (str): The classification stage determined by 
                           the LLM (e.g., "T2").
        confidence_score (float): A numeric value indicating confidence 
                                  (0.0 to 1.0).

    Returns:
        dict: A structured result in the format:
            {
                "reasoning": <str>,
                "stage": <str>,
                "confidence_score": <float>
            }
    """
    return {
        "reasoning": "",
        "stage": "",
        "confidence_score": 0
    }

In [None]:


semaphore = asyncio.Semaphore(5)
async def process_single_query(row_idx: int, row: pd.Series):
    async with semaphore:
        logger.info(f"Processing row index {row_idx}")
        report = row["text"]
        t_stage = row["t"]
        formatted_user_prompt = user_prompt.format(report=report)
        result = await run_multi_agent_system(
            system_instruction=system_instruction,
            user_prompt=formatted_user_prompt,
            available_tools=available_tools,
            tools=tools
        )


results = []
for idx, row in df.iterrows():
    logger.info(f"Processing row index {idx}")
    report = row["text"]
    t_stage = row["t"]

    formatted_user_prompt = user_prompt.format(report=report)



    # Run the multi-agent system for this single query
    result_dict = await process_single_query(
        question_text=question_text,
        ground_truth=ground_truth,
        choices=["A", "B", "C", "D", "E"],
        # choices=["Yes", "No"],
        n_specialists=5
    )
    # result_dict["File ID"] = row["File ID"]
    result_dict["qn_num"] = row["qn_num"]

    # Store result for later evaluation
    results.append(result_dict)

    if idx % 10 == 0:
        output_json_path = f"/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/step3_{idx}.json"
        with open(output_json_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        logger.info(f"Saved aggregated results to {output_json_path}")

# OPTIONAL: Save results to JSON
output_json_path = "/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/step3_final.json"
with open(output_json_path, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)
logger.info(f"Saved aggregated results to {output_json_path}")

In [None]:
client=OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
messages = [{"role": "system", "content": system_instruction}, {"role": "user", "content": formatted_user_prompt}]

In [None]:
messages

In [None]:
resp = client.chat.completions.create(
    model = "meta-llama/Llama-3.3-70B-Instruct", 
    messages = messages,
    tools = tools,
    tool_choice="auto", 
    # tool_choice="none",
    )

In [None]:
resp.choices[0].message.tool_calls

In [None]:
if resp.choices[0].message.tool_calls:
    print("Tool calls detected.")
    messages.append({
    "role": "assistant",
    "tool_calls": resp.choices[0].message.tool_calls
    })
    for call in resp.choices[0].message.tool_calls:
        
        args = safe_json_load(call.function.arguments)
        if call.function.name == "provide_final_prediction":
            return args
        result = available_tools[call.function.name](**args)
        print(result)
        messages.append({
        "role": "tool",
        "content": result,
        "tool_call_id": call.id,
        "name": call.function.name,
        })

else:
    final_result = resp.choices[0].message.content
    messages.append({"role": "assistant", "content": final_result})
    print(result)

In [None]:
resp.choices[0].message

In [None]:
args

In [None]:
resp.choices[0].message

In [None]:
def safe_json_load(s: str) -> Any:
    """
    Attempts to parse a JSON string using the standard json.loads.
    If that fails (e.g. due to an unterminated string), it will try using
    a more forgiving parser (demjson3). If both attempts fail,
    the original string is returned.
    """
    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        logger.error("Standard json.loads failed: %s", e)
        try:
            logger.info("Attempting to parse with demjson3 as fallback.")
            result = demjson3.decode(s)
            logger.info("demjson3 successfully parsed the JSON.")
            return result
        except Exception as e2:
            logger.error("Fallback parsing with demjson3 also failed: %s. Returning original input.", e2)
            return s


class LLMAgent:
    def __init__(self, system_prompt: str, 
                 client=AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="dummy")):
        self.client = client
        self.messages = [{"role": "system", "content": system_prompt}]

    async def llm_call(self, user_prompt: str,
                       guided_: dict = None,
                       tools: List[dict] = None) -> Any:
        logger.debug(f"LLMAgent.llm_call() - user_prompt[:60]: {user_prompt[:60]}...")
        self.messages.append({"role": "user", "content": user_prompt})
        params = {
            "model": "meta-llama/Llama-3.3-70B-Instruct",
            "messages": self.messages,
            "temperature": 0.5,
        }
        if guided_:
            logger.debug(f"Guided JSON/choice detected: {guided_}")
            params["extra_body"] = guided_
        if tools:
            params["tools"] = tools

        response = await self.client.chat.completions.create(**params)
        return response.choices[0].message
    
    def append_message(self, content, role='assistant'):
        logger.debug(f"Appending message with role='{role}' to conversation.")
        self.messages.append({"role": role, "content": content})
        return


class InitializerAgent(LLMAgent):
    def __init__(self, n_specialists: int):
        self.n_specialists = n_specialists
        system_prompt = (
            "You are an initializer agent in a multi-agent AI system designed to handle medical questions.\n"
            f"Your job is to select {self.n_specialists} medical specialists whose expertise best matches the user's query.\n"
            "For each specialist, specify their role and a list of relevant expertise areas related to the query.\n"
        )
        super().__init__(system_prompt)

    async def identify_specialists(self, query: str):
        logger.info("InitializerAgent: Identifying specialists.")
        class Specialist(BaseModel):
            specialist: str = Field(..., description="Role of the specialist")
            expertise: List[str] = Field(..., description="Areas of expertise for the specialist.")
        panel_dict = {f"Specialist_{i+1}": (Specialist, ...) for i in range(self.n_specialists)}
        SpecialistPanel = create_model("SpecialistPanel", **panel_dict)

        user_prompt = (
            "Here is the user's query:\n\n"
            f"<Query>\n{query}\n</Query>\n\n"
            "Based on the above query, identify the most suitable specialists."
        )
        response = await self.llm_call(user_prompt, guided_={"guided_json": SpecialistPanel.schema()})
        self.append_message(content=response.content)
        logger.debug(f"InitializerAgent response: {response.content}")
        return safe_json_load(response.content)


class SpecialistAgent(LLMAgent):
    def __init__(self, specialist: str, expertise: List[str]):
        self.specialist = specialist
        self.expertise = expertise
        system_prompt = (
            f"You are a {specialist}.\n"
            f"Your expertise includes:\n{expertise}\n"
            f"Analyze the user's query from the perspective of a {specialist}."
        )
        super().__init__(system_prompt)

    async def analyze_query(self, query: str, choices: List[str]):
        logger.info(f"[{self.specialist}] Analyzing query...")
        self.query = query
        self.choices = tuple(choices)
        choices_str = ', '.join(choices)

        user_prompt = (
            "Here is the query of interest:\n\n"
            f"<Query>\n{query}\n</Query>\n\n"
            f"The possible answers are: {choices_str}.\n"
            f"From your perspective as a {self.specialist}, first provide step-by-step reasoning (rationale), "
            "and then clearly state your final answer.\n\n"
        )

        class Response(BaseModel):
            reasoning: str = Field(..., description="Step-by-step reasoning leading to the final choice")
            choice: Literal[self.choices] = Field(..., description="Final choice")

        response = await self.llm_call(user_prompt, guided_={"guided_json": Response.model_json_schema()})
        self.append_message(content=response.content)
        logger.debug(f"[{self.specialist}] analyze_query response: {response.content}")
        return safe_json_load(response.content)
    
    async def debate(self, agents: Dict[str, Any]):
        logger.info(f"[{self.specialist}] Debating with other specialists.")
        other_specialists = {}
        for name, value in agents.items():
            if name != self.specialist:
                other_specialists[name] = value

        formatted_other_specialists = json.dumps(other_specialists, indent=4)
        user_prompt = (
            "Regarding the previous query, other specialists have also provided their reasoning and choices.\n"
            "Critically evaluate the reasoning and choice of those specialists.\n\n"
            f"Specialists and their choices:\n{formatted_other_specialists}\n\n"
            "Considering the newly provided perspectives, refine your own reasoning and choice.\n"
            "You can change your choice or stick with the original one.\n\n"
        )

        class Response(BaseModel):
            reasoning: str = Field(..., description="Step-by-step reasoning leading to final choice")
            choice: Literal[self.choices] = Field(..., description="Final choice")

        response = await self.llm_call(user_prompt, guided_={"guided_json": Response.model_json_schema()})
        self.append_message(content=response.content)
        logger.debug(f"[{self.specialist}] debate response: {response.content}")
        return safe_json_load(response.content)


class AggregatorAgent(LLMAgent):
    def __init__(self):
        system_prompt = (
            "You are the aggregator agent in a multi-agent AI system for medical queries.\n"
            "You have access to each specialist's entire chat history.\n"
            "Your job is to read those full conversations, analyze their reasoning and any conflicts, "
            "and then provide a single, definitive answer to the user.\n"
            "Provide a clear explanation for your final conclusion."
        )
        super().__init__(system_prompt)

    async def aggregate(self, query: str, choices: List[str], specialists_chat_history: Dict[str, Any]):
        logger.info("AggregatorAgent: Aggregating final answer from all specialists' chat history.")
        specialists_str = json.dumps(specialists_chat_history, indent=4)

        user_prompt = (
            f"Here is the query of interest:\n\n"
            f"<Query>\n{query}\n</Query>\n\n"
            "Below is the *entire conversation history* for each specialist:\n\n"
            f"{specialists_str}\n\n"
            "Please review all these conversations in detail and produce one single, definitive final answer. "
            "If there is no unanimous or majority choice, choose the answer best supported by the specialists' reasoning. "
            "Clearly justify your reasoning, then provide your final recommended answer."
        )

        class AggregatedResponse(BaseModel):
            aggregated_reasoning: str = Field(..., description="Detailed reasoning behind final choice")
            aggregated_choice: Literal[tuple(choices)] = Field(..., description="Single recommended choice")

        response = await self.llm_call(user_prompt, guided_={"guided_json": AggregatedResponse.model_json_schema()})
        self.append_message(content=response.content)
        logger.debug(f"AggregatorAgent response: {response.content}")
        return safe_json_load(response.content)


def check_consensus(status_dict: Dict[str, Any]) -> str:
    """
    Returns the consensus choice if >= 80% of specialists agree, else returns None.
    """
    logger.info("Checking for consensus among specialists.")
    specialists_count = len(status_dict)
    consensus_threshold = math.ceil(0.8 * specialists_count)

    choice_counts = {}
    for _, specialist_data in status_dict.items():
        final_choice = specialist_data['response_after_debate']['choice']
        choice_counts[final_choice] = choice_counts.get(final_choice, 0) + 1

    for choice, count in choice_counts.items():
        if count >= consensus_threshold:
            logger.info(f"Consensus found on choice '{choice}' with {count}/{specialists_count} specialists.")
            return choice
    logger.info("No consensus found.")
    return None


# --------------------------------
# 3) PROCESS A SINGLE ROW/QUERY
# --------------------------------
async def process_single_query(
    question_text: str,
    ground_truth: str,
    choices: List[str],
    n_specialists: int) -> Dict[str, Any]:
    """
    Given a single query (question + ground_truth + multiple choices), 
    run the multi-agent system (Initializer -> Specialists -> Debates -> Aggregator if needed).
    Return the final dictionary containing all the specialists' output and aggregator results.
    """

    # 1. Initialize specialists
    initializer = InitializerAgent(n_specialists=n_specialists)
    json_resp = await initializer.identify_specialists(query=question_text)
    if not isinstance(json_resp, dict):
        logger.error("Invalid JSON output from initializer; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # Build specialists status dict
    specialists_status = {}
    for _, agent_info in json_resp.items():
        specialist_name = agent_info["specialist"]
        expertise = agent_info["expertise"]
        specialists_status[specialist_name] = {"expertise": expertise}
    
    # 2. Run analyze_query for each specialist in parallel
    async def analyze_specialist(specialist_name: str, status: Dict[str, Any], query: str, choices: List[str]):
        specialist_agent = SpecialistAgent(specialist=specialist_name, expertise=status["expertise"])
        status["instance"] = specialist_agent
        message = await specialist_agent.analyze_query(query=query, choices=choices)
        if not isinstance(message, dict):
            logger.error(f"[{specialist_name}] Invalid JSON output from specialist; skipping this specialist.")
            return None
        status["original_response"] = message
        logger.info(f"[{specialist_name}] Completed analyze_query.")
        return specialist_name

    analyze_tasks = [
        asyncio.create_task(analyze_specialist(name, status, question_text, choices))
        for name, status in specialists_status.items()
    ]
    analyze_results = await asyncio.gather(*analyze_tasks)
    if any(r is None for r in analyze_results):
        logger.error("At least one specialist failed; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # Build a minimal dictionary for debate (remove 'instance')
    input_specialists_dict = {
        specialist_name: {
            k: v for k, v in specialist_data.items() 
            if k != "instance"
        }
        for specialist_name, specialist_data in specialists_status.items()
    }

    # 3. Debate step, also in parallel
    async def debate_specialist(specialist_name: str, status: Dict[str, Any], specialists_dict: Dict[str, Any]):
        specialist_agent = status["instance"]
        message = await specialist_agent.debate(specialists_dict)
        if not isinstance(message, dict):
            logger.error(f"[{specialist_name}] Invalid JSON output during debate; skipping this specialist.")
            return None
        status["response_after_debate"] = message
        specialists_dict[specialist_name]["response_after_debate"] = message
        logger.info(f"[{specialist_name}] Completed debate.")
        return specialist_name

    debate_tasks = [
        asyncio.create_task(debate_specialist(name, status, input_specialists_dict))
        for name, status in specialists_status.items()
    ]
    debate_results = await asyncio.gather(*debate_tasks)
    if any(r is None for r in debate_results):
        logger.error("At least one specialist failed during debate; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # 4. Check consensus
    consensus_choice = check_consensus(input_specialists_dict)
    aggregator_result = None

    if consensus_choice is not None:
        logger.info(f"Consensus reached: {consensus_choice}")
        input_specialists_dict["Aggregator"] = {
            "final_choice": consensus_choice, 
            "final_reasoning": "Consensus reached"
        }
    else:
        logger.info("No consensus reached; enabling aggregator path...")
        aggregator = AggregatorAgent()
        aggregated_response = await aggregator.aggregate(
            query=question_text,
            choices=choices,
            specialists_chat_history=input_specialists_dict
        )
        if not isinstance(aggregated_response, dict):
            logger.error("Invalid JSON output from aggregator; skipping this query.")
            return {}  # Skip processing and continue to the next query
        
        final_choice = aggregated_response['aggregated_choice']
        final_reasoning = aggregated_response['aggregated_reasoning']

        logger.info(f"Aggregator final choice: {final_choice}")
        logger.info(f"Aggregator reasoning: {final_reasoning}")

        aggregator_result = {
            "final_choice": final_choice,
            "final_reasoning": final_reasoning
        }
        input_specialists_dict["Aggregator"] = aggregator_result

    # Add question and ground_truth for reference
    input_specialists_dict["Question"] = question_text
    input_specialists_dict["Answer"] = ground_truth

    return input_specialists_dict


async def process_multiple_queries(
    qa_df: pd.DataFrame,
    choices: List[str],
    n_specialists: int,
    max_concurrency: int = 5
) -> List[Dict[str, Any]]:
    """
    Process multiple rows (queries) in `qa_df` asynchronously.
    Each row is passed to `process_single_query`.
    
    :param qa_df: DataFrame with columns ["question", "choice", "ground_truth"] at least.
    :param choices: A list of all possible answer choices, e.g. ["A", "B", "C", "D", "E"].
    :param n_specialists: Number of specialists to initialize for each query.
    :param max_concurrency: Limit on how many queries to process simultaneously.
    :return: A list of result dictionaries, one per row in `qa_df`.
    """

    # This semaphore keeps at most `max_concurrency` tasks running at once
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_single_query(row_idx: int, row: pd.Series):
        """
        This inner function is used to call `process_single_query` with concurrency control.
        """
        async with semaphore:
            logger.info(f"Starting row {row_idx}")
            question_text = row["question"] + "\n" + str(row["choice"])
            ground_truth = str(row["ground_truth"])
            result = await process_single_query(
                question_text=question_text,
                ground_truth=ground_truth,
                choices=choices,
                n_specialists=n_specialists
            )
            logger.info(f"Finished row {row_idx}")
            return result

    tasks = [
        asyncio.create_task(run_single_query(i, row))
        for i, row in qa_df.iterrows()
    ]

    # Wait for all tasks to complete
    all_results = await asyncio.gather(*tasks)

    # `all_results` is a list of return values from each `run_single_query`
    return all_results

async def main():

    logger.info("===== MAIN START =====")

    # Example CSV loading
    df_path = "/home/yl3427/cylab/llm_reasoning/reasoning/data/step2_ALL.csv"
    qa_df = pd.read_csv(df_path, encoding="latin-1")  # columns: idx, question, choice, ground_truth, qn_num
    # qa_df = pd.read_csv('/home/yl3427/cylab/SOAP_MA/Input/SOAP_5_problems.csv')
    logger.info("Loaded dataframe with %d rows.", len(qa_df))


    ################# 'process_single_query' Example usage #################
    results = []
    for idx, row in qa_df.iterrows():
        # if idx <= 10:
        #     continue
        logger.info(f"Processing row index {idx}")
        question_text = row["question"] + "\n" + str(row["choice"])
        ground_truth = str(row["ground_truth"])

        # patient_info = str(row["Subjective"]) + "\n" + str(row['Objective'])
        # question_text = f"""
        # Based on the following patient report, does the patient have sepsis?

        # {patient_info}
        # """
        # ground_truth = str(row["terms"])
        

        # Run the multi-agent system for this single query
        result_dict = await process_single_query(
            question_text=question_text,
            ground_truth=ground_truth,
            choices=["A", "B", "C", "D", "E"],
            # choices=["Yes", "No"],
            n_specialists=5
        )
        # result_dict["File ID"] = row["File ID"]
        result_dict["qn_num"] = row["qn_num"]

        # Store result for later evaluation
        results.append(result_dict)

        if idx % 10 == 0:
            output_json_path = f"/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/step2_{idx}.json"
            with open(output_json_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            logger.info(f"Saved aggregated results to {output_json_path}")

    # OPTIONAL: Save results to JSON
    output_json_path = "/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/step2_final.json"
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    logger.info(f"Saved aggregated results to {output_json_path}")

    logger.info("===== MAIN END =====")


In [None]:
# log
for name, status in specialists_status.items():
    print(f"Specialist: {name}")
    message = status["instance"].messages
    print(message)


In [None]:
system_instruction = """
You are a knowledgeable and meticulous medical expert specialized in diagnosing diseases based on partial information from SOAP notes. 
You will receive either:
1. A single-disease assessment request (“specialist” scenario), or 
2. A multiple-disease assessment request (“generalist” scenario).

In the “specialist” scenario, you focus on one disease and analyze evidence within the Subjective (S) and Objective (O) sections for or against that single disease. Your final answer must be in valid JSON with:
    {
        "reasoning": "Concise explanation of your thought process",
        "diagnosis": true_or_false
    }

In the “generalist” scenario, you must assess each disease from a given list. For each disease, identify subjective and objective evidence that supports or refutes the disease. If evidence strongly supports it, conclude the diagnosis is true; if not, conclude false. If conflicting or incomplete, offer a reasoned explanation and a likely conclusion. Your final answer must be in valid JSON with each disease as a key:
    {
      "DiseaseName1": { "reasoning": "Your reasoning...", "diagnosis": true_or_false },
      "DiseaseName2": { "reasoning": "Your reasoning...", "diagnosis": true_or_false },
      ...
    }

When reasoning, consider clinical clues like symptoms, exam findings, risk factors, and labs. Clearly and succinctly justify why each disease is likely or unlikely. If any information is missing or ambiguous, note the uncertainty and choose the most probable conclusion.

Follow these instructions precisely:
• Always return output in the exact JSON format requested (no extra fields or text).
• Provide concise, medically sound rationale for each decision.
"""

prompt_specialist = """
You are a medical expert specializing in {PROBLEM}.

You are provided with only the Subjective (S) and Objective (O) sections of a patient's SOAP-formatted progress note for a potential case of {PROBLEM}.
Identify relevant clues in the subjective and objective sections that align with or argue against {PROBLEM}. If evidence strongly suggests {PROBLEM}, conclude the diagnosis is true; if not, conclude it is false. If the evidence is uncertain or conflicting, explain your reasoning and lean toward the most likely conclusion.

Patient Report:
<Subjective>
{SUBJ}
</Subjective>

<Objective>
{OBJ}
</Objective>

Your answer must be output as valid JSON formatted exactly as follows:
    {{
        "reasoning": "Your reasoning here...",
        "diagnosis": true_or_false
    }}
"""

prompt_generalist = """
You are a medical expert in diagnostic reasoning.

You are provided with only the Subjective (S) and Objective (O) sections of a patient's SOAP-formatted progress note that may be relevant to one or more of the following diseases:
{PROBLEM_LIST}

The patient may have one or more of these diseases, or none at all. Evaluate each disease independently.
Identify relevant clues in the subjective and objective sections that align with or argue against each disease. If evidence strongly suggests the disease, conclude the diagnosis is true; if not, conclude it is false. If the evidence is uncertain or conflicting, explain your reasoning and lean toward the most likely conclusion.

Patient Report:
<Subjective>
{SUBJ}
</Subjective>

<Objective>
{OBJ}
</Objective>

Your answer must be output as valid JSON formatted exactly as follows:
{{
{json_keys}
}}
"""

system_instruction_mediator = """
You are the mediator agent in a medical multi-agent diagnostic system. 
"""

In [None]:
class Response(BaseModel):
    reasoning: str = Field(..., description="Step-by-step reasoning leading to the final diagnosis.")
    diagnosis: bool = Field(..., description="True if patient has the disease, False otherwise.")

In [None]:
from typing import get_origin, get_args, Union, Any

def generate_tools_spec(*functions):
    """
    Generate a list of tool definitions (function schemas) for OpenAI's tool calling.
    
    Each function's name, docstring, and parameters (with types and required flags)
    are extracted to form the JSON schema as a dictionary.
    
    Args:
        *functions: One or more Python function objects to document.
    Returns:
        List[dict]: A list of tool definition dictionaries compatible with OpenAI API.
    """
    # Mapping of Python types to JSON Schema types
    type_map = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
        type(None): "null"
    }
    tools = []
    for func in functions:
        # Basic function info
        func_name = func.__name__
        func_description = func.__doc__.strip() if func.__doc__ else ""
        sig = inspect.signature(func)
        
        properties = {}
        required = []
        for param in sig.parameters.values():
            # Skip *args and **kwargs as they cannot be described in JSON schema easily
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue
            param_name = param.name

            # Determine JSON schema type from annotation (if available)
            json_type = "string"  # default type
            annotation = param.annotation
            if annotation is not inspect._empty:
                origin = get_origin(annotation)
                # Handle Optional[X] or Union[X, None]
                if origin is Union:
                    args = [t for t in get_args(annotation) if t is not type(None)]
                    if len(args) == 1:
                        annotation = args[0]
                        origin = get_origin(annotation)
                # Map to JSON type if direct or via origin for generics
                if annotation in type_map:
                    json_type = type_map[annotation]
                elif origin in type_map:
                    json_type = type_map[origin]
                # Handle list item types for generics like list[int]
                if json_type == "array":
                    item_type = "string"  # default for items
                    args = get_args(annotation)
                    if args:
                        # Use first type argument for list item if present
                        item_type = type_map.get(args[0], "string")
                    properties[param_name] = {
                        "type": "array",
                        "items": {"type": item_type}
                    }
                elif json_type == "object":
                    # For dicts or unknown complex types, use object without specifics
                    properties[param_name] = {"type": "object"}
                else:
                    properties[param_name] = {"type": json_type}
            else:
                # No annotation, assume string
                properties[param_name] = {"type": "string"}

            # Mark required if no default value
            if param.default is inspect._empty:
                required.append(param_name)
        
        # Build the tool dictionary for this function
        tool_dict = {
            "type": "function",
            "function": {
                "name": func_name,
                "description": func_description,
                "parameters": {
                    "type": "object",
                    "properties": properties
                }
            }
        }
        if required:
            tool_dict["function"]["parameters"]["required"] = required
        tools.append(tool_dict)
    return tools


In [None]:
def retrieve_synonyms(problem: str) -> Optional[List[str]]: 
    """
    Retrieve the list of synonyms for a given problem.
    """
    problem = problem.lower()
    mi = ["myocardial infarction", "elevation mi", "non-stemi", " NSTEMI", " stemi"]
    chf = ["congestive heart failure", " chf", "HFrEF", "HFpEF"]
    pulmonary_embolism = ["pulmonary embolism"]
    pulmonary_hypertension = ["pulmonary hypertension", "pulmonary htn"]
    sepsis = ["sepsis", "septic shock"]
    urosepsis = ["urosepsis"]
    meningitis = ["meningitis"]
    aki = ["acute kidney injury", " aki", "acute renal failure", " arf"] # -> Acute tubular necrosis (ATN)인가 아닌가
    atn = ["acute tubular necrosis", " atn"]
    pancreatitis = ["pancreatitis"]
    gi_bleed = ["gastrointestinal bleed", "gi bleed"]
    hepatitis = ["hepatitis", " hep"]
    cholangitis = ["cholangitis"]
    asp_pneumonia = ["aspiration pneumonia"]

    prob_dict = {'myocardial infarction': mi, 
                 'congestive heart failure': chf, 
                 'pulmonary embolism': pulmonary_embolism, 
                 'pulmonary hypertension': pulmonary_hypertension, 
                 'sepsis': sepsis, 
                 'urosepsis': urosepsis, 
                 'meningitis': meningitis, 
                 'acute kidney injury': aki, 
                 'acute tubular necrosis': atn, 
                 'pancreatitis': pancreatitis, 
                 'gastrointestinal bleed': gi_bleed, 
                 'hepatitis': hepatitis, 
                 'cholangitis': cholangitis, 
                 'aspiration pneumonia': asp_pneumonia}
    result = prob_dict.get(problem, None)
    return result
tools = generate_tools_spec(retrieve_synonyms)

In [None]:
tools

In [None]:
client = OpenAI(api_key="dummy_key", base_url="http://localhost:8000/v1")

In [None]:
messages = [
    {"role": "user", "content": "What's the synonym for acute kidney injury?"}
]
client = OpenAI(api_key="dummy_key", base_url="http://localhost:8000/v1")
response = client.chat.completions.create(
    model=client.models.list().data[0].id,
    messages=messages,
    temperature= 0.1,
    tools=tools,
    tool_choice="auto" #none
)
response.choices[0].message

In [None]:
for tool_call in response.choices[0].message.tool_calls:
    print(tool_call)

In [None]:
def call_function(name, args):
    if name == "retrieve_synonyms":
        return retrieve_synonyms(**args)
    
for tool_call in response.choices[0].message.tool_calls:
    name = tool_call.function.name
    args = json.loads(tool_call.function.arguments)

    result = str(call_function(name, args))
    messages.append({
        "role": "tool",
        "tool_call_id": tool_call.id,
        "name": name,
        "output": result
    })

In [None]:
messages

툴콜링 됐을 때와 아닐때 모델 아웃풋 차이
```
ChatCompletionMessage(content=None, refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='chatcmpl-tool-e9f31a3069694cc69887d4e03d16b412', function=Function(arguments='{"problem": "acute kidney injury"}', name='retrieve_synonyms'), type='function')], reasoning_content=None)


ChatCompletionMessage(content='The synonym for acute kidney injury (AKI) is acute renal failure (ARF).', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[], reasoning_content=None)
```

In [None]:
response = client.chat.completions.create(
    model=client.models.list().data[0].id,
    messages=messages,
)

In [None]:
print(response.choices[0].message)

In [None]:
response

In [None]:
class LLM:
    def __init__(self, client: OpenAI):
        self.client = client

    def get_response(
        self, 
        messages: List[Dict], 
        temperature: Optional[float] = 0.1,
        guided_: Optional[dict] = None, # {"guided_json": json_schema}, {"guided_choice": ["positive", "negative"]}
        tools: Optional[List[Dict]] = None
    ):
        try:
            request_params = {
                "model": self.client.models.list().data[0].id,
                "messages": messages,
                "temperature": temperature,
            }
            if guided_:
                request_params["extra_body"] = guided_
            if tools:
                request_params["tools"] = tools

            response = self.client.chat.completions.create(**request_params)

            return response.choices[0].message

        except Exception as e:
            print(f"An error occurred: {e}")
            return None


    # def test_single_prob(self, dataset: pd.DataFrame, problem: str):
    #     pbar = tqdm(total=dataset.shape[0], desc=f"Testing {problem}")
    #     for idx, row in dataset.iterrows():
    #         subj_text = row["Subjective"]
    #         obj_text = row["Objective"]

    #         prompt_specialist_formatted = prompt_specialist.format(
    #             PROBLEM=problem,
    #             SUBJ=subj_text,
    #             OBJ=obj_text
    #         )
    #         messages = [
    #             {"role": "system", "content": system_instruction},
    #             {"role": "user", "content": prompt_specialist_formatted}
    #         ]
    #         response = self.get_response(
    #             messages,
    #             schema= DiseaseDiagnosis.model_json_schema()
    #         )
    #         if response:
    #             dataset.at[idx, f"is_{problem.lower().replace(' ', '_')}_pred_single"] = response["diagnosis"]
    #             dataset.at[idx, f"is_{problem.lower().replace(' ', '_')}_reasoning_single"] = response["reasoning"]

    #         pbar.update(1)
    #     pbar.close()
    #     return dataset
    
    # def test_multi_prob(self, dataset: pd.DataFrame, problem_lst: list):

    #     problem_dict = {problem: (DiseaseDiagnosis, ...) for problem in problem_lst}

    #     DynamicResponseMultiDiagnosis = create_model(
    #                 'DynamicResponseMultiDiagnosis',
    #                 **problem_dict
    #             )

    #     pbar = tqdm(total=dataset.shape[0], desc="Testing Multi-Diagnosis")
    #     for idx, row in dataset.iterrows():
    #         subj_text = row["Subjective"]
    #         obj_text = row["Objective"]

    #         json_keys_list = [
    #             f'  "{disease}": {{"reasoning": "Your reasoning here...", "diagnosis": true_or_false}}'
    #             for disease in problem_lst
    #         ]
    #         json_keys = ",\n".join(json_keys_list)

    #         prompt_generalist_formatted = prompt_generalist.format(
    #             PROBLEM_LIST=", ".join(problem_lst),
    #             SUBJ=subj_text,
    #             OBJ=obj_text,
    #             json_keys=json_keys,
    #         )

    #         messages = [
    #             {"role": "system", "content": system_instruction},
    #             {"role": "user", "content": prompt_generalist_formatted}
    #         ]

    #         response = self.get_response(
    #             messages,
    #             schema=DynamicResponseMultiDiagnosis.model_json_schema()
    #         )
    #         if response:
    #             for problem in problem_lst:
    #                 dataset.at[idx, f"is_{problem.lower().replace(' ', '_')}_pred_multi"] = response[problem]["diagnosis"]
    #                 dataset.at[idx, f"is_{problem.lower().replace(' ', '_')}_reasoning_multi"] = response[problem]["reasoning"]
    #         pbar.update(1)
    #     pbar.close()
    #     return dataset



In [None]:
client = OpenAI(api_key="dummy_key", base_url="http://localhost:8000/v1")
df = pd.read_csv(
    '/home/yl3427/cylab/SOAP_MA/data/mergedBioNLP2023.csv',
    usecols=['File ID', 'Subjective', 'Objective', 'Summary', 'cleaned_expanded_Summary', 'terms']
)
df = df.fillna('').apply(lambda x: x.str.lower())
df['combined_summary'] = df['Summary'] + df['cleaned_expanded_Summary'] + df['terms']

mi = ["myocardial infarction", "elevation mi", "non-stemi", " NSTEMI", " stemi"]
chf = ["congestive heart failure", " chf", "HFrEF", "HFpEF"]
pulmonary_embolism = ["pulmonary embolism"]
pulmonary_hypertension = ["pulmonary hypertension", "pulmonary htn"]
sepsis = ["sepsis", "septic shock"]
urosepsis = ["urosepsis"]
meningitis = ["meningitis"]
aki = ["acute kidney injury", " aki", "acute renal failure", " arf"] # -> Acute tubular necrosis (ATN)인가 아닌가
atn = ["acute tubular necrosis", " atn"]
pancreatitis = ["pancreatitis"]
gi_bleed = ["gastrointestinal bleed", "gi bleed"]
hepatitis = ["hepatitis", " hep"]
cholangitis = ["cholangitis"]
asp_pneumonia = ["aspiration pneumonia"]

prob_dict = {'myocardial infarction': mi, 
                'congestive heart failure': chf, 
                'pulmonary embolism': pulmonary_embolism, 
                'pulmonary hypertension': pulmonary_hypertension, 
                'sepsis': sepsis, 
                'urosepsis': urosepsis, 
                'meningitis': meningitis, 
                'acute kidney injury': aki, 
                'acute tubular necrosis': atn, 
                'pancreatitis': pancreatitis, 
                'gastrointestinal bleed': gi_bleed, 
                'hepatitis': hepatitis, 
                'cholangitis': cholangitis, 
                'aspiration pneumonia': asp_pneumonia}

ids = set()
for name, lst in prob_dict.items():
    problem_terms = lst
    problem_terms = [term.lower() for term in problem_terms]

    # Use the first term as the primary term to check in the combined summary.
    primary_term = problem_terms[0]

    # Build a regex pattern that matches any of the problem terms.
    # pattern = '|'.join(problem_terms)
    pattern = '|'.join(re.escape(term) for term in problem_terms)

    mask = (
        df['combined_summary'].str.contains(pattern, na=False) &
        ~df['Subjective'].str.contains(pattern, na=False) &
        ~df['Objective'].str.contains(pattern, na=False)
    )

    filtered_df = df[mask]

    ids.update(filtered_df['File ID'])

agent = Agent(client=client)

df = df[df['File ID'].isin(ids)]
df = df.reset_index(drop=True)

result_df = agent.test_multi_prob(df, list(prob_dict.keys()))
result_df.to_csv("multi_result_full.csv", index=False)

for name, lst in prob_dict.items():
    result_df = agent.test_single_prob(result_df, name)
    result_df.to_csv(f"single_result_{name}.csv", index=False)
result_df.to_csv("single_result_full.csv", index=False)

In [None]:
import pandas as pd
qa_df = pd.read_csv("/home/yl3427/cylab/llm_reasoning/reasoning/data/step1_ALL.csv", encoding="latin-1")
print(len(qa_df))
for idx, row in qa_df.iterrows():
    if idx == 10:
        # print(row['question'] + "\n" + str(row['choice']))
        print(row)

In [None]:
qa_df