## Close-ended Data generation for Visual Factual Hallucination using IU-Xray

In [3]:
import os
#set the openai env variable here

In [4]:
api_version = "2024-05-01-preview"
from langchain_openai import AzureChatOpenAI

model_kwargs = dict(
    model="gpt-4-128k", 
    azure_endpoint="endpoint here",
    api_key="key here",
    api_version=api_version,
    temperature=0.0,
)

gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)

In [6]:
import logging
logging.basicConfig()

from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
import json

In [8]:
system_msg_close = """ You are provided with the clinical report about a medical image. You task is to synthesize a set of new close-ended QA pairs according to the requirements. Unfortunately, you don’t have access to the actual image.
Below are requirements for generating the questions and answers in the conversation.
- Do not use phrases like "mentioned", "report", "context" in the conversation. Instead, refer to the information as being "in the image."
- Answer responsibly within the given information and knowledge context, avoiding information not included in the given context.
- Make sure your generation contains a key "type" for each QA pair, indicating its question type (as detailed in the following part).
- Make the questions more diverse.
- The ground truth type can be "yes or no" or one choice from multi-choice (you need to synthesize several choices in this case), condition on the question and available materials.

Here are a set of question types for your generation, you must assign each new QA pair a question type.
type_1: Anatomical Hallucination. Example questions: "Which part of the body does this image belong to?" "Does the picture contain liver?"
type_2: Measurement Hallucination like location, size. Example questions: "Where is the liver?"
type_3: Symptom-Based Hallucination. Example questions: "Is the lung healthy?" "Is there evidence of a pneumothorax?" "Is there a fracture?"
type_4: Technique Hallucination. Example questions:  "What modality is used to take this image? "

Instructions:
- Add one key to each QA pair, key= "ground_truth_type", value= "binary" if the type is "yes or no" else "multi-choice"
- When you see diagnosis information of a disease (e.g. lung cancer) in QA pairs, you should generate new QA pair by asking the symptoms of the disease 
- For "multi-choice" type QA, you must include one key "choices" of string type.
- Avoid the question that you can not generate a ground truth, for example avoid the answer "The image does not provide information". 
"""

In [34]:
#for filtering low-quality generated questions
sys_msg_filter = """
You are provided with a set of QA pairs in a json format. Your task is to improve or filter out low quality QA pairs.
Here are some standards for the measurement of low-quality:
- The answer is vague when the question type is multi-choice, try to make the answer clear, which choice is correct.
- The answer is like "xxx is not specified in the image.", which means that this qa pair should be filtered out becuase there is no ground truth.
You can also judge from other common standards.
"""

In [44]:
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

class Closed_QA(BaseModel):
    Q: str = Field(description="The question (Q)")
    A: str = Field(description="The answer (A)")
    type: str = Field(description="The QA type")
    choices: str = Field(default="", description="The QA choices for multi-choice question type")
    ground_truth_type: str = Field(description="The ground_truth type")
    

class Closed_QueryAnswer(BaseModel):
    qa_pairs: List[Closed_QA] = Field(description="The QA pairs list generated.")
    
#Here we use the classic parser in LangChain: Pydantic, to ensure a strict parsing process
closed_parser = PydanticOutputParser(pydantic_object=Closed_QueryAnswer)

colsed_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(system_msg_close),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template="The medical report: {report}. generate high-quality questions for which the correct answers can be inferred solely from the provided report. ",
            input_variables=["report"],
            partial_variables={"format_instructions": closed_parser.get_format_instructions()},
        )
    ),
])

#for filtering out low-quality generated questions
filter_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(sys_msg_filter),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template="The generated QA pairs in a json format: {qa_pairs}. Try to improve or filter out low-quality QA pairs, return in the original json format, but output the QA pair list as the value of additional key 'qa_pairs'. No other items, ensure the complete json format for parser!",
            input_variables=["qa_pairs"],
            partial_variables={"format_instructions": closed_parser.get_format_instructions()},
        )
    ),
])


from operator import itemgetter, attrgetter

