# **CLD-MEC**
Clinical Linguistics Detection-Medical Error Correction

In [1]:
from dotenv import load_dotenv

# Load variables from a local .env file into os.environ; variables that are
# already set are left as they are (override=False is the library default).
load_dotenv(override=False)

True

In [2]:
from pydantic import BaseModel, Field
from typing import Any, Dict, Type
from glob import glob

In [3]:
import json
from openai import OpenAI

# Shared OpenAI API client; every pipeline step in this notebook calls
# client.chat.completions.create through it.
# NOTE(review): presumably picks up its API key from the environment populated
# by load_dotenv() — confirm OPENAI_API_KEY is set before running.
client = OpenAI()

In [4]:
# Short alias -> pinned OpenAI model snapshot. Every step currently uses "gpt4".
models = dict(gpt4="gpt-4-0125-preview", gpt3="gpt-3.5-turbo-0125")

In [5]:
def to_openai_tool(pydantic_class: Type[BaseModel]) -> Dict[str, Any]:
    """Convert a pydantic model class into an OpenAI function-calling tool spec.

    The model's JSON schema becomes the function's ``parameters``. The schema
    ``title`` (by default the class name) becomes the function name, and the
    schema ``description`` (by default the class docstring) becomes its
    description.

    Args:
        pydantic_class: A pydantic v2 model class describing the arguments
            the LLM should return.

    Returns:
        A dict of the form ``{"type": "function", "function": {...}}`` for the
        ``tools`` argument of ``client.chat.completions.create``.
    """
    # model_json_schema() is the pydantic v2 replacement for the deprecated
    # .schema(). Compute it once and reuse it for all three fields.
    schema = pydantic_class.model_json_schema()
    function = {
        "name": schema["title"],
        # A model without a docstring has no "description" key. Fall back to
        # an empty description instead of raising KeyError.
        "description": schema.get("description", ""),
        "parameters": schema,
    }
    return {"type": "function", "function": function}

In [6]:
class PreProcessedNote(BaseModel):
    """Preprocess a clinical note by deleting the sentence that shows the cause and diagnosis."""
    # NOTE: the class docstring and the Field descriptions below become part of
    # the tool schema sent to the model (via to_openai_tool). Editing their
    # wording changes the prompt.

    # 0 = note left unchanged, 1 = a cause/diagnosis sentence was deleted.
    label: int = Field(..., description="The label of the note. Binary flag of zero (note was not preprocessed) or one (note was preprocessed).")
    # The sentence that was removed; may be "" when nothing was deleted.
    deleted_sentence: str = Field(
        ..., description="The sentence that was deleted from the note. Could be an empty string if no sentence was deleted.")
    # The note after deletion. correct_note feeds this field to the first
    # chain-of-thought step.
    full_final_note: str = Field(...,
                                 description="The final note after preprocessing.")

