In [4]:
import os
import re
#set the openai env variable here

In [6]:
# Azure OpenAI chat model used to synthesize new closed-ended QA pairs.
# Endpoint and key are read from the environment (set them in the cell above)
# instead of being hardcoded: passing literal placeholders here would take
# precedence over any environment variables and leak keys when sharing.
api_version = "2024-05-01-preview"
from langchain_openai import AzureChatOpenAI
model_kwargs = dict(
    model="gpt-4-128k",
    # os.environ[...] fails fast with a KeyError naming the missing variable
    # rather than sending a placeholder credential to the service.
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=api_version,
    temperature=0.0,  # minimize sampling randomness between runs
)

gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)

In [19]:
# System prompt for the QA-synthesis chain. It is passed through
# SystemMessagePromptTemplate.from_template, so it must not contain literal
# curly braces (they would be parsed as template variables).
system_msg = """ You are provided with the metadata and a set of existing QA pairs about one medical image. Your task is to synthesize a set of new close-ended QA pairs according to the requirements. Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.

- Do not use phrases like "mentioned", "QA", "context" in the conversation. Instead, refer to the information as being "in the image."
- Answer responsibly within the given information and knowledge context, avoiding information not included in the given context.
- Make sure your generation contains a key "type" for each QA pair, indicating its question type (as detailed in the following part).
- Make the questions diverse.
- The ground truth type can be "yes or no" or one choice from multi-choice (you need to synthesize several choices in this case), conditioned on the question and available materials.

Here is a set of question types for your generation; you must assign each new QA pair a question type.
type_1: Anatomical Hallucination. Example questions: "Which part of the body does this image belong to?" "Does the picture contain liver?"
type_2: Measurement Hallucination like location, size. Example questions: "Where is the liver?"
type_3: Symptom-Based Hallucination. Example questions: "Is the lung healthy?" "Is there evidence of a pneumothorax?" "Is there a fracture?"
type_4: Technique Hallucination. Example questions:  "What modality is used to take this image? "

Instructions:
- Add one key to each QA pair, key= "ground_truth_type", value= "binary" if the type is "yes or no" else "multi-choice"
- When you see diagnosis information of a disease (e.g. lung cancer) in QA pairs, you should generate new QA pairs asking about the symptoms of the disease.
- For "multi-choice" type QA, you must include one key "choices" of string type.
- Avoid questions for which you cannot generate a ground truth; for example, avoid the answer "The image does not provide information".

"""

In [20]:
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

def inspect(state):
    """Debug tap for a LangChain pipeline: print ``state`` and hand it back unchanged.

    The printed text is ``str(state)`` followed by a space and a blank line.
    Returns the very same object it received.
    """
    passthrough = state
    print(passthrough, end=" \n\n")
    return passthrough

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
import json


def format_qa(qa_pairs):
    """Return the value stored under the ``'qa_pairs'`` key of ``qa_pairs``.

    Raises KeyError when that key is absent.
    """
    extracted = qa_pairs["qa_pairs"]
    return extracted

# One synthesized closed-ended QA pair as parsed from the LLM reply.
# Field names/descriptions make up the JSON schema that PydanticOutputParser
# derives for this model, so renaming or reordering fields changes that schema.
# NOTE(review): recorded outputs in this notebook show an `omission_type` field
# that is not declared here; they appear to come from an earlier version of this class.
class Closed_QA(BaseModel):
    Q: str = Field(description="The question (Q)")
    A: str = Field(description="The answer (A)")
    # Hallucination category; system_msg asks for "type_1".."type_4" (not validated here).
    type: str = Field(description="The QA type")
    # Answer options as one string; defaults to "" (typically empty for binary questions).
    choices: str = Field(default="", description="The QA choices for multi-choice question type")
    # system_msg asks for "binary" or "multi-choice" (not validated here).
    ground_truth_type: str = Field(description="The ground_truth type")
    

# Top-level object the whole LLM reply is parsed into.
class QueryAnswer(BaseModel):
    qa_pairs: List[Closed_QA] = Field(description="The QA pairs list generated.")

# Here we use the classic parser in LangChain: Pydantic, to ensure a strict parsing process.
closed_parser = PydanticOutputParser(pydantic_object=QueryAnswer)

