In [1]:
from tqdm import tqdm
from openai import OpenAI, AsyncOpenAI
import re
from typing import Optional, Union, List, get_origin, get_args, Any, Dict, Literal, Callable
import inspect
# from __future__ import annotations
import asyncio
import json
import logging
import pandas as pd
from pydantic import BaseModel, Field, create_model
import math
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
        # logging.FileHandler('log/0327_MA_MedicalQA_mistral_merged.log', mode='a'),  # Write to file
        logging.StreamHandler()                     # Print to console
    ]
)

In [2]:
import json
import logging
logger = logging.getLogger(__name__)

def safe_json_load(s: str) -> any:
    """
    Attempts to parse a JSON string using multiple parsers.
    Order:
    1. json.loads (strict)
    2. demjson3.decode (tolerant)
    3. json5.loads (allows single quotes, unquoted keys, etc.)
    4. dirtyjson.loads (for messy JSON)
    5. jsom (if available)
    6. json_repair (attempt to repair the JSON and parse it)
    
    If all attempts fail, returns the original string.
    """
    # 1. Try standard JSON
    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        logger.error("Standard json.loads failed: %s", e)
    
    # 2. Try demjson3
    try:
        import demjson3
        logger.info("Attempting to parse with demjson3 as fallback.")
        result = demjson3.decode(s)
        logger.info("demjson3 successfully parsed the JSON.")
        return result
    except Exception as e2:
        logger.error("demjson3 fallback failed: %s", e2)
    
    # 3. Try json5
    try:
        import json5
        logger.info("Attempting to parse with json5 as fallback.")
        result = json5.loads(s)
        logger.info("json5 successfully parsed the JSON.")
        return result
    except Exception as e3:
        logger.error("json5 fallback failed: %s", e3)
    
    # 4. Try dirtyjson
    try:
        import dirtyjson
        logger.info("Attempting to parse with dirtyjson as fallback.")
        result = dirtyjson.loads(s)
        logger.info("dirtyjson successfully parsed the JSON.")
        return result
    except Exception as e4:
        logger.error("dirtyjson fallback failed: %s", e4)
    
    # 5. Try jsom
    try:
        import jsom
        logger.info("Attempting to parse with jsom as fallback.")
        parser = jsom.JsomParser()
        result = parser.loads(s)
        logger.info("jsom successfully parsed the JSON.")
        return result
    except Exception as e5:
        logger.error("jsom fallback failed: %s", e5)
    
    # 6. Try json_repair (attempt to fix the JSON and then load it)
    try:
        import json_repair
        logger.info("Attempting to repair JSON with json_repair as fallback.")
        repaired = json_repair.repair_json(s)
        result = json.loads(repaired)
        logger.info("json_repair successfully parsed the JSON.")
        return result
    except Exception as e6:
        logger.error("json_repair fallback failed: %s", e6)
    
    # All attempts failed; return the original input.
    logger.error("All JSON parsing attempts failed. Returning original input.")
    logger.error("Original input: %s", s)
    return s

test_str = "{'key': 'value" 
parsed = safe_json_load(test_str)
print("Parsed output:", parsed)

