In [31]:
from tqdm import tqdm
from openai import OpenAI, AsyncOpenAI
import re
from typing import Optional, Union, List, get_origin, get_args, Any, Dict, Literal
import inspect
# from __future__ import annotations
import asyncio
import json
import logging
import pandas as pd
from pydantic import BaseModel, Field, create_model
import math
import demjson3


logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
        # logging.FileHandler('0310_MA_MedicalQA_step2_async.log', mode='w'),  # Write to file
        logging.StreamHandler()                     # Print to console
    ]
)
logger = logging.getLogger(__name__)

In [52]:
def safe_json_load(s: str) -> Any:
    """
    Attempts to parse a JSON string using the standard json.loads.
    If that fails (e.g. due to an unterminated string), it will try using
    a more forgiving parser (demjson3). If both attempts fail,
    the original string is returned.
    """
    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        logger.error("Standard json.loads failed: %s", e)
        try:
            logger.info("Attempting to parse with demjson3 as fallback.")
            result = demjson3.decode(s)
            logger.info("demjson3 successfully parsed the JSON.")
            return result
        except Exception as e2:
            logger.error("Fallback parsing with demjson3 also failed: %s. Returning original input.", e2)
            return s

In [119]:
def compare_numbers(num1: float, num2: float) -> str:
    """
    Compares a float (num1) with another number (num2) and returns a string.
    """
    if num1 < num2:
        return f"{num1} is less than {num2}"
    elif num1 > num2:
        return f"{num1} is greater than {num2}"
    else:
        return f"{num1} is equal to {num2}"

def provide_final_prediction(reasoning: str, prediction: Literal["T1", "T2", "T3", "T4"]) -> str:
    """
    Returns a string with the reasoning and the final prediction.
    """
    answer = f"""
    Reasoning: {reasoning}
    Final prediction: {prediction}
    """
    return answer

available_tools = {"compare_numbers": compare_numbers, "provide_final_prediction": provide_final_prediction}

In [120]:
tools = generate_tools_spec(compare_numbers, provide_final_prediction)

In [121]:
tools

[{'type': 'function',
  'function': {'name': 'compare_numbers',
   'description': 'Compares a float (num1) with another number (num2) and returns a string.',
   'parameters': {'type': 'object',
    'properties': {'num1': {'type': 'number'}, 'num2': {'type': 'number'}},
    'required': ['num1', 'num2']}}},
 {'type': 'function',
  'function': {'name': 'provide_final_prediction',
   'description': 'Returns a string with the reasoning and the final prediction.',
   'parameters': {'type': 'object',
    'properties': {'reasoning': {'type': 'string'},
     'prediction': {'type': 'string'}},
    'required': ['reasoning', 'prediction']}}}]