# The human message must reference {format_instructions}; a partial variable that
# the template never mentions is silently ignored, and the model would then never
# see the JSON schema that closed_parser expects.
closed_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(system_msg),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=(
                "Given the metadata: {metadata} and existing QA pairs: {qa_pairs}, generate high-quality questions for which the correct answers can be inferred solely from the provided information. Ensure the questions align with the specified question types."
                "\n\n{format_instructions}"
            ),
            input_variables=["metadata", "qa_pairs"],
            partial_variables={"format_instructions": closed_parser.get_format_instructions()},
        )
    ),
])
# Backward-compatible alias for the original (misspelled) name.
colsed_prompt = closed_prompt

# prompt -> Azure chat model -> QueryAnswer
closed_chain = (
    closed_prompt
    | gpt_model
    | closed_parser
)


## Batch Processing and Generation

In [25]:
# Slake
import os
import json
from collections import defaultdict
from tqdm import tqdm

In [22]:
# Location of the RAD-VQA data and its original QA JSON (placeholders; edit before running).
rad_data = r"RAD-VQA DATA PATH"
test_file_path = os.path.join(rad_data, "JSON file of the original RQA-VQA")

with open(test_file_path, encoding='utf8') as fh:
    rad_test = json.loads(fh.read())

In [26]:
# Index the original QA records: qid -> record, and image name -> list of its qids
# (in the order the records appear in rad_test).
test_data_id_map = {}
img_data_ids = defaultdict(list)
for record in rad_test:
    qid = record['qid']
    test_data_id_map[qid] = record
    img_data_ids[record['image_name']].append(qid)

In [27]:
# Number of distinct images in the loaded RAD-VQA JSON.
len(img_data_ids)

314

In [31]:
# qids grouped under one image.
# NOTE(review): the recorded output mixes a string qid ('0') with integer qids,
# so qid types in the source JSON are not uniform.
img_data_ids['synpic54610.jpg']

['0', 13, 14, 16, 17, 21]

In [32]:
img_path = 'data/rad_vqa/images'  # image directory (not referenced by the visible cells)


def get_slake_qa(img_id):
    """Collect the original QA pairs and metadata for one image.

    The name mentions SLAKE, but this reads the RAD-VQA indices
    ``img_data_ids`` / ``test_data_id_map`` built from ``rad_test``.

    Args:
        img_id: image file name, e.g. ``'synpic54610.jpg'``.

    Returns:
        ``(qa_text, metadata)``. ``qa_text`` holds one ``Q: ...`` / ``A: ...`` line
        pair per original QA, in the order listed in ``img_data_ids``. ``metadata``
        is ``"image_organ:<organ>"``, taken from the last of those QA records.

    Raises:
        KeyError: if the image has no QA pairs. ``.get`` is used so that an unknown
            id is not silently inserted into the ``img_data_ids`` defaultdict.
    """
    qa_ids = img_data_ids.get(img_id)
    if not qa_ids:
        raise KeyError(f"no QA pairs found for image {img_id!r}")
    qa_records = [test_data_id_map[qid] for qid in qa_ids]
    final_qas = "".join(
        f"Q: {rec['question']} \nA: {rec['answer']} \n" for rec in qa_records
    )
    # All records belong to the same image, so the last one supplies its metadata.
    metadata = "image_organ:" + qa_records[-1]['image_organ']
    return final_qas, metadata

### Filter rules
- `binary` type -- if the answer contains neither "yes" nor "no", filter it out.
- `multi-choice` type -- if `choices` is empty, filter it out.

In [36]:
# Accumulates one {"img_path": <image name>, "new_qa_pairs": QueryAnswer} entry per image.
slake_new_qas = list()

In [37]:
slake_img_ids = [*img_data_ids]  # image names, in insertion order

In [3]:
# Query the LLM once per image. Images already present in slake_new_qas are
# skipped, so re-running this cell after an interruption (e.g. an API or parse
# error) resumes where it stopped instead of appending duplicate entries.
already_done = {entry["img_path"] for entry in slake_new_qas}
for _id in tqdm(slake_img_ids):
    if _id in already_done:
        continue
    qa, metadata = get_slake_qa(_id)
    ans = closed_chain.invoke(dict(qa_pairs=qa, metadata=metadata))
    slake_new_qas.append({"img_path": _id, "new_qa_pairs": ans})

In [47]:
# Last generated entry (314 images -> index 313).
# NOTE(review): the recorded output shows an `omission_type` field that Closed_QA
# does not declare, so it presumably predates the current schema; re-run to refresh.
slake_new_qas[313]

