In [1]:
import os
#set the openai env variable here

In [1]:
from langchain_chroma import Chroma

# Placeholders: fill in the collection name and on-disk location of the constructed store.
collection_name = "name of the constructed RAG data base"
db_path = "path to the constructed RAG data base"
# Open the persisted Chroma collection; a later cell builds the retriever from `db`.
# NOTE(review): no embedding_function is passed here — confirm the embedding used at
# query time matches the one used when the collection was constructed.
db = Chroma(
    collection_name=collection_name,
    persist_directory=db_path,
)

### Data preparation

In [2]:
#batch process
import json
# Root folder under which the report .txt files are looked up (placeholder value).
mimic_report_path = r"mimic reports folder"
sampled_files_path = r"mimic_cxr/sampled_paths.json" # the sampled images from the MIMIC-CXR

# Load the sampled paths from JSON; a later cell slices this object.
with open(sampled_files_path, 'r', encoding='utf8') as file:
    mimic_files = json.load(file)

def to_report(file_path): #get the report of a medical image
    """Load the report text that belongs to a sampled image path.

    The report location is derived from the first three "/"-separated
    components of ``file_path`` with ``.txt`` appended, and is resolved under
    the module-level ``mimic_report_path`` folder.

    Args:
        file_path: "/"-separated path taken from the sampled-paths list.

    Returns:
        A tuple ``(path, file_path, report)``: the derived relative report
        path, the unchanged input path, and the full text of the report file.

    Raises:
        OSError: if the derived report file cannot be opened.
    """
    path = "/".join(file_path.split("/")[:3]) + ".txt"
    # Read with an explicit encoding (matching how the sampled-paths JSON is
    # opened) so decoding does not depend on the platform's locale default.
    with open(os.path.join(mimic_report_path, path), "r", encoding="utf8") as f:
        report = f.read()
    return path, file_path, report

In [3]:
sampled_type4_paths = mimic_files[1400:1800] # slice [1400:1800) of the sampled paths (at most 400 items), used for open-ended QA generation

In [None]:
# Pair every sampled path with the text of its report.
# NOTE(review): the value stored under 'img_path' is the first element returned by
# to_report (the derived report path), not the original image path — confirm intended.
useful_ids = []
useful_data = []
for sample_path in sampled_type4_paths:
    report_rel_path, _unused_image_path, report_text = to_report(sample_path)
    useful_ids.append(sample_path)
    useful_data.append({'img_path': report_rel_path, 'report': report_text})

In [69]:
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

### Model, prompts and RAG chain

In [25]:
# Azure OpenAI chat model used by both LLM steps of the chain.
api_version = "2024-05-01-preview"
from langchain_openai import AzureChatOpenAI
model_kwargs = dict(
    model="gpt-4-128k",
    # Prefer credentials from the environment so real secrets never need to be
    # written into the notebook; the original placeholders remain as fallbacks.
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", "endpoint here"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", "key here"),
    api_version=api_version,
    temperature=0.0,  # deterministic decoding for reproducible generations
)

# Response caching is explicitly disabled for this model instance.
gpt_model = AzureChatOpenAI(**model_kwargs, cache=False)

In [24]:
# System prompt for the QA-generation step (used as the SystemMessage of rag_prompt).
# NOTE(review): the text contains the typo "You task"; it is left untouched here because
# editing the prompt changes what is sent to the model.
system_msg = """
You are provided with the clinical report about a medical image, and relevant retrieved knowledge. You task is to synthesize a set of open-ended QA pairs (asking medical terminologies, such as disease, clinical symptoms) according to the requirements. Unfortunately, you don’t have access to the actual image. Below are requirements for generating the questions and answers in the conversation.

- Do not use phrases like "mentioned", "report" in the conversation. Instead, refer to the information as being "In the image."
- Answer responsibly within the given information.
- You could rely on the knowledge if it is useful. 
- Only focus on the most crucial several terminologies in the reports.
Here is one example question: "What does mediastinal lipomatosis indicate when seen in an image?"
"""

In [41]:
# System prompt for the knowledge-filtering step (used as the SystemMessage of filter_prompt).
system_msg_knowledge = """
You should filter out useless and noisy retrieval knowledge, keep the important and useful knowledge about the given report, especially the knowledge about the medical terminologies.
"""

In [42]:
from pprint import pprint
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_community.document_transformers import LongContextReorder


import logging

# Raise the "langchain.retrievers.multi_query" logger to INFO level.
logging.basicConfig()
multi_query_logger = logging.getLogger("langchain.retrievers.multi_query")
multi_query_logger.setLevel(logging.INFO)

# Retriever over the Chroma store, configured with k=10 results per query.
sem_retriever = db.as_retriever(search_kwargs={"k": 10})

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
import json
from langchain_core.messages import HumanMessage, SystemMessage

class _BaseModel(BaseModel):
    # Base model whose str() is a JSON rendering of the instance attributes.

    def __str__(self):
        """Return the instance's attribute dict serialised as JSON with a 4-space indent."""
        return json.dumps(vars(self), indent=4)
        
def format_qa(qa_pairs):
    """Return the value stored under the ``'report'`` key of the chain input mapping."""
    report_text = qa_pairs['report']
    return report_text

def reorder_docs(docs):
    """Tag each document with its 1-based position, then reorder the list.

    The position is written in place to ``doc.metadata["rank"]`` before the
    documents are handed to ``LongContextReorder.transform_documents``, whose
    result is returned.
    """
    for rank, doc in enumerate(docs, start=1):
        doc.metadata["rank"] = rank
    return LongContextReorder().transform_documents(docs)

