# Method 6 — Intelligent (Layout-Aware) Multimodal Chunker + Summaries + Embeddings + FAISS


This notebook demonstrates a **layout-aware / intelligent multimodal chunking** pipeline:

1. **Parse with a layout-aware parser** (choose one):
   - **Azure Document Intelligence (prebuilt-layout)** (recommended if you have Azure keys)
   - **LandingAI ADE Parse** (if you already have the ADE JSON)
   - **Bring-your-own JSON** in the same schema (fallback)
2. Convert parser output into **typed elements** (`text`, `table`, `figure`) with:
   - **page number**, **bounding box / polygon**, (optional) **confidence**
3. **Hierarchical chunking** of text based on headings/sections when available.
4. **Crop figures/tables from the PDF** using their bounding regions.
5. Create **summaries**:
   - Figures: GPT-4o vision description (or BLIP if you want offline)
   - Tables: summarize from HTML/markdown
6. Embed all retrieval units (text chunks + figure summaries + table summaries) and index them in **FAISS**. For production, a managed vector database is recommended instead (Azure AI Search, Pinecone, Milvus, Vespa, etc.).
7. Inference:
   - Retrieve text/table chunks → generate an answer **with citations** (chunk ids + pages)
   - Use citations to **restrict figure search space** to cited pages (±1 window)
   - Retrieve the most relevant figures within that narrowed search space

> Default embeddings: **Cohere text embeddings** for all summaries/chunks. You can swap models easily.

## Install (run once)

```bash
pip install pymupdf pillow faiss-cpu tqdm cohere openai azure-ai-documentintelligence
```


In [None]:
# =========================
# 0) Config
# =========================
import os

# Input PDF, and the directory that receives every intermediate artifact
# written by the later cells (parsed elements, crops, summaries, index).
PDF_PATH = "YOUR_PDF_HERE.pdf"   # <-- set this
ARTIFACTS_DIR = "artifacts_layout"
os.makedirs(ARTIFACTS_DIR, exist_ok=True)  # created eagerly so later cells can write into it

# Choose parser: "azure_di" | "landingai_json" | "json_file"
# ("landingai_json" and "json_file" are handled identically by the parse cell.)
PARSER_MODE = "json_file"

# --- If PARSER_MODE == "azure_di" ---
# Credentials are read from the environment; an empty string means "not set".
AZURE_DI_ENDPOINT = os.environ.get("AZURE_DI_ENDPOINT", "")
AZURE_DI_KEY = os.environ.get("AZURE_DI_KEY", "")
AZURE_DI_MODEL_ID = "prebuilt-layout"

# --- If PARSER_MODE == "landingai_json" or "json_file" ---
PARSE_JSON_PATH = "YOUR_PARSE_OUTPUT.json"  # ADE parse JSON or your own JSON in expected schema

# Summaries / answers (requires OpenAI)
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_VISION_MODEL = "gpt-4o-mini"  # used for figure descriptions
OPENAI_TEXT_MODEL = "gpt-4o-mini"    # used for table summaries and final answers

# Embeddings (requires Cohere)
COHERE_API_KEY = os.environ.get("COHERE_API_KEY", "")
COHERE_TEXT_MODEL = "embed-multilingual-v3.0"

# Retrieval parameters
TOP_K_TEXT = 12    # hits requested from the combined index for answering
TOP_K_ASSETS = 6   # figure/table hits requested after page filtering
PAGE_WINDOW = 1  # cited pages ± this window for figures/tables


In [None]:
# =========================
# 1) Schemas + helpers
# =========================
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
import re, json, base64
import numpy as np

@dataclass
class Element:
    """A single typed layout element emitted by one of the parsers.

    Only one of ``bbox`` / ``polygon`` is normally populated, depending on
    which parser produced the element; both default to ``None``.
    """

    # Identifier assigned by the parser (or synthesised when missing).
    element_id: str
    # One of "text" | "table" | "figure".
    element_type: str
    # Page number, 1-based.
    page: int
    # Markdown/text for text, html/markdown for table, "" for figure.
    content: str
    # Normalized {left, top, right, bottom} in [0,1] if available.
    bbox: Optional[Dict[str, float]] = None
    # Azure polygon as a flat list (x1,y1,x2,y2,...).
    polygon: Optional[List[float]] = None
    # Parser confidence, when the parser reports one.
    confidence: Optional[float] = None
    # Heading trail filled in by the hierarchy step.
    heading_path: Optional[List[str]] = None