2025-04-01 00:16:49 - ERROR - Standard json.loads failed: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)
2025-04-01 00:16:49 - INFO - Attempting to parse with demjson3 as fallback.
2025-04-01 00:16:49 - ERROR - demjson3 fallback failed: String literal is not terminated
2025-04-01 00:16:49 - INFO - Attempting to parse with json5 as fallback.
2025-04-01 00:16:49 - ERROR - json5 fallback failed: <string>:1 Unexpected end of input at column 15
2025-04-01 00:16:49 - INFO - Attempting to parse with dirtyjson as fallback.
2025-04-01 00:16:49 - ERROR - dirtyjson fallback failed: Unterminated string starting at: line 1 column 9 (char 8)
2025-04-01 00:16:49 - INFO - Attempting to parse with jsom as fallback.
/home/yl3427/miniconda3/envs/vllm_env/lib/python3.9/site-packages/jsom/transformer_methods.py:106: SingleQuotedString: Single-quoted string at line 1, column 2 (near "...'key'...")
  warn(
2025-04-01 00:16:49 - ERROR - jsom fallback failed: No terminal matches ''

Parsed output: {'key': 'value'}


In [None]:
class Answer(BaseModel):
    answer: str = Field(..., description="Your answer")

schema = {"guided_json": Answer.model_json_schema()}

client=AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
resp = await client.chat.completions.create(
    messages=[
        {"role": "user", "content": "What is the capital of France?"}
    ],
    model="meta-llama/Llama-3.3-70B-Instruct",
    temperature=1.0,
    extra_body = schema
)

In [None]:
resp.choices[0].message

In [None]:
class LLMAgent:
    def __init__(self, system_prompt: str, model_name: str = "meta-llama/Llama-3.3-70B-Instruct",
                 client=AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="dummy")):
        self.model_name = model_name
        self.client = client
        self.messages = [{"role": "system", "content": system_prompt}]
    
    async def __llm_call_with_tools(self, params: dict, available_tools: dict) -> any:
        iter_count = 0
        response = await self.client.chat.completions.create(**params)
        message = response.choices[0].message
          
        while not hasattr(message, "tool_calls") or not message.tool_calls:
            logger.info("No tool calls detected in response.")
      
            self.messages.append({"role": "assistant", "content": message.content})
            self.messages.append({"role": "system", "content": "WARNING! Remember to call a tool."})

            iter_count += 1
            if iter_count > 3:
                logger.error("Exceeded maximum iterations for tool calls.")
                return message.content
            
            response = await self.client.chat.completions.create(**params)
            message = response.choices[0].message

        logger.info("Tool calls detected in response.")
        self.messages.append({"role": "assistant", "tool_calls": message.tool_calls})
        
        for call in message.tool_calls:
            logger.debug(f"Tool call: {call}")
            args = safe_json_load(call.function.arguments)
            result = available_tools[call.function.name](**args) # Assume synchronous for now
            self.messages.append({
                "role": "tool",
                "content": result,
                "tool_call_id": call.id,
                "name": call.function.name,
            })
                    
        response = await self.client.chat.completions.create(**params)
        return response.choices[0].message.content

    async def llm_call(self, user_prompt: str, temperature: float = 0.3,
                       guided_: dict = None,
                       tools_descript: List[dict] = None, available_tools: dict = None) -> Any:
        logger.debug(f"LLMAgent.llm_call() - user_prompt[:60]: {user_prompt[:60]}...")
        self.messages.append({"role": "user", "content": user_prompt})
        params = {
            "model": self.model_name,
            "messages": self.messages,
            "temperature": temperature,
        }
        if guided_:
            logger.debug(f"Guided JSON/choice detected: {guided_}")
            params["extra_body"] = guided_

        if tools_descript:
            params["tools"] = tools_descript
            assert available_tools is not None, "available_tools must be provided if tools_descript is used."
            return await self.__llm_call_with_tools(params, available_tools)
        else:
            params["tool_choice"] = "none"
            response = await self.client.chat.completions.create(**params)
            print("got the response!!!!!!!!!!!!!!!!!")
            return response.choices[0].message.content
    
    def append_message(self, content, role='assistant'):
        logger.debug(f"Appending message with role='{role}' to conversation.")
        self.messages.append({"role": role, "content": content})

