In [None]:
from unstructured.partition.pdf import partition_pdf
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import os
import glob

file_path = "C:/Users/User/Downloads/tmj_rag_app/Backend/data/pdfs"


In [29]:
#Import env variables
from dotenv import load_dotenv

load_dotenv()



True

In [30]:
# Get all PDF files in the folder
pdf_files = glob.glob(os.path.join(file_path, "*.pdf"))
print(f"Found {len(pdf_files)} PDF files to process")

# Partition all PDF files
all_elements = []
for pdf_file in pdf_files:
    base_file_name = os.path.splitext(os.path.basename(pdf_file))[0]
    print(f"Processing: {base_file_name}.pdf")
    
    elements = partition_pdf(
        filename=pdf_file,
        strategy="auto",
        infer_table_structure=True
    )
    
    # Store elements with their source file name
    for element in elements:
        element._source_file = base_file_name  # Store source file name
    
    all_elements.extend(elements)
    print(f"  Extracted {len(elements)} elements from {base_file_name}.pdf")

print(f"\nTotal elements extracted: {len(all_elements)}")


Found 11 PDF files to process
Processing: tmjDoc1.pdf
  Extracted 251 elements from tmjDoc1.pdf
Processing: tmjDoc10.pdf
  Extracted 59 elements from tmjDoc10.pdf
Processing: tmjDoc11.pdf
  Extracted 19 elements from tmjDoc11.pdf
Processing: tmjDoc2.pdf
  Extracted 164 elements from tmjDoc2.pdf
Processing: tmjDoc3.pdf
  Extracted 33 elements from tmjDoc3.pdf
Processing: tmjDoc4.pdf
  Extracted 137 elements from tmjDoc4.pdf
Processing: tmjDoc5.pdf
  Extracted 143 elements from tmjDoc5.pdf
Processing: tmjDoc6.pdf
  Extracted 24 elements from tmjDoc6.pdf
Processing: tmjDoc7.pdf
  Extracted 63 elements from tmjDoc7.pdf
Processing: tmjDoc8.pdf
  Extracted 72 elements from tmjDoc8.pdf
Processing: tmjDoc9.pdf
  Extracted 37 elements from tmjDoc9.pdf

Total elements extracted: 1002


In [25]:
# Display the elements partitioned from the PDF file
for i, element in enumerate(elements):
    print(f"\n--- Element {i} ---")
    print(f"Type: {element.category}")
    print(f"Text: {element.text}")
    print(f"Metadata: {element.metadata}")


--- Element 0 ---
Type: Header
Text: 12/5/25, 3:44 PM
Metadata: <unstructured.documents.elements.ElementMetadata object at 0x0000026584B82030>

--- Element 1 ---
Type: Header
Text: TMD and Jaw Pain | NIDCR
Metadata: <unstructured.documents.elements.ElementMetadata object at 0x00000265872D7500>

--- Element 2 ---
Type: NarrativeText
Text: = An official website of the United States government Here’s how you know
Metadata: <unstructured.documents.elements.ElementMetadata object at 0x0000026584B80980>

--- Element 3 ---
Type: Image
Text: 
Metadata: <unstructured.documents.elements.ElementMetadata object at 0x0000026584B80B00>

--- Element 4 ---
Type: Title
Text: National Institute of Dental and Craniofacial Research
Metadata: <unstructured.documents.elements.ElementMetadata object at 0x0000026584B81370>

--- Element 5 ---
Type: NarrativeText
Text: </>
Metadata: <unstructured.documents.elements.ElementMetadata object at 0x0000026584B80E90>

--- Element 6 ---
Type: NarrativeText
Text: MENU


In [33]:
# Convert to LangChain Documents from all PDFs
docs = []
MIN_CHAR_LENGTH = 50

