# Sistemas inteligentes para respostas a perguntas médicas

Gyovana M. Moriyama (216190)

Rafael A. Matumoto (273085)

In [None]:
!pip install -qU langchain_openai langchain-community langchain_experimental  faiss-cpu sentence-transformers openai datasets pydantic langgraph rank_bm25

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/50.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h

## Experiments with segmentation methods

In [None]:
import datetime
import os
import pickle
import re
from typing import List, Optional, Literal

import datasets
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
from rank_bm25 import BM25Okapi
from tqdm import tqdm

from google.colab import userdata, drive

from langchain.agents import AgentExecutor, create_react_agent
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.tools.retriever import create_retriever_tool
from langchain.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_community.retrievers import BM25Retriever
from langchain_core.load import dumps, loads
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import ChatOpenAI

In [None]:
# Mount Google Drive; the project folder referenced by `filepath` lives under /content/drive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Base project directory on the mounted Drive (note the trailing slash: paths are built by string
# concatenation). Retrievers are stored under `retriever/bm25/` and agent outputs under `results/`.
filepath = '/content/drive/MyDrive/IA024A_Processamento_de_Linguagem_Natural/Projeto Final/projeto/entrega3/'

In [None]:
# Read the OpenAI key from Colab's secret store and expose it as an environment variable,
# so the key itself never appears in the notebook.
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')

In [None]:
def save_retriever(filename, retriever):
    '''
    Pickles a retriever to `{filepath}retriever/bm25/{filename}.pkl`.

    filename: base name of the output file, without the `.pkl` extension.
    retriever: retriever object to serialize.

    Relies on the notebook-level `filepath` variable. The target directory is
    created when it does not exist yet.
    '''
    out_path = f'{filepath}retriever/bm25/{filename}.pkl'

    # create the folder on first use instead of failing with FileNotFoundError
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    with open(out_path, 'wb') as f:
        pickle.dump(retriever, f)

### MedQA-USMLE-4-options dataset

In [None]:
# Training split of the 4-option MedQA-USMLE dataset, fetched from the Hugging Face Hub.
data = datasets.load_dataset('GBaker/MedQA-USMLE-4-options', split='train')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Record each question's position in the full training split so sampled rows can be traced back.
# NOTE(review): `data` is rebound to its own result, so this cell is not idempotent — presumably
# re-running it fails because the column already exists; confirm.
data = data.add_column('original_quest_id', range(len(data)))

In [None]:
# Fixed-seed random sample of 50 questions; every experiment below is evaluated on this sample.
sampled_data = data.shuffle(seed=42).select(range(50))
sampled_data.to_pandas()

