In [1]:
import os, faiss, pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
from huggingface_hub import login

In [2]:

FAISS_INDEX_PATH = "../data/embeddings/faiss_index.bin"
METADATA_PATH = "../data/embeddings/metadata.pkl"
PDF_DIR = "../data/raw_pdfs"

load_dotenv(dotenv_path='../Secrets/.env')

login(token=os.environ['HF_TOKEN'])

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [3]:
# Load FAISS index
index = faiss.read_index(FAISS_INDEX_PATH)

# Load metadata
with open(METADATA_PATH, "rb") as f:
    metadata = pickle.load(f)
print(f"Loaded index with {index.ntotal} entries.")

Loaded index with 13895 entries.


In [4]:
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

In [None]:
def retrieve_top_k(query, k=5, similarity_threshold=0):
    """
    Retrieve top-k relevant chunks from FAISS, grouped by page with snapshots & images.
    Apply similarity threshold to filter noisy matches.
    """

    # Encode query
    query_vec = embed_model.encode([query], convert_to_numpy=True)

    # Search FAISS index
    distances, indices = index.search(query_vec, k)

    # Convert L2 distance to cosine-like similarity (optional, depends on index type)
    # For IndexFlatL2: lower distance = better, so invert to similarity
    max_dist = np.max(distances)
    similarities = 1 - (distances / max_dist)

    # Step 1: Collect raw results (with threshold filtering)
    raw_results = []
    for idx, dist, sim in zip(indices[0], distances[0], similarities[0]):
        if idx == -1:
            continue
        if sim < similarity_threshold:   # skip low-similarity chunks
            continue

        chunk_meta = metadata[idx]
        raw_results.append({
            "content": chunk_meta["content"],
            "page_num": chunk_meta["page_num"],
            "pdf_file": chunk_meta["pdf_file"],
            "page_snapshot": chunk_meta.get("page_snapshot"),
            "images": chunk_meta.get("images", []),
            "captions": chunk_meta.get("captions", []),
            "distance": float(dist),
            "similarity": float(sim),
            "link": f"{PDF_DIR}/{chunk_meta['pdf_file']}#page={chunk_meta['page_num']}",
        })

    # Step 2: Group by page
    grouped = {}
    for r in raw_results:
        page = r["page_num"]
        if page not in grouped:
            grouped[page] = {
                "page_num": page,
                "pdf_file": r["pdf_file"],
                "content": [],
                "images": [],
                "captions": r.get("captions", []),
                "page_snapshot": r["page_snapshot"],
                "link": r["link"]
            }
        grouped[page]["content"].append(r["content"])
        grouped[page]["images"].extend(r["images"])

    # Step 3: Merge
    results = []
    for page, data in grouped.items():
        merged_text = " ".join(data["content"])
        snippet = merged_text[:150] + "..." if len(merged_text) > 150 else merged_text
        results.append({
            "page_num": page,
            "pdf_file": data["pdf_file"],
            "content": merged_text,
            "page_snapshot": data["page_snapshot"],
            "images": list(set(data["images"])),
            "captions": data.get("captions", []),
            "link": data["link"],
            "snippet": snippet
        })

    return sorted(results, key=lambda x: x["page_num"])

In [31]:
from numpy.linalg import norm

def filter_images_by_caption_similarity(query, captions, threshold=0.1):
    """
    Return only images whose caption similarity with query exceeds threshold.
    """
    # Encode query
    query_emb = embed_model.encode([query], convert_to_numpy=True)[0]
    relevant_images = []

    for cap in captions:
        if cap.get("embedding") is not None:
            cap_emb = np.array(cap["embedding"])
            sim = np.dot(query_emb, cap_emb) / (norm(query_emb) * norm(cap_emb))
            if sim >= threshold:
                relevant_images.append(cap["image_path"])
    return relevant_images

In [None]:
from huggingface_hub import InferenceClient
import os

# Initialize client
client = InferenceClient(
    provider="auto",
    timeout=400,
)