In [None]:
report = "SPECIMENS: 1. F/S LINGULAR NODULE. 2. F/S LEFT LOWER LOBE WEDGE. 3. MARGIN OF LEFT LOWER LOBE WEDGE. 4. LEVEL 5 LYMPH NODE, LEFT. 5. LEVEL 7 LYMPH NODE, LEFT. 6. LEVEL 9 LYMPH NODE, LEFT. 7. BASILAR SEGMENT LEFT LOWER LOBE. SEE ADDENDUM. Reason For Addendum #1: Molecular Studies. Reason For Addendum #2: Molecular Studies. DIAGNOSIS: 1. LUNG, LINGULA: WEDGE RESECTION. - ADENOCARCINOMA, SOLID PREDOMINANT (1 CM), SEE NOTE. - THE STAPLE LINE MARGIN IS FREE OF TUMOR. Note: The tumor consists of solid (60%), acinar (30%) and lepidic. (10%) components. Immunohistochemical stains show that the tumor. cells are positive for TTF-1 and Napsin-A while negative for pó3,. supporting the diagnosis. 2,3,7. LUNG, LEFT LOWER LOBE, BASILAR SEGMENT: WEDGE BIOPSIES AND. COMPLETION SEGMENTECTOMY. - ADENOCARCINOMA, MICROPAPILLARY PREDOMINANT (3.5 CM),. SEE NOTE. - THE BRONCHIAL AND VASCULAR MARGINS ARE FREE OF TUMOR. - FOUR LYMPH NODES, NEGATIVE FOR CARCINOMA (0/4). Note: The tumor consists of micropapillary (60%), papillary (20%),. lepidic (10%) and acinar (10%) components and measures 3.5 cm in. aggregate dimension in Parts 2, 3 and 7. It is morphologically. distinct from the concurrent lingular adenocarcinoma, favoring a. separate primary. Results of mutational studies will be reported in. addenda. 4. LYMPH NODE, LEFT LEVEL 5: BIOPSY. - FOUR LYMPH NODES, NEGATIVE FOR CARCINOMA (0/4). 5. LYMPH NODE, LEFT LEVEL 7: BIOPSY. - ONE LYMPH NODE, NEGATIVE FOR CARCINOMA (0/1). 6. LYMPH NODE, LEFT LEVEL 9: BIOPSY. - ONE LYMPH NODE, NEGATIVE FOR CARCINOMA (0/1). Specimens: 1: F/S LINGULAR NODULE. 2: F/S LEFT LOWER LOBE WEDGE. 3: MARGIN OF LEFT LOWER LOBE WEDGE. 4: LEVEL 5 LYMPH NODE, LEFT. 5: LEVEL 7 LYMPH NODE, LEFT. 6: LEVEL 9 LYMPH NODE, LEFT. 7: BASILAR SEGMENT LEFT LOWER LOBE. LUNG: Resection. SPECIMEN. Specimen: Lobe(s) of lung (specify). lingula and left lower. Procedure: Segmentectomy. Specimen Laterality: Left. Tumor Site: Lower lobe. Tumor Focality: Synchronous carcinomas (specify sites). lingula and left lower lobe. TUMOR. Histologic Type: Adenocarcinoma, mixed subtype. Histologic Grade: G3: Poorly differentiated. EXTENT. Tumor Size: Greatest dimension (cm). 3.5cm. Visceral Pleura Invasion: Not identified. MARGINS. Bronchial Margin. Bronchial Margin Involvement by Invasive Carcinoma: Uninvolved by. invasive carcinoma. Vascular Margin: Uninvolved by invasive carcinoma. Parenchymal Margin: Uninvolved by invasive carcinoma. ACCESSORY FINDINGS. Lymph-Vascular Invasion: Not identified. STAGE (pTNM). TNM Descriptors: m (multiple primary tumors). Primary Tumor (pT): pT2a: Tumor greater than 3 cm, but 5 cm or less in greatest. dimension surrounded by lung or visceral pleura without. bronchoscopic evidence of invasion more proximal than the lobar. bronchus (i.e., not in the main bronchus); or Tumor 5 cm or less. in greatest dimension with any of the following features of extent: involves main bronchus, 2 cm or more distal to the carina; invades. the visceral pleura; associated with atelectasis or obstructive. pneumonitis that extends to the hilar region but does not inolve the. entire lung. Regional Lymph Nodes (pN). pNO: No regional lymph node metastasis. Distant Metastases (pM): Not applicable. ADDITIONAL NON-TUMOR. Additional Pathologic Finding(s): Emphysema. CLINICAL HISTORY AND PRE - OPERATIVE DIAGNOSIS: M ex-smoker (120 pack-years) with a 1.6 cm RUL nodule, 3.7 cm LLL. ground-glass mass and a 1 cm lingular nodule. MACROSCOPIC DESCRIPTION: The specimen is received in seven parts, each labeled with the. patient's name. 1. Part one is received fresh, labeled 'lingular nodule'. It. consists of a lingular segment of the lung measuring 15 x 7 x 3.5. cm. The staple line measures 17 cm which is shaved and the bronchial. around inked blue. The pleura is gray pink and glistening and. mottled moderately with fine black streaks. Also noted is a precut. area on the pleura which reveals a sub pleural gray white firm. nodule, measuring 1 x 0.9 x 0.8 cm located 2 cm from the staple line. margin. The remainder of the bronchial is pink red blotchy and. crepitant. No other nodule is grossly identified. Frozen section is. performed on the nodule and resubmitted for permanent section. Representative section of the specimen and entire nodule are. submitted. 2. Part two is labeled 'left lower lobe wedge'. It consists of one. piece of yellow-pink soft tissue measuring 2 x 1 x 0.5 cm. Entirely. submitted in one cassette. 3. Part three is received fresh, labeled 'margin of left lower lobe. wedge'. It consists of a stapled strip of light tan tissue. measuring 3.5 x 0.4 x 0.3 cm. The staple line is shaved and the soft. tissue is submitted. 4. Part four is labeled 'level #5 lymph node'. It consists of. multiple soft red-black lymph nodes measuring 1.5 x 1.1 x 0.6 cm in. aggregate. Entirely submitted in one cassette. 5. Part five is labeled 'level #7 lymph node'. It consists of a 1.5. x 1 x 0.8 cm red-black lymph node. Entirely submitted in one. cassette. 6. Part six is labeled 'level #9 lymph node'. It consists of a 1.2. x 1 x 0.6 cm red-black lymph node. Entirely submitted in one. cassette. 7. Part seven is received fresh, labeled 'basilar segment left lower. lobe silk stitches on tumor location'. It consists of a segment of. lung measuring 20 x 14 x 4.5 cm. The bronchial margin measures 1.5. cm and the vascular margin measures 1 cm. The pleura is gray pink. glistening and mottled moderately with fine black streaks. There are. two staple lines measuring 8 cm and 5 cm in length. There is a. precut area on the pleura which is marked with a stitch. At this. area there is a grey tan firm mass, measuring 1.5 x 1 cm and located. 1.5 cm from the bronchial resection margin, and 0.4 cm from the. closest staple line which is shaved and the bronchial is inked. black. The mass is located on the mediastinal surface of the lung. The remaining of the bronchial is pink red blotchy and crepitant. No. other nodule is grossly identified. Representative sections of the. specimen and the entire mass area with adjacent tissue are. submitted. SUMMARY OF SECTIONS: 1A frozen section of the nodule. 1B-1D remaining of the nodule. 1E random section of the lung parenchyma. 1F staple line margin. 2A in toto. 3A soft tissue submitted, shaved. 4A in toto. 5A in toto. 6A in toto. 7A bronchial margin, shaved. 7B vascular margin, shaved. 7C-7H mass from stitch area with overlying pleura. 7I-7L mass from stitch area after shaving the staple line. 7M anthracotic lymph node from the hilum. SPECIAL PROCEDURES: INTRA - OPERATIVE CONSULTATION: 1. Lung, lingula: wedge resection (Frozen Section). - Adenocarcinoma with lepidic, acinar and solid patterns. Result reported by. 2. Lung, left lower lobe: wedge biopsy (Frozen Section). - Adenocarcinoma with acinar and micropapillary patterns. Result reported by. on. Intra-Operative Consultation #1 performed by. Intra-Operative Consultation #2 performed by. Final Diagnosis performed by. ADDENDUM 1: Integrated Oncology. MOLECULAR ONCOLOGY. KRAS MUTATION ANALYSIS. Specimen #: RESULTS: Positive for a p.G12A (c.34G>C) mutation in codon 12 of. the KRAS gene. INTERPRETATION: Mutations in the KRAS gene are reported to. correlate with poor prognosis and resistance to tyrosine kinase. inhibitor therapies in patients with non-small lung cancer. COMMENT: KRAS mutations occur in 15-30% of non-small-cell lung cancer (NSCLC). patients and are strongly associated with adenocarcinoma and smoking. history. This assay analyzes codons 12 and 13 in exon 2 of the KRAS gene;. based on the current literature, approximately 98% of mutations are. expected to occur in these codons. The analytical sensitivity of. the assay is approximately 10%; thus mutations present in a low. percentage of cells may not be detected. This test is validated for use in identifying KRAS codon 12 and. codon 13 mutations in fresh, frozen, or formalin-fixed paraffin. embedded tissue. In particular the test performance has been. established in samples of colorectal cancer and non-small cell lung. carcinoma which harbor these mutations, although several other. tissues are also known to harbor KRAS mutations (e.g. tumors of. pancreas, bile duct, ovary, appendix, etc.). METHOD/LIMITATION: Tissue sections are reviewed by a pathologist and relevant tumor is. selected for analysis. DNA is isolated from the sample, quantified. and amplified by polymerase chain reaction (PCR) using primers to. exon 2 of the KRAS gene. PCR products are subjected to single. nucleotide primer extension to detect mutations at codons 12 and 13;. primer extension products are analyzed using capillary gel. electrophoresis and fluorescence detection. False positive or. negative results may occur for reasons that include genetic variants. or somatic heterogeneity of the tissue sample. REFERENCES: NSCLC. Mascaux C, Iannino N, et al. British Journal of Cancer, 2005;. 92:11-139. Pao W, Wang TY, et al. PLoS Medicine, 205; 2(1):57-61. Eberhard DA, Johnson BE, et al. J Clin Oncol, 2005; 23:5900-5909. Han SW, Kim TY, et al. Clin Cancer Res. 2006; 12(8):253888888-2544. CRC. DiFiore C, Blanchard F, et al. Br J Cancer, 20007; 96:1166-1169. Lievre A, Bachet J-B, et al. Cancer Res, 2006; 66:3992-3995,. This test was developed and its performance characteristics. determined by. The laboratory is. regulated under the Clinical Laboratory Improvement Amendments of. 1988 (CLIA) as qualified to perform high complexity clinical. testing. This particular test is not considered a stand alone test. and should be only used in the context of other diagnostic tests or. clinical work-up related to treatment decisions. Addendum #1 performed by. ADDENDUM 2: EGFR Mutation Analysis. Specimen #: Clinical Data: Adenocarcinoma. RESULTS: No mutation detected. INTERPRETATION: No mutations were identified; in the sample. provided for analysis. Fewer than %5 of non-small-cell lung. carcinoma patients without identifiable mutations are reported to be. responsive to EGFR tyrosine kinase inhibitor therapies. COMMENT: Forty percent (40%) or more cellularity is optimal for. this mutation analysis. The sample submitted showed 50% tumor. cellularity upon pathologist review. A frequently occurring sequence change 2361G>A (Q787Q) was. identified. This polymorphism is known not to have clinical. significance. Mutations in th vrosine kinase domain of the epidermal growth. factor recept. gene are reported to be associated with. differenti. ness or resistance to EGFR tyrosine kinase. inhibitr. The objective response rate among. patien. ing. itation ranges from 55 to 82%. 8-21 of the EGFR tyrosine kinase domain;. noni 11: most mutations in non-small-cell lung. (NSCLC). are expected to occur in these exons. Mutations. present a less that 10-. 0% of extracted DNA may not be detected by. this Muta on. au is in a sample may change during tumor. progress on or the burse. therapy; therefore, this result cannot. be used. Ince. absence of a mutation in another sample. ! fr. n this tumor. This. non-small cell lung carcinoma. The. clinica. jit. tility of this test in other tumor types. is unknown. METHOD: Tissue sections are reviewed by a pathologist and relevant tumor is. selected for analysis. DNA is isolated from the samples, quantified. and amplified by polymerase chain reaction (PCR) using primers to. exons 18-21 of the EGFR gene. PCR products are analyzed by. bi-directional direct DNA sequencing using capillary gel. electrophoresis and fluorescence detection. False positive or. negative results may occur for reasons that include genetic variants. or somatic heterogeneity of the tissue sample. REFERENCES: 1. Azzoli CG, et al. J Clin Oncol. 2009; 27-6251-6266. 2. Jackman DM et al. Clin Carcinoma Res 2009; 15:5267-5273. 3. Mok TS, et al. N Engl J Med. 2009;361:947-957. 4. Sharma SV, et al. Nat Rev Carcinoma. 2007;169-181. DISCLAIMER: This test was developed and its performance characteristics. determined by. The laboratory is. regulated under the Clinical Laboratory Improvement Amendments of. 1988 (CLIA) as qualified to perform high complexity clinical. testing. This particular test is not considered a stand alone test. and should be only used in the context of other diagnostic tests or. clinical work-up related to treatment decisions. Addendum #2 performed by. The electronic signature attests that the named Attending. Pathologist has evaluated the specimen referred to in the signed. section of the report and formulated the diagnosis therein. This report may include one or more immunohistochemical stain. results that use analyte specific reagents. The tests were developed and their performance characteristics. determined by. They have not been cleared or approved by the US Food and Drug. Administration. The FDA has determined that such clearance or approval is not. necessary."

