In [None]:
import os
import re
#set the openai env variable here

In [7]:
# Azure OpenAI chat model shared by every generation and filtering chain below.
# Credentials are read from the environment (falling back to the placeholders),
# so real keys never have to be written into the notebook itself.
import os

api_version = "2024-05-01-preview"
from langchain_openai import AzureChatOpenAI

model_kwargs = dict(
    model="gpt-4-128k",
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", "endpoint here"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", "key here"),
    api_version=api_version,
    temperature=0.0,  # minimise sampling randomness for more reproducible data
)
gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)  # do not use LangChain's LLM cache

In [10]:
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

### Closed-ended QA generation

In [51]:
# System prompt for closed-ended QA generation. It is passed through
# SystemMessagePromptTemplate.from_template, so it must not contain literal
# "{" or "}" characters (they would be parsed as template variables).
system_msg_close = """You are provided with the clinical report about a medical image. Your task is to synthesize a set of new closed-ended QA pairs according to the requirements. Unfortunately, you don't have access to the actual image. Below are the requirements for generating the questions and answers in the conversation.
- Do not use phrases like "mentioned", "report", "context" in the conversation. Instead, refer to the information as being "in the image."
- Answer responsibly within the given information and knowledge context, avoiding information not included in the given context.
- Make sure your generation contains a key "type" for each QA pair, indicating its question type (as detailed in the following part).
- Make the questions diverse.
- The ground truth can be "yes or no" or one choice from a multi-choice list (you need to synthesize several choices in this case), conditioned on the question and available materials.

Here is a set of question types for your generation; you must assign each new QA pair a question type.
type_1: Anatomical Hallucination. Example questions: "Which part of the body does this image belong to?" "Does the picture contain liver?"
type_2: Measurement Hallucination like location, size. Example questions: "Where is the liver?"
type_3: Symptom-Based Hallucination. Example questions: "Is the lung healthy?" "Is there evidence of a pneumothorax?" "Is there a fracture?"
type_4: Technique Hallucination. Example questions: "What modality is used to take this image?"

Instructions:
- Add one key to each QA pair, key= "ground_truth_type", value= "binary" if the type is "yes or no" else "multi-choice"
- When you see diagnosis information of a disease (e.g. lung cancer) in QA pairs, you should generate a new QA pair asking about the symptoms of the disease.
- For "multi-choice" type QA, you must include one key "choices" of string type.
- Avoid questions for which you cannot generate a ground truth; for example, avoid the answer "The image does not provide information".
"""

In [55]:
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
import json
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

# Schema for one generated closed-ended QA pair. The field names are the JSON keys
# the LLM must emit (and that trans_file reads back), so they must not be renamed.
# Class docstrings are deliberately omitted: pydantic uses a model's docstring as
# its schema description, which would change the parser's format instructions.
# NOTE(review): the saved outputs further down show an `omission_type` field that
# this schema does not declare — they were produced by an older version of this cell.
class Closed_QA(BaseModel):
    Q: str = Field(description="The question (Q)")
    A: str = Field(description="The answer (A)")
    # Hallucination category requested by the system prompt ("type_1" .. "type_4").
    type: str = Field(description="The QA type")
    # Only meaningful for multi-choice questions; defaults to "" so binary QAs may omit it.
    choices: str = Field(default="", description="The QA choices for multi-choice question type")
    # "binary" or "multi-choice", as instructed by the system prompt.
    ground_truth_type: str = Field(description="The ground_truth type")
    

# Top-level reply object: the list of QA pairs under the key "qa_pairs".
class Closed_QueryAnswer(BaseModel):
    qa_pairs: List[Closed_QA] = Field(description="The QA pairs list generated.")
# PydanticOutputParser validates the LLM's JSON reply against Closed_QueryAnswer and
# provides the format instructions that describe this schema to the model.
closed_parser = PydanticOutputParser(pydantic_object=Closed_QueryAnswer)

