In [None]:
import os
# Switch the working directory to the parent folder; presumably this is the
# project root that the relative "data/..." paths and the `src` package expect.
# NOTE: the path is relative, so every re-run of this cell moves up one more level.
os.chdir("../")
# Display the resulting working directory as the cell output.
%pwd

'C:\\Users\\ankit\\Desktop\\medipyrag'

In [2]:
# Imports used throughout the notebook
from pypdf import PdfReader, PdfWriter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import warnings
import torch
import os
from typing import List
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
from astrapy import DataAPIClient
from astrapy.info import CollectionDefinition, CollectionVectorOptions
from astrapy.constants import VectorMetric
from astrapy.data.info.vectorize import VectorServiceOptions
from langchain_astradb import AstraDBVectorStore
from langchain.chains import create_retrieval_chain
from langchain_google_genai import ChatGoogleGenerativeAI

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.prompts import PromptTemplate
from dotenv import load_dotenv
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")


In [None]:
import config
import json
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from src.autorag import auto_retriever
from langchain_core.pydantic_v1 import BaseModel, Field


# ✅ Schema for structured medical response
# genllm() below passes this class to llm.with_structured_output() when
# structured=True. Every field is a required string (Field(...) has no default).
# NOTE(review): a later cell defines another MedicalResponse whose fields are
# List[str]; running that cell rebinds this name in the kernel namespace.
class MedicalResponse(BaseModel):
    definition: str = Field(..., description="Definition of the disease/condition")
    causes_risk_factors: str = Field(..., description="Causes and risk factors")
    symptoms: str = Field(..., description="Symptoms of the condition")
    diagnosis: str = Field(..., description="Diagnostic methods")
    treatment_cure: str = Field(..., description="Treatment or cure")
    prognosis_complications: str = Field(..., description="Prognosis and complications")
    prevention_lifestyle: str = Field(..., description="Prevention and lifestyle advice")
    additional_notes: str = Field(..., description="Any additional notes")


def genllm(question: str = "What is acne?", structured: bool = False):
    """
    Answer a medical question using documents fetched by ``auto_retriever``.

    Args:
        question: Question text; handed to the retrieval chain as ``input``.
        structured: When True the LLM is bound to the ``MedicalResponse``
            schema and the schema fields are returned as a dict. When False
            the answer is returned as plain text.

    Returns:
        dict: the ``MedicalResponse`` fields when ``structured`` is True,
        otherwise ``{"answer": <text>}``.

    Raises:
        ValueError: if ``structured`` is True and the model did not return a
            structured object.
    """
    # Local import: only the structured path needs it.
    from langchain_core.runnables import RunnablePassthrough

    # ✅ Initialize Gemini LLM
    llm = ChatGoogleGenerativeAI(
        api_key=config.gemini_key,
        model=config.gemini_model,
        max_retries=2,
    )

    # ✅ Enforce schema if structured=True
    if structured:
        llm = llm.with_structured_output(MedicalResponse)

    # ✅ Prompt template
    # create_stuff_documents_chain requires a {context} variable in the prompt;
    # without it the retrieved documents never reach the model.
    prompt_template = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a highly knowledgeable **medical AI assistant**.
Always produce a structured response with ALL sections.

If there is no relevant info in the context, write: "Not available in the provided context."