prompt = """
You are provided with a pathology report of a cancer patient.
Determine the T stage by comparing the tumor size with standard thresholds.
If you need to compare any numerical values with the threshold values, use the 'compare_numbers' function.
When you have decided the stage, provide the final answer by calling 'provide_final_prediction' function with your reasoning and the final prediction as arguments.

Pathology report:
{report}

"""

formatted_prompt = prompt.format(report=report)

In [111]:
client=OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
messages = [{"role": "system", "content": "You are an AI assistant specialized in cancer staging.",
    "role": "user", "content": formatted_prompt}]

2025-03-11 00:04:13 - DEBUG - load_ssl_context verify=True cert=None trust_env=True http2=False
2025-03-11 00:04:13 - DEBUG - load_verify_locations cafile='/usr/lib/ssl/certs/ca-certificates.crt'


In [117]:
resp = client.chat.completions.create(
    model = "meta-llama/Llama-3.3-70B-Instruct", 
    messages = messages,
    tools = tools,
    tool_choice="auto", 
    # tool_choice="none",
    )


2025-03-11 00:05:38 - DEBUG - Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'user', 'content': "\nYou are provided with a pathology report of a cancer patient.\n\nPathology report:\nSPECIMENS: 1. F/S LINGULAR NODULE. 2. F/S LEFT LOWER LOBE WEDGE. 3. MARGIN OF LEFT LOWER LOBE WEDGE. 4. LEVEL 5 LYMPH NODE, LEFT. 5. LEVEL 7 LYMPH NODE, LEFT. 6. LEVEL 9 LYMPH NODE, LEFT. 7. BASILAR SEGMENT LEFT LOWER LOBE. SEE ADDENDUM. Reason For Addendum #1: Molecular Studies. Reason For Addendum #2: Molecular Studies. DIAGNOSIS: 1. LUNG, LINGULA: WEDGE RESECTION. - ADENOCARCINOMA, SOLID PREDOMINANT (1 CM), SEE NOTE. - THE STAPLE LINE MARGIN IS FREE OF TUMOR. Note: The tumor consists of solid (60%), acinar (30%) and lepidic. (10%) components. Immunohistochemical stains show that the tumor. cells are positive for TTF-1 and Napsin-A while negative for pó3,. supporting the diagnosis. 2,3,7. LUNG, LEFT LOWER LOBE, BASILAR SEGMENT: WEDGE BIO

In [None]:
if resp.choices[0].message.tool_calls:
    print("Tool calls detected.")
    messages.append({
    "role": "assistant",
    "tool_calls": resp.choices[0].message.tool_calls
    })
    for call in resp.choices[0].message.tool_calls:
        
        args = safe_json_load(call.function.arguments)
        if call.function.name == "provide_final_prediction":
            return args
        result = available_tools[call.function.name](**args)
        print(result)
        messages.append({
        "role": "tool",
        "content": result,
        "tool_call_id": call.id,
        "name": call.function.name,
        })

else:
    final_result = resp.choices[0].message.content
    messages.append({"role": "assistant", "content": final_result})
    print(result)

Tool calls detected.

    Reasoning: The tumor size is 3.5 cm, which is greater than 3 cm but less than 5 cm. According to the TNM staging system, this corresponds to a T2a stage.
    Final prediction: T2a
    
Final prediction detected.


In [87]:
resp.choices[0].message

ChatCompletionMessage(content='In order to determine the T stage of the lung cancer based on the provided pathology report and to follow the format required for the response, we need to compare the size of the tumor with standard thresholds defined by the TNM staging system. The size of the tumor is mentioned as 3.5 cm in the report. According to the TNM staging system for lung cancer, a tumor size greater than 3 cm but less than or equal to 5 cm corresponds to T2a.\n\nHere is how you could write the Python code to implement these steps and provide the final answer:\n\n```python\ndef provide_final_prediction(reasoning, prediction):\n    return f"The final answer is {prediction} because {reasoning}."\n\n# Compare the tumor size with the threshold for T2a\nnum1 = 3.5  # Tumor size in cm\nnum2 = 3    # Threshold size for T2a in cm\n\n# Determine the T stage based on the comparison\nif num1 > num2 and num1 <= 5:\n    T_stage = "T2a"\n    reasoning = "the tumor size is greater than 3 cm but

In [125]:
args

{'reasoning': 'The tumor size is 3.5 cm, which is greater than 3 cm but less than 5 cm. According to the TNM staging system, this corresponds to a T2a stage.',
 'prediction': 'T2a'}

In [127]:
resp.choices[0].message

ChatCompletionMessage(content=None, refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='chatcmpl-tool-8f3cefefc2434cb9ab9bed620d8071f9', function=Function(arguments='{"reasoning": "The tumor size is 3.5 cm, which is greater than 3 cm but less than 5 cm. According to the TNM staging system, this corresponds to a T2a stage.", "prediction": "T2a"}', name='provide_final_prediction'), type='function')], reasoning_content=None)