# Prompt for closed-ended QA generation: the system message carries the task
# rules, the human message carries the report plus the parser's JSON format
# instructions. `{format_instructions}` must appear in the template text;
# otherwise the partial variable is silently unused and the model never sees
# the schema the parser enforces.
colsed_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(system_msg_close),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=(
                "The medical report of the image: {report}. "
                "Generate high-quality questions for which the correct answers "
                "can be inferred solely from the provided report.\n"
                "{format_instructions}"
            ),
            input_variables=["report"],
            partial_variables={"format_instructions": closed_parser.get_format_instructions()},
        )
    ),
])

# LCEL pipeline: prompt -> Azure chat model -> Pydantic parser (yields Closed_QueryAnswer).
closed_chain = colsed_prompt | gpt_model | closed_parser

In [190]:
# Batch inputs: the report root directory and the JSON list of sampled image paths.
mimic_report_path = r"mimic reports folder"  # TODO: point at the local report directory
sampled_files_path = r"mimic_cxr/sampled_paths.json"  # image paths sampled from MIMIC-CXR

with open(sampled_files_path, mode='r', encoding='utf8') as sampled_fh:
    mimic_files = json.load(sampled_fh)

In [185]:
# Number of sampled paths (2000): the first 500 feed the closed-ended set, paths 500-999 the open-ended set.
len(mimic_files)

2000

In [186]:
# Example path layout: pXX/p<id>/s<id>/<image>.jpg
mimic_files[0]

'p19/p19454978/s57331547/7d047120-d24a497e-fc26ea7e-6c3acc0c-ce5bc190.jpg'

In [3]:
def to_report(file_path, report_dir=None):
    """Load the free-text report that belongs to a sampled image.

    The report file is the first three components of the image path with
    ``.txt`` appended, e.g. ``p19/p19454978/s57331547/x.jpg`` ->
    ``p19/p19454978/s57331547.txt``.

    Args:
        file_path: Image path relative to the dataset root, "/"-separated.
        report_dir: Root directory of the reports; defaults to the notebook-level
            ``mimic_report_path`` (backward-compatible with the one-argument call).

    Returns:
        Tuple ``(report_rel_path, file_path, report_text)``.
    """
    if report_dir is None:
        report_dir = mimic_report_path
    report_rel_path = "/".join(file_path.split("/")[:3]) + ".txt"
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) is not portable.
    with open(os.path.join(report_dir, report_rel_path), "r", encoding="utf8") as f:
        report = f.read()
    return report_rel_path, file_path, report

In [192]:
# Closed-ended subset: the first 500 sampled paths (disjoint from the open-ended slice 500-999).
sampled_close_paths = mimic_files[0:500]

In [193]:
mimic_close_qas = list()  # accumulates {"img_path": ..., "new_qa_pairs": ...} records
from tqdm import tqdm

In [1]:
# Generate closed-ended QA pairs for each sampled image. Resumable: images already
# in `mimic_close_qas` are skipped, so re-running after an interruption neither
# duplicates records nor repeats API calls. API/parse failures are collected in
# `failed_close` instead of aborting the whole batch.
done_close = {rec["img_path"] for rec in mimic_close_qas}
failed_close = []
for _id in tqdm(sampled_close_paths):
    if _id in done_close:  # to_report returns the input path unchanged as img_path
        continue
    report_path, img_path, report = to_report(_id)
    try:
        ans = closed_chain.invoke(dict(report=report))
    except Exception as err:  # API or output-parsing failure: record it and keep going
        failed_close.append((img_path, repr(err)))
        continue
    mimic_close_qas.append({"img_path": img_path, "new_qa_pairs": ans})
    done_close.add(img_path)

if failed_close:
    print(f"{len(failed_close)} reports failed; see `failed_close`.")

In [224]:
# Number of generated closed-ended records.
len(mimic_close_qas)

500

In [196]:
# Inspect one generated record.
mimic_close_qas[1]