In [7]:
def preprocess_note(clinical_note: str) -> str:
    """Ask the model to remove the sentence that gives away the cause/diagnosis.

    The model is offered a single ``PreProcessedNote`` tool and is expected to
    answer with a tool call whose arguments match that schema.

    Args:
        clinical_note: Raw clinical note text.

    Returns:
        The first tool call's arguments, validated as ``PreProcessedNote`` and
        re-serialized as a JSON string. If the model replies without calling
        the tool, returns the JSON of an unmodified note instead: ``label`` 0,
        empty ``deleted_sentence``, and ``full_final_note`` equal to the input.

    Raises:
        pydantic.ValidationError: If the tool arguments do not match the schema.
    """
    tools = [to_openai_tool(PreProcessedNote)]
    response = client.chat.completions.create(
        model=models["gpt4"],  # change this for the prompt
        messages=[
            {
                "role": "system",
                "content": """
                I will give you a clinical note, you have to delete the shotest sentence that shows the cause or diagnosis, following to these conditions:
                1) If the clinical note mentions any of clincal management actions (treatment, clinical care plan, or any intervention,....ect) related to ( management of past medical history, management history of present illness, diagnosis), then do not delete anything. Give this label 0.
                2) Else, then delete the sentence that shows the cause and diagnosis. Give this label 1
                3) Print the assigned labels 1 or 0.
                4) Print the deleted part if applicable.
                5) Print the full final note.
                """,
            },
            {
                "role": "user",
                "content": clinical_note,
            },
        ],
        seed=42,
        tools=tools,
    )
    tool_calls = response.choices[0].message.tool_calls
    if not tool_calls:
        # tool_calls is None when the model answers in plain text; iterating
        # over it used to raise TypeError. Treat it as "nothing deleted", like
        # the no-tool-call fallback in revise_and_correct.
        return PreProcessedNote(
            label=0, deleted_sentence="", full_final_note=clinical_note
        ).model_dump_json()
    # Only the first tool call is used; any further calls are ignored. The
    # OpenAI SDK delivers function arguments as a JSON string.
    arguments = tool_calls[0].function.arguments
    return PreProcessedNote.model_validate_json(arguments).model_dump_json()

In [8]:
def cot(preprocessed_note: str) -> str:
    """Run the first chain-of-thought (CoT) reasoning pass over a note.

    In JSON mode, the model is asked for:

    * a step-by-step differential diagnosis,
    * the most likely cause,
    * a definitive diagnosis,
    * a step-by-step treatment plan.

    Args:
        preprocessed_note: Note text (``correct_note`` passes the
            ``full_final_note`` produced by ``preprocess_note``).

    Returns:
        The parsed JSON reply rendered with ``str()``. This is a Python dict
        repr, not JSON text, and is embedded verbatim in later prompts.
    """
    system_prompt = """
                1)Based on Evidance-Based Medicine, use step-by-step deduction to create a differential diagnosis and then use step by step deduction to identify both of the most likly causing (Pathogen {name of the bacteria, worm, virus, fungi,....etc.}, poison,.... etc) and diagnosis separately. The answer should also be definitive to one cause and one diagnosis Without requiring any further clinical investigating action.
                2) Then, step by step, deduce the most correct (treatment, clinical care plan, clinical management, intervention. )
                You are designed to output JSON.
                The JSON should be structured like this:
                {
                "Differential Diagnosis Step by Step": {
                    "Step 1": ...,
                    "Step 2": ...,
                    "Step N": ...
                    },
                "Differential Diagnosis": { 
                    "Most Likely Cause": ...,
                    "Explanation": ...
                    },
                "Treatment Step by Step": {
                    "Step 1": ...,
                    "Step 2": ...,
                    "Step N": ...
                    },
                "Definitive Diagnosis": ...,
                "Treatment": {
                    "Definitive Treatment": ...
                    }
                }
                """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": preprocessed_note},
    ]
    # Swap the models[...] key here to try the prompt with another model.
    completion = client.chat.completions.create(
        model=models["gpt4"],
        response_format={"type": "json_object"},
        messages=conversation,
        seed=42,
    )
    reasoning = json.loads(completion.choices[0].message.content)
    return str(reasoning)