**Context:**
{context}
                """,
            ),
            ("human", "{input}"),
        ]
    )

    # ✅ Build chain
    # The default output parser is StrOutputParser, which only handles text or
    # chat messages. A schema-bound LLM returns a pydantic object, so in
    # structured mode the model output is passed through untouched instead.
    question_answer_chain = create_stuff_documents_chain(
        llm,
        prompt_template,
        output_parser=RunnablePassthrough() if structured else None,
    )
    retrieval_chain = create_retrieval_chain(
        retriever=auto_retriever,
        combine_docs_chain=question_answer_chain,
    )

    # ✅ Run
    response = retrieval_chain.invoke({"input": question})
    answer = response["answer"]

    if not structured:
        return {"answer": answer}  # ✅ wrap plain text in JSON-safe dict

    if answer is None:
        # Schema-bound models can come back empty when no structured call is produced.
        raise ValueError("The model did not return a structured MedicalResponse.")

    # answer is a pydantic object; round-trip through JSON for a JSON-safe dict
    return json.loads(answer.json())


# ✅ Smoke test: run the same question through both output modes
if __name__ == "__main__":
    for banner, want_structured in (
        ("\n--- Unstructured ---", False),
        ("\n--- Structured ---", True),
    ):
        print(banner)
        print(genllm("treatment of cancer", structured=want_structured))


In [None]:
# Optional step: write a copy of the book without its first pages.
# Skip this cell if no pages need to be dropped.
reader = PdfReader("data/Medical_book.pdf")
writer = PdfWriter()

# Zero-based indexes of the pages to leave out.
skip_pages = range(0, 14)

for i, page in enumerate(reader.pages):
    if i in skip_pages:
        continue
    writer.add_page(page)

# Save the trimmed copy next to the original.
with open("data/medical_book_fixed.pdf", "wb") as f:
    writer.write(f)


In [3]:
def load_docs(filepath):
    """Load the PDF at ``filepath`` with PyPDFLoader and return its documents."""
    return PyPDFLoader(filepath).load()


ext_docs = load_docs("data/medical_book_fixed.pdf")
# Number of loaded documents, shown as the cell output.
len(ext_docs)

623

In [4]:
def filter_to_minimal_docs(docs: List[Document]) -> List[Document]:
    """Return new documents that keep the page text and only the ``source`` metadata."""
    return [
        Document(
            page_content=original.page_content,
            metadata={"source": original.metadata.get("source")},
        )
        for original in docs
    ]


minimal_docs = filter_to_minimal_docs(ext_docs)

In [5]:
def text_splitter(minimal_docs):
    """Split documents with RecursiveCharacterTextSplitter (chunk_size=1000, chunk_overlap=100)."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return splitter.split_documents(minimal_docs)


texts_chunk = text_splitter(minimal_docs)
# Number of chunks produced, shown as the cell output.
len(texts_chunk)

3123

In [6]:
def get_embeddings():
    """
    Build the HuggingFace embedding model used by the notebook.

    The model runs on CUDA when it is available and on the CPU otherwise.
    The chosen device is printed, preceded by the GPU name when one is used.

    Returns:
        HuggingFaceEmbeddings for "sentence-transformers/all-MiniLM-L6-v2".
    """
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    if torch.cuda.is_available():
        device = "cuda"
        # Query the GPU name only when CUDA is usable: calling this without
        # CUDA raises, which made the CPU fallback unreachable.
        print(torch.cuda.get_device_name(0))
    else:
        device = "cpu"
    print(f"Using device: {device}")
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device}
    )
    return embeddings

embeddings = get_embeddings()
# Length of one embedding vector, shown as the cell output.
len(embeddings.embed_query("Hello world"))


NVIDIA GeForce RTX 3050 Laptop GPU
Using device: cuda


384

In [7]:
# Load variables from a .env file (if any), then read the Astra DB settings.
# os.getenv returns None for a variable that is not set.
load_dotenv()
db_endpoint=os.getenv("ASTRADB_ENDPOINT")
db_token=os.getenv("ASTRA_TOKEN")
# NOTE(review): db_keyspace is not referenced by any other cell visible here.
db_keyspace=os.getenv("DB_KEYSPACE")
db_collection_name=os.getenv("DB_COLLECTION_NAME")


In [14]:
# initialize database
# Create the Data API client with the token read from the environment.
client = DataAPIClient(db_token)
# NOTE(review): key_spaces is not referenced by any other cell visible here.
key_spaces="medical_lab"

# Get database reference
database = client.get_database(db_endpoint)
# Handle to the collection named by the DB_COLLECTION_NAME environment variable.
collection = database.get_collection(db_collection_name)
print(f"database connection established database_name: {database.name()}, and collection_name={collection.name}")

database connection established database_name: rag, and collection_name=medical_book


In [15]:
collection_name = "medical_book"
model_name = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_PROVIDER="huggingface"
# NOTE: despite its name this holds the embeddings object, not a model name.
EMBEDDING_MODEL_NAME=embeddings

# ✅ Create the collection only when it does not exist yet
if collection_name not in database.list_collection_names():
    # Plain vector collection. Vectors are computed client-side by `embeddings`
    # (the vector store cells pass embedding=embeddings), so no server-side
    # vectorize service is declared here: mixing the two modes would leave the
    # collection configured differently from how the vector store uses it.
    collection_definition = CollectionDefinition(
        vector=CollectionVectorOptions(
            metric=VectorMetric.COSINE,
            # 384 is the vector length reported by embeddings.embed_query above.
            dimension=384,
        )
    )

    collection = database.create_collection(
        collection_name,
        definition=collection_definition,
    )

    print(f"* Collection: {collection.full_name}\n")