# Initate a new Manager agent on every new note
class Manager(LLMAgent):
    def __init__(self, note: str, hadm_id: str, problem: str, label: Literal["Yes", "No"], 
                 n_specialists: Union[int, Literal["auto"]] = 5,  consensus_threshold: float = 0.8,
                 max_consensus_attempts = 3, max_assignment_attempts = 1,
                 static_specialists: Optional[List[object]] = None, 
                 summarizer: Dict[Literal["before_manager", "before_specialist"], Callable[..., any]] = None):

        system_prompt = (
                "You are a Manager agent in a multi-agent AI system designed to handle medical questions.\n"
                f"Your job is to select medical specialists whose expertise best matches the user's query,\n"
                "and to ensure that the specialists reach a consensus on the answer.\n"
            )
        super().__init__(system_prompt)

        self.note = note
        self.hadm_id = hadm_id
        self.problem = problem
        self.label = label
        
        self.status_dict = {"note": note, "hadm_id": hadm_id, "problem": problem, "label": label}
        self.n_specialists = n_specialists
        self.consensus_threshold = consensus_threshold
    
        self.max_consensus_attempts = max_consensus_attempts
        self.consensus_attempts = 0 # round

        self.max_assignment_attempts = max_assignment_attempts
        self.assignment_attempts = 0 # panel

        self.static_specialists = static_specialists
        self.summarizer = summarizer
    
    async def _assign_specialists(self):
        self.assignment_attempts += 1
        self.status_dict[f"panel_{self.assignment_attempts}"] = {}

        if self.n_specialists == "auto":            
            user_prompt = (
                "Here is a 'S' and 'O' from a SOAP note:\n\n"
                f"<SOAP>\n{self.note}\n</SOAP>\n\n"
                f"Based on the information in the SOAP note, we need to confirm if the patient has {self.problem}.\n\n"
                "In order to successfully diagnose on this case, please provide a list of medical specialties that are needed to address this case.")
            
            class Specialties(BaseModel):
                specialties: List[str] = Field(..., description="List of medical specialties needed to address the case.")
            print("Is this schema the problem?: Specialties.model_json_schema()")
            response = await self.llm_call(user_prompt, guided_={"guided_json": Specialties.model_json_schema()})
            self.append_message(content=response)
            specialties_lst = safe_json_load(response)["specialties"]
            logger.debug(f"{len(specialties_lst)} specialties identified: {specialties_lst}")
            self.n_specialists = len(specialties_lst)
        else:
            user_prompt = (
                "Here is a 'S' and 'O' from a SOAP note:\n\n"
                f"<SOAP>\n{self.note}\n</SOAP>\n\n"
                f"Based on the information in the SOAP note, we need to confirm if the patient has {self.problem}.\n\n"
                f"In order to successfully diagnose, please provide a list of {self.n_specialists} medical specialties that are needed to address this case.")
            
            class Specialty(BaseModel):
                name: str = Field(..., description="Name of the medical specialty needed to address the case.")
            specialties_dict = {f"specialty_{i+1}": (Specialty, ...) for i in range(self.n_specialists)}
            Specialties_N = create_model("Specialties_N", **specialties_dict)
            print("Is this schema the problem?: Specialties_N.model_json_schema()")
            response = await self.llm_call(user_prompt, guided_={"guided_json": Specialties_N.model_json_schema()})
            self.append_message(content=response)
            specialties_lst = [safe_json_load(response)[f"specialty_{i+1}"]['name'] for i in range(self.n_specialists)]

        self.status_dict[f"panel_{self.assignment_attempts}"]["Initially Identified Specialties"] = specialties_lst
        self.status_dict[f"panel_{self.assignment_attempts}"]["Collected Specialists"] = {}

        user_prompt = (
            f"Now, based on the specialties you provided ({specialties_lst}), "
            f"please collect a panel of {self.n_specialists} specialists, each of whom is responsible for one of the specialties.\n"
            "For each specialist, please specify their role and a list of relevant expertise areas related to the case.\n"
        )
        class Specialist(BaseModel):
            specialist: str = Field(..., description="The official job title of the specialist.")
            expertise: List[str] = Field(..., description="Areas of expertise for the specialist.")
        panel_dict = {f"specialist_{i+1}": (Specialist, ...) for i in range(self.n_specialists)}
        SpecialistPanel = create_model("SpecialistPanel", **panel_dict)
        print("Is this schema the problem?: SpecialistPanel.model_json_schema()")
        response = await self.llm_call(user_prompt, guided_={"guided_json": SpecialistPanel.model_json_schema()})
        self.append_message(content=response)
        specialists_dict = safe_json_load(response)
        
        for _, specialist in specialists_dict.items():
            role = specialist["specialist"]
            expertise = specialist["expertise"]
            self.status_dict[f"panel_{self.assignment_attempts}"]["Collected Specialists"][role] = {"expertise": expertise, "answer_history": {}} # key of answer_history is round_id
        return self.status_dict
    
    def _check_consensus(self, panel_id:int, round_id:int) -> str:
        choice_counts = {}
        majority_count = math.ceil(self.n_specialists * self.consensus_threshold)

        for role, answ_hist in self.status_dict[f"panel_{panel_id}"]["Collected Specialists"].items():
            final_choice = answ_hist["answer_history"][f"round_{round_id}"]['choice']
            choice_counts[final_choice] = choice_counts.get(final_choice, 0) + 1

        for choice, count in choice_counts.items():
            if count >= majority_count:
                logger.info(f"Consensus found on choice '{choice}' with {count}/{self.n_specialists} specialists.")
                return choice
            
        logger.info("No consensus found.")
        return None
    
    async def _aggregate(self):
        specialists_chat_history = dict(self.status_dict)
        specialists_chat_history.pop("label")
        specialists_chat_history.pop("hadm_id")
        specialists_chat_history.pop("problem")
        specialists_str = json.dumps(specialists_chat_history, indent=4)

        user_prompt = (
            "No consensus was reached among the specialists.\n"
            "Now, you need to analyze each specialist's reasoning and choice, "
            "and provide a single, definitive answer.\n"
            "You have access to the entire conversation history of each specialist.\n"
            "Your task is to read all these conversations in detail and produce one single, definitive final answer.\n"
            "Choose the answer best supported by the specialists' reasoning.\n"
            "\n"
            "Below is the *entire conversation history* for each specialist:\n\n"
            f"{specialists_str}\n\n"
            "Please review all these conversations in detail and produce one single, definitive final answer. "
            "Clearly justify your reasoning, then provide your final recommended answer."
        )

        class AggregatedResponse(BaseModel):
            aggregated_reasoning: str = Field(..., description="Detailed reasoning behind final choice")
            aggregated_choice: Literal["Yes", "No"] = Field(..., description=f"Single recommended choice whether the patient has {self.problem}.")
        print("Is this schema the problem?: AggregatedResponse.model_json_schema()")
        response = await self.llm_call(user_prompt, guided_={"guided_json": AggregatedResponse.model_json_schema()})
        return safe_json_load(response)
    
    async def run(self):

        while self.assignment_attempts < self.max_assignment_attempts:
            logger.info(f"Assignment attempt #{self.assignment_attempts + 1} started.")
            self.status_dict = await self._assign_specialists() # 여기서 self.assignment_attempts += 1
            
            panel = []
            for role in self.status_dict[f"panel_{self.assignment_attempts}"]["Collected Specialists"].keys():
                panel.append(DynamicSpecialist(role, self.status_dict[f"panel_{self.assignment_attempts}"]["Collected Specialists"][role]))
                
            
            analyze_tasks = [asyncio.create_task(specialist.analyze_note(self.note, self.problem)) for specialist in panel]
            analyze_results = await asyncio.gather(*analyze_tasks)
            
            if any(r is None for r in analyze_results):
                logger.error("At least one specialist failed; skipping this panel.")
                continue
            # analyze_results는 이 이후로 딱히 안쓰임

            self.consensus_attempts += 1
            consensus_choice = self._check_consensus(self.assignment_attempts, self.consensus_attempts)
            if consensus_choice:
                self.status_dict["final"] = {"final_choice": consensus_choice, "final_reasoning": "Consensus reached"}
                return self.status_dict
            
            while self.consensus_attempts < self.max_consensus_attempts:                
                debate_tasks = [asyncio.create_task(specialist.debate(self.status_dict[f"panel_{self.assignment_attempts}"]["Collected Specialists"])) for specialist in panel]
                debate_results = await asyncio.gather(*debate_tasks)
                if any(r is None for r in debate_results):
                    logger.error("At least one specialist failed during debate; skipping this round.")
                    continue
                # debate_results는 이 이후로 딱히 안쓰임

                self.consensus_attempts += 1
                consensus_choice = self._check_consensus(self.assignment_attempts, self.consensus_attempts)
                if consensus_choice:
                    self.status_dict["final"] = {"final_choice": consensus_choice, "final_reasoning": "Consensus reached"}
                    return self.status_dict
                
            logger.info("No consensus reached after maximum consensus attempts among the panel.")
          
        logger.info("No consensus reached after maximum assignment attempts.")
        aggregated_response = await self._aggregate()
        self.status_dict["final"] = {
            "final_choice": aggregated_response["aggregated_choice"],
            "final_reasoning": aggregated_response["aggregated_reasoning"]
        }
        return self.status_dict
    

