In [1]:
import os
import pandas as pd
#set the openai env variable here

In [2]:

# Azure OpenAI chat model used for QA synthesis.
api_version = "2024-05-01-preview"
from langchain_openai import AzureChatOpenAI
# Credentials are taken from the environment when available so that real keys
# never have to be typed into (and saved with) the notebook. The original
# placeholder strings are kept as fallbacks, so behavior is unchanged when the
# environment variables are not set.
model_kwargs = dict(
    model="gpt-4-128k",
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", "endpoint here"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", "key here"),
    api_version=api_version,
    temperature=0.0,  # lowest-variance decoding for data generation
)

# cache=False is passed so the model does not reuse cached responses.
gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)

In [3]:
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
import json
from langchain_core.messages import HumanMessage, SystemMessage

In [4]:
#batch process
# Root folder that the report .txt files are read from (placeholder path).
mimic_report_path = r"mimic reports folder"
sampled_files_path = r"mimic_cxr/sampled_paths.json" # the sampled images from the MIMIC-CXR

# Load the sampled image paths. Later cells index and iterate this object like
# a list of path strings.
with open(sampled_files_path, 'r', encoding='utf8') as file:
    mimic_files = json.load(file)

In [5]:
# Alias of the sampled image paths; this is the collection the batch loop iterates over.
sampled_type4_path = mimic_files

In [1]:
# Load the clinical-notes table (placeholder file name). Later cells read its
# 'subject_id' and 'text' columns.
notes = pd.read_csv("MIMIC-IV, csv file for report")
print(notes.shape)
# Distinct subject_id values present in the notes table.
pids = list(set(notes['subject_id'].to_list()))
print(len(pids))

In [4]:
# Number of times each patient id has been processed by to_report_notes.
pid_repeat_dict = dict()
import random
def to_report_notes(file_path, notes, max_notes=15, rng=None):
    """Load the report for one sampled image and gather that patient's notes.

    Args:
        file_path: '/'-separated image path. The first three components name
            the report file (with ".txt" appended) and the second component,
            minus its first character, is parsed as the integer patient id
            (presumably "pXX/p<subject_id>/s<study_id>/..." -- TODO confirm
            against sampled_paths.json).
        notes: DataFrame with at least 'subject_id' and 'text' columns.
        max_notes: Upper bound on the number of notes returned. When the
            patient has more, a random subset of this size is drawn.
        rng: Optional object exposing ``sample(population, k)`` (for example
            ``random.Random(seed)``) for reproducible sampling. Defaults to
            the global ``random`` module.

    Returns:
        Tuple ``(report_path, file_path, report_text, sampled_notes, useful_flag)``.
        ``useful_flag`` is currently always True.

    Side effects:
        Increments ``pid_repeat_dict[pid]``.
    """
    parts = file_path.split("/")
    path = "/".join(parts[:3]) + ".txt"
    # Explicit encoding, consistent with how the sampled-paths JSON is opened.
    with open(os.path.join(mimic_report_path, path), "r", encoding="utf8") as f:
        report = f.read()

    pid = int(parts[1][1:])
    notes_list_pid = notes.loc[notes['subject_id'] == pid, 'text'].to_list()

    pid_repeat_dict[pid] = pid_repeat_dict.get(pid, 0) + 1

    useful_flag = True
    if len(notes_list_pid) > max_notes:
        sampler = random if rng is None else rng
        sampled_notes = sampler.sample(notes_list_pid, max_notes)
    else:
        sampled_notes = notes_list_pid.copy()
    return path, file_path, report, sampled_notes, useful_flag
# Smoke test on the last sampled path. Note that this call also bumps
# pid_repeat_dict for that patient.
to_report_notes(mimic_files[-1], notes)

In [8]:
# Collect the image path, report text and sampled notes for every sampled path
# that to_report_notes flags as usable.
useful_ids, useful_data = [], []
for _id in sampled_type4_path:
    report_path, img_path, report, notes_pid, useful = to_report_notes(_id, notes)
    if not useful:
        continue
    _data = dict(img_path=img_path, report=report, notes=notes_pid)
    useful_ids.append(_id)
    useful_data.append(_data)

# Data generation for Contextual Hallucination