{'img_path': 'p11/p11924226/s56990167/dc00203a-4168ce8c-d79d47d2-eef8780b-d3fe037a.jpg',
 'new_qa_pairs': Closed_QueryAnswer(qa_pairs=[Closed_QA(Q='Is the heart size within normal limits in the image?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Does the image show any abnormalities in the lung fields?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='What part of the mediastinum appears altered in the image?', A='The superior mediastinum', type='type_2', choices='A. The superior mediastinum, B. The inferior mediastinum, C. The anterior mediastinum, D. The posterior mediastinum', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q="Is the patient's position affecting the clarity of the image?", A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Are follow-up films suggested for the patient?', A='Yes', type='type_3', choices='', omiss

In [210]:
# Next step: a second LLM pass that repairs or removes low-quality QA pairs.
# This prompt goes through SystemMessagePromptTemplate.from_template, so it must
# not contain literal "{" or "}" characters.

sys_msg_filter = """
You are provided with a set of QA pairs in JSON format. Your task is to improve or filter out low-quality QA pairs.
Use the following standards to judge and improve quality:
- Improve the answer format if possible.
- For "imaging modality" type questions, ensure that the ground truth is X-ray, since all the data are chest X-rays; remove choices that are not common radiology modalities.
- For QA pairs with type='type_2', ensure that the question asks about an attribute (position, color, size, etc.) of an **organ**, not of other things such as the **patient**; delete the pair if it is not about an organ.
- If the answer to a multi-choice question is vague, make it clear which choice is correct.
- If the answer is like "xxx is not specified in the image." or "The image does not provide information", filter the QA pair out, because there is no ground truth.
- Remove questions that cannot be answered from the given context alone, like "Are there any changes in the xxx compared to the **previous study**?"
- Ensure the choices of a multi-choice question use this format: "A.XXXX, B.XXXX, C.XXXX, ...."
- Add more choices if a multi-choice QA has only one choice (e.g. only "A.XXXX")!
You can also judge by other common quality standards.
Then format the remaining QA pairs in a high-quality format.
Notice: just remove a QA pair if it is of low quality.
"""
# Prompt for the filtering pass. The reply is parsed with the same schema as the
# generation step, so the parser's format instructions are appended via
# `{format_instructions}` — declaring the partial variable alone does not insert it.
filter_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(sys_msg_filter),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=(
                "The generated QA pairs in JSON format: {qa_pairs}. "
                "Try to improve or filter out low-quality QA pairs and return them in "
                "the original JSON format, with the QA pair list as the value of the "
                "key 'qa_pairs'. No other items; make sure the JSON is complete so it "
                "can be parsed.\n"
                "{format_instructions}"
            ),
            input_variables=["qa_pairs"],
            partial_variables={"format_instructions": closed_parser.get_format_instructions()},
        )
    ),
])

from operator import itemgetter, attrgetter

# LCEL pipeline for the filtering pass; reuses the closed-ended parser/schema.
filter_chain = filter_prompt | gpt_model | closed_parser

In [221]:
filtered_mimic_close_qas = list()  # records after the LLM filtering pass

In [2]:
# Second LLM pass: repair or drop low-quality QA pairs for each image.
# The pairs are serialized with json.dumps so the model really receives JSON
# (formatting a dict directly yields a Python repr with single quotes).
# Resumable: images already filtered are skipped, so a re-run does not duplicate
# records; API/parse failures are collected in `failed_filter`.
done_filtered = {rec["img_path"] for rec in filtered_mimic_close_qas}
failed_filter = []
for record in tqdm(mimic_close_qas):
    if record["img_path"] in done_filtered:
        continue
    qa_json = json.dumps(record["new_qa_pairs"].dict(), ensure_ascii=False)
    try:
        ans = filter_chain.invoke(dict(qa_pairs=qa_json))
    except Exception as err:  # API or output-parsing failure: record it and keep going
        failed_filter.append((record["img_path"], repr(err)))
        continue
    filtered_mimic_close_qas.append({"img_path": record["img_path"], "new_qa_pairs": ans})
    done_filtered.add(record["img_path"])

if failed_filter:
    print(f"{len(failed_filter)} QA sets failed filtering; see `failed_filter`.")

In [236]:
# Number of records after filtering.
len(filtered_mimic_close_qas)

500

In [237]:
# Last record after filtering, together with the total record count.
(filtered_mimic_close_qas[-1], len(filtered_mimic_close_qas))