def format_docs(docs):
    """Render documents as a numbered, newline-separated block of text.

    Each line has the form ``"<n>. (<source>) <page_content>"`` with ``n``
    starting at 1.

    Args:
        docs: iterable of objects exposing ``metadata`` (a mapping) and
            ``page_content``.

    Returns:
        The joined text; an empty string when ``docs`` is empty.
    """
    return "\n".join(
        # A document without a "source" entry gets a placeholder instead of
        # raising KeyError and aborting the whole batch run.
        "{}. ({}) {}".format(position, doc.metadata.get("source", "unknown"), doc.page_content)
        for position, doc in enumerate(docs, start=1)
    )

# Schema for a single generated question/answer pair; both fields are strings.
class QA(BaseModel):
    question: str = Field(description="The question")
    answer: str = Field(description="The answer")

# Schema for the QA-generation output: a list of QA items under `qa_pairs`.
class QueryAnswer(BaseModel):
    qa_pairs: List[QA] = Field(description="The QA pairs list generated.")

# Schema for the knowledge-filtering output: a single string field `knowledge`.
# Inherits the JSON-style __str__ from _BaseModel.
class KnowledgeFilter(_BaseModel):
    knowledge: str = Field(description="The filtered knowledge.")

# Pydantic output parsers, one per LLM step: rag_parser parses into QueryAnswer
# (QA generation) and knowledge_filter_parser parses into KnowledgeFilter (knowledge filtering).
rag_parser = PydanticOutputParser(pydantic_object=QueryAnswer)
knowledge_filter_parser = PydanticOutputParser(pydantic_object=KnowledgeFilter)
    

# Prompt for the knowledge-filtering step: the filtering system message plus a
# human turn carrying the report and the raw retrieved knowledge.
filter_prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=system_msg_knowledge),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            # The parser's format instructions must be referenced in the template
            # text; declaring them only in partial_variables never sends them to
            # the model, leaving the output schema unspecified.
            template=(
                "The medical report: {report}, retrieved knowledge: {knowledge_ori}, "
                "return the filtered knowledge in a json format.\n{format_instructions}"
            ),
            input_variables=["knowledge_ori", "report"],
            partial_variables={"format_instructions": knowledge_filter_parser.get_format_instructions()},
        )
    ),
])


# Prompt for the QA-generation step: the generation system message plus a human
# turn carrying the report and the filtered knowledge.
rag_prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=system_msg),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            # Reference {format_instructions} in the template so the QueryAnswer
            # schema actually reaches the model; a partial variable that is not
            # used in the template text is silently dropped.
            template=(
                "The medical report: {report}, retrieved knowledge: {knowledge}. "
                "Generate high-quality open-ended question-answer pairs.\n{format_instructions}"
            ),
            input_variables=["knowledge", "report"],
            partial_variables={"format_instructions": rag_parser.get_format_instructions()},
        )
    ),
])


from operator import itemgetter, attrgetter

# Two-step chain invoked with {"report": <report text>}; returns a QueryAnswer.
#   1. Retrieve knowledge for the report, reorder and format it (knowledge_ori).
#   2. Ask the model to filter that knowledge (knowledge).
#   3. Ask the model to generate QA pairs from the report plus filtered knowledge.
# RunnablePassthrough.assign keeps the original "report" key in the payload at
# every step. Previously the second mapping bound "report" to the parsed
# KnowledgeFilter object, so the QA prompt never saw the clinical report, and the
# first mapping bound it to the whole input dict rather than the report text.
rag_chain = (
    RunnablePassthrough.assign(
        knowledge_ori=format_qa | sem_retriever | reorder_docs | format_docs,
    )
    | RunnablePassthrough.assign(
        knowledge=filter_prompt | gpt_model | knowledge_filter_parser | attrgetter("knowledge"),
    )
    | rag_prompt
    | gpt_model
    | rag_parser
)

In [71]:
from tqdm import tqdm
# Accumulator for generated results; the generation loop appends one dict per report.
open_qas = []

In [4]:
# Invoke the RAG chain once per collected report and store the parsed QA pairs
# together with the path recorded for that report.
for sample in tqdm(useful_data):
    generated = rag_chain.invoke({"report": sample['report']})
    open_qas.append({"img_path": sample['img_path'], "qa_pairs": generated.dict()})

In [82]:
# store
def trans_file_open(qas_list):
    """Flatten per-image QA results into one record per question.

    Args:
        qas_list: list of dicts, each with an ``'img_path'`` value and a
            ``'qa_pairs'`` dict whose own ``'qa_pairs'`` entry is a list of
            ``{"question": ..., "answer": ...}`` dicts.

    Returns:
        A list of dicts with the keys ``img_name``, ``question``, ``answer``,
        ``question_type`` (always ``'type4_Knowledge'``), ``structured_answer``
        (always ``None``), ``qid`` (running index starting at 0 across all
        images) and ``img_id``.
    """
    flat_records = []
    for entry in qas_list:
        image_ref = entry['img_path']
        for pair in entry['qa_pairs']['qa_pairs']:
            flat_records.append({
                'img_name': image_ref,
                'question': pair["question"],
                'answer': pair["answer"],
                'question_type': 'type4_Knowledge',
                'structured_answer': None,
                # Number of records emitted so far == the next running id.
                'qid': len(flat_records),
                'img_id': image_ref,
            })
    return flat_records

In [5]:
# Flatten the collected per-report results into one record per question.
mimic_open_json = trans_file_open(open_qas)

In [86]:
# Write the flattened records to disk as JSON with a 4-space indent.
with open("saved json file", mode='w') as out_fh:
    json.dump(mimic_open_json, out_fh, indent=4)

print("Data saved!")

Data saved!