In [None]:
def safe_json_load(s: str) -> Any:
    """
    Attempts to parse a JSON string using the standard json.loads.
    If that fails (e.g. due to an unterminated string), it will try using
    a more forgiving parser (demjson3). If both attempts fail,
    the original string is returned.
    """
    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        logger.error("Standard json.loads failed: %s", e)
        try:
            logger.info("Attempting to parse with demjson3 as fallback.")
            result = demjson3.decode(s)
            logger.info("demjson3 successfully parsed the JSON.")
            return result
        except Exception as e2:
            logger.error("Fallback parsing with demjson3 also failed: %s. Returning original input.", e2)
            return s


class LLMAgent:
    def __init__(self, system_prompt: str, 
                 client=AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="dummy")):
        self.client = client
        self.messages = [{"role": "system", "content": system_prompt}]

    async def llm_call(self, user_prompt: str,
                       guided_: dict = None,
                       tools: List[dict] = None) -> Any:
        logger.debug(f"LLMAgent.llm_call() - user_prompt[:60]: {user_prompt[:60]}...")
        self.messages.append({"role": "user", "content": user_prompt})
        params = {
            "model": "meta-llama/Llama-3.3-70B-Instruct",
            "messages": self.messages,
            "temperature": 0.5,
        }
        if guided_:
            logger.debug(f"Guided JSON/choice detected: {guided_}")
            params["extra_body"] = guided_
        if tools:
            params["tools"] = tools

        response = await self.client.chat.completions.create(**params)
        return response.choices[0].message
    
    def append_message(self, content, role='assistant'):
        logger.debug(f"Appending message with role='{role}' to conversation.")
        self.messages.append({"role": role, "content": content})
        return


class InitializerAgent(LLMAgent):
    def __init__(self, n_specialists: int):
        self.n_specialists = n_specialists
        system_prompt = (
            "You are an initializer agent in a multi-agent AI system designed to handle medical questions.\n"
            f"Your job is to select {self.n_specialists} medical specialists whose expertise best matches the user's query.\n"
            "For each specialist, specify their role and a list of relevant expertise areas related to the query.\n"
        )
        super().__init__(system_prompt)

    async def identify_specialists(self, query: str):
        logger.info("InitializerAgent: Identifying specialists.")
        class Specialist(BaseModel):
            specialist: str = Field(..., description="Role of the specialist")
            expertise: List[str] = Field(..., description="Areas of expertise for the specialist.")
        panel_dict = {f"Specialist_{i+1}": (Specialist, ...) for i in range(self.n_specialists)}
        SpecialistPanel = create_model("SpecialistPanel", **panel_dict)

        user_prompt = (
            "Here is the user's query:\n\n"
            f"<Query>\n{query}\n</Query>\n\n"
            "Based on the above query, identify the most suitable specialists."
        )
        response = await self.llm_call(user_prompt, guided_={"guided_json": SpecialistPanel.schema()})
        self.append_message(content=response.content)
        logger.debug(f"InitializerAgent response: {response.content}")
        return safe_json_load(response.content)


class SpecialistAgent(LLMAgent):
    def __init__(self, specialist: str, expertise: List[str]):
        self.specialist = specialist
        self.expertise = expertise
        system_prompt = (
            f"You are a {specialist}.\n"
            f"Your expertise includes:\n{expertise}\n"
            f"Analyze the user's query from the perspective of a {specialist}."
        )
        super().__init__(system_prompt)

    async def analyze_query(self, query: str, choices: List[str]):
        logger.info(f"[{self.specialist}] Analyzing query...")
        self.query = query
        self.choices = tuple(choices)
        choices_str = ', '.join(choices)

        user_prompt = (
            "Here is the query of interest:\n\n"
            f"<Query>\n{query}\n</Query>\n\n"
            f"The possible answers are: {choices_str}.\n"
            f"From your perspective as a {self.specialist}, first provide step-by-step reasoning (rationale), "
            "and then clearly state your final answer.\n\n"
        )

        class Response(BaseModel):
            reasoning: str = Field(..., description="Step-by-step reasoning leading to the final choice")
            choice: Literal[self.choices] = Field(..., description="Final choice")

        response = await self.llm_call(user_prompt, guided_={"guided_json": Response.model_json_schema()})
        self.append_message(content=response.content)
        logger.debug(f"[{self.specialist}] analyze_query response: {response.content}")
        return safe_json_load(response.content)
    
    async def debate(self, agents: Dict[str, Any]):
        logger.info(f"[{self.specialist}] Debating with other specialists.")
        other_specialists = {}
        for name, value in agents.items():
            if name != self.specialist:
                other_specialists[name] = value

        formatted_other_specialists = json.dumps(other_specialists, indent=4)
        user_prompt = (
            "Regarding the previous query, other specialists have also provided their reasoning and choices.\n"
            "Critically evaluate the reasoning and choice of those specialists.\n\n"
            f"Specialists and their choices:\n{formatted_other_specialists}\n\n"
            "Considering the newly provided perspectives, refine your own reasoning and choice.\n"
            "You can change your choice or stick with the original one.\n\n"
        )

        class Response(BaseModel):
            reasoning: str = Field(..., description="Step-by-step reasoning leading to final choice")
            choice: Literal[self.choices] = Field(..., description="Final choice")

        response = await self.llm_call(user_prompt, guided_={"guided_json": Response.model_json_schema()})
        self.append_message(content=response.content)
        logger.debug(f"[{self.specialist}] debate response: {response.content}")
        return safe_json_load(response.content)


class AggregatorAgent(LLMAgent):
    def __init__(self):
        system_prompt = (
            "You are the aggregator agent in a multi-agent AI system for medical queries.\n"
            "You have access to each specialist's entire chat history.\n"
            "Your job is to read those full conversations, analyze their reasoning and any conflicts, "
            "and then provide a single, definitive answer to the user.\n"
            "Provide a clear explanation for your final conclusion."
        )
        super().__init__(system_prompt)

    async def aggregate(self, query: str, choices: List[str], specialists_chat_history: Dict[str, Any]):
        logger.info("AggregatorAgent: Aggregating final answer from all specialists' chat history.")
        specialists_str = json.dumps(specialists_chat_history, indent=4)

        user_prompt = (
            f"Here is the query of interest:\n\n"
            f"<Query>\n{query}\n</Query>\n\n"
            "Below is the *entire conversation history* for each specialist:\n\n"
            f"{specialists_str}\n\n"
            "Please review all these conversations in detail and produce one single, definitive final answer. "
            "If there is no unanimous or majority choice, choose the answer best supported by the specialists' reasoning. "
            "Clearly justify your reasoning, then provide your final recommended answer."
        )

        class AggregatedResponse(BaseModel):
            aggregated_reasoning: str = Field(..., description="Detailed reasoning behind final choice")
            aggregated_choice: Literal[tuple(choices)] = Field(..., description="Single recommended choice")

        response = await self.llm_call(user_prompt, guided_={"guided_json": AggregatedResponse.model_json_schema()})
        self.append_message(content=response.content)
        logger.debug(f"AggregatorAgent response: {response.content}")
        return safe_json_load(response.content)