class DynamicSpecialist(LLMAgent):
    def __init__(self, specialist: str, status: dict): 
        self.specialist = specialist
        # {"expertise": expertise, "answer_history": {}} # key of answer_history is panel_id + round_id
        self.expertise = status["expertise"]
        self.answer_history = status["answer_history"]
        self.round_id = 0
        system_prompt = (
            f"You are a {self.specialist}.\n"
            f"Your expertise includes:\n{self.expertise}\n"
            f"Analyze the user's query from the perspective of a {self.specialist}."
        )
        super().__init__(system_prompt)

    async def analyze_note(self, note: str, problem: str):
        self.round_id += 1
        self.note = note
        self.problem = problem

        class Response(BaseModel):
            reasoning: str = Field(..., description="Step-by-step reasoning leading to the final choice")
            choice: Literal["Yes", "No"] = Field(..., description=f"Final choice (Whether the patient has {self.problem})")
        self.schema = Response.model_json_schema()

        user_prompt = (
            "Here is a 'S' and 'O' from a SOAP note:\n\n"
            f"<SOAP>\n{self.note}\n</SOAP>\n\n"
            "Based on the information in the SOAP note, we need to confirm if the patient has the following problem:\n\n"
            f"<Problem>\n{self.problem}\n</Problem>\n\n"
            f"From your perspective as a {self.specialist}, first provide step-by-step reasoning (rationale), "
            "and then clearly state your final answer (Yes or No).\n\n"
        )

        response = await self.llm_call(user_prompt, guided_={"guided_json": self.schema})
        self.append_message(content=response)
        self.answer_history[f"round_{self.round_id}"] = safe_json_load(response)
        return safe_json_load(response)
    
    async def debate(self, stepback_status: dict):
        self.round_id += 1
        other_specialists = {}
        for role, value in stepback_status.items():
            if role != self.specialist:
                other_specialists[role] = value["answer_history"][f"round_{self.round_id - 1}"]

        formatted_other_specialists = json.dumps(other_specialists, indent=4)
        user_prompt = (
            "Regarding the previous query, other specialists have also provided their reasoning and choices.\n"
            "Critically evaluate the reasoning and choice of those specialists.\n\n"
            f"Specialists and their choices:\n{formatted_other_specialists}\n\n"
            "Considering the newly provided perspectives, refine your own reasoning and choice.\n"
            "You can change your choice or stick with the original one.\n\n"
        )

        response = await self.llm_call(user_prompt, guided_={"guided_json": self.schema})
        self.append_message(content=response)
        self.answer_history[f"round_{self.round_id}"] = safe_json_load(response)
        return safe_json_load(response)