@dataclass
class RetrievalUnit:
    """One item that is embedded and indexed for retrieval.

    Text chunks, table summaries and figure summaries all share this shape so
    that they can live in a single index.
    """

    # Unique id; also the citation key shown to the answering model.
    unit_id: str
    # One of "text" | "table" | "figure".
    unit_type: str
    # Page the unit comes from.
    page: int
    # The string that is sent to the embedding model.
    text_for_embedding: str
    # Optional path to the underlying raw artifact on disk.
    raw_content_ref: Optional[str] = None
    # Optional free-form metadata dictionary.
    meta: Optional[Dict[str, Any]] = None

def clean_anchor_tags(s: str) -> str:
    """Strip empty ``<a id='...'></a>`` anchors (lower-case hex/dash ids) from *s*.

    Whitespace directly following each removed anchor is dropped as well, and
    the final result is trimmed on both ends.
    """
    anchor_pattern = r"<a id=['\"][a-f0-9\-]+['\"]></a>\s*"
    without_anchors = re.sub(anchor_pattern, "", s)
    return without_anchors.strip()

def ensure_dir(p: str):
    """Create directory *p* (including parents) if needed and return *p* unchanged."""
    from os import makedirs

    makedirs(p, exist_ok=True)
    return p


In [None]:
# =========================
# 2) Parse document with a layout-aware parser
# =========================
from typing import List
import os, json

def parse_from_json(parse_json: Dict[str, Any]) -> List[Element]:
    """Convert a LandingAI-ADE-like parse result into a list of ``Element``s.

    Expected shape: ``parse_json['chunks']`` is a list of dicts carrying
    ``type``, ``id``/``chunk_id``, ``markdown``/``text`` and a ``grounding``
    entry with ``page``, ``box``/``bbox`` and optionally ``confidence``.

    Args:
        parse_json: Decoded JSON document produced by the parser.

    Returns:
        One ``Element`` per chunk, in the order the chunks appear. Unknown
        chunk types are mapped to ``"text"``; a missing page defaults to 1.
    """
    # Some exports abbreviate the box keys; the cropping step reads the long names.
    box_key_aliases = {"l": "left", "t": "top", "r": "right", "b": "bottom"}
    elements: List[Element] = []

    for ch in parse_json.get("chunks", []):
        t = ch.get("type", "text")
        if t not in ("text", "table", "figure"):
            t = "text"

        cid = ch.get("id") or ch.get("chunk_id") or f"chunk_{len(elements)+1}"

        grounding = ch.get("grounding") or {}
        if isinstance(grounding, list):
            # A chunk may be grounded in several regions; use the first one
            # instead of failing on ``list.get``.
            grounding = next((g for g in grounding if isinstance(g, dict)), {})

        # NOTE(review): pages are treated as 1-based here. If the parser emits
        # 0-based page indices they need a +1 before cropping — confirm against
        # the parser's schema.
        page = grounding.get("page")
        if page is None:
            page = ch.get("page") or 1
        page = int(page)

        box = grounding.get("box") or grounding.get("bbox") or None
        if isinstance(box, dict):
            box = {box_key_aliases.get(k, k): v for k, v in box.items()}

        # Explicit None check: a confidence of 0.0 is a real value, not "missing".
        conf = grounding.get("confidence")
        if conf is None:
            conf = ch.get("confidence")

        content = ch.get("markdown") or ch.get("text") or ""
        content = clean_anchor_tags(content)

        elements.append(Element(
            element_id=str(cid),
            element_type=t,
            page=page,
            content=content,
            bbox=box,
            confidence=conf
        ))
    return elements