In [9]:
def keyword(cot: str, clinical_note: str) -> str:
    """Mask the diagnosis keyword the model believes is erroneous.

    The first-pass rationale is embedded in the system prompt. The model is
    asked, in JSON mode, to delete the suspicious diagnosis/condition keyword
    from the note and replace it with "0".

    Args:
        cot: Rationale text produced by ``cot``. The parameter name shadows the
            ``cot`` function inside this body; it is kept for compatibility
            with keyword-argument callers.
        clinical_note: The original (non-preprocessed) clinical note.

    Returns:
        The ``FullFinalNote`` value from the model's JSON reply, i.e. the note
        text with the keyword masked. The annotation was ``dict``, but the
        function returns this single field, not the whole reply.

    Raises:
        KeyError: If the reply has no ``FullFinalNote`` key.
    """
    # step 1 01
    response = client.chat.completions.create(
        model=models["gpt4"],  # change this for the prompt
        response_format={"type": "json_object"},
        messages=[
            {
                "role": "system",
                # f-string: {cot} is interpolated, and {{ }} render as literal
                # braces in the JSON template.
                "content": f"""
                1) Use this interpretable clinical reasoning rationale you have produced for this clinial note: 
                {cot}
                2) Based on the interpretable clinical reasoning rationale, If the clinical note mentions a diagnosis or a medical condition that is based on a clinal presentation or findings that are not directly connected to each other in most common clinical contexts, then there should be a medical error in the diagnosis.
                3) Delete the diagnosis or a medical condition related keyword from the clinical note.
                4) Print the deleted keyword if applicable.
                5) Print the full final note, where the deleted keyword should be masked with this label -> "0"
                You are designed to output JSON.
                It has to be structured like this:
                {{
                "DeletedKeyword": ...,
                "FullFinalNote": ...
                }}
                """,
            },
            {"role": "user", "content": clinical_note},
        ],
        seed=42,
    )
    return json.loads(response.choices[0].message.content)["FullFinalNote"]

In [10]:
def cot_no_keyword(removed_keyword: str) -> str:
    """Run the second chain-of-thought pass on the keyword-masked note.

    This step sends exactly the same request as ``cot`` (same model, JSON
    mode, system prompt and seed), just on different input text. It therefore
    delegates to ``cot`` instead of keeping a second copy of the prompt that
    could drift out of sync.

    Args:
        removed_keyword: Note text returned by ``keyword``, in which the model
            was asked to mask the suspect diagnosis keyword with "0".

    Returns:
        Same as ``cot``: the parsed JSON reply rendered with ``str()``.
    """
    # step 1 02
    return cot(removed_keyword)

In [12]:
def verify_cot(cot_round_two: str) -> str:
    """Ask the model to check a rationale against clinical-practice guidelines.

    In JSON mode, the model is asked to:

    1. write guideline-targeted queries,
    2. report what it "retrieves" for them,
    3. compare that with the rationale,
    4. emit a revised ``FinalCOT``.

    No external retrieval happens here: this is a single chat completion with
    no tools, so the model answers its own queries.

    Args:
        cot_round_two: Rationale text produced by ``cot_no_keyword``.

    Returns:
        The parsed JSON reply rendered with ``str()`` (a Python dict repr).
    """
    # The prompt below is a plain string, not an f-string, so the JSON
    # template's outer braces are written single. Doubled braces would be
    # sent to the model literally as "{{" and "}}".
    response = client.chat.completions.create(
        model=models["gpt4"],
        response_format={"type": "json_object"},
        messages=[
            {
                "role": "system",
                "content": """
                you have to verify your interpretable clinical reasoning rationale of diagnosis you have produced of its related clinial note. the verification should be by genirating questions that target and retive information from the most apprpriate clinical practice guidelines.
                -make the query adress the name of the guidline you want to retrive that response from.
                -if you want ro check fro the diagnosis clinical findings, make the query adress the related clinical findigs you want to check for the diagnosis.
                -make the directed query adress the most liky correct (cause, diagnosis).
                -make the direced query adress the recommendations part of the guidline related to (diagnosis, clinical management, treatment, drug of choice)
                -search from the directed guidlines.
                -returt the information you gained.
                -compare your interpretable clinical reasoning rationale with the retrived information from the guidline, if there is discrepency, show it.
                -if there is a major discrepency, take the retrived information as ground truth and print out the final COT after being revised.
                You are designed to output JSON.
                It has to be structured like this:
                {
                "VerificationQueries": {
                    "Query 1": ...,
                    "Query 2": ...,
                    "Query 3": ...,
                    "Query N": ...
                },
                "RetrievedInformation": {
                    "Response 1": ...,
                    "Response 2": ...,
                    "Response 3": ...,
                    "Response N": ...
                },
                "Comparison": {
                    "Clinical Findings": ...,
                    "Causes": ...,
                    "Treatment": ...
                },
                "Discrepancy": ... (could be nullable),
                "FinalCOT": {
                    "Differential Diagnosis Process": {
                    "Step 1": ...,
                    "Step 2": ...,
                    "Step 3": ...,
                    "Step N": ...
                    },
                    "Definitive Cause": {
                    "Most Likely Pathogen/Cause": ...
                    },
                    "Definitive Diagnosis": ...,
                    "Treatment Plan": {
                    "Step 1": ...,
                    "Step 2": ...,
                    "Step 3": ...,
                    "Step 4": ...,
                    "Step N": ...
                    }
                }
                }
                You are designed to output JSON.
                """,
            },
            {
                "role": "user",
                "content": cot_round_two,
            },
        ],
        seed=42,
    )

    return str(json.loads(response.choices[0].message.content))