In [6]:
note_text = "title: 24 hour events: - family meeting - continued discussion re: goals of care, slowly moving toward cmo, for now dnr - ps 10/5 trial, sao2 in mid 90s, abg looked good - in pm vomited, so tfs stopped - uop continues to be low, stopped dopamine drip - urine cx - yeast --> changed foley allergies: no known drug allergies\nlast dose of antibiotics: vancomycin -  12:00 am piperacillin -  12:00 am levofloxacin -  10:00 pm piperacillin/tazobactam (zosyn) -  06:00 am infusions: fentanyl (concentrate) - 25 mcg/hour midazolam (versed) - 0.5 mg/hour other icu medications: heparin sodium (prophylaxis) -  12:00 am other medications: changes to medical and family history: review of systems is unchanged from admission except as noted below review of systems: flowsheet data as of   07:33 am vital signs hemodynamic monitoring fluid balance 24 hours since 12 am tmax: 36.8 c (98.3 tcurrent: 36.4 c (97.6 hr: 99 (58 - 118) bpm bp: 113/60(73) {84/44(54) - 132/74(87)} mmhg rr: 28 (14 - 30) insp/min spo2: 100% heart rhythm: sr (sinus rhythm) cvp: 14 (14 - 14)mmhg total in: 2,525 ml 853 ml po: tf: 1,082 ml 340 ml ivf: 893 ml 283 ml blood products: total out: 572 ml 112 ml urine: 572 ml 112 ml ng: stool: drains: balance: 1,953 ml 741 ml respiratory support o2 delivery device: endotracheal tube ventilator mode: cpap/psv vt (set): 450 (450 - 450) ml vt (spontaneous): 375 (298 - 400) ml ps : 12 cmh2o rr (set): 14 rr (spontaneous): 33 peep: 5 cmh2o fio2: 40% rsbi: 127 pip: 18 cmh2o spo2: 100% abg: 7.39/38/98./21/-1 ve: 11.6 l/min pao2 / fio2: 245 physical examination general appearance: thin, intubated and sedated eyes / conjunctiva: perrl head, ears, nose, throat: normocephalic, endotracheal tube cardiovascular: (pmi normal), (s1: normal), (s2: normal) respiratory / chest: (expansion: symmetric), (breath sounds: clear : anterior, wheezes : ) abdominal: soft, non-tender, bowel sounds present, moderately distended, peg tube in place extremities: right: absent, left: absent skin:  warm neurologic: unable to assess due to sedation labs / radiology 317 k/ul 8.6 g/dl 117 mg/dl 0.9 mg/dl 21 meq/l 4.0 meq/l 11 mg/dl 111 meq/l 141 meq/l 25.9 % 8.6 k/ul [image002.jpg]   02:54 am   03:46 am   03:42 am   08:07 pm   03:59 am   04:17 am   04:23 pm   03:52 am   04:58 pm   04:27 am wbc 9.0 8.5 11.2 8.1 9.6 7.2 8.6 hct 29.0 27.3 29.1 24.1 26.7 25.2 25.9 plt 47 295 301 317 cr 0.9 0.8 0.9 0.9 1.0 0.9 0.9 tropt 0.03 tco2 28 22 24 glucose 136 117 77 96 104 131 117 other labs: pt / ptt / inr:11.7/31.0/1.0, ck / ckmb / troponin-t:28/3/0.03, alt / ast:, alk phos / t bili:70/0.5, differential-neuts:83.4 %, lymph:9.2 %, mono:3.5 %, eos:3.4 %, lactic acid:2.5 mmol/l, albumin:2.5 g/dl, ldh:139 iu/l, ca++:7.4 mg/dl, mg++:1.8 mg/dl, po4:1.6 mg/dl"
hadm_id = "123"
problem = "sepsis"
label = "Yes"