def check_consensus(status_dict: Dict[str, Any]) -> str:
    """
    Returns the consensus choice if >= 80% of specialists agree, else returns None.
    """
    logger.info("Checking for consensus among specialists.")
    specialists_count = len(status_dict)
    consensus_threshold = math.ceil(0.8 * specialists_count)

    choice_counts = {}
    for _, specialist_data in status_dict.items():
        final_choice = specialist_data['response_after_debate']['choice']
        choice_counts[final_choice] = choice_counts.get(final_choice, 0) + 1

    for choice, count in choice_counts.items():
        if count >= consensus_threshold:
            logger.info(f"Consensus found on choice '{choice}' with {count}/{specialists_count} specialists.")
            return choice
    logger.info("No consensus found.")
    return None


# --------------------------------
# 3) PROCESS A SINGLE ROW/QUERY
# --------------------------------
async def process_single_query(
    question_text: str,
    ground_truth: str,
    choices: List[str],
    n_specialists: int) -> Dict[str, Any]:
    """
    Given a single query (question + ground_truth + multiple choices), 
    run the multi-agent system (Initializer -> Specialists -> Debates -> Aggregator if needed).
    Return the final dictionary containing all the specialists' output and aggregator results.
    """

    # 1. Initialize specialists
    initializer = InitializerAgent(n_specialists=n_specialists)
    json_resp = await initializer.identify_specialists(query=question_text)
    if not isinstance(json_resp, dict):
        logger.error("Invalid JSON output from initializer; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # Build specialists status dict
    specialists_status = {}
    for _, agent_info in json_resp.items():
        specialist_name = agent_info["specialist"]
        expertise = agent_info["expertise"]
        specialists_status[specialist_name] = {"expertise": expertise}
    
    # 2. Run analyze_query for each specialist in parallel
    async def analyze_specialist(specialist_name: str, status: Dict[str, Any], query: str, choices: List[str]):
        specialist_agent = SpecialistAgent(specialist=specialist_name, expertise=status["expertise"])
        status["instance"] = specialist_agent
        message = await specialist_agent.analyze_query(query=query, choices=choices)
        if not isinstance(message, dict):
            logger.error(f"[{specialist_name}] Invalid JSON output from specialist; skipping this specialist.")
            return None
        status["original_response"] = message
        logger.info(f"[{specialist_name}] Completed analyze_query.")
        return specialist_name

    analyze_tasks = [
        asyncio.create_task(analyze_specialist(name, status, question_text, choices))
        for name, status in specialists_status.items()
    ]
    analyze_results = await asyncio.gather(*analyze_tasks)
    if any(r is None for r in analyze_results):
        logger.error("At least one specialist failed; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # Build a minimal dictionary for debate (remove 'instance')
    input_specialists_dict = {
        specialist_name: {
            k: v for k, v in specialist_data.items() 
            if k != "instance"
        }
        for specialist_name, specialist_data in specialists_status.items()
    }

    # 3. Debate step, also in parallel
    async def debate_specialist(specialist_name: str, status: Dict[str, Any], specialists_dict: Dict[str, Any]):
        specialist_agent = status["instance"]
        message = await specialist_agent.debate(specialists_dict)
        if not isinstance(message, dict):
            logger.error(f"[{specialist_name}] Invalid JSON output during debate; skipping this specialist.")
            return None
        status["response_after_debate"] = message
        specialists_dict[specialist_name]["response_after_debate"] = message
        logger.info(f"[{specialist_name}] Completed debate.")
        return specialist_name

    debate_tasks = [
        asyncio.create_task(debate_specialist(name, status, input_specialists_dict))
        for name, status in specialists_status.items()
    ]
    debate_results = await asyncio.gather(*debate_tasks)
    if any(r is None for r in debate_results):
        logger.error("At least one specialist failed during debate; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # 4. Check consensus
    consensus_choice = check_consensus(input_specialists_dict)
    aggregator_result = None

    if consensus_choice is not None:
        logger.info(f"Consensus reached: {consensus_choice}")
        input_specialists_dict["Aggregator"] = {
            "final_choice": consensus_choice, 
            "final_reasoning": "Consensus reached"
        }
    else:
        logger.info("No consensus reached; enabling aggregator path...")
        aggregator = AggregatorAgent()
        aggregated_response = await aggregator.aggregate(
            query=question_text,
            choices=choices,
            specialists_chat_history=input_specialists_dict
        )
        if not isinstance(aggregated_response, dict):
            logger.error("Invalid JSON output from aggregator; skipping this query.")
            return {}  # Skip processing and continue to the next query
        
        final_choice = aggregated_response['aggregated_choice']
        final_reasoning = aggregated_response['aggregated_reasoning']

        logger.info(f"Aggregator final choice: {final_choice}")
        logger.info(f"Aggregator reasoning: {final_reasoning}")

        aggregator_result = {
            "final_choice": final_choice,
            "final_reasoning": final_reasoning
        }
        input_specialists_dict["Aggregator"] = aggregator_result

    # Add question and ground_truth for reference
    input_specialists_dict["Question"] = question_text
    input_specialists_dict["Answer"] = ground_truth

    return input_specialists_dict


async def process_multiple_queries(
    qa_df: pd.DataFrame,
    choices: List[str],
    n_specialists: int,
    max_concurrency: int = 5
) -> List[Dict[str, Any]]:
    """
    Process multiple rows (queries) in `qa_df` asynchronously.
    Each row is passed to `process_single_query`.
    
    :param qa_df: DataFrame with columns ["question", "choice", "ground_truth"] at least.
    :param choices: A list of all possible answer choices, e.g. ["A", "B", "C", "D", "E"].
    :param n_specialists: Number of specialists to initialize for each query.
    :param max_concurrency: Limit on how many queries to process simultaneously.
    :return: A list of result dictionaries, one per row in `qa_df`.
    """

    # This semaphore keeps at most `max_concurrency` tasks running at once
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_single_query(row_idx: int, row: pd.Series):
        """
        This inner function is used to call `process_single_query` with concurrency control.
        """
        async with semaphore:
            logger.info(f"Starting row {row_idx}")
            question_text = row["question"] + "\n" + str(row["choice"])
            ground_truth = str(row["ground_truth"])
            result = await process_single_query(
                question_text=question_text,
                ground_truth=ground_truth,
                choices=choices,
                n_specialists=n_specialists
            )
            logger.info(f"Finished row {row_idx}")
            return result

    tasks = [
        asyncio.create_task(run_single_query(i, row))
        for i, row in qa_df.iterrows()
    ]

    # Wait for all tasks to complete
    all_results = await asyncio.gather(*tasks)

    # `all_results` is a list of return values from each `run_single_query`
    return all_results

async def main():

    logger.info("===== MAIN START =====")

    # Example CSV loading
    df_path = "/home/yl3427/cylab/llm_reasoning/reasoning/data/step2_ALL.csv"
    qa_df = pd.read_csv(df_path, encoding="latin-1")  # columns: idx, question, choice, ground_truth, qn_num
    # qa_df = pd.read_csv('/home/yl3427/cylab/SOAP_MA/Input/SOAP_5_problems.csv')
    logger.info("Loaded dataframe with %d rows.", len(qa_df))


    ################# 'process_single_query' Example usage #################
    results = []
    for idx, row in qa_df.iterrows():
        # if idx <= 10:
        #     continue
        logger.info(f"Processing row index {idx}")
        question_text = row["question"] + "\n" + str(row["choice"])
        ground_truth = str(row["ground_truth"])

        # patient_info = str(row["Subjective"]) + "\n" + str(row['Objective'])
        # question_text = f"""
        # Based on the following patient report, does the patient have sepsis?

        # {patient_info}
        # """
        # ground_truth = str(row["terms"])
        

        # Run the multi-agent system for this single query
        result_dict = await process_single_query(
            question_text=question_text,
            ground_truth=ground_truth,
            choices=["A", "B", "C", "D", "E"],
            # choices=["Yes", "No"],
            n_specialists=5
        )
        # result_dict["File ID"] = row["File ID"]
        result_dict["qn_num"] = row["qn_num"]

        # Store result for later evaluation
        results.append(result_dict)

        if idx % 10 == 0:
            output_json_path = f"/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/step2_{idx}.json"
            with open(output_json_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            logger.info(f"Saved aggregated results to {output_json_path}")

    # OPTIONAL: Save results to JSON
    output_json_path = "/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/step2_final.json"
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    logger.info(f"Saved aggregated results to {output_json_path}")

    logger.info("===== MAIN END =====")


2025-03-07 15:14:33 - DEBUG - load_ssl_context verify=True cert=None trust_env=True http2=False
2025-03-07 15:14:33 - DEBUG - load_verify_locations cafile='/usr/lib/ssl/certs/ca-certificates.crt'


In [None]:
# log
for name, status in specialists_status.items():
    print(f"Specialist: {name}")
    message = status["instance"].messages
    print(message)


In [None]:
system_instruction = """
You are a knowledgeable and meticulous medical expert specialized in diagnosing diseases based on partial information from SOAP notes. 
You will receive either:
1. A single-disease assessment request (“specialist” scenario), or 
2. A multiple-disease assessment request (“generalist” scenario).

In the “specialist” scenario, you focus on one disease and analyze evidence within the Subjective (S) and Objective (O) sections for or against that single disease. Your final answer must be in valid JSON with:
    {
        "reasoning": "Concise explanation of your thought process",
        "diagnosis": true_or_false
    }

In the “generalist” scenario, you must assess each disease from a given list. For each disease, identify subjective and objective evidence that supports or refutes the disease. If evidence strongly supports it, conclude the diagnosis is true; if not, conclude false. If conflicting or incomplete, offer a reasoned explanation and a likely conclusion. Your final answer must be in valid JSON with each disease as a key:
    {
      "DiseaseName1": { "reasoning": "Your reasoning...", "diagnosis": true_or_false },
      "DiseaseName2": { "reasoning": "Your reasoning...", "diagnosis": true_or_false },
      ...
    }

When reasoning, consider clinical clues like symptoms, exam findings, risk factors, and labs. Clearly and succinctly justify why each disease is likely or unlikely. If any information is missing or ambiguous, note the uncertainty and choose the most probable conclusion.

Follow these instructions precisely:
• Always return output in the exact JSON format requested (no extra fields or text).
• Provide concise, medically sound rationale for each decision.
"""

prompt_specialist = """
You are a medical expert specializing in {PROBLEM}.

You are provided with only the Subjective (S) and Objective (O) sections of a patient's SOAP-formatted progress note for a potential case of {PROBLEM}.
Identify relevant clues in the subjective and objective sections that align with or argue against {PROBLEM}. If evidence strongly suggests {PROBLEM}, conclude the diagnosis is true; if not, conclude it is false. If the evidence is uncertain or conflicting, explain your reasoning and lean toward the most likely conclusion.

Patient Report:
<Subjective>
{SUBJ}
</Subjective>

<Objective>
{OBJ}
</Objective>

Your answer must be output as valid JSON formatted exactly as follows:
    {{
        "reasoning": "Your reasoning here...",
        "diagnosis": true_or_false
    }}
"""

prompt_generalist = """
You are a medical expert in diagnostic reasoning.

You are provided with only the Subjective (S) and Objective (O) sections of a patient's SOAP-formatted progress note that may be relevant to one or more of the following diseases:
{PROBLEM_LIST}

The patient may have one or more of these diseases, or none at all. Evaluate each disease independently.
Identify relevant clues in the subjective and objective sections that align with or argue against each disease. If evidence strongly suggests the disease, conclude the diagnosis is true; if not, conclude it is false. If the evidence is uncertain or conflicting, explain your reasoning and lean toward the most likely conclusion.

Patient Report:
<Subjective>
{SUBJ}
</Subjective>

<Objective>
{OBJ}
</Objective>

Your answer must be output as valid JSON formatted exactly as follows:
{{
{json_keys}
}}
"""

system_instruction_mediator = """
You are the mediator agent in a medical multi-agent diagnostic system. 
"""

In [None]:
class Response(BaseModel):
    reasoning: str = Field(..., description="Step-by-step reasoning leading to the final diagnosis.")
    diagnosis: bool = Field(..., description="True if patient has the disease, False otherwise.")

In [90]:
from typing import get_origin, get_args, Union, Any

def generate_tools_spec(*functions):
    """
    Generate a list of tool definitions (function schemas) for OpenAI's tool calling.
    
    Each function's name, docstring, and parameters (with types and required flags)
    are extracted to form the JSON schema as a dictionary.
    
    Args:
        *functions: One or more Python function objects to document.
    Returns:
        List[dict]: A list of tool definition dictionaries compatible with OpenAI API.
    """
    # Mapping of Python types to JSON Schema types
    type_map = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
        type(None): "null"
    }
    tools = []
    for func in functions:
        # Basic function info
        func_name = func.__name__
        func_description = func.__doc__.strip() if func.__doc__ else ""
        sig = inspect.signature(func)
        
        properties = {}
        required = []
        for param in sig.parameters.values():
            # Skip *args and **kwargs as they cannot be described in JSON schema easily
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue
            param_name = param.name

            # Determine JSON schema type from annotation (if available)
            json_type = "string"  # default type
            annotation = param.annotation
            if annotation is not inspect._empty:
                origin = get_origin(annotation)
                # Handle Optional[X] or Union[X, None]
                if origin is Union:
                    args = [t for t in get_args(annotation) if t is not type(None)]
                    if len(args) == 1:
                        annotation = args[0]
                        origin = get_origin(annotation)
                # Map to JSON type if direct or via origin for generics
                if annotation in type_map:
                    json_type = type_map[annotation]
                elif origin in type_map:
                    json_type = type_map[origin]
                # Handle list item types for generics like list[int]
                if json_type == "array":
                    item_type = "string"  # default for items
                    args = get_args(annotation)
                    if args:
                        # Use first type argument for list item if present
                        item_type = type_map.get(args[0], "string")
                    properties[param_name] = {
                        "type": "array",
                        "items": {"type": item_type}
                    }
                elif json_type == "object":
                    # For dicts or unknown complex types, use object without specifics
                    properties[param_name] = {"type": "object"}
                else:
                    properties[param_name] = {"type": json_type}
            else:
                # No annotation, assume string
                properties[param_name] = {"type": "string"}

            # Mark required if no default value
            if param.default is inspect._empty:
                required.append(param_name)
        
        # Build the tool dictionary for this function
        tool_dict = {
            "type": "function",
            "function": {
                "name": func_name,
                "description": func_description,
                "parameters": {
                    "type": "object",
                    "properties": properties
                }
            }
        }
        if required:
            tool_dict["function"]["parameters"]["required"] = required
        tools.append(tool_dict)
    return tools


In [None]:
def retrieve_synonyms(problem: str) -> Optional[List[str]]: 
    """
    Retrieve the list of synonyms for a given problem.
    """
    problem = problem.lower()
    mi = ["myocardial infarction", "elevation mi", "non-stemi", " NSTEMI", " stemi"]
    chf = ["congestive heart failure", " chf", "HFrEF", "HFpEF"]
    pulmonary_embolism = ["pulmonary embolism"]
    pulmonary_hypertension = ["pulmonary hypertension", "pulmonary htn"]
    sepsis = ["sepsis", "septic shock"]
    urosepsis = ["urosepsis"]
    meningitis = ["meningitis"]
    aki = ["acute kidney injury", " aki", "acute renal failure", " arf"] # -> Acute tubular necrosis (ATN)인가 아닌가
    atn = ["acute tubular necrosis", " atn"]
    pancreatitis = ["pancreatitis"]
    gi_bleed = ["gastrointestinal bleed", "gi bleed"]
    hepatitis = ["hepatitis", " hep"]
    cholangitis = ["cholangitis"]
    asp_pneumonia = ["aspiration pneumonia"]

    prob_dict = {'myocardial infarction': mi, 
                 'congestive heart failure': chf, 
                 'pulmonary embolism': pulmonary_embolism, 
                 'pulmonary hypertension': pulmonary_hypertension, 
                 'sepsis': sepsis, 
                 'urosepsis': urosepsis, 
                 'meningitis': meningitis, 
                 'acute kidney injury': aki, 
                 'acute tubular necrosis': atn, 
                 'pancreatitis': pancreatitis, 
                 'gastrointestinal bleed': gi_bleed, 
                 'hepatitis': hepatitis, 
                 'cholangitis': cholangitis, 
                 'aspiration pneumonia': asp_pneumonia}
    result = prob_dict.get(problem, None)
    return result
tools = generate_tools_spec(retrieve_synonyms)

In [None]:
tools

In [None]:
client = OpenAI(api_key="dummy_key", base_url="http://localhost:8000/v1")

In [None]:
messages = [
    {"role": "user", "content": "What's the synonym for acute kidney injury?"}
]
client = OpenAI(api_key="dummy_key", base_url="http://localhost:8000/v1")
response = client.chat.completions.create(
    model=client.models.list().data[0].id,
    messages=messages,
    temperature= 0.1,
    tools=tools,
    tool_choice="auto" #none
)
response.choices[0].message

In [None]:
for tool_call in response.choices[0].message.tool_calls:
    print(tool_call)

In [None]:
def call_function(name, args):
    if name == "retrieve_synonyms":
        return retrieve_synonyms(**args)
    
for tool_call in response.choices[0].message.tool_calls:
    name = tool_call.function.name
    args = json.loads(tool_call.function.arguments)

    result = str(call_function(name, args))
    messages.append({
        "role": "tool",
        "tool_call_id": tool_call.id,
        "name": name,
        "output": result
    })

In [None]:
messages

툴콜링 됐을 때와 아닐때 모델 아웃풋 차이
```
ChatCompletionMessage(content=None, refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='chatcmpl-tool-e9f31a3069694cc69887d4e03d16b412', function=Function(arguments='{"problem": "acute kidney injury"}', name='retrieve_synonyms'), type='function')], reasoning_content=None)


ChatCompletionMessage(content='The synonym for acute kidney injury (AKI) is acute renal failure (ARF).', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[], reasoning_content=None)
```

In [None]:
response = client.chat.completions.create(
    model=client.models.list().data[0].id,
    messages=messages,
)

In [None]:
print(response.choices[0].message)

In [None]:
response

In [None]:
class LLM:
    def __init__(self, client: OpenAI):
        self.client = client

    def get_response(
        self, 
        messages: List[Dict], 
        temperature: Optional[float] = 0.1,
        guided_: Optional[dict] = None, # {"guided_json": json_schema}, {"guided_choice": ["positive", "negative"]}
        tools: Optional[List[Dict]] = None
    ):
        try:
            request_params = {
                "model": self.client.models.list().data[0].id,
                "messages": messages,
                "temperature": temperature,
            }
            if guided_:
                request_params["extra_body"] = guided_
            if tools:
                request_params["tools"] = tools

            response = self.client.chat.completions.create(**request_params)

            return response.choices[0].message

        except Exception as e:
            print(f"An error occurred: {e}")
            return None


    # def test_single_prob(self, dataset: pd.DataFrame, problem: str):
    #     pbar = tqdm(total=dataset.shape[0], desc=f"Testing {problem}")
    #     for idx, row in dataset.iterrows():
    #         subj_text = row["Subjective"]
    #         obj_text = row["Objective"]

    #         prompt_specialist_formatted = prompt_specialist.format(
    #             PROBLEM=problem,
    #             SUBJ=subj_text,
    #             OBJ=obj_text
    #         )
    #         messages = [
    #             {"role": "system", "content": system_instruction},
    #             {"role": "user", "content": prompt_specialist_formatted}
    #         ]
    #         response = self.get_response(
    #             messages,
    #             schema= DiseaseDiagnosis.model_json_schema()
    #         )
    #         if response:
    #             dataset.at[idx, f"is_{problem.lower().replace(' ', '_')}_pred_single"] = response["diagnosis"]
    #             dataset.at[idx, f"is_{problem.lower().replace(' ', '_')}_reasoning_single"] = response["reasoning"]

    #         pbar.update(1)
    #     pbar.close()
    #     return dataset
    
    # def test_multi_prob(self, dataset: pd.DataFrame, problem_lst: list):

    #     problem_dict = {problem: (DiseaseDiagnosis, ...) for problem in problem_lst}

    #     DynamicResponseMultiDiagnosis = create_model(
    #                 'DynamicResponseMultiDiagnosis',
    #                 **problem_dict
    #             )

    #     pbar = tqdm(total=dataset.shape[0], desc="Testing Multi-Diagnosis")
    #     for idx, row in dataset.iterrows():
    #         subj_text = row["Subjective"]
    #         obj_text = row["Objective"]

    #         json_keys_list = [
    #             f'  "{disease}": {{"reasoning": "Your reasoning here...", "diagnosis": true_or_false}}'
    #             for disease in problem_lst
    #         ]
    #         json_keys = ",\n".join(json_keys_list)

    #         prompt_generalist_formatted = prompt_generalist.format(
    #             PROBLEM_LIST=", ".join(problem_lst),
    #             SUBJ=subj_text,
    #             OBJ=obj_text,
    #             json_keys=json_keys,
    #         )

    #         messages = [
    #             {"role": "system", "content": system_instruction},
    #             {"role": "user", "content": prompt_generalist_formatted}
    #         ]

    #         response = self.get_response(
    #             messages,
    #             schema=DynamicResponseMultiDiagnosis.model_json_schema()
    #         )
    #         if response:
    #             for problem in problem_lst:
    #                 dataset.at[idx, f"is_{problem.lower().replace(' ', '_')}_pred_multi"] = response[problem]["diagnosis"]
    #                 dataset.at[idx, f"is_{problem.lower().replace(' ', '_')}_reasoning_multi"] = response[problem]["reasoning"]
    #         pbar.update(1)
    #     pbar.close()
    #     return dataset



In [None]:
client = OpenAI(api_key="dummy_key", base_url="http://localhost:8000/v1")
df = pd.read_csv(
    '/home/yl3427/cylab/SOAP_MA/data/mergedBioNLP2023.csv',
    usecols=['File ID', 'Subjective', 'Objective', 'Summary', 'cleaned_expanded_Summary', 'terms']
)
df = df.fillna('').apply(lambda x: x.str.lower())
df['combined_summary'] = df['Summary'] + df['cleaned_expanded_Summary'] + df['terms']

mi = ["myocardial infarction", "elevation mi", "non-stemi", " NSTEMI", " stemi"]
chf = ["congestive heart failure", " chf", "HFrEF", "HFpEF"]
pulmonary_embolism = ["pulmonary embolism"]
pulmonary_hypertension = ["pulmonary hypertension", "pulmonary htn"]
sepsis = ["sepsis", "septic shock"]
urosepsis = ["urosepsis"]
meningitis = ["meningitis"]
aki = ["acute kidney injury", " aki", "acute renal failure", " arf"] # -> Acute tubular necrosis (ATN)인가 아닌가
atn = ["acute tubular necrosis", " atn"]
pancreatitis = ["pancreatitis"]
gi_bleed = ["gastrointestinal bleed", "gi bleed"]
hepatitis = ["hepatitis", " hep"]
cholangitis = ["cholangitis"]
asp_pneumonia = ["aspiration pneumonia"]

prob_dict = {'myocardial infarction': mi, 
                'congestive heart failure': chf, 
                'pulmonary embolism': pulmonary_embolism, 
                'pulmonary hypertension': pulmonary_hypertension, 
                'sepsis': sepsis, 
                'urosepsis': urosepsis, 
                'meningitis': meningitis, 
                'acute kidney injury': aki, 
                'acute tubular necrosis': atn, 
                'pancreatitis': pancreatitis, 
                'gastrointestinal bleed': gi_bleed, 
                'hepatitis': hepatitis, 
                'cholangitis': cholangitis, 
                'aspiration pneumonia': asp_pneumonia}

ids = set()
for name, lst in prob_dict.items():
    problem_terms = lst
    problem_terms = [term.lower() for term in problem_terms]

    # Use the first term as the primary term to check in the combined summary.
    primary_term = problem_terms[0]

    # Build a regex pattern that matches any of the problem terms.
    # pattern = '|'.join(problem_terms)
    pattern = '|'.join(re.escape(term) for term in problem_terms)

    mask = (
        df['combined_summary'].str.contains(pattern, na=False) &
        ~df['Subjective'].str.contains(pattern, na=False) &
        ~df['Objective'].str.contains(pattern, na=False)
    )

    filtered_df = df[mask]

    ids.update(filtered_df['File ID'])

agent = Agent(client=client)

df = df[df['File ID'].isin(ids)]
df = df.reset_index(drop=True)

result_df = agent.test_multi_prob(df, list(prob_dict.keys()))
result_df.to_csv("multi_result_full.csv", index=False)

for name, lst in prob_dict.items():
    result_df = agent.test_single_prob(result_df, name)
    result_df.to_csv(f"single_result_{name}.csv", index=False)
result_df.to_csv("single_result_full.csv", index=False)

In [20]:
import pandas as pd
qa_df = pd.read_csv("/home/yl3427/cylab/llm_reasoning/reasoning/data/step1_ALL.csv", encoding="latin-1")
print(len(qa_df))
for idx, row in qa_df.iterrows():
    if idx == 10:
        # print(row['question'] + "\n" + str(row['choice']))
        print(row)

119
ï»¿idx                                                         11
question        A 25-year-old man volunteers to participate in...
choice                     A: , B: , C: , D: , E: , F: , G: , H: 
ground_truth                                                    H
qn_num                                                         11
Name: 10, dtype: object


In [21]:
qa_df

Unnamed: 0,ï»¿idx,question,choice,ground_truth,qn_num
0,1,Serum LDL-cholesterol concentrations are measu...,"A: 105-155, B: 120-140, C: 125-135, D: 128-132...",B,1
1,2,A 48-year-old man dies suddenly of a cardiac a...,"A. Acute inflammation, B. Fibrinous exudate, C...",E,2
2,3,"In a sample of 100 individuals, the mean leuko...","A: 5500-9500/mm3, B: <6500/mm3 or >8500/mm3, C...",D,3
3,4,A 55-year-old woman comes to the clinic becaus...,"A: Candida albicans, B: Chlamydia trachomatis,...",A,4
4,5,A 39-year-old man comes to the physician becau...,"A: Asthma, B: Bronchiectasis, C: Chronic pulmo...",E,5
...,...,...,...,...,...
114,115,A 15-year-old boy is brought to the office by ...,"A: Cytomegalovirus, B: Epstein-Barr virus, C: ...",B,115
115,116,A 14-year-old boy is brought to the emergency ...,"A: BMI, B: Family history, C: Medication use, ...",A,116
116,117,A 38-year-old woman comes to the clinic to dis...,"A: ""As the patient, you really should make any...",B,117
117,118,A 36-year-old woman with hypertension comes to...,"A: Fetal lung/epithelial differentiation, B: F...",C,118