def parse_with_azure_document_intelligence(pdf_path: str) -> List[Element]:
    """Analyze a PDF with the Azure Document Intelligence layout model.

    Paragraphs become ``"text"`` elements and tables become ``"table"``
    elements whose content is a minimal ``<table>`` HTML rendering of the
    cell grid. Page and polygon come from the first bounding region of each
    item; when there is none, the page defaults to 1 and the polygon to None.

    Polygons are stored exactly as returned by the service.
    NOTE(review): only ``result.paragraphs`` and ``result.tables`` are
    converted, so this path produces no ``"figure"`` elements.

    Args:
        pdf_path: Path of the PDF to analyze.

    Returns:
        Text elements (in the order returned by the service) followed by
        table elements.

    Raises:
        ValueError: If ``AZURE_DI_ENDPOINT`` or ``AZURE_DI_KEY`` is empty.
    """
    from html import escape

    from azure.core.credentials import AzureKeyCredential
    from azure.ai.documentintelligence import DocumentIntelligenceClient
    from azure.ai.documentintelligence.models import ContentFormat

    if not AZURE_DI_ENDPOINT or not AZURE_DI_KEY:
        raise ValueError("Set AZURE_DI_ENDPOINT and AZURE_DI_KEY env vars.")

    client = DocumentIntelligenceClient(endpoint=AZURE_DI_ENDPOINT, credential=AzureKeyCredential(AZURE_DI_KEY))

    def first_region(item):
        """Return (page_number, polygon) of the item's first bounding region."""
        if item.bounding_regions:
            region = item.bounding_regions[0]
            return region.page_number, region.polygon
        return 1, None

    # NOTE(review): the keyword names below are SDK-version specific — confirm
    # them against the installed azure-ai-documentintelligence release.
    with open(pdf_path, "rb") as f:
        poller = client.begin_analyze_document(
            model_id=AZURE_DI_MODEL_ID,
            analyze_request=f,
            output_content_format=ContentFormat.MARKDOWN,
        )
        # Wait for completion while the upload stream is still open.
        result = poller.result()

    elements: List[Element] = []

    for i, p in enumerate(getattr(result, "paragraphs", None) or []):
        page, poly = first_region(p)
        elements.append(Element(
            element_id=f"az_para_{i}",
            element_type="text",
            page=int(page),
            content=(p.content or "").strip(),
            polygon=list(poly) if poly else None,
            confidence=getattr(p, "confidence", None)
        ))

    for i, t in enumerate(getattr(result, "tables", None) or []):
        page, poly = first_region(t)

        # Rebuild a dense row x column grid from the sparse cell list.
        max_row = max((c.row_index for c in t.cells), default=-1)
        max_col = max((c.column_index for c in t.cells), default=-1)
        grid = [["" for _ in range(max_col + 1)] for _ in range(max_row + 1)]
        for c in t.cells:
            # Escape so that '<', '>' or '&' inside a cell cannot break the markup.
            grid[c.row_index][c.column_index] = escape((c.content or "").strip(), quote=False)

        table_html = "<table>" + "".join(
            "<tr>" + "".join(f"<td>{cell}</td>" for cell in row) + "</tr>"
            for row in grid
        ) + "</table>"

        elements.append(Element(
            element_id=f"az_table_{i}",
            element_type="table",
            page=int(page),
            content=table_html,
            polygon=list(poly) if poly else None
        ))

    return elements

# Pick the parser according to PARSER_MODE; both JSON modes share one loader.
if PARSER_MODE in ("landingai_json", "json_file"):
    with open(PARSE_JSON_PATH, "r", encoding="utf-8") as f:
        parse_json = json.load(f)
    elements = parse_from_json(parse_json)
elif PARSER_MODE == "azure_di":
    elements = parse_with_azure_document_intelligence(PDF_PATH)
else:
    raise ValueError("Unknown PARSER_MODE")

# Quick sanity report on what the parser returned.
print("Parsed elements:", len(elements))
print("Types:", {e.element_type for e in elements})
print("Example:", elements[0] if elements else None)

# Keep the raw parse on disk so later steps can be re-run without re-parsing.
with open(os.path.join(ARTIFACTS_DIR, "elements_raw.json"), "w", encoding="utf-8") as f:
    json.dump(
        [asdict(e) for e in elements],
        f,
        ensure_ascii=False,
        indent=2,
    )


In [None]:
# =========================
# 3) Build hierarchy + intelligent chunking
# =========================
def guess_is_heading(text: str) -> bool:
    """Heuristically decide whether *text* looks like a section heading.

    After trimming, a string of at least 3 characters counts as a heading when
    it starts with a dotted number followed by a word ("2.1 Scope"), is all
    upper-case and at most 80 characters long, or begins with the word
    "chapter" or "section" (case-insensitive).
    """
    candidate = text.strip()
    if len(candidate) < 3:
        return False

    looks_numbered = re.match(r"^(\d+(\.\d+)*)\s+\S+", candidate) is not None
    looks_upper = candidate.isupper() and len(candidate) <= 80
    looks_keyword = re.match(r"^(chapter|section)\b", candidate.lower()) is not None
    return looks_numbered or looks_upper or looks_keyword