for el in all_elements:
    #Skip headers, dont contain useful info
    if el.category == "Header":
        continue
    
    text = getattr(el, "text", None)
    if not text:
        continue

    if (len(text.strip()) < MIN_CHAR_LENGTH):
        continue

    # Get source file name from the element
    source_file = getattr(el, "_source_file", "unknown")
    
    meta = el.to_dict().get("metadata", {}) or {}
    docs.append(
        Document(
            page_content=text,
            metadata={
                "source": source_file,
                "page_number": meta.get("page_number"),
                "type": el.category if hasattr(el, "category") else meta.get("type"),
            },
        )
    )

print(f"Total documents created: {len(docs)}")
print(f"Documents from {len(set(doc.metadata.get('source') for doc in docs))} different PDFs")

#Print first few documents as sample
for i, doc in enumerate(docs[:5], 1):  # Show first 5 as sample
    meta = doc.metadata
    print(f"--- Document {i} ---")
    print(f"Source: {meta.get('source')} | Page: {meta.get('page_number')} | Type: {meta.get('type')}")
    print(f"Content length: {len(doc.page_content)} characters")
    print(f"Content:\n{doc.page_content}")
    print()

splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, 
    chunk_overlap=200,
    add_start_index=True
)
chunked_docs = splitter.split_documents(docs)
print(f"Total chunks created: {len(chunked_docs)}")

Total documents created: 426
Documents from 11 different PDFs
--- Document 1 ---
Source: tmjDoc1 | Page: 1 | Type: NarrativeText
Content length: 73 characters
Content:
= An official website of the United States government Here’s how you know

--- Document 2 ---
Source: tmjDoc1 | Page: 1 | Type: NarrativeText
Content length: 54 characters
Content:
National Institute of Dental and Craniofacial Research

--- Document 3 ---
Source: tmjDoc1 | Page: 1 | Type: UncategorizedText
Content length: 60 characters
Content:
</espanol/temas-de-salud/los-trastornos-temporomandibulares>

--- Document 4 ---
Source: tmjDoc1 | Page: 2 | Type: NarrativeText
Content length: 221 characters
Content:
Temporomandibular disorders (TMDs) area group of more than 30 conditions that cause pain and dysfunction in the jaw joint and muscles that control jaw movement. “TMDs” refers to the disorders, and “TMJ” refers only to the

--- Document 5 ---
Source: tmjDoc1 | Page: 2 | Type: UncategorizedText
Content length: 81 cha

In [34]:
#Print the chunked documents
for i, doc in enumerate(chunked_docs, 1):
    meta = doc.metadata
    print(f"--- Chunk {i} ---")
    print(f"source={meta.get('source')} page={meta.get('page_number')} type={meta.get('type')} section={meta.get('section')}")
    print(doc.page_content)
    print()

--- Chunk 1 ---
source=tmjDoc1 page=1 type=NarrativeText section=None
= An official website of the United States government Here’s how you know

--- Chunk 2 ---
source=tmjDoc1 page=1 type=NarrativeText section=None
National Institute of Dental and Craniofacial Research

--- Chunk 3 ---
source=tmjDoc1 page=1 type=UncategorizedText section=None
</espanol/temas-de-salud/los-trastornos-temporomandibulares>

--- Chunk 4 ---
source=tmjDoc1 page=2 type=NarrativeText section=None
Temporomandibular disorders (TMDs) area group of more than 30 conditions that cause pain and dysfunction in the jaw joint and muscles that control jaw movement. “TMDs” refers to the disorders, and “TMJ” refers only to the

--- Chunk 5 ---
source=tmjDoc1 page=2 type=UncategorizedText section=None
Healthy temporomandibular joint during mouth opening & closing. temporomandibular

--- Chunk 6 ---
source=tmjDoc1 page=2 type=NarrativeText section=None
itself. People have two TMJs; one on each side of the jaw. You can feel t

In [35]:
#Create embeddings
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
)


In [36]:
#Create Chroma db
vector_store = Chroma(
    collection_name="tmj_rag_app",
    embedding_function=embeddings,
    persist_directory="../data/chroma_db"
)