manager = Manager(
    note=note_text,
    hadm_id=hadm_id,
    problem=problem,
    label=label,
    n_specialists=5,            # or "auto"
    consensus_threshold=0.8,
    max_consensus_attempts=3,
    max_assignment_attempts=1,
    static_specialists=None,    # Or any custom specialists
    summarizer=None
)

# Run the manager's workflow
result = await manager.run()
print("Final Manager Result:")
print(result)


2025-04-01 00:26:12 - INFO - Assignment attempt #1 started.


Is this schema the problem?: Specialties_N.model_json_schema()


2025-04-01 00:26:32 - INFO - HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


got the response!!!!!!!!!!!!!!!!!
Is this schema the problem?: SpecialistPanel.model_json_schema()


2025-04-01 00:26:54 - INFO - HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


got the response!!!!!!!!!!!!!!!!!


2025-04-01 00:27:44 - INFO - HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


got the response!!!!!!!!!!!!!!!!!


2025-04-01 00:27:46 - INFO - HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


got the response!!!!!!!!!!!!!!!!!


2025-04-01 00:27:50 - INFO - HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


got the response!!!!!!!!!!!!!!!!!


2025-04-01 00:27:54 - INFO - HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


got the response!!!!!!!!!!!!!!!!!


2025-04-01 00:27:56 - INFO - HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
2025-04-01 00:27:56 - INFO - Consensus found on choice 'Yes' with 4/5 specialists.