Unnamed: 0,question,answer,options,meta_info,answer_idx,metamap_phrases,original_quest_id
0,A 35-year-old woman comes to your office with ...,Hypercoagulable state,"{'A': 'Pallor, cyanosis, and erythema of the h...",step1,C,"[35 year old woman, office, variety, complaint...",2622
1,An 8-year-old boy is brought to the pediatrici...,GAA,"{'A': 'CGG', 'B': 'GAA', 'C': 'CAG', 'D': 'GCC'}",step1,B,"[year old boy, brought, pediatrician, mother, ...",1754
2,A 36-year-old man is brought to the emergency ...,Breakdown of endothelial tight junctions,{'A': 'Release of vascular endothelial growth ...,step1,C,"[36 year old man, brought, emergency departmen...",3718
3,A 35-year-old woman presents to the ER with sh...,Cor pulmonale,"{'A': 'Left-sided heart failure', 'B': 'Corona...",step1,D,"[35 year old woman presents, ER, shortness of ...",9107
4,A 5-year-old boy is brought in by his parents ...,Begin cognitive behavioral therapy,{'A': 'Increase oral hydration and fiber intak...,step2&3,D,"[5 year old boy, brought, parents, recurrent a...",1838
5,A 5-day-old male newborn is brought to the eme...,IV acyclovir,"{'A': 'IV ganciclovir', 'B': 'Pyrimethamine', ...",step2&3,C,"[5 day old male newborn, brought, emergency de...",4147
6,A 30-year-old woman presents to the clinic for...,Discoid lupus erythematosus (DLE),"{'A': 'Alopecia areata', 'B': 'Discoid lupus e...",step2&3,B,"[30 year old woman presents, clinic, 3 month h...",3631
7,A 68-year-old man is admitted to the intensive...,Low urine sodium,"{'A': 'Decreased urine osmolarity', 'B': 'Leuk...",step2&3,C,"[68 year old man, admitted to, intensive care ...",2397
8,A 44-year-old with a past medical history sign...,Herpes simplex virus,"{'A': 'Cryptococcus', 'B': 'Group B streptococ...",step1,C,"[year old, past medical history significant, h...",9469
9,A 8-month-old boy is brought to the physician ...,CT scan of the head,"{'A': 'Growth hormone therapy', 'B': 'Levothyr...",step2&3,C,"[month old boy, brought, physician, the evalua...",9045


## MedQA

In [None]:
# download MedQA data from https://drive.google.com/file/d/1ImYUSLk9JbgHXOemfvyiDiirluZHPeQw/view?usp=sharing
!gdown -q 1ImYUSLk9JbgHXOemfvyiDiirluZHPeQw
# extract the archive into /content/medQA (the textbooks are read from there by load_documents)
!unzip -q data_clean.zip -d /content/medQA

In [None]:
def load_documents(medqa_path='/content/medQA/data_clean/textbooks/en/'):
    '''
    Loads reference textbooks for MedQA as a list of Document objects.

    medqa_path: directory searched recursively for `.txt` files. Defaults to
        the English textbooks extracted from the MedQA archive.

    Returns the list produced by the directory loader; each document carries
    its file path in `metadata['source']`.
    '''
    loader = DirectoryLoader(medqa_path, glob='**/*.txt', loader_cls=TextLoader)

    return loader.load()

In [None]:
# Load every textbook once; the splitter cell below reuses `docs`.
docs = load_documents()

In [None]:
# chunking parameters: passed to the text splitter below and embedded in the saved retriever's file name
chunk_size = 900
overlap_size = 90

## Recursive Character Text Splitter

In [None]:
# Recursive character splitter: separators are tried in the order given
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=overlap_size,
    separators=['\n\n', '\n', ' '],
)

# one text and one metadata dict per textbook; 'ref' keeps only the file name of the source path
page_texts = [d.page_content for d in docs]
chunk_refs = [{'ref': d.metadata['source'].split('/')[-1]} for d in docs]

split_documents = text_splitter.create_documents(page_texts, metadatas=chunk_refs)

# BM25 index built over the resulting chunks
retriever_bm25 = BM25Retriever.from_documents(split_documents)

In [None]:
# Persist the index so later sessions can load it instead of rebuilding it from the textbooks.
save_retriever(f'bm25_recsplitter_{chunk_size}_{overlap_size}', retriever_bm25)

## Load saved retriever

In [None]:
# Retrievers saved so far under `{filepath}retriever/bm25/` (kept as comments: bare file names are not valid code):
#   bm25_recsplitter_500_50.pkl
#   bm25_recsplitter_300_30.pkl
#   bm25_recsplitter_600_60.pkl
#   bm25_recsplitter_900_90.pkl

In [None]:
# Load a previously saved BM25 retriever (chunk size 900, overlap 90) instead of rebuilding it.
# NOTE: pickle.load can execute arbitrary code — only load files that were created by this notebook.
with open(f'{filepath}retriever/bm25/bm25_recsplitter_900_90.pkl', 'rb') as f:
    retriever_bm25 = pickle.load(f)

In [None]:
# create a search tool from the bm25 retriever
# NOTE(review): the description promises the top-3 documents, but the retriever is built without an
# explicit `k` — confirm how many documents it actually returns. "Searchs" is a typo; it is left as is
# because the description is presumably part of the text the agent sees.
tool_search_bm25 = create_retriever_tool(
    retriever=retriever_bm25,
    name='search_docs',
    description='Searchs the query in documents and returns the top-3 most relevant ones.',
)

### ReAct agent

In [None]:
# Chat model name and sampling temperature; both are also written to the header of each results file.
model = 'gpt-4o-mini'
model_temp = 0.5

In [None]:
# Chat model client shared by the ReAct agent and the sentence-extraction chain below.
llm = ChatOpenAI(
    model=model,
    temperature=model_temp,
)

In [None]:
# react prompt
# Placeholders: {question} and {options} are supplied per question when the executor is invoked;
# {tools}, {tool_names} and {agent_scratchpad} are presumably filled in by the agent machinery.
# The model is asked to finish with a single letter (A/B/C/D), which the evaluation cell parses.
react_prompt_template = '''
Solve the question answering task alternating between Question, Thought, Action, Input and Observation steps.
You only have access to the following tools: {tools}

Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do. Break the problem down into subproblems and smaller steps and decide which action to take.
Action: the action to take, must be one of [{tool_names}]. If no action is needed, return your Final answer instead.
Action Input: the input to the action
Observation: the output of the action.
... (this Thought/Action/Action Input/Observation can repeat any number of times)
Final Answer: A/B/C/D

Begin!
Question: {question}
Options: {options}
Thought: {agent_scratchpad}
'''.strip()

In [None]:
# Wrap the raw template string in a prompt object that the agent can consume.
react_prompt = ChatPromptTemplate.from_template(react_prompt_template)

### ReAct with BM25 Okapi retriever

Agent results are serialized to disk with LangChain's `dumps`/`loads` helpers; see https://python.langchain.com/docs/how_to/serialization/

In [None]:
# ReAct agent whose only tool is the BM25 document search.
agent = create_react_agent(
    llm=llm,
    tools=[tool_search_bm25],
    prompt=react_prompt
)

In [None]:
# Executor that runs the agent loop.
# - return_intermediate_steps: the tool responses are needed later for the context relevance evaluation
# - max_iterations: questions that are not answered within 5 iterations are counted as unanswered
agent_executor = AgentExecutor(
    agent=agent,
    tools=[tool_search_bm25],
    verbose=False,
    handle_parsing_errors=True,
    return_intermediate_steps=True,
    max_iterations=5
)

In [None]:
# Run the agent on every sampled question and stream the serialized results to a timestamped file.
run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = f'{filepath}results/answers_bm25_recsplitter_900_90_{run_stamp}.txt'
rule = '-' * 10

with open(results_file, 'w') as out:

    # human-readable header describing the run
    out.write(f'{rule}\n')
    out.write(f'Model: {model}\n')
    out.write(f'Temperature: {model_temp}\n')
    out.write(f'Prompt: {react_prompt}\n')

    # the serialized results start on their own line, right after the second rule
    out.write(f'{rule}\n' + '[{')

    for idx, item in enumerate(tqdm(sampled_data)):

        res = agent_executor.invoke({'question': item['question'], 'options': item['options']})

        # one `"index": <serialized result>` entry per question, each followed by a comma
        out.write(f'"{idx}": ')
        out.write(dumps(res) + ',')

    # dummy final entry that absorbs the trailing comma written after the last question
    out.write('"placeholder": "placeholder"')
    out.write('}]')

100%|██████████| 50/50 [08:41<00:00, 10.43s/it]


#### Results with BM25 in 50 sample questions (train dataset)

In [None]:
# List the result files saved so far.
os.listdir(f'{filepath}results')

['answers_faiss_20241103_151837.txt',
 'answers_faiss_20241103_152229.txt',
 'answers_faiss_20241103_152940.txt',
 'answers_faiss_20241103_153308.txt',
 'answers_faiss_20241103_153501.txt',
 'answers_faiss_20241103_154054.txt',
 'answers_faiss_20241103_155219.txt',
 'answers_faiss_20241103_165217.txt',
 'dataset_faiss_500_50',
 'dataset_faiss_500_50.csv',
 'answers_bm25_20241104_000018.txt',
 'dataset_bm25_500_50',
 'dataset_bm25_500_50.csv',
 'answers_bm25_recsplitter_500_50_20241105_001257.txt',
 'answers_faiss_600_20_20241105_003556.txt',
 'answers_bm25_recsplitter_500_50_20241105_010747.txt',
 'answers_bm25_recsplitter_600_20_20241105_012522.txt',
 'answers_bm25_charsplitter_600_20_20241105_013434.txt',
 'answers_bm25_charsplitter_500_50_20241105_014929.txt',
 'answers_bm25_recsplitter_500_50_20241105_225628.txt',
 'answers_bm25_recsplitter_600_20_20241105_230708.txt',
 'answers_bm25_charsplitter_500_50_20241105_231550.txt',
 'answers_bm25_charsplitter_600_20_20241105_232617.txt',


In [None]:
# Result files compared below, one per (chunk size, overlap) setting of the recursive splitter.
# The evaluation loops read splitter name, chunk size and overlap from fixed positions of the
# '_'-separated file name, so the naming scheme must be kept.
answers_files = ['answers_bm25_recsplitter_500_50_20241105_225628.txt',
                'answers_bm25_recsplitter_300_30_20241106_000239.txt',
                'answers_bm25_recsplitter_600_60_20241106_001127.txt',
                'answers_bm25_recsplitter_900_90_20241106_002117.txt']

In [None]:
# Empty accuracy table, one row per evaluated results file.
acc_df = pd.DataFrame(columns=['Model', 'Chunk Size', 'Overlap Size', 'Accuracy'])

In [None]:
# Output that is compared against to detect questions the agent gave up on (max_iterations is 5)
iteration_limit_msg = 'Agent stopped due to iteration limit or time limit.'

# matches a literal A, B, C, or D at the beginning of the output, followed by the text of the
# corresponding option in parentheses, e.g. "A (answer text)" => "A"
# (a character class needs no '|': '[A|B|C|D]' would also accept a literal '|')
ans_ptn = re.compile(r'^([ABCD])\s\(')

# one record per results file; the table is rebuilt from scratch so re-running the cell adds no duplicates
acc_rows = list()

for ans_file in answers_files:
    # file names look like answers_bm25_<splitter>_<chunk>_<overlap>_<date>_<time>.txt
    name_parts = ans_file.split('_')
    model_name = name_parts[2]
    chunk_size = name_parts[3]
    overlap_size = name_parts[4]
    print(f'Model: {model_name}, Chunk Size: {chunk_size}, Overlap Size: {overlap_size}')

    with open(f'{filepath}results/{ans_file}') as f:
        file_lines = f.readlines()

    # the serialized results are the last line of the file, after the run header
    recovered_results = loads(file_lines[-1])[0]

    limit_error = list()   # questions that hit the iteration limit
    not_parsed = list()    # questions whose output could not be reduced to a single letter
    answers = list()       # predicted letter per question, None when there is no usable answer

    for k, v in recovered_results.items():
        # skip placeholder key
        if k == 'placeholder':
            continue

        output = v['output']

        if output == iteration_limit_msg:
            answers.append(None)
            limit_error.append(k)

        elif output in ('A', 'B', 'C', 'D'):
            answers.append(output)

        else:
            match = ans_ptn.match(output)
            if match:
                answers.append(match.group(1))
            else:
                # anything else (including empty or invalid one-character outputs) is a parsing error
                answers.append(None)
                not_parsed.append(k)
                print(k, output)

    # summary for this file: computed once, after all of its questions have been processed
    # questions left unanswered because the agent hit the iteration limit
    print('Unanswered questions or answers not properly parsed: ', len(limit_error))
    # parsing errors/other errors
    print('Parsing errors/other errors: ', len(not_parsed))
    # accuracy over the whole sample; questions without a usable answer count as wrong
    accuracy = sum(pred == gt for gt, pred in zip(sampled_data['answer_idx'], answers)) / len(sampled_data)
    print('Accuracy: ', accuracy)
    print('\n')

    acc_rows.append([model_name, chunk_size, overlap_size, accuracy])

acc_df = pd.DataFrame(acc_rows, columns=['Model', 'Chunk Size', 'Overlap Size', 'Accuracy'])


Model: recsplitter, Chunk Size: 500, Overlap Size: 50
Unanswered questions or answers not properly parsed:  5
Parsing errors/other errors:  0
Accuracy:  0.68


Model: recsplitter, Chunk Size: 300, Overlap Size: 30
Unanswered questions or answers not properly parsed:  1
Parsing errors/other errors:  0
Accuracy:  0.74


Model: recsplitter, Chunk Size: 600, Overlap Size: 60
Unanswered questions or answers not properly parsed:  5
Parsing errors/other errors:  0
Accuracy:  0.7


Model: recsplitter, Chunk Size: 900, Overlap Size: 90
Unanswered questions or answers not properly parsed:  6
Parsing errors/other errors:  0
Accuracy:  0.7




  recovered_results = loads(tmp[-1])[0]


In [None]:
# Accuracy per chunk size / overlap configuration.
acc_df

Unnamed: 0,Model,Chunk Size,Overlap Size,Accuracy
0,recsplitter,500,50,0.68
1,recsplitter,300,30,0.74
2,recsplitter,600,60,0.7
3,recsplitter,900,90,0.7


## Context relevance evaluation with Ragas


The Recursive Character Text Splitter with chunk size 300 and overlap 30 is used as the reference configuration, as it achieved the highest accuracy (0.74, see the table above).

In [None]:
# Structured-output schema for a single extracted sentence.
# NOTE(review): the class docstring is presumably forwarded to the model as the schema description,
# so treat it as prompt text and do not reword it casually — confirm.
class RelevantSentence(BaseModel):
    '''A sentence you believe is relevant to answer the provided question.
    '''
    sentence: str

# Structured-output schema returned by the extraction chain: the list of sentences judged relevant.
# The context relevance score below uses the length of `contexts`.
# NOTE(review): the class docstring is presumably forwarded to the model as the schema description — confirm.
class Context(BaseModel):
    '''The list of contexts you retrieved based on a question.
    '''
    contexts: List[RelevantSentence]

In [None]:
# Prompt for the sentence-extraction step of the context relevance metric.
# {context} receives the concatenated tool responses of one question, {question} the question text.
ext_sentence_prompt = '''
Please extract relevant sentences from the provided context that can potentially help answer the following question.
If no relevant sentences are found, or if you believe the question cannot be answered from the given context, return the phrase "Insufficient Information".
While extracting candidate sentences you’re not allowed to make any changes to sentences from given context.
Here is the context: {context}
And here is the question: {question}
'''.strip()

ext_sentence_prompt_template = ChatPromptTemplate.from_template(ext_sentence_prompt)

In [None]:
# Extraction chain: the prompt is piped into the chat model, whose output is requested in the `Context` schema.
chain_extraction = ext_sentence_prompt_template | llm.with_structured_output(Context)

In [None]:
# Empty context relevance table, one row per evaluated results file.
ctxt_rel_df = pd.DataFrame(columns=['Chunk Size', 'Overlap Size', 'Context relevancy'])

In [None]:
# result_dict[question]['intermediate_steps'] -> List[List[AgentAction(tool, tool_input, log), response: str]]
def format_context(answer):
    '''
    Concatenates the tool responses stored in an agent result into one string.

    answer: mapping with an 'intermediate_steps' sequence whose items hold the
        tool response at index 1.

    Returns the responses joined in order, without any separator.
    '''
    return ''.join(step[1] for step in answer['intermediate_steps'])

In [None]:
# one record per results file; the table is rebuilt from scratch so re-running the cell adds no duplicates
ctxt_rel_rows = list()

for ans_file in answers_files:

    # file names look like answers_bm25_<splitter>_<chunk>_<overlap>_<date>_<time>.txt
    name_parts = ans_file.split('_')
    chunk_size = name_parts[3]
    chunk_overlap = name_parts[4]

    print(chunk_size, chunk_overlap)

    with open(f'{filepath}results/{ans_file}') as f:
        tmp = f.readlines()

    # the serialized results are the last line of the file, after the run header
    recovered_results = loads(tmp[-1])[0]

    # concatenated tool responses per question, skipping the trailing placeholder entry
    react_logs = [format_context(i) for i in recovered_results.values() if i != 'placeholder']

    # NOTE(review): `answers` is whatever the accuracy cell left behind (the answers of the last file it
    # processed), so this column does not necessarily belong to `ans_file`. It is not used below.
    sampled_data_2 = sampled_data.add_column('answer_react_faiss', answers)
    sampled_data_2 = sampled_data_2.add_column('react_log', react_logs)

    relevant_contexts = list()
    context_rel = list()

    for question in tqdm(sampled_data_2):

        try:
            # select relevant sentences
            contexts = chain_extraction.invoke({'context': question['react_log'], 'question': question['question']})

            # fraction of relevant sentences in the context ('.'-separated pieces approximate sentences)
            relevance = len(contexts.contexts) / len(question['react_log'].split('.'))
        except Exception as err:
            # best effort: report the failure and record the question as missing instead of aborting
            print(f'Sentence extraction failed: {err}')
            contexts = None
            relevance = np.nan

        # exactly one entry per question in both lists, whether or not the extraction succeeded
        relevant_contexts.append(contexts)
        context_rel.append(relevance)

    # mean over the questions for which the extraction succeeded
    cont_rel_dataset = np.nanmean(context_rel)
    ctxt_rel_rows.append([chunk_size, chunk_overlap, cont_rel_dataset])

ctxt_rel_df = pd.DataFrame(ctxt_rel_rows, columns=['Chunk Size', 'Overlap Size', 'Context relevancy'])


500 50


100%|██████████| 50/50 [03:23<00:00,  4.08s/it]


300 30


100%|██████████| 50/50 [03:15<00:00,  3.90s/it]


600 60


100%|██████████| 50/50 [03:41<00:00,  4.44s/it]


900 90


100%|██████████| 50/50 [04:33<00:00,  5.47s/it]


In [None]:
# Mean context relevance per chunk size / overlap configuration.
ctxt_rel_df

Unnamed: 0,Chunk Size,Overlap Size,Context relevancy
0,500,50,0.297487
1,300,30,0.378731
2,600,60,0.473106
3,900,90,0.29558
