In [1]:
import os
#set the openai env variable here

In [2]:
from langchain_openai import AzureChatOpenAI

# Azure OpenAI API version passed to the client.
api_version = "2024-05-01-preview"

# The endpoint and key are read from the environment when present so that real
# credentials never need to be written into the notebook. The placeholder
# strings remain only as fallbacks for an unconfigured environment.
model_kwargs = dict(
    model="gpt-4-128k",
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", "endpoint here"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", "key here"),
    api_version=api_version,
    temperature=0.0,
)

# Chat model shared by the generation chain; LangChain's response cache is disabled.
gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)

In [3]:
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
import json
from langchain_core.messages import HumanMessage, SystemMessage

In [1]:
# Folder that holds the report .txt files, and the JSON file listing the
# image paths sampled from MIMIC-CXR.
mimic_report_path = r"mimic reports folder"
sampled_files_path = r"mimic_cxr/sampled_paths.json"

# Read the sampled image paths into memory once.
with open(sampled_files_path, mode="r", encoding="utf8") as sampled_fh:
    mimic_files = json.load(sampled_fh)

def to_report(file_path): #get the report of a medical image
    """Load the text report associated with a sampled image path.

    The report location is built from the first three "/"-separated
    components of ``file_path`` with ".txt" appended, and is resolved
    relative to the module-level ``mimic_report_path``.

    Args:
        file_path: "/"-separated relative path of the image.

    Returns:
        A tuple ``(report_path, file_path, report_text)`` where
        ``report_path`` is the derived relative ".txt" path and
        ``file_path`` is the argument returned unchanged.

    Raises:
        OSError: if the report file cannot be opened.
    """
    path = "/".join(file_path.split("/")[:3]) + ".txt"
    # Explicit encoding (same as the sampled-paths JSON load) so the read does
    # not depend on the platform's default text encoding.
    with open(os.path.join(mimic_report_path, path), "r", encoding="utf8") as f:
        report = f.read()
    return path, file_path, report

In [None]:
# The slice [1000:1400] selects at most 400 entries for close-ended QA generation.
# NOTE(review): an earlier comment stated 500 samples — confirm the intended size.
sampled_type4_paths = mimic_files[1000:1400] #sample 400 for close-ended

In [8]:
# Pair every sampled image path with the text of its report.
useful_ids = []
useful_data = []
for _id in sampled_type4_paths:
    # to_report returns (report_path, image_path, report_text).
    _report_path, file_path, report = to_report(_id)
    useful_ids.append(_id)
    # Store the image path under 'img_path', not the derived report ".txt"
    # path: downstream records use this value as img_name / img_id.
    useful_data.append({'img_path': file_path, 'report': report})

# data generation for close-ended evaluation of Knowledge Hallucination beyond Images

In [38]:
# System prompt for the QA-generation model. It frames the task (synthesising
# close-ended diagnosis QA pairs from a report, without access to the image)
# and asks for balanced answer labels.
# NOTE(review): "You task" below looks like a typo for "Your task"; the text is
# left untouched here because it is sent to the model verbatim.
system_msg = """
You are provided with the clinical report about a medical image. You task is to synthesize a set of close-ended QA pairs (diagnosis) according to the requirements. Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.

Do not use phrases like "mentioned", "report", "context" in the conversation. Instead, refer to the information as being "in the image".
Answer responsibly within the given report, avoiding information not included in the given context. 

Instructions:
Ensure balanced labels in your generated questions. For example, "yes-or-no" questions should have an equal number of "yes" and "no" answers. To achieve this balance, you may use negative sampling to generate questions with the answer "no".

"""

