In [None]:
import numpy as np
import faiss, os, json, pickle
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi
from datasets import load_dataset
from tqdm import tqdm
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
nltk.download('stopwords')
nltk.download('punkt')

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/f74111102/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/f74111102/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [16]:
from dotenv import load_dotenv
load_dotenv()

from openai import OpenAI
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
model = "gpt-4o-mini"

def call_openai(messages:list[dict]) -> str:
    response = client.chat.completions.create(
        model=model,
        messages=messages,
    )
    return response.choices[0].message.content

In [None]:
# you need to first create the cleaned contexts and embedding file.
ds = load_dataset("pubmed_qa", "artificial", split="train")
contexts = [sent for item in ds['context'] for sent in item['contexts']]
cleaned_contexts = [__bm25_precleaning(context) for context in contexts]
pickle.dump(cleaned_contexts, open("pubmedqa_artificial_cleaned_contexts.txt", "wb"))

def __bm25_precleaning(self, context:str) -> list[str]:
    stop_words = set(stopwords.words('english'))
    words = word_tokenize(context)
    filtered_words = [word for word in words if word.lower() not in stop_words]
    return filtered_words



embedding_model = SentenceTransformer('google/embeddinggemma-300m')
embedded = embedding_model.encode(ds['question'], convert_to_numpy=True, show_progress_bar=True, batch_size=64)
pickle.dump(embedded, open("embedding_pubmedqa_artificial.pkl", "wb"))

# If you've already created, comment out the above code and load the files directly.

In [None]:
class MedQADataset:
    def __init__(self):
        self.dataset = load_dataset("qiaojin/PubMedQA", "pqa_artificial", split="train")
        self.contexts = [sent for item in self.dataset['context'] for sent in item['contexts']]
        self.cleaned_contexts = open("pubmedqa_artificial_cleaned_contexts.txt", "r").read().splitlines()
        
        self.embedded = pickle.load(open("embedding_pubmedqa_artificial.pkl", "rb")) if os.path.exists("embedding_pubmedqa_artificial.pkl") else None

        self.bm25 = BM25Okapi([context.split(" ") for context in self.cleaned_contexts])
        self.embedding_model = SentenceTransformer('google/embeddinggemma-300m')
        self.cross_model = CrossEncoder("BSC-NLP4BIA/Medprocner-CE-Reranker")
        self.device = "cuda" if os.getenv("CUDA_VISIBLE_DEVICES") else "cpu"

    def __bm25_precleaning(self, context:str) -> list[str]:
        stop_words = set(stopwords.words('english'))
        words = word_tokenize(context)
        filtered_words = [word for word in words if word.lower() not in stop_words]
        return filtered_words

    def load_to_device(self):
        self.cross_model.to(self.device)
        self.embedding_model.to(self.device)

        return self
    
    def encode_contexts(self):
        if self.embedded is None:
            self.embedded = self.embedding_model.encode(self.dataset['question'], convert_to_numpy=True, show_progress_bar=True, batch_size=64)
            pickle.dump(self.embedded, open("embedding_pubmedqa_artificial.pkl", "wb"))
        return self

    def hybrid_retrieval(self, query:str, top_k:int=5) -> list[int]:
        if self.embedded is None:
            self.embedded = self.embedding_model.encode(self.dataset['question'], convert_to_numpy=True, show_progress_bar=True, batch_size=64)
            pickle.dump(self.embedded, open("embedding_pubmedqa_artificial.pkl", "wb"))

        # BM25 part
        # bm25 top_k = 1000 
        tokenized_query = self.__bm25_precleaning(query)
        bm25_scores = self.bm25.get_scores(tokenized_query)
        bm25_topk_indices = np.argsort(bm25_scores)[-1000*2:][::-1]

        # SBERT part
        # SBERT top_k = 50
        query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
        faiss_index = faiss.IndexFlatL2(self.embedded[bm25_topk_indices].shape[1])
        faiss_index.add(self.embedded[bm25_topk_indices])
        _, sbert_topk_indices = faiss_index.search(query_embedding, k=50)

        candidate_indices = bm25_topk_indices[sbert_topk_indices[0]]
        candidate_contexts = [self.contexts[idx] for idx in candidate_indices]
        cross_inputs = [[query, context] for context in candidate_contexts]
        cross_scores = self.cross_model.predict(cross_inputs, batch_size=64, convert_to_numpy=True, show_progress_bar=False)
        final_topk_indices = np.argsort(cross_scores)[-top_k:][::-1]
        # contexts
        final_indices = candidate_indices[final_topk_indices]
        final_contexts = [self.contexts[idx] for idx in final_indices]
        return final_contexts


In [4]:
DS = MedQADataset().load_to_device()

True


## Adaptive RAG

In [40]:
question = "What is the evidence that intermittent fasting improves insulin sensitivity in adults?"

In [41]:
ADAPTIVE_PROMPT_TEMPLATE = ""
with open("adaptive_prompt.txt", "r") as f:
    ADAPTIVE_PROMPT_TEMPLATE = f.read()
    f.close()

In [42]:
def Adaptive_RAG(query:str) -> dict:
    prompt = ADAPTIVE_PROMPT_TEMPLATE.replace("{{query}}", query)
    print("Query: ", query)

    # LLM choise: 1. RETRIEVE 2. RESPONSE
    messages = [
        {"role": "user", "content": prompt}
    ]
    resp = call_openai(messages)
    response = json.loads(resp)

    while response['action'] == "RETRIEVE":
        contexts = DS.hybrid_retrieval(query, top_k=5)
        messages.append({"role": "assistant", "content": response['query']})
        messages.append({"role": "user", "content": "Retrieved Context:\n" + "\n".join(contexts)})
        
        print("AI: ", response['query'])
        print("Retrieved Contexts: ", contexts)

        resp = call_openai(messages)
        response = json.loads(resp)
    
    print("Final Answer: ", response['text'])
    return response['text']

In [43]:
Adaptive_RAG(question)

Query:  What is the evidence that intermittent fasting improves insulin sensitivity in adults?
AI:  intermittent fasting insulin sensitivity adults clinical trials systematic review meta-analysis
Retrieved Contexts:  ['The high-fat diet induced a significant drop in insulin sensitivity (determined by euglycaemic-hyperinsulinaemic clamp) compared to baseline (0.100+/-0.009 vs 0.083+/-0.007 micro mol.kg(-1).min(-1).(pmol.l(-1)), p=0.01). The drop in insulin sensitivity was more pronounced in subjects with low serum adiponectin (0.094+/-0.011 vs 0.077+/-0.010 micro mol.kg(-1).min(-1).(pmol.l(-1)), p=0.02) than in subjects with high serum adiponectin (0.103+/-0.011 vs 0.090+/-0.040 micro mol.kg(-1).min(-1).(pmol.l(-1)), p=0.16). In the whole group the high-carbohydrate, low-fat diet did not cause an increase in insulin sensitivity (0.095+/-0.007 vs 0.102+/-0.009 micro mol.kg(-1).min(-1).(pmol.l(-1)), p=0.06). However, insulin sensitivity was significantly increased in the subgroup with low

'After multiple retrieval attempts, relevant documents indicate that intermittent fasting (IF) has been shown to increase whole-body insulin sensitivity. However, the exact mechanisms and the degree to which IF influences intermediary metabolism remain uncertain. Therefore, some evidence supports the idea that IF can improve insulin sensitivity in adults, but further investigation into the interaction with factors like adiponectin levels and magnesium intake may be necessary.'