In [37]:
#Add documents to the vector store and persist vector store
vector_store.add_documents(chunked_docs)

['3212d767-4dfd-40db-8ec3-a1399729568e',
 '5030bb42-0150-4b05-9a06-8db2abac6023',
 'e8a5a23a-5c34-4230-814b-07f7706ace4c',
 'd75343a5-16a6-4b3b-bd2e-bb0474555028',
 '1541b482-7a71-410f-b19c-d97c76a0a8c3',
 '650b3800-9f9b-40f5-80d0-e63ccfe9982e',
 '738f358c-725e-48cb-922e-81ac57007746',
 'b0fe93aa-64b1-48e5-bf90-847b7075d0ea',
 'a0ef09f9-2468-4c13-a4cb-df63746bfe30',
 '4e177ef2-ab90-490b-a4bd-3108f0c542fb',
 '67d39de9-6c6d-4932-9cc9-60dab2229930',
 '4e310e91-e325-45ec-8421-526a7b179540',
 '42f8bca1-3087-46be-ba62-4be27c42f441',
 'ca60c77c-4149-4a1e-b707-1b64cb0137ef',
 'a7ebcc67-33dc-4b10-ac2d-13c14d2a89d3',
 '0425052f-6693-46e5-8ce0-8220534e52b0',
 '25b0abd4-67d0-479e-82e9-fbcfa1e81255',
 'cdf8472e-aafc-4c7a-89e1-4d69610a34f1',
 'a5f7666a-a5d9-4bb3-b574-50fd9a36f11e',
 'de387a74-09d7-4c49-a60d-ee4bd8394a35',
 '51a0f8e8-64c5-4e8f-abce-483ae20c7b71',
 'c3b57993-c67e-4d1e-b7f8-2f02877cc89e',
 '52fc87a6-f5f9-4876-ab2a-f608c5365590',
 '5b11d875-d724-463f-978b-472e05cceb3c',
 'c05f5c9a-af62-

In [44]:
#Prompt template
medical_template = ChatPromptTemplate.from_messages([
    ("system",
     "You are a medical information assistant specializing in TMJ disorders. "
     "Answer ONLY using the provided context. "
     "If the context does not contain enough information, say so clearly. "
     "Use precise medical terminology. "
     "This is for informational purposes only and does not replace professional medical advice."),
    ("human",
     "Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:")
])

#Build retriever from vector store
retriever = vector_store.as_retriever(
    searchType="mmr",
    searchKwargs={"k": 5}
)


#Function to format docs and create RunnableLambda
def format_docs(docs):
    formatted = "\n\n".join(doc.page_content for doc in docs)
    print(formatted)
    return formatted

format_docs_runnable = RunnableLambda(format_docs)

output_parser = StrOutputParser()

In [39]:
#Initialize LLM
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature = 0.0
)

In [45]:
#Create RAG chain
rag_chain = (
    {
        "context": retriever | format_docs_runnable,
        "question": RunnablePassthrough(),
    }
    | medical_template
    | llm
    | output_parser  
)

In [48]:
#Query the chain
answer = rag_chain.invoke("What habits make TMJ worse?")
print(answer)

Factors that may raise the risk of getting TMJ disorders include:

Arthritis is one cause of TMJ symptoms. It can result from an injury or from grinding the teeth at night. Another common cause involves displacement or dislocation of the disk that is located between the jawbone and the socket. A displaced disk may produce clicking or popping sounds, limit jaw movement and cause pain when opening and closing the mouth.

The exact cause of TMJ disorder is often hard to determine. The pain may be due to a mix of factors, including habits such as teeth clenching, gum chewing and nail biting; stress; and painful conditions that occur along with TMJ disorder such as fibromyalgia, osteoarthritis or jaw injury. The habit of teeth clenching or grinding also is known as bruxism.

TMJ disorders [https://medlineplus.gov/ency/article/001227 htm]
Habits that may exacerbate TMJ disorders include teeth clenching, gum chewing, nail biting, and bruxism (grinding of the teeth).