closed_chain = (
    colsed_prompt 
    | gpt_model 
    | closed_parser
    | {"qa_pairs": attrgetter("qa_pairs")}
    | filter_prompt
    | gpt_model
    | closed_parser
)

In [11]:
#batch processing
files_path = "the path of IU-Xray data, the original QAs json file"

with open(files_path, 'r', encoding='utf8') as file:
    xray_files = json.load(file)

In [28]:
len(xray_files['test']), xray_files['test'][5] #use the test set

(590,
 {'id': 'CXR2915_IM-1317',
  'report': 'The heart size is normal. The mediastinal contour is within normal limits. The lungs are free of any focal infiltrates. There are no nodules or masses. No visible pneumothorax. No visible pleural fluid. The XXXX are grossly normal. There is no visible free intraperitoneal air under the diaphragm.',
  'image_path': ['CXR2915_IM-1317/0.png', 'CXR2915_IM-1317/1.png'],
  'split': 'test'})

In [25]:
#use 290 for close-ended generation

In [41]:
xray_close_qas = []
from tqdm import tqdm

In [1]:
for i in tqdm(xray_files['test']):
    img_path, report = i['image_path'][0] if i['image_path'][0][-5]=="0" else i['image_path'][1], i['report']
    ans = closed_chain.invoke(dict(report=report))
    final_new_qa = {"img_path":img_path, "new_qa_pairs": ans}
    xray_close_qas.append(final_new_qa)

In [48]:
len(xray_close_qas)

290

In [32]:
#first version, with low-quality pairs
xray_close_qas

[{'img_path': 'CXR3030_IM-1405/0.png',
  'new_qa_pairs': Closed_QueryAnswer(qa_pairs=[Closed_QA(Q='Does the image show any signs of a large pleural effusion?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Is there any evidence of pneumothorax in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Can we see any acute bony abnormalities in the image?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='What part of the body is depicted in the image?', A='Chest', type='type_1', choices='A. Chest, B. Abdomen, C. Skull, D. Pelvis', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q='Is the cardiomediastinal silhouette normal in the image?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Are there any signs of focal consolidation in the image?', A='No', type='type_3', choices='', omiss

In [51]:
# store
def trans_file(qas_list):
    new_list = []
    id_ = 0
    for i in qas_list:
        img_id = i['img_path']
        new_qas = i['new_qa_pairs'].dict()['qa_pairs']
        # old_qa_ids = img_data_ids[img_id]
        # old_qa_pair = test_data_id_map[old_qa_ids[0]]
        for j in new_qas:
            new_dict = dict()
            new_dict['img_name'] = img_id
            new_dict['question'] = j['Q']
            new_dict['answer'] = j['A']
            if "image" in j['A'].lower() and "not" in j['A'] and "provide" in j['A']:
                continue
            if j['ground_truth_type'] == "binary":
                if "yes" not in j['A'].lower() and "no" not in j['A'].lower():
                    print(j['A'].lower())
                    continue
            new_dict['hallucination_type'] = j['type']
            new_dict['question_type'] = j['ground_truth_type']
            new_dict['choices'] = j['choices']
            new_dict['qid'] = id_
            new_dict['img_id'] = img_id
            new_dict['location'] = None
            new_dict['modality'] = None
            new_list.append(new_dict)
            id_ += 1
    return new_list

In [4]:
xray_new_qas_json = trans_file(xray_close_qas)

In [53]:
len(xray_new_qas_json)

2017

In [54]:
xray_new_qas_json[890]

{'img_name': 'CXR141_IM-0260/0.png',
 'question': 'Does the image show any abnormalities in the cardiac size?',
 'answer': 'No',
 'hallucination_type': 'type_3',
 'omission_type': 0,
 'question_type': 'binary',
 'choices': '',
 'qid': 890,
 'img_id': 'CXR141_IM-0260/0.png',
 'location': None,
 'modality': None}

In [55]:
with open("out path for iu_xray data", 'w') as json_file:
    json.dump(xray_new_qas_json, json_file, indent=4)

print(f"Data saved!")

Data saved!