In [13]:
class FinalRevision(BaseModel):
    """Revise and correct a clinical note based on the clinical reasoning rationale."""
    # NOTE: the class docstring and Field descriptions become the tool schema
    # sent to the model (via to_openai_tool). Wording changes alter the prompt.

    # 1 = note was revised (an error was found), 0 = not revised.
    error_flag: int = Field(
        ...,
        description="The error flag of the note. Binary flag of zero (note was not revised) or one (note was revised).",
    )
    # Index of the erroneous sentence in the sentence-split note. The field is
    # an int, so the description must not invite an empty string: "" fails
    # validation and the whole note would error out. -1 is the "no error"
    # value already used by revise_and_correct's fallback and by the
    # prediction writer.
    error_location: int = Field(
        ...,
        description="The location of the error in the note. Use -1 if no error was found. The note is split into sentences with an index for each, return that index.",
    )
    # Replacement text for the sentence at error_location.
    sentence_correction: str = Field(
        ...,
        description="The corrected sentence based on the clinical reasoning rationale.",
    )

In [14]:
def revise_and_correct(verified_cot: str, clinical_note: str) -> str | dict:
    """Compare the note with the verified rationale and propose one correction.

    The verified rationale is embedded in the system prompt as ground truth.
    The model is offered a single ``FinalRevision`` tool for its answer.

    Args:
        verified_cot: Rationale text produced by ``verify_cot``.
        clinical_note: The note to correct. ``correct_note`` passes the
            sentence-indexed version, as the prompt expects.

    Returns:
        Either the first tool call's arguments, validated as
        ``FinalRevision`` and serialized to a JSON string, or the dict
        ``{"error_flag": 0, "error_location": -1, "sentence_correction": "NA"}``
        if the model did not call the tool. Callers currently write either
        form to disk via ``str()``.

    Raises:
        pydantic.ValidationError: If the tool arguments do not match the schema.
    """
    tools = [to_openai_tool(FinalRevision)]
    response = client.chat.completions.create(
        model=models["gpt4"],  # change this for the prompt
        messages=[
            {
                "role": "system",
                "content": f"""
                1) Use this interpretable clinical reasoning rationale you have produce as a ground truth
                {verified_cot}
                2) compare if the clinical note match the ground truth to tell if the clinical note has a medical error in (diagnosis (pathogen, poison, disease), clinical manageent (treatment, clinical care plan, intervention (oreder certaint lab test, tranfer, certain image by name, procesure).).
                3) Identify any dicrepency between the ground truth and the clinical note.
                4) then if there is any thing in the clinical note related to either diagnosis or cause  that is not available (referenced) in the groung truth reference, then label it as medical error. and skip the steps related to clinical management.
                5) then else if there is any thing in the clinical note related to clinial managemnt after diagnosis is not available (referenced) in the groung truth reference specifically in (clinical management related sectons), then label it as medical error.  and skip the steps related to the diagnosis or cause.
                If there is a medical error, identify it's type (diagnosis, cause, or clinical management) and print it, identify the exact related shotest part and print it, and correct it with the shortest possiple correction. do not change the format of the correced part only correct the relaed keyword.
                Then if the error type is erelated to clincal management related errors, the corrected sentance should be definative to the exact needed medication, procesure, image,..... ect. not general. not as a recommendation. correct the note directly with the most corret propable needed audit.
                If the error type related to diagnosis, cause, or clinical management consider this error correction to be edited on the final corrected note. the priority to add the correction of diagnosis and cause first to be considered. consider one correction only, depend on the context.
                finally print out the corrected final note.

                The clinical note you have to correct is split into sentences with an index for each.
                The correction you return includes the error flag, the error location, and the sentence correction.

                """,
            },
            {
                "role": "user",
                "content": clinical_note,
            },
        ],
        seed=42,
        tools=tools,
    )
    tool_calls = response.choices[0].message.tool_calls
    if not tool_calls:
        # No tool call (plain-text reply): treat the note as error-free.
        return {"error_flag": 0, "error_location": -1, "sentence_correction": "NA"}
    # Only the first tool call is used; any further calls are ignored. The
    # OpenAI SDK delivers function arguments as a JSON string.
    arguments = tool_calls[0].function.arguments
    return FinalRevision.model_validate_json(arguments).model_dump_json()