def assign_heading_paths(elements: List["Element"]) -> List["Element"]:
    """Order elements by page and tag each text element with its current heading.

    Elements are sorted by page with a stable sort, so the parser's original
    order is kept within a page. (Sorting on the string ``element_id`` would
    place e.g. ``az_para_10`` before ``az_para_2`` and scramble reading order.)

    Walking the sorted list, every text element that ``guess_is_heading``
    accepts becomes the current one-level heading path; each text element's
    ``heading_path`` is set to a copy of the current path (``[]`` before the
    first heading). Non-text elements are left untouched.

    Args:
        elements: Parsed elements; their ``heading_path`` is mutated in place.

    Returns:
        A new list holding the same element objects in page order.
    """
    # sorted() is stable: ties on page keep the incoming (reading) order.
    els = sorted(elements, key=lambda e: e.page)
    current_path: List[str] = []
    for e in els:
        if e.element_type != "text":
            continue
        if guess_is_heading(e.content):
            current_path = [e.content.strip()]
        # Copy so that later path changes do not alias earlier elements.
        e.heading_path = current_path.copy()
    return els

def section_chunk_text(elements: List[Element], max_chars: int = 1400, overlap: int = 200) -> List[RetrievalUnit]:
    """Group consecutive text elements into size-bounded retrieval chunks.

    Text elements are accumulated into a buffer for as long as they share the
    same page and heading. The buffer is emitted as a ``RetrievalUnit`` when
    the page or heading changes, or when adding the next piece would push the
    buffer past ``max_chars``. In the latter case the next chunk starts with
    the last ``overlap`` characters of the emitted chunk followed by the full
    new piece, so no text is dropped.

    A chunk can therefore exceed ``max_chars`` by up to ``overlap`` characters,
    and a single piece longer than ``max_chars`` is kept whole rather than
    split.

    Unit ids have the form ``p<page>_sec<n>_<k>``, where ``n`` is derived from
    a CRC32 of the heading (stable across runs, unlike the built-in ``hash``
    which is salted per process) and ``k`` counts chunks with the same prefix.

    Args:
        elements: Elements in reading order; non-text elements are ignored.
        max_chars: Soft upper bound on the buffer size before a flush.
        overlap: Number of trailing characters carried into the next chunk.

    Returns:
        Text retrieval units in document order.
    """
    import zlib

    units: List[RetrievalUnit] = []
    prefix_counts: Dict[str, int] = {}  # chunks emitted so far, per prefix
    buf = ""
    buf_meta = None

    def flush():
        """Emit the current buffer (if any) as a unit and reset the state."""
        nonlocal buf, buf_meta
        if buf_meta and buf.strip():
            seq = prefix_counts.get(buf_meta["prefix"], 0) + 1
            prefix_counts[buf_meta["prefix"]] = seq
            units.append(RetrievalUnit(
                unit_id=f"{buf_meta['prefix']}_{seq}",
                unit_type="text",
                page=buf_meta["page"],
                text_for_embedding=buf.strip(),
                meta=buf_meta
            ))
        buf = ""
        buf_meta = None

    for e in elements:
        if e.element_type != "text":
            continue
        heading = " > ".join(e.heading_path or [])
        section_no = zlib.crc32(heading.encode("utf-8")) % 10_000
        prefix = f"p{e.page}_sec{section_no}"

        # A page or heading change always closes the running chunk.
        if buf_meta is not None and (buf_meta["page"] != e.page or buf_meta["heading"] != heading):
            flush()

        piece = e.content.strip()
        if not piece:
            continue

        if buf and len(buf) + len(piece) + 1 > max_chars:
            # Carry the tail of the chunk being closed into the next one.
            tail = buf[-overlap:] if overlap > 0 else ""
            flush()
            buf = f"{tail}\n{piece}" if tail else piece
        else:
            buf = f"{buf}\n{piece}" if buf else piece

        if buf_meta is None:
            buf_meta = {"page": e.page, "heading": heading, "prefix": prefix}

    flush()
    return units

# Headings must be attached first: the chunker groups text on them.
elements_sorted = assign_heading_paths(elements)
text_units = section_chunk_text(elements_sorted)

print("Text retrieval units:", len(text_units))
print("Example text unit:", text_units[0] if text_units else None)

# Persist the text chunks alongside the other artifacts.
with open(os.path.join(ARTIFACTS_DIR, "text_units.json"), "w", encoding="utf-8") as f:
    json.dump(
        [asdict(u) for u in text_units],
        f,
        ensure_ascii=False,
        indent=2,
    )


In [None]:
# =========================
# 4) Crop figures and tables from PDF using bounding boxes
# =========================
import fitz  # PyMuPDF

# Output folders for the cropped figure and table images (created if missing).
fig_dir = ensure_dir(os.path.join(ARTIFACTS_DIR, "figures"))
tbl_dir = ensure_dir(os.path.join(ARTIFACTS_DIR, "tables"))

# Open the source PDF once; the crop loop below renders regions from it and
# closes it when done.
doc = fitz.open(PDF_PATH)