else:
    print(f"Collection '{collection_name}' already exists")


Collection 'medical_book' already exists


In [35]:
# Build a vector store on the collection without inserting anything yet;
# vectors are computed client-side with `embeddings`.
vstore = AstraDBVectorStore.from_documents(
    documents=[],  # start empty
    embedding=embeddings,
    api_endpoint=db_endpoint,
    token=db_token,
    collection_name=collection_name,
    batch_size=10,
)

# ✅ Insert documents
# NOTE(review): presumably every run of this cell inserts the chunks again;
# confirm before re-running against a collection that is already populated.
inserted_ids = vstore.add_documents(texts_chunk)
print(f"Inserted {len(inserted_ids)} documents.")


In [11]:
# Vector store handle on the existing collection; nothing is inserted in this cell.
# Uses the same `embeddings` object as the insert cell above.
existing_vstore= AstraDBVectorStore(
    api_endpoint=db_endpoint,
    token=db_token,
    collection_name=collection_name,
    embedding=embeddings,
)

In [None]:
# Quick check: print the three best matches for a sample query with their scores.
results = existing_vstore.similarity_search_with_score("urine test", k=3)
for match, relevance in results:
    print("Text:", match.page_content)
    print("Score:", relevance)


In [15]:
# Retriever over the existing store: similarity search with k=3 per query.
retriever=existing_vstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})

In [None]:
# Run one sample query through the retriever; uncomment the last line to view the result.
retrieved_docs=retriever.invoke("What is the treatment of diabetes?")
# retrieved_docs

Option 1: raw text output

In [62]:
import os
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain

# System instructions: fixed eight-section answer format, followed by the
# retrieved context and the user question.
SYSTEM_PROMPT = """
You are a highly knowledgeable **medical AI assistant**.
You MUST always produce a structured response with ALL sections listed below.
If there is no relevant information, explicitly write: "Not available in the provided context."

### Response Format (always include ALL 8 headers):
1. **Definition:**
2. **Causes & Risk Factors:**
3. **Symptoms:**
4. **Diagnosis:**
5. **Treatment / Cure:**
6. **Prognosis / Complications:**
7. **Prevention & Lifestyle Advice:**
8. **Additional Notes:**

---
**Context:**
{context}

**Question:**
{input}

**Structured Medical Response:**
"""

# ✅ Gemini chat model (temperature 0 for consistent, structured answers)
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    api_key=os.getenv("GEMINI_KEY"),
    temperature=0,
    max_tokens=5000,
    max_retries=2,
)

# ✅ Chat prompt: system instructions plus a fixed human turn
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("human", "Answer the question above in the exact structured format."),
    ]
)

# ✅ Documents → LLM chain, wrapped by the retrieval chain
question_answer_chain = create_stuff_documents_chain(llm, prompt_template)
retrieval_chains = create_retrieval_chain(
    retriever=retriever,
    combine_docs_chain=question_answer_chain,
)

# ✅ Try it on a sample question
response = retrieval_chains.invoke({"input": "What is acne?"})
print(response["answer"])


1.  **Definition:**
    Acne is a common skin disease characterized by pimples on the face, chest, and back. It occurs when the pores of the skin become clogged with oil (sebum), dead skin cells, and bacteria. The medical term for common acne is Acne vulgaris, and it is the most common skin disease, affecting nearly 17 million people in the United States.

2.  **Causes & Risk Factors:**
    The exact cause of acne is unknown, but several factors contribute to its development:
    *   **Blocked Pores/Hair Follicles:** Pores or hair follicles become blocked, preventing sebum (a waxy material that normally flows out onto the skin and hair) from exiting. This leads to the collection of sebum, bacteria, and dead skin cells inside the pores or follicles.
    *   **Bacterial Invasion:** Plugged follicles can be invaded by *Propionibacterium acnes*, a bacteria that normally lives on the skin, leading to inflammation.
    *   **Hormonal Changes:** Teenagers are more likely to develop acne due t

Option 2: structured output

In [68]:
import os
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from typing import List

# --------- Define Schema ----------
# Schema passed to llm.with_structured_output() below. `definition` is a single
# string; every other section is a list of strings. All fields are required.
# NOTE(review): this rebinds the name MedicalResponse, which an earlier cell
# defines with plain `str` fields.