In [15]:
def correct_note(clinical_note: str, clinical_note_sentences: str) -> str | dict:
    """Run the full CLD-MEC pipeline on a single note.

    Pipeline:

    1. Preprocess: drop the cause/diagnosis sentence.
    2. First chain-of-thought pass.
    3. Mask the suspect diagnosis keyword in the original note.
    4. Second chain-of-thought pass.
    5. Guideline-style verification.
    6. Final revision against the sentence-indexed note.

    Args:
        clinical_note: Raw note text, used by the preprocessing and
            keyword-masking steps.
        clinical_note_sentences: Passed only to the final revision step.
            Presumably the same note split into indexed sentences, since that
            step's prompt says so; confirm against the dataset columns.

    Returns:
        Whatever ``revise_and_correct`` returns.
    """
    preprocessed = preprocess_note(clinical_note)
    stripped_note = json.loads(preprocessed)["full_final_note"]
    first_rationale = cot(stripped_note)
    masked_note = keyword(first_rationale, clinical_note)
    second_rationale = cot_no_keyword(masked_note)
    checked_rationale = verify_cot(second_rationale)
    return revise_and_correct(checked_rationale, clinical_note_sentences)

In [16]:
import pandas as pd

In [17]:
# Load the official MEDIQA-CORR test set.
# NOTE(review): "unicode_escape" decodes backslash escape sequences in the raw
# bytes. Confirm this is intended rather than e.g. "utf-8" or "latin-1".
test = pd.read_csv(
    "March-26-2024-MEDIQA-CORR-Official-Test-Set.csv",
    encoding="unicode_escape",
)

In [18]:
# Split the test set into ten batches of row lists. Each row list's first
# three entries are consumed by process_batch as
# [note_id, clinical_note, clinical_note_sentences].
rows = [test.iloc[i, :].values.tolist() for i in range(len(test))]
# Nine fixed-size batches of 100 rows each.
(batch_1, batch_2, batch_3, batch_4, batch_5,
 batch_6, batch_7, batch_8, batch_9) = [
    rows[start:start + 100] for start in range(0, 900, 100)
]
# The last batch takes every remaining row. The old hard-coded end index
# (925) assumed a fixed dataset length: it raised IndexError for a shorter
# file and skipped extra rows for a longer one.
batch_10 = rows[900:]

In [19]:
# Together the batches must contain as many rows as the test set.
assert sum(
    len(batch)
    for batch in (batch_1, batch_2, batch_3, batch_4, batch_5,
                  batch_6, batch_7, batch_8, batch_9, batch_10)
) == len(test)

925

In [21]:
import concurrent.futures
import logging