def rect_from_bbox(page, bbox: Dict[str, float]) -> fitz.Rect:
    """Scale a fractional ``{left, top, right, bottom}`` box to page coordinates.

    Horizontal values are multiplied by the page width and vertical values by
    the page height, both taken from ``page.rect``.
    """
    page_w = page.rect.width
    page_h = page.rect.height
    x0 = bbox["left"] * page_w
    y0 = bbox["top"] * page_h
    x1 = bbox["right"] * page_w
    y1 = bbox["bottom"] * page_h
    return fitz.Rect(x0, y0, x1, y1)

def rect_from_polygon(page, polygon: List[float], unit_hint: str = "points") -> fitz.Rect:
    """Return the axis-aligned bounding rectangle of a flat ``[x1, y1, x2, y2, ...]`` polygon.

    When ``unit_hint`` is ``"inch"`` every coordinate is multiplied by 72.0;
    for any other value the coordinates are used unchanged. ``page`` is
    accepted for signature symmetry with ``rect_from_bbox`` and is not read.
    """
    x_coords = polygon[0::2]  # even positions hold x values
    y_coords = polygon[1::2]  # odd positions hold y values
    corners = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
    if unit_hint == "inch":
        corners = [value * 72.0 for value in corners]
    return fitz.Rect(*corners)

asset_units: List[RetrievalUnit] = []

# Azure Document Intelligence reports polygons in the page's own unit, which
# is inches for PDF input, whereas PyMuPDF works in points (1 inch = 72 pt).
# Treating those polygons as points would crop a tiny, wrong region.
polygon_unit = "inch" if PARSER_MODE == "azure_di" else "points"

try:
    for e in elements_sorted:
        if e.element_type not in ("figure", "table"):
            continue
        if e.page < 1 or e.page > len(doc):
            continue  # element refers to a page the PDF does not have
        page = doc.load_page(e.page - 1)  # Element.page is 1-based

        # Prefer the fractional bbox; fall back to the polygon.
        rect = None
        if e.bbox:
            rect = rect_from_bbox(page, e.bbox)
        elif e.polygon:
            rect = rect_from_polygon(page, e.polygon, unit_hint=polygon_unit)

        if rect is None:
            continue  # no geometry, nothing to crop

        pix = page.get_pixmap(clip=rect, dpi=200)
        img_bytes = pix.tobytes("png")

        if e.element_type == "figure":
            out_path = os.path.join(fig_dir, f"{e.element_id}.png")
            with open(out_path, "wb") as f:
                f.write(img_bytes)
            # text_for_embedding is filled in later by the summary step.
            asset_units.append(RetrievalUnit(
                unit_id=f"fig_{e.element_id}",
                unit_type="figure",
                page=e.page,
                text_for_embedding="",
                raw_content_ref=out_path,
                meta={"source_element_id": e.element_id}
            ))
        else:
            out_path = os.path.join(tbl_dir, f"{e.element_id}.png")
            with open(out_path, "wb") as f:
                f.write(img_bytes)

            # Tables keep their markup next to the crop; the summary step reads it.
            html_path = os.path.join(tbl_dir, f"{e.element_id}.html")
            with open(html_path, "w", encoding="utf-8") as f:
                f.write(e.content or "")

            asset_units.append(RetrievalUnit(
                unit_id=f"tbl_{e.element_id}",
                unit_type="table",
                page=e.page,
                text_for_embedding="",
                raw_content_ref=html_path,
                meta={"crop_image_path": out_path, "source_element_id": e.element_id}
            ))
finally:
    # Release the PDF even if rendering or writing a crop fails.
    doc.close()

print("Asset units created (cropped):", len(asset_units))
print("Example asset:", asset_units[0] if asset_units else None)

with open(os.path.join(ARTIFACTS_DIR, "asset_units_raw.json"), "w", encoding="utf-8") as f:
    json.dump([asdict(u) for u in asset_units], f, ensure_ascii=False, indent=2)


In [None]:
# =========================
# 5) Summarize figures + tables (for retrieval)
# =========================
import os, json, base64
from openai import OpenAI

# Shared OpenAI client used by the figure/table summarisation helpers below.
client = OpenAI(api_key=OPENAI_API_KEY)