In [24]:
# System prompt for the QA-synthesis model. It is passed verbatim as the
# system message of the chat prompt, so the text below is runtime behavior.
system_msg = """
You are provided with the report about a medical image, and the additional clinical notes of this patient. You task is to synthesize a set of close-ended QA pairs according to the requirements. Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.

Answer responsibly within the given information. Given the clinical notes and x-ray reports provided, you should design questions to test for contextual visual hallucination.  The goal is to ensure that the model interprets the x-ray images accurately within the specific context provided by the clinical notes,  without generating clinically inappropriate or inconsistent responses.

Example yes-or-no questions:
(1) Diagnostic Relevance: "Given the patient’s history of severe back pain and the x-ray findings, is it likely that the back pain is due to a cardiopulmonary issue?" (Expected Answer: No)
(2) Complication Risk: "Is it necessary to monitor the patient for potential complications related to the spine, given the findings of thoracic kyphosis and vertebral wedging?" (Expected Answer: Yes)
(3) Family History: "Given the patient's family history of colorectal cancer, as mentioned in the clinical notes, are there any signs of colorectal abnormalities or precancerous lesions visible in the abdominal CT scan?" (Expected Answer: Yes)
(4) Symptom Analysis: "Does the absence of focal consolidation in the chest x-ray suggest that the patient’s cough is unrelated to a pulmonary infection?" (Expected Answer: Yes)

Instructions:
Generate "yes-or-no" questions, ensuring a balanced distribution of labels (yes and no).
"""

In [69]:
# Schema of a single close-ended question with its expected answer.
# (Plain comments are used instead of class docstrings so that the JSON schema
# embedded in the format instructions stays exactly as before.)
class QA(BaseModel):
    question: str = Field(description="The question")
    answer: str = Field(description="The answer")

# Top-level object the model must return: a list of QA pairs.
class QueryAnswer(BaseModel):
    qa_pairs: List[QA] = Field(description="The QA pairs for contextual evaluation.")

#Here we use the classic parser in LangChain: Pydantic, to ensure a strict parsing process
qa_parser = PydanticOutputParser(pydantic_object=QueryAnswer)

# Two fixes relative to the previous template:
#   * the notes placeholder is now {clinical_notes}, matching both
#     `input_variables` and the key passed to `qa_chain.invoke(...)`;
#   * {format_instructions} is actually referenced, so the model is told the
#     JSON schema that `qa_parser` will enforce on its reply.
qa_prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=system_msg),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=(
                "The medical report of the image: {report}, clinical notes: {clinical_notes}. "
                "Generate high-quality close-ended question-answer pairs.\n"
                "{format_instructions}"
            ),
            input_variables=["report", "clinical_notes"],
            partial_variables={"format_instructions": qa_parser.get_format_instructions()},
        )
    ),
])


# prompt -> chat model -> strict pydantic parsing into a QueryAnswer instance
qa_chain = (
    qa_prompt 
    | gpt_model 
    | qa_parser
)

In [73]:
# Accumulator for generated QA records; the generation loop appends one entry
# per sample. Re-running this cell resets it.
qas = []
from tqdm import tqdm

In [5]:
# Generate QA pairs for every usable sample.
# Loop-local names are deliberately distinct from the module-level `notes`
# DataFrame: rebinding `notes` here would replace the table with a list and
# break any later re-run of the cells that filter it by subject_id.
for sample in tqdm(useful_data):
    sample_report, sample_notes = sample['report'], sample['notes']
    img_path = sample['img_path']
    qa = qa_chain.invoke(dict(report=sample_report, clinical_notes=sample_notes))
    final_new_qa = {"img_path": img_path, "qa_pairs": qa.dict()}
    qas.append(final_new_qa)

In [61]:
# store
def trans_file_close(qas_list):
    """Flatten per-image QA results into one record per question.

    Args:
        qas_list: Iterable of dicts shaped like
            ``{"img_path": ..., "qa_pairs": {"qa_pairs": [qa, ...]}}`` where
            each ``qa`` is a dict with "question" and "answer" keys and,
            optionally, "ground_truth_type" and "choices".

    Returns:
        A list of dicts with keys ``img_name``, ``question``, ``answer``,
        ``question_type`` (always ``'type3_contextual'``),
        ``ground_truth_type``, ``choices``, ``qid`` (running index starting at
        0 across all images) and ``img_id``. ``img_name`` and ``img_id`` both
        carry the record's ``img_path``.
    """
    new_list = []
    for item in qas_list:
        img_id = item['img_path']
        for qa in item['qa_pairs']['qa_pairs']:
            new_list.append({
                'img_name': img_id,
                'question': qa["question"],
                'answer': qa["answer"],
                'question_type': 'type3_contextual',
                # The QA schema only guarantees "question" and "answer"; these
                # two are optional, so a missing key yields None rather than a
                # KeyError.
                'ground_truth_type': qa.get('ground_truth_type'),
                'choices': qa.get('choices'),
                # The dict is built before the append, so this is the record's
                # 0-based position in the output.
                'qid': len(new_list),
                'img_id': img_id,
            })

    return new_list

In [42]:
# Flatten the per-image QA results into one record per question.
type3_list = trans_file_close(qas)

In [87]:
# Write the flattened QA records to disk as indented JSON (placeholder file
# name), then print a confirmation.
with open("saved json file", 'w') as json_file:
    json.dump(type3_list, json_file, indent=4)
print("Data saved!")

Data saved!