In [22]:
# IDs of notes whose pipeline run raised; process_batch appends to this list.
failed = list()

In [23]:
def process_batch(batch) -> None:
    """Correct every note in ``batch`` and write each result to disk.

    For each item, ``correct_note(item[1], item[2])`` is called and its result
    is written as ``str(result)`` plus a newline to
    ``corrected_notes/<item[0]>.txt``. If an item raises, the exception is
    printed, ``item[0]`` is appended to the module-level ``failed`` list, and
    processing continues with the next item.

    Args:
        batch: Iterable of row lists as built by the batching cell.
    """
    import os

    # open() does not create parent directories, so a missing output folder
    # used to make every item fail. exist_ok=True also keeps concurrent calls
    # from different worker threads harmless.
    os.makedirs("corrected_notes", exist_ok=True)
    for item in batch:
        note_id = item[0]
        try:
            result = correct_note(item[1], item[2])
            with open(f"corrected_notes/{note_id}.txt", "w") as f:
                f.write(f"{result}\n")
        except Exception as e:
            # Best-effort batch run: record the failure and keep going.
            failed.append(note_id)
            print(f"{e} | {note_id}")

In [None]:
# Process the ten batches concurrently (ten worker threads for ten batches)
# and log how each batch finished.
batches = [
    batch_1, batch_2, batch_3, batch_4, batch_5,
    batch_6, batch_7, batch_8, batch_9, batch_10,
]
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
    # future -> 1-based batch number, used in the log messages below.
    futures = {
        executor.submit(process_batch, batch): batch_index
        for batch_index, batch in enumerate(batches, start=1)
    }

    for future in concurrent.futures.as_completed(futures):
        batch_num = futures[future]
        try:
            future.result()
        except Exception as exc:
            logging.error(f'Batch_{batch_num} generated an exception: {exc}')
        else:
            logging.info(f'Batch_{batch_num} is complete')

In [None]:
# Collect every per-note result file. glob() returns paths in arbitrary,
# filesystem-dependent order. Sort them so the notes, and therefore the rows
# of prediction.txt, come out in a reproducible order.
corrected_notes = sorted(glob("corrected_notes/*.txt"))

In [None]:
# Expect one result file per test-set row. process_batch writes a file only
# when a note succeeds, so a shortfall means some ids ended up in `failed`.
assert len(test) == len(corrected_notes)

In [None]:
# note id -> parsed result dict, filled from the corrected_notes files.
notes = dict()

In [None]:
import ast
from pathlib import Path

# Parse each result file back into a dict keyed by note id. process_batch
# writes str(result), where result is either a FinalRevision JSON string or
# revise_and_correct's fallback dict repr. ast.literal_eval accepts both for
# these int/str fields. Unlike eval(), it cannot execute code that happens to
# be in model-generated text.
for path in corrected_notes:
    # Path.stem drops the directory (with either separator style) and only
    # the final ".txt" suffix. Ids containing dots therefore stay intact, and
    # Windows paths work. The name avoids shadowing the builtin `id`.
    note_id = Path(path).stem
    with open(path) as f:
        notes[note_id] = ast.literal_eval(f.read())

In [None]:
# For notes flagged as error-free, force the "no error" sentinels used in the
# submission file, whatever location/correction the model returned.
for record in notes.values():
    if record["error_flag"] != 0:
        continue
    record["error_location"] = -1
    record["sentence_correction"] = "NA"

In [None]:
# Write the submission file, one record per line:
#   <note_id> <error_flag> <error_location> "<sentence_correction>"
with open("prediction.txt", "w") as f:
    for note_id, record in notes.items():
        # The format is line-based, so a model-written correction containing
        # line breaks would split its record across lines. Join the pieces
        # with single spaces instead.
        correction = " ".join(str(record["sentence_correction"]).splitlines())
        f.write(
            f"{note_id} {record['error_flag']} {record['error_location']} \"{correction}\"\n")