({'img_path': 'p16/p16043637/s52793175/1b3d4f71-68977c5e-a070ff6b-29584c84-b70bf667.jpg',
  'new_qa_pairs': Closed_QueryAnswer(qa_pairs=[Closed_QA(Q='Does the image show any signs of pneumonia?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Is there a pacemaker visible in the image?', A='Yes', type='type_1', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Can you identify the presence of pleural effusion in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='What type of medical device is noted on the left side in the image?', A='A: A left-sided pacemaker', type='type_4', choices='A: A left-sided pacemaker, B: A stent, C: An infusion pump, D: A defibrillator', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q='Is there evidence of a pneumothorax in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Close

In [235]:
# Flatten the filtered closed-ended records into benchmark rows for saving.
def trans_file(qas_list):
    """Flatten filtered closed-ended QA records into benchmark JSON rows.

    Each input record is ``{"img_path": str, "new_qa_pairs": model}``, where
    ``model.dict()`` returns ``{"qa_pairs": [ {Q, A, type, choices,
    ground_truth_type}, ... ]}``. A QA pair is dropped when it has no usable
    ground truth:

    * the answer says the image does not provide / does not specify the info;
    * a multi-choice question has no labelled second option ("B.", "B:", "B)");
    * a binary question's answer contains neither the word "yes" nor "no"
      (these dropped answers are printed).

    Args:
        qas_list: Records as produced by the filtering loop.

    Returns:
        List of dicts with keys ``img_name``, ``question``, ``answer``,
        ``hallucination_type``, ``question_type``, ``choices``, ``qid``,
        ``img_id``, ``location`` and ``modality``; ``qid`` is a running index
        over the kept QA pairs.
    """
    new_list = []
    for record in qas_list:
        img_id = record['img_path']
        for qa in record['new_qa_pairs'].dict()['qa_pairs']:
            answer = qa['A']
            answer_lower = answer.lower()
            # "The image does not provide/specify ..." -> no ground truth.
            if "image" in answer_lower and "not" in answer_lower and (
                "provide" in answer_lower or "specified" in answer_lower
            ):
                continue
            # A bare "B" substring also matches words like "Bilateral"; require a
            # labelled second option such as "B.", "B:" or "B)".
            if qa['ground_truth_type'] == "multi-choice" and not re.search(r"\bB\s*[.:)]", qa['choices']):
                continue
            # Whole-word match, so "normal" / "not" / "unknown" don't count as "no".
            if qa['ground_truth_type'] == "binary" and not re.search(r"\b(?:yes|no)\b", answer_lower):
                print(answer_lower)
                continue
            new_list.append({
                'img_name': img_id,
                'question': qa['Q'],
                'answer': answer,
                'hallucination_type': qa['type'],
                'question_type': qa['ground_truth_type'],
                'choices': qa['choices'],
                'qid': len(new_list),
                'img_id': img_id,
                'location': None,
                'modality': None,
            })
    return new_list

In [4]:
mimic_new_qas_json = trans_file(qas_list=filtered_mimic_close_qas)  # flatten to benchmark rows

In [240]:
# Persist the closed-ended benchmark. This path must differ from the open-ended
# save path; if both cells write to the same file, the later save overwrites this one.
close_qas_save_path = "saved close-ended json file"  # TODO: set the real output path
with open(close_qas_save_path, 'w', encoding='utf8') as json_file:
    json.dump(mimic_new_qas_json, json_file, indent=4)

print(f"Data saved to {close_qas_save_path}!")

Data saved!


### Open-ended data generation for Visual Factual Hallucination

In [147]:
# System prompt for open-ended QA generation. It contains literal braces, so it
# must be sent as a plain SystemMessage rather than through a prompt template.
# "Sturctured_Answer" is misspelled on purpose: it must match the schema key.
system_msg_open = """You are provided with the clinical report about a medical image. Your task is to synthesize a set of open-ended QA pairs according to the requirements. Unfortunately, you don't have access to the actual image. Below are the requirements for generating the questions and answers in the conversation.

Do not use phrases like "mentioned", "report", "context" in the conversation. Instead, refer to the information as being "in the image."
Answer responsibly within the given report, avoiding information not included in the given context.

You need to generate a question that queries three types of components: (1) anatomical structure; (2) anatomical measurement (organ location, size); (3) symptoms (such as normal or abnormal symptoms, not a direct diagnosis).
Here is an example question: "List your findings of anatomical structures and measurements in detail, as well as the possible symptoms, abnormal findings on these structures."

Instructions:
Additionally return a structured answer with a clear classification into the three components, which means simply classifying the answer into this form:
       {"Sturctured_Answer": {"anatomy": List[str], "measurement": List[str], "symptom": List[str]} }
Make sure the classification is precise and accurate. If you are not sure about the category, do not include it in the structured result.
The structured output of "measurement" should be the measurements of organs or important structures; it can be an empty list if there are no important measurements.
"""

In [148]:

# Structured breakdown of an open-ended answer into the three components the
# system prompt asks for. Class docstrings are deliberately omitted: pydantic
# uses a model's docstring as its schema description, which would change the
# format instructions generated by the parser.
class StructureAns(BaseModel):
    anatomy: List[str] = Field(description="The structured answer about anatomical structure")
    measurement: List[str] = Field(description="The structured answer about anatomical structure measurement")
    symptom: List[str] = Field(description="The structured answer about possible symptoms")

# One open-ended QA pair together with its structured answer.
class Type2_QA(BaseModel):
    Question: str = Field(description="The question")
    Answer: str = Field(description="The answer")
    # Key spelling ("Sturctured") is kept as-is to match the key written in system_msg_open.
    Sturctured_Answer: StructureAns = Field(description="The structured answer")

# Top-level reply schema: a single QA pair under the key "type_2".
# NOTE(review): the saved output of the open-ended loop also shows a "type_1"
# entry, and trans_file_open reads "type_1"; this schema only declares "type_2" —
# confirm which keys are intended.
class OpenQA(BaseModel):
    type_2: Type2_QA = Field(description="The QA generated.")


# Validates the model's JSON reply against OpenQA and supplies its format instructions.
open_parser = PydanticOutputParser(pydantic_object=OpenQA)

In [132]:
# Human-readable example of the expected reply shape (not referenced by any chain
# in this notebook). Bound with "=": with ":" the line is only a variable
# annotation and `output_format` is never defined. The top-level key matches the
# OpenQA schema ("type_2").
output_format = """
{
"type_2": {"Question": str, "Answer":str, "Sturctured_Answer":{"anatomy": list[str], "measurement": list[str], "symptom": list[str]}}
}

"""
from langchain_core.messages import HumanMessage, SystemMessage

In [149]:
# Prompt for open-ended QA generation. The system message is a plain SystemMessage
# (not a template) because system_msg_open contains literal braces that a template
# would treat as variables. The human template must include `{format_instructions}`;
# otherwise the OpenQA schema enforced by the parser never reaches the model.
rag_prompt_wo_knowledge = ChatPromptTemplate.from_messages([
    SystemMessage(content=system_msg_open),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=(
                "The medical imaging report: {report}. "
                "Ensure that you follow the output format detailed below.\n"
                "{format_instructions}"
            ),
            input_variables=["report"],
            partial_variables={"format_instructions": open_parser.get_format_instructions()},
        )
    ),
])

# LCEL pipeline: open-ended prompt -> Azure chat model -> OpenQA parser.
rag_chain_wo_knowledge = rag_prompt_wo_knowledge | gpt_model | open_parser

In [215]:
open_qas = list()  # accumulates {"img_path": ..., "qa_pairs": ...} records
# Open-ended subset: sampled paths 500-999, disjoint from the closed-ended first 500.
sampled_open_paths = mimic_files[500:1000]

In [5]:
# Generate one open-ended QA record per sampled image. Resumable: images already
# in `open_qas` are skipped, so re-running after an interruption neither
# duplicates records nor repeats API calls. API/parse failures are collected in
# `failed_open` instead of aborting the whole batch.
done_open = {rec["img_path"] for rec in open_qas}
failed_open = []
for _id in tqdm(sampled_open_paths):
    if _id in done_open:  # to_report returns the input path unchanged as img_path
        continue
    report_path, img_path, report = to_report(_id)
    try:
        ans = rag_chain_wo_knowledge.invoke(dict(report=report))
    except Exception as err:  # API or output-parsing failure: record it and keep going
        failed_open.append((img_path, repr(err)))
        continue
    open_qas.append({"img_path": img_path, "qa_pairs": ans})
    done_open.add(img_path)

if failed_open:
    print(f"{len(failed_open)} reports failed; see `failed_open`.")

In [231]:
# Record count and one parsed example reply.
(len(open_qas), open_qas[1]["qa_pairs"].dict())

(500,
 {'type_1': {'Question': 'What does the image show in a few words?',
   'Answer': 'The image is a portable AP chest X-ray showing severe cardiomegaly, mild vascular congestion, retrocardiac opacities, and a possible small left effusion.'},
  'type_2': {'Question': 'List your findings of anatomical structures and measurements in detail, as well as the possible symptoms, abnormal findings on these structures.',
   'Answer': 'The chest X-ray reveals a severely enlarged heart, consistent with cardiomegaly. The vascular markings are mildly congested. There are retrocardiac opacities that have shown improvement, suggesting resolving atelectasis. Additionally, there may be a small effusion on the left side. No evidence of pneumothorax is present.',
   'Sturctured_Answer': {'anatomy': ['heart',
     'vascular markings',
     'retrocardiac area',
     'left pleural space'],
    'measurement': [],
    'symptom': ['severe cardiomegaly',
     'mild vascular congestion',
     'improved retroc

In [232]:
# Flatten the open-ended records into benchmark rows for saving.
def trans_file_open(qas_list, qa_types=None):
    """Flatten open-ended generation records into benchmark JSON rows.

    Emits exactly one row per QA type present in each record's parsed reply
    (e.g. ``"type_2"``), iterating the keys in sorted order.

    Args:
        qas_list: Records ``{"img_path": str, "qa_pairs": model}`` where
            ``model.dict()`` maps QA-type keys to ``{"Question", "Answer",
            optional "Sturctured_Answer"}`` dicts.
        qa_types: Optional iterable of QA-type keys to export (e.g.
            ``["type_1"]``); by default every key present is exported.

    Returns:
        List of dicts with keys ``img_name``, ``question``, ``answer``,
        ``question_type``, ``qid`` and ``img_id``, plus ``structured_answer``
        when the QA pair carries one; ``qid`` is a running index.
    """
    new_list = []
    for record in qas_list:
        img_id = record['img_path']
        parsed = record['qa_pairs'].dict()
        # The OpenQA schema in this notebook declares only "type_2", while saved
        # outputs also contain "type_1"; iterate over the keys actually present
        # instead of hard-coding one (a missing key would raise KeyError).
        if qa_types is None:
            wanted = sorted(parsed)
        else:
            wanted = [t for t in qa_types if t in parsed]
        for qa_type in wanted:
            qa = parsed[qa_type]
            row = {
                'img_name': img_id,
                'question': qa["Question"],
                'answer': qa["Answer"],
                'question_type': qa_type,
                'qid': len(new_list),
                'img_id': img_id,
            }
            if qa.get("Sturctured_Answer") is not None:
                row['structured_answer'] = qa["Sturctured_Answer"]
            new_list.append(row)  # one fresh dict per QA pair, appended once

    return new_list

In [6]:
mimic_open_json = trans_file_open(qas_list=open_qas)  # flatten to benchmark rows

In [234]:
# Persist the open-ended benchmark. This path must differ from the closed-ended
# save path; if both cells write to the same file, this save overwrites that one.
open_qas_save_path = "saved open-ended json file"  # TODO: set the real output path
with open(open_qas_save_path, 'w', encoding='utf8') as json_file:
    json.dump(mimic_open_json, json_file, indent=4)

print(f"Data saved to {open_qas_save_path}!")

Data saved!