In [55]:
class QA(BaseModel):
    # One generated question/answer pair. Plain comments are used instead of a
    # docstring so that the JSON schema sent to the model only changes by the
    # added fields.
    question: str = Field(description="The question")
    answer: str = Field(description="The answer")
    # trans_file_close reads the three fields below from every pair. They carry
    # defaults so QA(question=..., answer=...) remains valid.
    question_topic: str = Field(default="diagnosis", description="The topic of the question, e.g. diagnosis")
    ground_truth_type: str = Field(default="binary", description="The answer format, e.g. binary for yes-or-no questions")
    choices: str = Field(default="", description="The candidate options for multi-choice questions; empty string otherwise")
    
# Top-level schema given to the output parser: a list of QA items under "qa_pairs".
class QueryAnswer(BaseModel):
    qa_pairs: List[QA] = Field(description="The QA pairs for contextual evaluation.")

# LangChain's Pydantic output parser validates the model reply against the
# QueryAnswer schema and also produces the matching format instructions.
qa_parser = PydanticOutputParser(pydantic_object=QueryAnswer)

# The human message carries the report plus the parser's format instructions.
# The {format_instructions} placeholder must appear in the template; otherwise
# the partial variable is never rendered and the model is never told which
# JSON structure the parser expects.
qa_prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=system_msg),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=(
                "The medical report of the image: {report}. Generate high-quality close-ended question-answer pairs focused on the diagnosis."
                "\n{format_instructions}"
            ),
            input_variables=["report"],
            partial_variables={"format_instructions": qa_parser.get_format_instructions()},
        )
    ),
])

# Pipeline: render the prompt, call the chat model, parse the reply into QueryAnswer.
qa_chain = qa_prompt | gpt_model | qa_parser

In [15]:
# Accumulates one {"img_path", "qa_pairs"} record per processed report.
# The generation loop appends to this list, so re-run this cell to reset it
# before regenerating.
type2_qas = []
from tqdm import tqdm

In [2]:
# Generate QA pairs for every report and keep them next to the stored image path.
for sample in tqdm(useful_data):
    report, img_path = sample['report'], sample['img_path']
    qa = qa_chain.invoke({"report": report})
    type2_qas.append({"img_path": img_path, "qa_pairs": qa.dict()})

In [26]:
# store
def trans_file_close(qas_list):
    """Flatten generated QA results into one record per diagnosis question.

    Args:
        qas_list: iterable of dicts with an 'img_path' entry and a
            'qa_pairs' entry whose own 'qa_pairs' key holds the list of
            QA dicts.

    Returns:
        A list of dicts with keys img_name, question, answer, question_type,
        ground_truth_type, choices, qid and img_id. Only pairs whose
        'question_topic' equals "diagnosis" are kept; qid counts the kept
        records from 0.
    """
    records = []
    for entry in qas_list:
        image = entry['img_path']
        for pair in entry['qa_pairs']['qa_pairs']:
            if pair['question_topic'] != "diagnosis":
                continue
            records.append({
                'img_name': image,
                'question': pair["question"],
                'answer': pair["answer"],
                'question_type': 'type4_Diagnosis',
                'ground_truth_type': pair['ground_truth_type'],
                'choices': pair['choices'],
                # Number of records kept so far is the running id.
                'qid': len(records),
                'img_id': image,
            })
    return records

In [27]:
# Flatten the generated QA pairs, then preview the record count and the first record.
type2_close = trans_file_close(type2_qas)
len(type2_close), type2_close[0]

(1172,
 {'img_name': 'p19/p19454978/s57331547/7d047120-d24a497e-fc26ea7e-6c3acc0c-ce5bc190.jpg',
  'question': 'Is there evidence of atelectasis according to the image?',
  'answer': 'Yes',
  'question_type': 'type4_Diagnosis',
  'ground_truth_type': 'binary',
  'choices': '',
  'qid': 0,
  'img_id': 'p19/p19454978/s57331547/7d047120-d24a497e-fc26ea7e-6c3acc0c-ce5bc190.jpg'})

In [28]:
# Persist the flattened QA records as indented JSON.
with open("saved json file", mode="w") as out_fh:
    json.dump(type2_close, out_fh, indent=4)

print("Data saved!")

Data saved!