def b64_image(path: str) -> str:
    """Read the file at *path* and return its bytes as a base64 text string."""
    with open(path, "rb") as fh:
        raw_bytes = fh.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def describe_figure(image_path: str) -> str:
    """Ask the vision model for a short, search-oriented description of a figure.

    The PNG at *image_path* is inlined as a base64 data URL and sent together
    with a fixed instruction; the model's text output is returned stripped.
    """
    data_url = f"data:image/png;base64,{b64_image(image_path)}"
    instruction = "Describe this figure for semantic search. Include what it depicts, labels, and key entities. Keep it concise but informative."
    user_message = {
        "role": "user",
        "content": [
            {"type": "input_text", "text": instruction},
            {"type": "input_image", "image_url": data_url},
        ],
    }
    resp = client.responses.create(model=OPENAI_VISION_MODEL, input=[user_message])
    return resp.output_text.strip()

def summarize_table_html(table_html: str) -> str:
    """Ask the text model for a 2-6 sentence retrieval summary of a table.

    The table markup is embedded verbatim in the prompt; the model's text
    output is returned stripped.
    """
    prompt = f"""Summarize this table for retrieval. Mention what the table represents, key columns, and notable values or trends.
Return a concise summary (2-6 sentences).

TABLE HTML:
{table_html}
"""
    resp = client.responses.create(model=OPENAI_TEXT_MODEL, input=prompt)
    return resp.output_text.strip()

# Summaries are cached on disk by unit id so that re-running the cell does not
# pay for the same model calls twice.
cache_path = os.path.join(ARTIFACTS_DIR, "asset_summaries.json")
if os.path.exists(cache_path):
    with open(cache_path, "r", encoding="utf-8") as f:
        summaries = json.load(f)
else:
    summaries = {}

try:
    for u in asset_units:
        if u.unit_id in summaries:
            continue  # already summarised in a previous run
        if u.unit_type == "figure":
            summaries[u.unit_id] = describe_figure(u.raw_content_ref)
        elif u.unit_type == "table":
            with open(u.raw_content_ref, "r", encoding="utf-8") as f:
                html = f.read()
            summaries[u.unit_id] = summarize_table_html(html)
finally:
    # Write the cache even when a model call fails part-way through, so the
    # summaries obtained so far are not lost.
    with open(cache_path, "w", encoding="utf-8") as f:
        json.dump(summaries, f, ensure_ascii=False, indent=2)

# The summary is what gets embedded for figures and tables.
for u in asset_units:
    u.text_for_embedding = summaries.get(u.unit_id, "")

print("Example summaries:", list(summaries.items())[:2])


In [None]:
# =========================
# 6) Embed retrieval units + build FAISS index
# =========================
import faiss
import cohere

# Shared Cohere client used by cohere_embed_texts for documents and queries.
co = cohere.ClientV2(api_key=COHERE_API_KEY)

