In [1]:
from unstructured.partition.pdf import partition_pdf
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

file_path = "C:/Users/User/Downloads/tmj_rag_app/data/pdfs"
base_file_name = "tmjDoc2"


  from .autonotebook import tqdm as notebook_tqdm


In [10]:
#Import env variables
from dotenv import load_dotenv

load_dotenv()



True

In [58]:
# Partition the PDF file
elements = partition_pdf(
    filename=f"{file_path}/{base_file_name}.pdf",
    strategy="hi_res",
    infer_table_structure=True
)





In [None]:
# Display the elements partitioned from the PDF file
for i, element in enumerate(elements):
    print(f"\n--- Element {i} ---")
    print(f"Type: {element.category}")
    print(f"Text preview: {element.text}...")
    print(f"Metadata: {element.metadata}")

In [59]:
# Convert to LangChain Documents
docs = []
for el in elements:
    #Skip headers, dont contain useful info
    if el.category == "Header":
        continue
    
    text = getattr(el, "text", None)
    if not text:
        continue
    meta = el.to_dict().get("metadata", {}) or {}
    docs.append(
        Document(
            page_content=text,
            metadata={
                "source": base_file_name,
                "page_number": meta.get("page_number"),
                "type": el.category if hasattr(el, "category") else meta.get("type"),
            },
        )
    )

splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, 
    chunk_overlap=200,
    add_start_index=True
)
chunked_docs = splitter.split_documents(docs)

In [None]:
#Print the chunked documents
for i, doc in enumerate(chunked_docs, 1):
    meta = doc.metadata
    print(f"--- Chunk {i} ---")
    print(f"source={meta.get('source')} page={meta.get('page_number')} type={meta.get('type')} section={meta.get('section')}")
    print(doc.page_content)
    print()

In [4]:
#Create embeddings
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
)


In [5]:
#Create Chroma db
vector_store = Chroma(
    collection_name="tmj_rag_app",
    embedding_function=embeddings,
    persist_directory="../data/chroma_db"
)

In [62]:
#Add documents to the vector store and persist vector store
vector_store.add_documents(chunked_docs)

['68159ad0-2139-411c-af69-b6d2ac22d967',
 'a32713d8-2cf7-4ddd-835c-1e4a308445fe',
 'ae6f5dd4-a6bd-4473-907b-90d5baf0d462',
 '0b5ee6b2-e0f8-4600-9273-d3ec26a2aa74',
 'd48f0708-3918-489b-a913-b13fdce3c43d',
 '0a392bc3-cbb7-4b50-b8a6-61ace5ddb6e2',
 '5cd46457-a4ef-4324-bd6a-f920963560c2',
 'f147b413-c993-4e48-82e4-2f2cc85464b3',
 '1b71ba49-e9e9-40b3-b42f-b5760b96b9ca',
 'de09603e-c3e5-4cf2-b360-003fe06601e4',
 '6ab9438f-4bb3-46dd-8d3b-ebe0054e1cdc',
 '16e67336-897c-4f81-bd96-7811e4e777ec',
 'e0b11cf3-5b39-4ba3-8257-a25085ab277a',
 'defef2c3-d09b-4e5e-8980-9ee674a34622',
 '7bb15130-22e4-45bc-8aee-4b9262eaa3d5',
 'a4d456bd-2f14-42d5-9f61-2c1ac83a188b',
 'b6ac6736-fb3d-45b8-ad00-c8de407f0e36',
 '91123220-85bd-4f02-94b1-bbc81e16ce75',
 '4f82406a-15a4-4e62-ba92-98731303d10e',
 '954fdb08-4e3f-4722-a5c4-feb6a56ed985',
 '1f75e589-a8e1-4167-85a3-a3aba8ecfa3e',
 '5da46d03-720d-4969-8729-4833d15df91c',
 '8f61bb98-ead5-446f-8aee-f3191fa7c8ca',
 '11e061da-db86-4c99-b5c9-dc82c70ae01f',
 '397917cf-71e4-

In [11]:
#Prompt template
medical_template = ChatPromptTemplate.from_messages([
    ("system",
     "You are a medical information assistant specializing in TMJ disorders. "
     "Answer ONLY using the provided context. "
     "If the context does not contain enough information, say so clearly. "
     "Use precise medical terminology. "
     "This is for informational purposes only and does not replace professional medical advice."),
    ("human",
     "Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:")
])

#Build retriever from vector store
retriever = vector_store.as_retriever(
    searchType="mmr",
    searchKwargs={"k": 3}
)


#Function to format docs and create RunnableLambda
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

format_docs_runnable = RunnableLambda(format_docs)

output_parser = StrOutputParser()

In [12]:
#Initialize LLM
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature = 0.0
)

In [13]:
#Create RAG chain
rag_chain = (
    {
        "context": retriever | format_docs_runnable,
        "question": RunnablePassthrough(),
    }
    | medical_template
    | llm
    | output_parser  
)

In [14]:
#Query the chain
answer = rag_chain.invoke("Explain arthroscopy for TMJ")
print(answer)

Arthroscopy for TMJ (temporomandibular joint) is a minimally invasive surgical procedure used to diagnose and treat various disorders of the TMJ. During the procedure, a small camera (arthroscope) is inserted into the joint space through a small incision. This allows the surgeon to visualize the internal structures of the joint, including the articular disc, ligaments, and surrounding tissues. 

Arthroscopy can be used to remove adhesions, repair damaged tissues, or address other abnormalities within the joint. It is typically indicated for patients who have not responded to conservative treatments such as physical therapy, medications, or splints. The benefits of arthroscopy include reduced recovery time, less postoperative pain, and minimal scarring compared to open joint surgery.
