In [3]:
import os
#set the openai env variable here

In [None]:
import os

from langchain_openai import AzureChatOpenAI

# Azure OpenAI REST API version handed to the client.
api_version = "2024-05-01-preview"

# Connection settings for the chat model. The endpoint and key are taken from the
# environment when available so real credentials never need to be typed into (and
# committed with) the notebook; the original placeholder strings remain as fallbacks.
model_kwargs = dict(
    model="gpt-4-128k",
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", "endpoint here"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", "key here"),
    api_version=api_version,
    temperature=0.0,
)

# Chat model shared by the generation chain; response caching is explicitly disabled.
gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)

In [241]:
# System prompt for the SLAKE close-ended QA generation chain.
# It tells the model what inputs it receives (metadata, bounding boxes, existing QA pairs),
# lists the four question types (type_1 .. type_4) and names the keys every generated
# QA pair must carry: "type", "ground_truth_type" and, for multi-choice questions, "choices".
# NOTE(review): the text is sent to the model verbatim, so wording changes alter generation.
slake_system_msg = """ You are provided with the metadata, object bounding boxes and a set of existing QA pairs about one medical image. You task is to synthesize a set of new close-ended QA pairs according to the requirements.
Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.
- Do not use phrases like "mentioned", "bounding boxes", "context" in the conversation. Instead, refer to the information as being "in the image."
- Answer responsibly within the given information and knowledge context, avoiding information not included in the given context.
- Make sure your generation contains a key "type" for each QA pair, indicating its question type (as detailed in the following part).
- Make the questions more diverse.
- The ground truth type can be "yes or no" or one choice from multi-choice (you need to synthesize several choices in this case), condition on the question and available materials.

Here are a set of question types for your generation, you must assign each new QA pair a question type.
type_1: Anatomical Hallucination. Example questions: "Which part of the body does this image belong to?" "Does the picture contain liver?"
type_2: Measurement Hallucination like location, size. Example questions: "Where is the liver?"
type_3: Symptom-Based Hallucination. Example questions: "Is the lung healthy?" "Is there evidence of a pneumothorax?" "Is there a fracture?"
type_4: Technique Hallucination. Example questions:  "What modality is used to take this image? "

Instructions:
- Add one key to each QA pair, key= "ground_truth_type", value= "binary" if the type is "yes or no" else "multi-choice"
- When you see diagnosis information of a disease (e.g. lung cancer) in QA pairs, you should generate new QA pair by asking the symptoms of the disease 
- For "multi-choice" type QA, you must include one key "choices" of string type.
- Avoid the question that you can not generate a ground truth, for example avoid the answer "The image does not provide information". 
"""

In [256]:
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

def format_qa(qa_pairs):
    """Return the value stored under the ``'qa_pairs'`` key of ``qa_pairs``.

    Args:
        qa_pairs: Mapping that contains a ``'qa_pairs'`` entry.

    Returns:
        The object stored under ``'qa_pairs'``, unchanged.

    Raises:
        KeyError: if the mapping has no ``'qa_pairs'`` entry.
    """
    pairs = qa_pairs['qa_pairs']
    return pairs

# Schema of one generated close-ended QA pair; used as the item type of
# SlakeQueryAnswer.qa_pairs below.
# Documented with `#` comments only, so the class body itself is left exactly as it was.
class Slake_closed_QA(BaseModel):
    # Question text.
    Q: str = Field(description="The question (Q)")
    # Answer text.
    A: str = Field(description="The answer (A)")
    # Question type label; the system prompt defines type_1 .. type_4.
    type: str = Field(description="The QA type")
    # Answer options for multi-choice questions; defaults to an empty string.
    choices: str = Field(default="", description="The QA choices for multi-choice question type")
    # "binary" or "multi-choice", as requested in the system prompt.
    ground_truth_type: str = Field(description="The ground_truth type")
    

# Top-level schema handed to PydanticOutputParser: a single "qa_pairs" list whose
# items are Slake_closed_QA objects.
class SlakeQueryAnswer(BaseModel):
    # All QA pairs generated for one image.
    qa_pairs: List[Slake_closed_QA] = Field(description="The QA pairs list generated.")

# Here we use the classic parser in LangChain: Pydantic, to ensure a strict parsing process
slake_closed_parser = PydanticOutputParser(pydantic_object=SlakeQueryAnswer)

# Chat prompt: the fixed system message followed by one human message that carries the
# per-image inputs. The human template ends with {format_instructions} so the parser's
# schema description is actually sent to the model; previously the partial variable was
# supplied but never referenced by the template, so the model never saw the schema.
slake_colsed_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(slake_system_msg),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=(
                "The metadata: {metadata} , object bounding boxes:{bounding_boxes}, existing QA pairs: {qa_pairs}. "
                "Generate high-quality questions for which the correct answers can be inferred solely from the provided information. "
                "Ensure the questions align with the specified question types.\n"
                "{format_instructions}"
            ),
            input_variables=["metadata", "bounding_boxes", "qa_pairs"],
            partial_variables={"format_instructions": slake_closed_parser.get_format_instructions()},
        )
    ),
])

# Full pipeline: fill the prompt -> query the chat model -> parse into SlakeQueryAnswer.
slake_closed_chain = (
    slake_colsed_prompt
    | gpt_model
    | slake_closed_parser
)


## Batch Processing and Generation

In [252]:
# Slake
import os
import json
from collections import defaultdict
from tqdm import tqdm

In [196]:
# Dataset root and the QA annotation file inside it (both are placeholders to fill in).
slake_data = r"SLAKE FOLDER"
test_file_path = os.path.join(slake_data, "the original SLAKE QAs json file")

# Parse the whole annotation file into memory as `slake_test`.
with open(test_file_path, mode='r', encoding='utf8') as file:
    slake_test = json.load(file)

In [201]:
# qid -> full QA record (a later record with the same qid replaces an earlier one).
test_data_id_map = {record['qid']: record for record in slake_test}

# img_name -> list of qids of that image, in file order.
img_data_ids = defaultdict(list)
for record in slake_test:
    img_data_ids[record['img_name']].append(record['qid'])

In [207]:
# img_data_ids.keys()

In [205]:
img_data_ids['xmlab102/source.jpg']

[11934, 11935, 11936, 11937, 11938, 11939, 11940, 11941, 11942, 11943, 11944]

In [220]:
# Root folder that holds one sub-folder per image (each with a detection.json).
img_path = 'data/slake/imgs'
def get_slake_qa(img_id):
    """Collect the prompt inputs for one SLAKE image.

    Args:
        img_id: Image name as used for the keys of ``img_data_ids``
            (e.g. ``'xmlab102/source.jpg'``). The part before the first ``/`` is
            used as the image's folder name under ``img_path``.

    Returns:
        A ``(metadata, final_qas, bounding_dectection)`` tuple:
        ``metadata`` is a string naming the organ and modality, ``final_qas`` is
        the existing QA pairs rendered as ``"Q: ... \\nA: ... \\n"`` text, and
        ``bounding_dectection`` is the parsed content of the image's
        ``detection.json``.

    Raises:
        KeyError: if no QA pair is registered for ``img_id`` (or a record lacks a
            required field).
        OSError: if ``detection.json`` cannot be opened.
    """
    # `.get` avoids silently inserting an empty entry into the defaultdict, and the
    # explicit check replaces the UnboundLocalError the old loop produced on no records.
    qa_ids = img_data_ids.get(img_id)
    if not qa_ids:
        raise KeyError(f"no QA pairs found for image {img_id!r}")

    qa_records = [test_data_id_map[_id] for _id in qa_ids]
    final_qas = ''.join(
        f"Q: {record['question']} \nA: {record['answer']} \n" for record in qa_records
    )

    # Image-level metadata is identical across an image's QA records, so take the last one.
    last_record = qa_records[-1]
    # Elsewhere in this notebook the modality is read from the 'modality' field; fall back
    # to it when a record has no 'metadata' field instead of failing with a KeyError.
    modality = last_record['metadata'] if 'metadata' in last_record else last_record['modality']
    # The two fields are separated explicitly so they do not run together in the prompt.
    metadata = "image_organ:" + last_record['location'] + ", image modality:" + modality

    bounding_path = os.path.join(img_path, img_id.split("/")[0], 'detection.json')
    with open(bounding_path, 'r', encoding='utf8') as file:
        bounding_dectection = json.load(file)
    return metadata, final_qas, bounding_dectection

In [221]:
metadata, qa, bounding = get_slake_qa('xmlab102/source.jpg')

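### Filter rules:
- binary type -- if the answer contains neither "yes" nor "no", filter it out
- multi-choice type -- if no "choices" are provided, filter it out
- any type -- if the answer says the image does not provide the information, filter it out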

In [253]:
slake_new_qas = [] # {"img_path":XXX, "new_qa_pairs":[xxx]}

In [250]:
slake_img_ids = list(img_data_ids.keys())

In [1]:
# Generate new QA pairs for every image.
# - Images already present in `slake_new_qas` are skipped, so the cell can be re-run
#   (e.g. after an interruption) without duplicating results or paying for them twice.
# - A failure on one image (API error, unparsable model output, missing file) is recorded
#   in `slake_failed_ids` and reported, instead of aborting the whole batch.
slake_failed_ids = []  # (img_id, repr(error)) for every image that could not be processed
_done_img_ids = {entry["img_path"] for entry in slake_new_qas}
for _id in tqdm(slake_img_ids):
    if _id in _done_img_ids:
        continue
    try:
        metadata, qa, bounding = get_slake_qa(_id)
        ans = slake_closed_chain.invoke(dict(qa_pairs=qa, bounding_boxes=bounding, metadata=metadata))
    except Exception as err:  # KeyboardInterrupt is not caught, so the loop can still be stopped
        slake_failed_ids.append((_id, repr(err)))
        tqdm.write(f"[skipped] {_id}: {err!r}")
        continue
    final_new_qa = {"img_path":_id, "new_qa_pairs": ans}
    slake_new_qas.append(final_new_qa)
    _done_img_ids.add(_id)

In [269]:
slake_new_qas[-1]

{'img_path': 'xmlab80/source.jpg',
 'new_qa_pairs': SlakeQueryAnswer(qa_pairs=[Slake_closed_QA(Q='Is the imaging modality used for this image an MRI?', A='No', type='type_4', choices='', omission_type=0, ground_truth_type='binary'), Slake_closed_QA(Q='Is the primary organ shown in the image the heart?', A='No', type='type_1', choices='', omission_type=0, ground_truth_type='binary'), Slake_closed_QA(Q='Is the abnormality found in the left lung?', A='Yes', type='type_2', choices='', omission_type=1, ground_truth_type='binary'), Slake_closed_QA(Q='Does the image show any signs of pneumonia?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Slake_closed_QA(Q='What disease is present in the image?', A='A. Lung Cancer', type='type_3', choices='A. Lung Cancer, B. Liver Cirrhosis, C. Cardiomegaly', omission_type=-1, ground_truth_type='multi-choice'), Slake_closed_QA(Q='Is there a presence of a tumor in the image?', A='Yes', type='type_3', choices='', omission_t

In [293]:
def trans_file(qas_list):
    """Flatten generated QA results into a list of JSON-serialisable records.

    Args:
        qas_list: List of ``{"img_path": <image name>, "new_qa_pairs": <parsed result>}``
            entries, where the parsed result exposes ``model_dump()`` or ``dict()``
            returning a mapping with a ``"qa_pairs"`` list.

    Returns:
        A list of dicts with the keys ``img_name``, ``question``, ``answer``,
        ``hallucination_type``, ``question_type``, ``choices``, ``img_id``, ``qid``,
        ``location`` and ``modality``. ``qid`` numbers the kept records from 0.
        ``img_id``, ``location`` and ``modality`` are copied from the first original
        QA record of the image.

    Filtering (dropped pairs do not consume a ``qid``):
        * answers stating that the image does not provide the information;
        * ``binary`` pairs whose answer has neither "yes" nor "no" as a whole word;
        * ``multi-choice`` pairs that come without any choices.
    """
    import re

    # Whole-word match: a plain substring test wrongly accepts answers such as
    # "Abnormal" or "unknown" because they contain the letters "no".
    yes_no_pattern = re.compile(r"\b(?:yes|no)\b", re.IGNORECASE)

    id_ = 0
    new_list = []
    for entry in qas_list:
        img_id = entry['img_path']
        parsed = entry['new_qa_pairs']
        # Prefer pydantic v2's model_dump(); fall back to dict() for v1-style objects.
        to_dict = getattr(parsed, 'model_dump', None) or parsed.dict
        new_qas = to_dict()['qa_pairs']
        old_qa_ids = img_data_ids[img_id]
        old_qa_pair = test_data_id_map[old_qa_ids[0]]
        for qa in new_qas:
            answer = qa['A']
            answer_lower = answer.lower()
            # "The image does not provide ..." style non-answers (case-insensitive).
            if all(token in answer_lower for token in ("image", "not", "provide")):
                continue
            if qa['ground_truth_type'] == "binary":
                if yes_no_pattern.search(answer) is None:
                    print(answer_lower)
                    continue
            # A multi-choice question is unusable without its options.
            if qa['ground_truth_type'] == "multi-choice" and not (qa['choices'] or "").strip():
                continue
            new_list.append({
                'img_name': img_id,
                'question': qa['Q'],
                'answer': answer,
                'hallucination_type': qa['type'],
                'question_type': qa['ground_truth_type'],
                'choices': qa['choices'],
                'img_id': old_qa_pair['img_id'],
                'qid': id_,
                'location': old_qa_pair['location'],
                'modality': old_qa_pair['modality'],
            })
            id_ += 1
    return new_list

In [2]:
slake_new_qas_json = trans_file(slake_new_qas)

In [300]:
slake_new_qas_json[9] #data example

{'img_name': 'xmlab103/source.jpg',
 'question': 'What is the condition of the right lung?',
 'answer': 'Abnormal',
 'hallucination_type': 'type_3',
 'omission_type': -1,
 'question_type': 'multi-choice',
 'choices': 'A: Healthy, B: Abnormal, C: Not visible',
 'img_id': 103,
 'qid': 9,
 'location': 'Abdomen',
 'modality': 'CT'}

In [298]:
# Write the flattened QA records to disk. The encoding is given explicitly (matching the
# utf8 used when reading the dataset) so the result does not depend on the platform default.
with open("saved json file", 'w', encoding='utf8') as json_file:
    json.dump(slake_new_qas_json, json_file, indent=4)

print("Data saved!")

Data saved!
