# Components

In [117]:
import os
from langchain_ollama import OllamaEmbeddings
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
import faiss

# Embedding model shared by the vector store (for indexing) and, later, by the
# retriever built on top of it (for embedding queries).
embeddings = OllamaEmbeddings(model="mxbai-embed-large")

# Embed a throwaway string only to learn the embedding dimensionality, which
# the FAISS index needs at construction time.
# NOTE(review): this presumably issues a request to the embedding backend on
# every run of the cell, even when a saved index is loaded later — confirm.
d = len(embeddings.embed_query("hello world"))
index = faiss.IndexFlatL2(d)

# Start from an empty store: no documents and an empty index-to-docstore-id
# mapping. It is either filled by mega_load or replaced by an index loaded
# from disk further down.
vector_store = FAISS(
    embedding_function=embeddings,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={}
)

In [None]:
from openai import OpenAI
import backoff
import openai

# OpenAI client; the API key is read from the OPENAI_API_KEY environment
# variable rather than being hard-coded in the notebook.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# System prompt sent by ask_gpt with every request. It restricts the model to
# answering with a single option letter.
# NOTE(review): the name `sys` is also the name of the standard-library
# module; importing that module later in the notebook would overwrite this
# string. Renaming requires updating ask_gpt at the same time.
sys = """
Please answer the question based on the provided information and context.
Return a single letter "A", "B", "C" or "D" corresponding to the correct answer.
Do not return any other text.
"""

@backoff.on_exception(
    backoff.expo,
    openai.RateLimitError,
    max_time=999,
    max_tries=99,
)
def completions_with_backoff(**request_kwargs):
    """Create a chat completion, retrying when the API reports rate limiting.

    All keyword arguments are forwarded unchanged to
    ``client.chat.completions.create``. Only ``openai.RateLimitError``
    triggers a retry; the wait between attempts grows exponentially and the
    call gives up after 99 tries or once ``max_time=999`` is exceeded
    (seconds, per the ``backoff`` library). Any other exception propagates
    immediately.
    """
    return client.chat.completions.create(**request_kwargs)

def ask_gpt(prompt: str) -> str:
    """Send one multiple-choice question to the chat model and return its reply.

    Args:
        prompt: The user message (question, options and any retrieved
            context). Inside this function the parameter shadows the
            module-level ``prompt`` template.

    Returns:
        The text of the first choice with surrounding whitespace removed.
        An empty string is returned when the API reports no text content, so
        callers always receive a ``str`` as the annotation promises.
    """
    response = completions_with_backoff(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": sys},
            {"role": "user", "content": prompt},
        ],
        # The system prompt asks for a single letter, so one token is enough.
        max_tokens=1,
        n=1,
    )
    content = response.choices[0].message.content
    # `content` may be None or carry stray whitespace; normalize it so the
    # caller's membership test against "A".."D" behaves predictably.
    return (content or "").strip()

In [119]:
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

def mega_load(data_path, msg="files"):
    """Chunk every regular file directly inside ``data_path`` into ``vector_store``.

    Each file is read as UTF-8, every line is stripped of surrounding
    whitespace and re-joined with newlines, and the result is split into
    overlapping chunks that are added to the module-level ``vector_store``.
    Sub-directories are ignored (no recursion).

    Args:
        data_path: Directory containing the text files to ingest.
        msg: Label used only in the progress message printed per file.

    Raises:
        OSError: If ``data_path`` or one of its files cannot be read.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    # os.listdir returns entries in arbitrary order; sort so the ingest order
    # (and therefore the built index) is reproducible across runs.
    files = sorted(
        name
        for name in os.listdir(data_path)
        if os.path.isfile(os.path.join(data_path, name))
    )
    # The splitter only holds configuration, so one instance serves all files.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)

    for i, file_name in enumerate(files):
        print(f"Processing {i + 1}th {msg}: {file_name}")
        file_path = os.path.join(data_path, file_name)
        with open(file_path, 'r', encoding='utf-8') as file:
            # Single join instead of repeated string concatenation.
            content = "".join(line.strip() + "\n" for line in file)
        all_splits = text_splitter.split_text(content)
        if not all_splits:
            # Empty file: nothing to embed, and an empty batch may be
            # rejected by the vector store backend, so skip it.
            continue
        vector_store.add_texts(all_splits)

# Location of the persisted FAISS index. The default is the original
# machine-specific path; set the FAISS_INDEX_PATH environment variable to run
# the notebook elsewhere without editing this cell.
faiss_index_path = os.getenv(
    "FAISS_INDEX_PATH",
    "/home/ngjabach/Documents/NgJaBach/Medical-Graph-RAG/faiss_index",
)
if os.path.exists(faiss_index_path):
    # SECURITY: allow_dangerous_deserialization=True opts in to loading a
    # pickled docstore, which can execute arbitrary code. Only point
    # faiss_index_path at an index this notebook (or a trusted party) wrote.
    vector_store = FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)
else:
    # No saved index yet: ingest both corpora into the current vector_store
    # and persist the result so later runs take the fast path above.
    print("Begin vector storing...")
    mega_load("./books_1")
    mega_load("./books_2")
    print("Done vector storing!")
    vector_store.save_local(faiss_index_path)

# Similarity retriever returning the 16 nearest chunks per query.
retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 16})

In [120]:
# Template for the user message sent to the model. generate() fills
# `{context}` with the retrieved passages and `{question}` with the question
# text via str.format, so any literal braces added here must be doubled.
prompt = """
---Start of Context---
{context}
---End of Context---