def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale every row of the 2-D array *x* to unit L2 norm.

    Row norms are floored at *eps* before dividing, so all-zero rows stay
    zero instead of producing NaNs.
    """
    row_norms = np.linalg.norm(x, axis=1, keepdims=True)
    safe_norms = np.clip(row_norms, eps, None)
    return x / safe_norms

def cohere_embed_texts(texts: List[str], model: str, input_type: str, batch_size: int = 64) -> np.ndarray:
    """Embed *texts* with Cohere in batches and return one float32 row per text.

    Each batch is sent through the module-level ``co`` client requesting
    ``float`` embeddings. An empty input yields an array of shape ``(0, 1)``.
    """
    batch_arrays = []
    for start in range(0, len(texts), batch_size):
        payload = [
            {"content": [{"type": "text", "text": t}]}
            for t in texts[start:start + batch_size]
        ]
        resp = co.embed(
            model=model,
            inputs=payload,
            input_type=input_type,
            embedding_types=["float"],
        )
        batch_arrays.append(np.array(resp.embeddings.float_, dtype=np.float32))

    if not batch_arrays:
        return np.zeros((0, 1), dtype=np.float32)
    return np.vstack(batch_arrays)

def build_faiss_ip_index(emb: np.ndarray) -> faiss.Index:
    """Build a flat inner-product FAISS index holding the rows of *emb*.

    The vectors are cast to float32 first; the index dimension is the number
    of columns of *emb*.
    """
    vectors = emb.astype(np.float32)
    dim = vectors.shape[1]
    ip_index = faiss.IndexFlatIP(dim)
    ip_index.add(vectors)
    return ip_index

def faiss_search(index: faiss.Index, q: np.ndarray, top_k: int):
    """Search *index* with the query matrix *q* and return results for its first row.

    Returns:
        A ``(scores, ids)`` pair for the first query row, each of length *top_k*.
    """
    query = q.astype(np.float32)
    all_scores, all_ids = index.search(query, top_k)
    return all_scores[0], all_ids[0]

# Text chunks come first, then figure/table assets. Row i of the embedding
# matrix (and entry i of the index) corresponds to all_units[i].
all_units: List[RetrievalUnit] = text_units + asset_units

texts_to_embed = [u.text_for_embedding for u in all_units]
emb = cohere_embed_texts(
    texts_to_embed,
    model=COHERE_TEXT_MODEL,
    input_type="search_document",
)
emb_n = l2_normalize(emb)  # unit vectors, so inner product == cosine similarity
index = build_faiss_ip_index(emb_n)

# Persist embeddings, index and unit metadata so inference can run from disk.
np.save(os.path.join(ARTIFACTS_DIR, "all_units_embeddings.npy"), emb_n)
faiss.write_index(index, os.path.join(ARTIFACTS_DIR, "faiss_all_units_cohere.index"))
with open(os.path.join(ARTIFACTS_DIR, "all_units_meta.json"), "w", encoding="utf-8") as f:
    json.dump(
        [asdict(u) for u in all_units],
        f,
        ensure_ascii=False,
        indent=2,
    )

print("Indexed units:", len(all_units), "dim:", emb_n.shape[1])


In [None]:
# =========================
# 7) Inference: answer with citations + restrict assets by cited pages
# =========================
import os, json, re
from openai import OpenAI
from PIL import Image
from IPython.display import display
import faiss
import numpy as np

client = OpenAI(api_key=OPENAI_API_KEY)

# Reload everything from disk so this cell can run without re-embedding.
index = faiss.read_index(os.path.join(ARTIFACTS_DIR, "faiss_all_units_cohere.index"))
# Use a context manager so the metadata file handle is closed deterministically.
with open(os.path.join(ARTIFACTS_DIR, "all_units_meta.json"), "r", encoding="utf-8") as f:
    all_units = [RetrievalUnit(**u) for u in json.load(f)]
emb_n = np.load(os.path.join(ARTIFACTS_DIR, "all_units_embeddings.npy"))

def embed_query(query: str) -> np.ndarray:
    """Embed *query* as a Cohere ``search_query`` and return it L2-normalized (one row)."""
    return l2_normalize(
        cohere_embed_texts(
            [query],
            model=COHERE_TEXT_MODEL,
            input_type="search_query",
        )
    )

def retrieve_units(query: str, top_k: int = 12):
    """Return the best-matching retrieval units for *query* from the global index.

    Args:
        query: Natural-language query text.
        top_k: Maximum number of hits to return.

    Returns:
        A list of ``(unit, score)`` tuples, best first. It holds fewer than
        *top_k* entries when the index is smaller than *top_k*, and is empty
        when the index is empty or *top_k* is not positive.
    """
    # FAISS pads results with label -1 when asked for more hits than it holds;
    # indexing all_units with -1 would silently return the last unit.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []
    q = embed_query(query)
    scores, idxs = faiss_search(index, q, k)
    return [(all_units[int(i)], float(s)) for s, i in zip(scores, idxs) if i >= 0]

def build_context_for_answer(hits):
    """Render text and table hits as ``[unit_id | p<page>] <text>`` blocks.

    Figure hits are skipped. Blocks are separated by a blank line and keep
    the order of *hits*; scores are ignored.
    """
    return "\n\n".join(
        f"[{unit.unit_id} | p{unit.page}] {unit.text_for_embedding}"
        for unit, _score in hits
        if unit.unit_type in ("text", "table")
    )

def answer_with_citations(query: str, hits) -> str:
    """Ask the text model to answer *query* from the retrieved context only.

    The prompt instructs the model to cite unit ids inline in square brackets
    and to say so when it does not know. Returns the model's text, stripped.
    """
    context_block = build_context_for_answer(hits)
    prompt = f"""You are answering a question using ONLY the provided context excerpts.
- Cite sources inline using square brackets with the unit id, e.g., [p3_sec1234_2] or [tbl_az_table_1].
- If you don't know, say so.

Question: {query}