def generate_answer_multimodal(query, retrieved_chunks, model="unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF"):
    """
    Generate structured answer (Definition, Causes, Diagnosis, Treatment) using text + images.
    """

    # Combine retrieved text context
    context_text = "\n\n".join([
        f"--- Page {c['page_num']} ---\nText:\n{c['content']}"
        for c in retrieved_chunks
    ])

    print(context_text)

    # Collect relevant image paths (page snapshots + extracted images)
    image_paths = []
    for c in retrieved_chunks:
        if c.get("page_snapshot"):
            image_paths.append(c["page_snapshot"])
        image_paths.extend(c.get("images", []))

    system_prompt = """
You are a knowledgeable medical assistant.

- Use ONLY the provided context (text) to answer the user query.
- Output strictly in English.
- If specific sections like Definition, Causes, Diagnosis, or Treatment are relevant, organize the answer using these sections.
- If the query does not fit these sections (e.g., a simple definition or general explanation), provide a concise, well-structured paragraph instead.
- If any requested information is missing in the context, explicitly state "Not mentioned in the document."
- Do NOT invent facts, add unrelated metadata, or include random symbols.
- Do NOT mention images unless explicitly described in the text.
- If you need to say about document then say According to my sources.

For structured answers (when applicable), use this format:

**Definition:**
<text>

**Causes:**
<text>

**Diagnosis:**
<text>

**Treatment:**
<text>

For general answers (when structured sections are irrelevant), provide a single detailed paragraph addressing the query clearly and factually and in detail.

For Example:
Query:
What is asthma and how is it diagnosed?

Answer:
**Definition:**
Asthma is a chronic inflammatory disease of the airways characterized by episodes of wheezing and breathlessness.

**Causes:**
- Allergic reactions
- Environmental triggers
- Respiratory infections

**Diagnosis:**
- Spirometry
- Peak flow measurement

**Treatment:**
- Inhaled corticosteroids
- Bronchodilators

For Example:
Query:
What is antigen?

Answer:
An Antigen (Ag) is a molecule, moiety, foreign particulate matter, or an allergen, such as pollen, that can bind to a specific antibody or T-cell receptor. The presence of antigens in the body may trigger an immune response.
Antigens can be proteins, peptides (amino acid chains), polysaccharides (chains of simple sugars), lipids, or nucleic acids.[3][4] Antigens exist on normal cells, cancer cells, parasites, viruses, fungi, and bacteria.

"""

    # ---- User Prompt ----
    user_prompt = f"""
Context:
{context_text}

User Question:
{query}
"""

    # Prepare messages for API
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": [{"type": "text", "text": user_prompt}]}
    ]

    # Attach images as multimodal content
    # for img_path in image_paths:
    #     if os.path.exists(img_path):
    #         messages[0]["content"].append({"type": "image", "image": img_path})

    # Call API (non-streaming to avoid timeout issues)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        stream=False,
        max_tokens=1000
    )

    # Extract text output
    answer = response.choices[0].message["content"]
    return answer

def rag_query_multimodal(query, k=5):
    """Retrieve text + images, generate multimodal structured answer."""
    # Retrieve top-k chunks (text + metadata)
    retrieved = retrieve_top_k(query, k=k)

    # Generate answer with text + image context
    answer = generate_answer_multimodal(query, retrieved, 
                                        model="google/gemma-3-27b-it"
                                        )

    # Prepare references with page, snapshots, images (caption-aware)
    references = []
    for r in retrieved:
        # Filter images by caption similarity
        relevant_images = filter_images_by_caption_similarity(
            query,
            r.get("captions", []),
            threshold=0.4
        )

        references.append({
            "page": r["page_num"],
            "link": r["link"],
            "snippet": r["snippet"],
            "page_snapshot": r.get("page_snapshot"),
            "images": relevant_images  # only relevant images
        })
    
    return answer, references

query1 = "What are the causes and treatment for Atherosclerosis in detail?"
query2 = "What is Hypertension?"
query3 = "What is Alkaline phosphatase test?"
query4 = "What is Hemoglobin?"
query5 = "How to cure Diabetes?"
query6 = "What is Hernia?"
answer, references = rag_query_multimodal(query4, k=50)

print("### Answer ###\n", answer)
print("\n### References ###")
for ref in references:
    print(f"- Page {ref['page']}: {ref['link']} ({ref['snippet']})")
    print(f"  Page Snapshot: {ref['page_snapshot']}")
    if ref['images']:
        print(f"  Images: {', '.join(ref['images'])}")

  return forward_call(*args, **kwargs)


--- Page 103 ---
Text:
They carry one gene for albinism.

--- Page 120 ---
Text:
Sandra Bain Cushman Alkali-resistant hemoglobin test see Fetal hemoglobin test Alkaline phosphatase test Definition Alkaline phosphatase is an enzyme found throughout the body.

--- Page 147 ---
Text:
This is known as hypoxemia.

--- Page 195 ---
Text:
(Hemoglobin is composed of four chains of amino acids.) HEMOLYTIC ANEMIA. An inherited form of hemolytic anemia, thalassemia stems from the body s inability to manufacture as much normal hemoglobin as it needs. Some people are born with hemolytic anemia . Hemolytic anemia can enlarge the spleen, accelerating the destruction of red blood cells (hemolysis). AUTOIMMUNE HEMOLYTIC ANEMIAS. Warm antibody hemolytic anemia is the most common type of this disorder. In cold antibody hemolytic anemia, the body attacks red blood cells at or below normal body temperature.

--- Page 196 ---
Text:
Sickle cell anemia is a chronic, incurable condition that causes the body to

In [None]:
# from PIL import Image
# import matplotlib.pyplot as plt
# import os

# def show_reference_images(references):
#     """
#     Display page snapshots and extracted images from references (non-Streamlit).
#     """

#     for ref in references:
#         page = ref["page"]
#         snapshot = ref.get("page_snapshot")
#         images = ref.get("images", [])

#         print(f"\n--- Page {page} ---")

#         # Show page snapshot if available
#         if snapshot and os.path.exists(snapshot):
#             print(f"Page Snapshot: {snapshot}")
#             img = Image.open(snapshot)
#             plt.figure(figsize=(8, 10))
#             plt.imshow(img)
#             plt.axis("off")
#             plt.title(f"Full Page {page}")
#             plt.show()

#         # Show extracted images (figures/tables)
#         if images:
#             for img_path in images:
#                 if os.path.exists(img_path):
#                     print(f"Extracted Image: {img_path}")
#                     img = Image.open(img_path)
#                     plt.figure(figsize=(6, 6))
#                     plt.imshow(img)
#                     plt.axis("off")
#                     plt.title(f"Figure from Page {page}")
#                     plt.show()

# show_reference_images(references)

NameError: name 'references' is not defined