{'img_path': 'synpic19114.jpg',
 'new_qa_pairs': QueryAnswer(qa_pairs=[Closed_QA(Q='Is there evidence of pleural plaques in the image?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='What is the abnormality detected in the lungs?', A='Pleural plaques', type='type_3', choices='A. Pleural effusion, B. Pleural plaques, C. Pneumothorax, D. Lung nodules', omission_type=-1, ground_truth_type='multi-choice'), Closed_QA(Q='Is the aortic arch enlarged?', A='No', type='type_3', choices='', omission_type=0, ground_truth_type='binary'), Closed_QA(Q='Does the image show a normal aortic arch contour?', A='Yes', type='type_3', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='Are the pleural plaques localized to the hemithoraces?', A='Yes', type='type_2', choices='', omission_type=1, ground_truth_type='binary'), Closed_QA(Q='What part of the body is shown in the image?', A='Chest', type='type_1', choices='A. Chest, B. Abdomen, C. P

In [61]:
def _is_refusal(answer):
    """Return True for answers like "The image does not provide ..." (case-insensitive)."""
    text = answer.lower()
    return "image" in text and "not" in text and "provide" in text


def _has_yes_no(answer):
    """Return True when ``answer`` contains "yes" or "no" as a standalone word.

    A substring test would wrongly accept answers like "not clear", "none" or "unknown".
    """
    words = re.findall(r"[a-z]+", answer.lower())
    return "yes" in words or "no" in words


def trans_file(qas_list):
    """Flatten LLM-generated QA pairs into the saved JSON record format.

    Args:
        qas_list: entries ``{"img_path": <image name>, "new_qa_pairs": QueryAnswer}``
            as produced by the generation loop.

    Returns:
        A list of dicts with keys img_name, question, answer, hallucination_type,
        question_type, choices, qid, img_id, location, modality. ``qid`` numbers
        the kept pairs consecutively from 0. ``location`` is the ``image_organ``
        of the image's first original QA record.

    These pairs are dropped (and the binary/multi-choice drops are reported):
        * answers stating the image does not provide the information;
        * "binary" pairs whose answer has no standalone "yes"/"no";
        * "multi-choice" pairs with empty ``choices``.
    """
    new_list = []
    for entry in qas_list:
        img_id = entry['img_path']
        old_qa_pair = test_data_id_map[img_data_ids[img_id][0]]
        # Attribute access works for both pydantic v1 and v2 (`.dict()` is deprecated in v2).
        for qa in entry['new_qa_pairs'].qa_pairs:
            answer = qa.A
            choices = qa.choices or ""
            if _is_refusal(answer):
                continue
            if qa.ground_truth_type == "binary" and not _has_yes_no(answer):
                print(f"dropped binary QA without yes/no answer: {answer!r}")
                continue
            if qa.ground_truth_type == "multi-choice" and not choices.strip():
                print(f"dropped multi-choice QA without choices: {qa.Q!r}")
                continue
            new_list.append({
                'img_name': img_id,
                'question': qa.Q,
                'answer': answer,
                'hallucination_type': qa.type,
                'question_type': qa.ground_truth_type,
                'choices': choices,
                'qid': len(new_list),
                'img_id': old_qa_pair['image_name'],
                'location': old_qa_pair['image_organ'],
                'modality': None,
            })
    return new_list

In [2]:
# Flatten and filter the generated QA pairs into the final record list.
rad_new_qas_json = trans_file(qas_list=slake_new_qas)

In [63]:
# NOTE(review): the recorded output contains an 'omission_type' key that trans_file
# does not emit, so it is presumably stale from an earlier version of this notebook.
rad_new_qas_json[789] #data example

{'img_name': 'synpic51926.jpg',
 'question': 'Does the image show any abnormalities in the digestive system?',
 'answer': 'Yes',
 'hallucination_type': 'type_3',
 'omission_type': 1,
 'question_type': 'binary',
 'choices': '',
 'qid': 789,
 'img_id': 'synpic51926.jpg',
 'location': 'ABD',
 'modality': None}

In [64]:
# Persist the filtered QA records (placeholder path; set it before running).
output_path = "saved json file"
with open(output_path, mode='w') as out_fh:
    json.dump(rad_new_qas_json, out_fh, indent=4)
print("Data saved!")

Data saved!