Context:
{context_block}
"""
    response = client.responses.create(model=OPENAI_TEXT_MODEL, input=prompt)
    return response.output_text.strip()

def cited_pages_from_answer(answer: str) -> set:
    """Extract the page numbers referenced by citations in *answer*.

    Two citation shapes are recognised:

    * anything starting with ``[p<page>_`` (the id opens the bracket), and
    * full text-unit ids ``p<page>_sec<n>_<k>`` anywhere in the answer, which
      also covers several ids sharing one bracket, e.g.
      ``[p3_sec12_1, p5_sec9_2]``.

    If neither matches, bare tokens such as ``p7`` are used as a fallback.

    Returns:
        A set of page numbers as ints (possibly empty).
    """
    bracket_leading = re.findall(r"\[p(\d+)_", answer)
    # The lookbehind stops the id from matching in the middle of another word.
    full_ids = re.findall(r"(?<![A-Za-z0-9_])p(\d+)_sec\d+_\d+", answer)
    pages = {int(p) for p in bracket_leading + full_ids}
    if not pages:
        pages = {int(p) for p in re.findall(r"\bp(\d+)\b", answer)}
    return pages

def allowed_asset_indices(pages: set, window: int = 1):
    """Return indices into ``all_units`` of figures/tables near the cited pages.

    A unit qualifies when its page lies within ``window`` pages of any page in
    *pages* (page numbers below 1 are ignored). With no cited pages at all,
    every figure/table index is returned.
    """
    asset_types = ("figure", "table")
    if not pages:
        return [i for i, u in enumerate(all_units) if u.unit_type in asset_types]

    allowed_pages = {
        p + offset
        for p in pages
        for offset in range(-window, window + 1)
        if p + offset >= 1
    }
    return [
        i
        for i, u in enumerate(all_units)
        if u.unit_type in asset_types and u.page in allowed_pages
    ]

def retrieve_assets_filtered(query: str, allowed_idxs, top_k: int = 6):
    """Rank only the units at *allowed_idxs* against *query*.

    A temporary inner-product index is built over the selected rows of
    ``emb_n``; hits are mapped back to ``all_units`` through *allowed_idxs*.

    Returns:
        Up to *top_k* ``(unit, score)`` tuples, or ``[]`` when nothing is allowed.
    """
    if not allowed_idxs:
        return []

    query_vec = embed_query(query)
    candidate_vecs = emb_n[allowed_idxs].astype(np.float32)
    candidate_index = build_faiss_ip_index(candidate_vecs)
    limit = min(top_k, len(allowed_idxs))
    scores, local_ids = faiss_search(candidate_index, query_vec, limit)

    results = []
    for rank, local_id in enumerate(local_ids):
        # local_id is a position inside the temporary index, not in all_units.
        unit = all_units[allowed_idxs[int(local_id)]]
        results.append((unit, float(scores[rank])))
    return results

def show_asset_hits(asset_hits, max_show: int = 4):
    """Print and display the first *max_show* figure/table hits.

    For each hit a one-line summary is printed. Figures are displayed from
    ``raw_content_ref``. Tables display their cropped image when
    ``meta['crop_image_path']`` exists on disk, followed by the first 800
    characters of the table HTML read from ``raw_content_ref``.

    Args:
        asset_hits: Sequence of ``(unit, score)`` tuples.
        max_show: Maximum number of hits to show.
    """
    for u, s in asset_hits[:max_show]:
        print(f"{u.unit_id} ({u.unit_type}) page={u.page} score={s:.3f}")
        if not u.raw_content_ref:
            continue  # nothing on disk to show for this hit

        if u.unit_type == "figure":
            # Context manager closes the image file once it has been rendered.
            with Image.open(u.raw_content_ref) as img:
                display(img)
        elif u.unit_type == "table":
            crop_path = u.meta.get("crop_image_path") if u.meta else None
            if crop_path and os.path.exists(crop_path):
                with Image.open(crop_path) as img:
                    display(img)
            with open(u.raw_content_ref, "r", encoding="utf-8") as f:
                html = f.read()
            print(html[:800] + ("..." if len(html) > 800 else ""))

# End-to-end demo: retrieve -> answer with citations -> narrow asset search
# to the cited pages -> show the best figures/tables.
QUERY = "What are the indian strategies in education system"
hits = retrieve_units(QUERY, top_k=TOP_K_TEXT)
answer = answer_with_citations(QUERY, hits)
print(answer)

# Pages are recovered from the citation ids the model put in its answer.
pages = cited_pages_from_answer(answer)
print("\nCited pages:", sorted(pages))

# Only figures/tables on the cited pages (± PAGE_WINDOW) are candidates;
# with no cited pages every figure/table is considered.
allowed_idxs = allowed_asset_indices(pages, window=PAGE_WINDOW)
asset_hits = retrieve_assets_filtered(QUERY, allowed_idxs, top_k=TOP_K_ASSETS)

print("\nTop figures/tables from cited pages (±window):")
show_asset_hits(asset_hits, max_show=4)


## Notes

- Replace the heading heuristic with parser-provided section hierarchy when available.
- If your parser returns normalized boxes, cropping is straightforward; if it returns absolute polygons (Azure DI), ensure correct unit conversion.
- For production: store assets in object storage (GCS/S3) and persist metadata + embeddings in a DB/vector store.
