# Components

In [16]:
import os
from langchain_ollama import OllamaEmbeddings
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS

# Embedding model (Ollama "mxbai-embed-large"); used to size the index below
# and handed to the vector store as its embedding function.
embeddings = OllamaEmbeddings(model="mxbai-embed-large")

# Embed a throwaway query once to discover the embedding dimensionality.
d = len(embeddings.embed_query("hello world"))
# FAISS IndexFlatL2 index created with that dimensionality.
index = faiss.IndexFlatL2(d)

# Vector store that starts empty: fresh in-memory docstore and no id mapping.
vector_store = FAISS(
    embedding_function=embeddings,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={}
)

In [17]:
from openai import OpenAI

# OpenAI client; the API key is read from the OPENAI_API_KEY environment variable.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    # Optional: uncomment to point the client at a different endpoint.
    # base_url="https://api.llama-api.com/"
)

def ask_gpt(prompt: str, system: str, model: str = "gpt-4o") -> str:
    """Send one system + user message pair to the chat API and return the reply text.

    Args:
        prompt: Content of the user message.
        system: Content of the system message.
        model: Name of the chat model to request.

    Returns:
        The text of the first returned choice. If the API returns no text
        content (``message.content`` is ``None``), an empty string is returned
        so the ``-> str`` contract always holds.
    """
    response = client.chat.completions.create(
        model=model,
        # Alternatives tried previously: "llama3.2-3b", "llama3.3-70b".
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
    )
    content = response.choices[0].message.content
    # `content` is optional in the API response; callers concatenate the
    # result with other strings, so never hand them None.
    return content if content is not None else ""

# System prompt used for the answer-generation call.
sys_res = "Read the context, patient record, and answer the task question. If you don't know, just say you don't know."

# System prompt used for the keyword-extraction call that builds the retrieval query.
sys_key = "Extract as much keywords and keyphrases most related to the task from the record as possible, but concisely and raw text."

In [18]:
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from langchain.document_loaders import PyPDFLoader

def chunking(sauce: str, chunk_size: int, chunk_overlap: int) -> "List[Document]":
    """Load a PDF and split its contents into chunks.

    Args:
        sauce: Path of the PDF file handed to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to ``RecursiveCharacterTextSplitter``.

    Returns:
        The chunks returned by ``split_documents`` for the loaded PDF.
    """
    print("Begin chunking...")
    docs = PyPDFLoader(sauce).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = text_splitter.split_documents(docs)
    # Report completion only after the splitting work has actually finished.
    print("Done!")
    return chunks


# Path to the saved vector store. Set the FAISS_INDEX_PATH environment variable
# to relocate it on another machine; the default keeps the original location.
faiss_index_path = os.getenv(
    "FAISS_INDEX_PATH",
    "/home/ngjabach/Documents/State-of-the-Art-Papers/SurgeryLLM (VIET + DONE)/faiss_index",
)

# (PDF path, chunk_size, chunk_overlap) for every reference document to index.
pdf_sources = [
    ("ExternalDoc/labval.pdf", 200, 10),
    ("ExternalDoc/aortic.pdf", 400, 20),
    ("ExternalDoc/coronary.pdf", 400, 20),
    ("ExternalDoc/valvular.pdf", 400, 20),
]

if os.path.exists(faiss_index_path):
    # SECURITY: allow_dangerous_deserialization=True lets load_local unpickle
    # the stored data. Only load an index directory that you created yourself
    # or otherwise trust.
    vector_store = FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)
else:
    # No saved index yet: chunk every PDF, embed the chunks, and persist the result.
    all_splits = []
    for _pdf_path, _chunk_size, _chunk_overlap in pdf_sources:
        all_splits.extend(chunking(_pdf_path, _chunk_size, _chunk_overlap))

    print("Begin vector storing...")
    _ = vector_store.add_documents(documents=all_splits)
    print("Done vector storing!")

    vector_store.save_local(faiss_index_path)

# Similarity retriever configured with k=20.
retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 20})

Begin chunking...
Done!
Begin chunking...
Done!
Begin chunking...
Done!
Begin chunking...
Done!
Begin vector storing...
Done vector storing!


In [19]:
# Template for the answer-generation prompt; placeholders: {reference}, {record}, {task}.
prompt_res = '''
---
{reference}
---
A. Record: {record} 
B. Task: {task}
C. Answer:
'''

# Template for the keyword-extraction prompt; placeholders: {record}, {task}.
prompt_key = '''
A. Record: {record}
B. Task: {task}
'''