{question}
"""

In [121]:
class State(TypedDict):
    """State passed between the steps of the retrieval/generation graph."""
    # Documents returned by the retriever; written by retrieve().
    context: List[Document]
    # The model's reply; written by generate().
    answer: str
    # The question text supplied by the caller when invoking the graph.
    question: str

# Define application steps
def retrieve(state: State):
    """Look up documents for the question stored in ``state``.

    Returns a dict with a single ``context`` key holding whatever the
    module-level ``retriever`` returned for ``state["question"]``.
    """
    question = state["question"]
    return {"context": retriever.invoke(question)}

def generate(state: State):
    """Build the model prompt from the retrieved context and ask for an answer.

    The ``page_content`` of every document in ``state["context"]`` is joined
    with blank lines, substituted into the module-level ``prompt`` template
    together with ``state["question"]``, and sent through ``ask_gpt``.

    Returns a dict with a single ``answer`` key holding the model's reply.
    """
    passages = [doc.page_content for doc in state["context"]]
    user_message = prompt.format(
        question=state["question"],
        context="\n\n".join(passages),
    )
    return {"answer": ask_gpt(user_message)}

# Wire the two steps into a linear pipeline: START -> retrieve -> generate.
graph_builder = StateGraph(State)
graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

In [122]:
import pandas as pd
from tqdm import tqdm

# Letters the model is allowed to answer with.
VALID_ANSWERS = ("A", "B", "C", "D")
# Upper bound on re-asking one question. Without a cap, a question the model
# keeps answering with something other than a valid letter would loop (and
# bill API calls) forever.
MAX_ATTEMPTS = 5

df = pd.read_json("MedQA.jsonl", lines=True)

correct = 0
unanswered = 0
for _, row in tqdm(df.iterrows(), total=len(df)):
    question = row['question']
    options = row['options']
    expected = row['answer_idx']
    # Named `brompt` on purpose: generate() reads the module-level `prompt`
    # template, so that name must not be rebound here.
    brompt = (
        f"Question: {question}\n"
        f"A: {options['A']}\nB: {options['B']}\nC: {options['C']}\nD: {options['D']}\n"
        f"Answer:"
    )

    res = None
    for _attempt in range(MAX_ATTEMPTS):
        reply = graph.invoke({"question": brompt})["answer"]
        # The reply may be None or padded with whitespace; normalize before
        # checking it against the allowed letters.
        reply = (reply or "").strip()
        if reply in VALID_ANSWERS:
            res = reply
            break

    if res is None:
        # No valid letter within the attempt budget: scored as incorrect.
        unanswered += 1
    elif res == expected:
        correct += 1

# Guard against an empty dataset so the summary still prints.
accuracy = correct / len(df) * 100 if len(df) else 0.0
print(f"Total: {len(df)}")
print(f"Correct: {correct}")
print(f"Accuracy: {accuracy:.2f}%")
if unanswered:
    print(f"Unanswered after {MAX_ATTEMPTS} attempts: {unanswered}")

100%|██████████| 1273/1273 [46:26<00:00,  2.19s/it] 

Total: 1273
Correct: 865
Accuracy: 67.95%