class MedicalResponse(BaseModel):
    definition: str = Field(..., description="Brief definition of the condition.")
    causes_risk_factors: List[str] = Field(..., description="Causes and risk factors as a list.")
    symptoms: List[str] = Field(..., description="List of symptoms.")
    diagnosis: List[str] = Field(..., description="Diagnostic methods.")
    treatment_cure: List[str] = Field(..., description="Treatment or cure methods.")
    prognosis_complications: List[str] = Field(..., description="Prognosis and possible complications.")
    prevention_lifestyle: List[str] = Field(..., description="Prevention and lifestyle advice.")
    additional_notes: List[str] = Field(..., description="Additional notes.")

# --------- Initialize Gemini ----------
# The API key is read from the GEMINI_KEY environment variable.
llm = ChatGoogleGenerativeAI(
    api_key=os.getenv("GEMINI_KEY"),
    model="gemini-2.5-flash",
    temperature=0,
    max_tokens=5000,
)

# ✅ Wrap LLM with structured output
# structured_llm is the object RetrievalChain.invoke calls below.
structured_llm = llm.with_structured_output(MedicalResponse)

# --------- Retrieval Prompt ----------
# Template with two variables, {context} and {question}; RetrievalChain.invoke
# fills both through prompt.format(...).
prompt = ChatPromptTemplate.from_template("""
You are a medical AI assistant.
Use the following retrieved context to answer the question in a structured way.
If information is missing, respond with "Not available in the provided context."

Context:
{context}

Question:
{question}
""")

# --------- Build Retrieval Chain ----------
class RetrievalChain:
    """Minimal pipeline: retrieve documents, format the prompt, call the LLM.

    Args:
        retriever: Object whose ``invoke(question)`` returns documents that
            expose ``page_content``.
        prompt_template: Optional object providing
            ``format(context=..., question=...)``. Defaults to the
            notebook-level ``prompt``.
        llm: Optional object providing ``invoke(text)``. Defaults to the
            notebook-level ``structured_llm``.
    """

    def __init__(self, retriever, prompt_template=None, llm=None):
        self.retriever = retriever
        self.prompt_template = prompt_template
        self.llm = llm

    def invoke(self, question: str) -> "MedicalResponse":
        """Answer ``question`` and return whatever the LLM's ``invoke`` returns."""
        # Step 1: retrieve docs. `invoke` replaces the deprecated
        # `get_relevant_documents` and matches how the retriever is called
        # elsewhere in this notebook.
        docs = self.retriever.invoke(question)
        context_text = "\n\n".join(d.page_content for d in docs)

        # Fall back to the notebook-level objects only when none were injected;
        # the globals are looked up at call time.
        template = self.prompt_template if self.prompt_template is not None else prompt
        model = self.llm if self.llm is not None else structured_llm

        # Step 2: format prompt
        formatted_prompt = template.format(context=context_text, question=question)

        # Step 3: structured LLM call → returns MedicalResponse
        return model.invoke(formatted_prompt)

# --------- Example Usage ----------
# Wrap the retriever built in an earlier cell.
retrieval_chain = RetrievalChain(retriever)

# RetrievalChain.invoke takes the question as a plain string.
response = retrieval_chain.invoke("What is acne?")
# print(response)
print(f"response = {response.json(indent=2)}")  # Pretty JSON


response = {
  "definition": "Acne is a common skin disease characterized by pimples on the face, chest, and back. It occurs when the pores of the skin become clogged with oil, dead skin cells, and bacteria.",
  "causes_risk_factors": [
    "The exact cause of acne is unknown.",
    "Age: Due to hormonal changes, teenagers are more likely to develop acne.",
    "Gender: Boys have more severe acne and develop it more often than girls.",
    "Pores or hair follicles become blocked, allowing sebum to collect inside.",
    "Bacteria and dead skin cells can collect, causing inflammation.",
    "Plugged follicle is invaded by Propionibacterium acnes, a bacteria that normally lives on the skin.",
    "A pimple forms when the damaged follicle weakens and bursts open, releasing sebum, bacteria, and skin and white blood cells into the surrounding tissues."
  ],
  "symptoms": [
    "Small swellings on the skin surface.",
    "Whiteheads or blackheads (small, not inflamed swellings).",
    "Pimple