# Task questions; each one is run separately against the same record.
task1 = "Check the record and identify results outside of reference ranges."
task2 = "Identify unavailable preoperative tests."
task3 = "Surgical recommendation."
task4 = "Prepare sample operative notes."

# Sample patient record used as the {record} value in both templates.
record = '''
Patient: John Smith
Age: 58, Male
Medical History: Diabetes; multivessel coronary artery disease with left anterior descending (LAD) involvement
Vital Signs: BP 140/85 mmHg, HR 80 bpm
Laboratory Findings: Hemoglobin 9.0 g/dL
Preoperative Workup: Basic clinical assessment, coronary angiography
'''

In [20]:
# Shared debug trace: the graph steps and the task runner append their
# keyword / prompt / response entries to this list.
log = []

# Define state for application
class State(TypedDict):
    """State dictionary shared by the graph steps."""
    task: str # task question; supplied when the graph is invoked
    record: str # patient record text; supplied when the graph is invoked
    reference: List[Document] # documents returned by the retrieve step
    answer: str # reply text returned by the generate step

# Define application steps
def retrieve(state: State):
    """Graph step: extract keywords for the task and fetch matching documents.

    Fills the keyword prompt template with the state's record and task, asks
    the chat model for keywords, records them in the debug log, and queries
    the retriever with them.

    Returns:
        A partial state update of the form ``{"reference": <retrieved documents>}``.
    """
    print("Begin retriving!")
    record_text = state["record"]
    task_text = state["task"]
    keyword_prompt = prompt_key.format(record=record_text, task=task_text)
    extracted = ask_gpt(keyword_prompt, sys_key)
    # Keep the raw keywords in the trace so the retrieval query can be inspected.
    log.append("\n\nDebug: Keywords\n\n" + extracted)
    matches = retriever.invoke(extracted)
    print("Done retrieving!")
    return {"reference": matches}

def generate(state: State):
    """Graph step: answer the task using the retrieved documents as context.

    Joins the page content of the state's reference documents, fills the
    answer prompt template, records the prompt in the debug log, and asks the
    chat model for the answer.

    Returns:
        A partial state update of the form ``{"answer": <model reply>}``.
    """
    print("Begin generating!")
    passages = [doc.page_content for doc in state["reference"]]
    context = "\n\n".join(passages)
    filled_prompt = prompt_res.format(
        record=state["record"],
        task=state["task"],
        reference=context,
    )
    # Keep the full prompt in the trace so the model input can be inspected.
    log.append("\n\nDebug: Prompt\n\n" + filled_prompt)
    reply = ask_gpt(filled_prompt, sys_res)
    print("Done generating!")
    return {"answer": reply}

print("Compiling...")
# Register retrieve and generate as a sequence of steps over State.
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
# Make "retrieve" the entry point of the graph.
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()
print("Done compiling!")

Compiling...
Done compiling!


In [21]:
def logging(taskn: str, recordn: str, cnt: int):
    """Run the graph for one task and write its debug trace to ``Result/Task{cnt}.txt``.

    The shared ``log`` list is cleared first, so the written file contains
    only the entries produced while handling this task, followed by the
    final response.

    NOTE(review): this function's name shadows the standard-library
    ``logging`` module in this namespace; it is kept because callers use it.

    Args:
        taskn: Task question passed to the graph as ``task``.
        recordn: Patient record passed to the graph as ``record``.
        cnt: Task number, used in the progress message and the output file name.
    """
    log.clear()
    print(f"Doing Task {cnt}")
    response = graph.invoke({"task": taskn, "record": recordn})
    log.append("\n\nDebug: Response\n\n" + response["answer"])
    # Create the output directory on demand so a missing "Result/" folder does
    # not discard the work after the model calls have already been made.
    os.makedirs("Result", exist_ok=True)
    with open(f"Result/Task{cnt}.txt", "w", encoding="utf-8") as file:
        file.write("\n".join(log))

# Run every task against the same record; task numbers start at 1.
for _task_no, _task in enumerate((task1, task2, task3, task4), start=1):
    logging(_task, record, _task_no)

Doing Task 1
Begin retriving!
Done retrieving!
Begin generating!
Done generating!
Doing Task 2
Begin retriving!
Done retrieving!
Begin generating!
Done generating!
Doing Task 3
Begin retriving!
Done retrieving!
Begin generating!
Done generating!
Doing Task 4
Begin retriving!
Done retrieving!
Begin generating!
Done generating!