got the response!!!!!!!!!!!!!!!!!
Final Manager Result:
{'note': 'title: 24 hour events: - family meeting - continued discussion re: goals of care, slowly moving toward cmo, for now dnr - ps 10/5 trial, sao2 in mid 90s, abg looked good - in pm vomited, so tfs stopped - uop continues to be low, stopped dopamine drip - urine cx - yeast --> changed foley allergies: no known drug allergies\nlast dose of antibiotics: vancomycin -  12:00 am piperacillin -  12:00 am levofloxacin -  10:00 pm piperacillin/tazobactam (zosyn) -  06:00 am infusions: fentanyl (concentrate) - 25 mcg/hour midazolam (versed) - 0.5 mg/hour other icu medications: heparin sodium (prophylaxis) -  12:00 am other medications: changes to medical and family history: review of systems is unchanged from admission except as noted below review of systems: flowsheet data as of   07:33 am vital signs hemodynamic monitoring fluid balance 24 hours since 12 am tmax: 36.8 c (98.3 tcurrent: 36.4 c (97.6 hr: 99 (58 - 118) bpm bp: 113/60(

In [7]:
result

{'note': 'title: 24 hour events: - family meeting - continued discussion re: goals of care, slowly moving toward cmo, for now dnr - ps 10/5 trial, sao2 in mid 90s, abg looked good - in pm vomited, so tfs stopped - uop continues to be low, stopped dopamine drip - urine cx - yeast --> changed foley allergies: no known drug allergies\nlast dose of antibiotics: vancomycin -  12:00 am piperacillin -  12:00 am levofloxacin -  10:00 pm piperacillin/tazobactam (zosyn) -  06:00 am infusions: fentanyl (concentrate) - 25 mcg/hour midazolam (versed) - 0.5 mg/hour other icu medications: heparin sodium (prophylaxis) -  12:00 am other medications: changes to medical and family history: review of systems is unchanged from admission except as noted below review of systems: flowsheet data as of   07:33 am vital signs hemodynamic monitoring fluid balance 24 hours since 12 am tmax: 36.8 c (98.3 tcurrent: 36.4 c (97.6 hr: 99 (58 - 118) bpm bp: 113/60(73) {84/44(54) - 132/74(87)} mmhg rr: 28 (14 - 30) insp/

In [None]:
!pip install --upgrade outlines

In [None]:
async def process_single_query(
    question_text: str,
    ground_truth: str,
    choices: List[str],
    n_specialists: int) -> Dict[str, Any]:
    """
    Given a single query (question + ground_truth + multiple choices), 
    run the multi-agent system (Initializer -> Specialists -> Debates -> Aggregator if needed).
    Return the final dictionary containing all the specialists' output and aggregator results.
    """

    # 1. Initialize specialists
    initializer = Manager(n_specialists=n_specialists)
    json_resp = await initializer.identify_specialists(query=question_text)
    if not isinstance(json_resp, dict):
        logger.error("Invalid JSON output from initializer; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # Build specialists status dict
    specialists_status = {}
    for _, agent_info in json_resp.items():
        specialist_name = agent_info["specialist"]
        expertise = agent_info["expertise"]
        specialists_status[specialist_name] = {"expertise": expertise}
    
    # 2. Run analyze_query for each specialist in parallel
    async def analyze_specialist(specialist_name: str, status: Dict[str, Any], query: str, choices: List[str]):
        specialist_agent = Specialist(specialist=specialist_name, expertise=status["expertise"])
        status["instance"] = specialist_agent
        message = await specialist_agent.analyze_query(query=query, choices=choices)
        if not isinstance(message, dict):
            logger.error(f"[{specialist_name}] Invalid JSON output from specialist; skipping this specialist.")
            return None
        status["original_response"] = message
        logger.info(f"[{specialist_name}] Completed analyze_query.")
        return specialist_name

    analyze_tasks = [
        asyncio.create_task(analyze_specialist(name, status, question_text, choices))
        for name, status in specialists_status.items()
    ]
    analyze_results = await asyncio.gather(*analyze_tasks)
    if any(r is None for r in analyze_results):
        logger.error("At least one specialist failed; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # Build a minimal dictionary for debate (remove 'instance')
    input_specialists_dict = {
        specialist_name: {
            k: v for k, v in specialist_data.items() 
            if k != "instance"
        }
        for specialist_name, specialist_data in specialists_status.items()
    }

    # 3. Debate step, also in parallel
    async def debate_specialist(specialist_name: str, status: Dict[str, Any], specialists_dict: Dict[str, Any]):
        specialist_agent = status["instance"]
        message = await specialist_agent.debate(specialists_dict)
        if not isinstance(message, dict):
            logger.error(f"[{specialist_name}] Invalid JSON output during debate; skipping this specialist.")
            return None
        status["response_after_debate"] = message
        specialists_dict[specialist_name]["response_after_debate"] = message
        logger.info(f"[{specialist_name}] Completed debate.")
        return specialist_name

    debate_tasks = [
        asyncio.create_task(debate_specialist(name, status, input_specialists_dict))
        for name, status in specialists_status.items()
    ]
    debate_results = await asyncio.gather(*debate_tasks)
    if any(r is None for r in debate_results):
        logger.error("At least one specialist failed during debate; skipping this query.")
        return {}  # Skip processing and continue to the next query

    # 4. Check consensus
    consensus_choice = check_consensus(input_specialists_dict)
    aggregator_result = None

    if consensus_choice is not None:
        logger.info(f"Consensus reached: {consensus_choice}")
        input_specialists_dict["Aggregator"] = {
            "final_choice": consensus_choice, 
            "final_reasoning": "Consensus reached"
        }
    else:
        logger.info("No consensus reached; enabling aggregator path...")
        aggregator = AggregatorAgent()
        aggregated_response = await aggregator.aggregate(
            query=question_text,
            choices=choices,
            specialists_chat_history=input_specialists_dict
        )
        if not isinstance(aggregated_response, dict):
            logger.error("Invalid JSON output from aggregator; skipping this query.")
            return {}  # Skip processing and continue to the next query
        
        final_choice = aggregated_response['aggregated_choice']
        final_reasoning = aggregated_response['aggregated_reasoning']

        logger.info(f"Aggregator final choice: {final_choice}")
        logger.info(f"Aggregator reasoning: {final_reasoning}")

        aggregator_result = {
            "final_choice": final_choice,
            "final_reasoning": final_reasoning
        }
        input_specialists_dict["Aggregator"] = aggregator_result

    # Add question and ground_truth for reference
    input_specialists_dict["Question"] = question_text
    input_specialists_dict["Answer"] = ground_truth

    return input_specialists_dict


async def process_multiple_queries(
    qa_df: pd.DataFrame,
    choices: List[str],
    n_specialists: int,
    max_concurrency: int = 5
) -> List[Dict[str, Any]]:
    """
    Process multiple rows (queries) in `qa_df` asynchronously.
    Each row is passed to `process_single_query`.
    
    :param qa_df: DataFrame with columns ["question", "choice", "ground_truth"] at least.
    :param choices: A list of all possible answer choices, e.g. ["A", "B", "C", "D", "E"].
    :param n_specialists: Number of specialists to initialize for each query.
    :param max_concurrency: Limit on how many queries to process simultaneously.
    :return: A list of result dictionaries, one per row in `qa_df`.
    """

    # This semaphore keeps at most `max_concurrency` tasks running at once
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_single_query(row_idx: int, row: pd.Series):
        """
        This inner function is used to call `process_single_query` with concurrency control.
        """
        async with semaphore:
            logger.info(f"Starting row {row_idx}")
            question_text = row["question"] + "\n" + str(row["choice"])
            ground_truth = str(row["ground_truth"])
            result = await process_single_query(
                question_text=question_text,
                ground_truth=ground_truth,
                choices=choices,
                n_specialists=n_specialists
            )
            logger.info(f"Finished row {row_idx}")
            return result

    tasks = [
        asyncio.create_task(run_single_query(i, row))
        for i, row in qa_df.iterrows()
    ]

    # Wait for all tasks to complete
    all_results = await asyncio.gather(*tasks)

    # `all_results` is a list of return values from each `run_single_query`
    return all_results

async def main():

    logger.info("===== MAIN START =====")

    # Example CSV loading
    # df_path = "/home/yl3427/cylab/SOAP_MA/Input/step1_ALL.csv"
    df_path = "/home/yl3427/cylab/SOAP_MA/Input/filtered_merged_QA.csv"
    qa_df = pd.read_csv(df_path, lineterminator='\n')  # columns: idx, question, choice, ground_truth, qn_num

    # qa_df = pd.read_csv('/home/yl3427/cylab/SOAP_MA/Input/SOAP_5_problems.csv')
    logger.info("Loaded dataframe with %d rows.", len(qa_df))


    ################# 'process_single_query' Example usage #################
    results = []
    for idx, row in qa_df.iterrows():
        # if row["qn_num"] not in [13, 42]:
        # if row["File ID"] not in ['123147.txt']:
        #     continue

        logger.info(f"Processing row index {idx}")

        question_text = row["question"] + "\n" + str(row["choice"])
        ground_truth = str(row["ground_truth"])
        # patient_info = str(row["Subjective"]) + "\n" + str(row['Objective'])
        # question_text = f"""
        # Based on the following patient report, does the patient have congestive heart failure?"

        # {patient_info}
        # """
        # ground_truth = str(row["terms"])
        

        # Run the multi-agent system for this single query
        result_dict = await process_single_query(
            question_text=question_text,
            ground_truth=ground_truth,
            choices=["A", "B", "C", "D", "E"],
            # choices=["Yes", "No"],
            n_specialists=5
        )
        # result_dict["File ID"] = row["File ID"]
        result_dict["qn_num"] = row["qn_num"]

        # Store result for later evaluation
        results.append(result_dict)

        if idx % 10 == 0:
            output_json_path = f"/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/merged_{idx}_mistral.json"
            with open(output_json_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            logger.info(f"Saved aggregated results to {output_json_path}")

    # OPTIONAL: Save results to JSON
    output_json_path = "/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/merged_final_mistral.json"
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    logger.info(f"Saved aggregated results to {output_json_path}")


    logger.info("===== MAIN END =====")

In [None]:
await main()