## Mistral OCR: Clean Annotations + Headings Comparison

This notebook runs a clean, readable pipeline:
- Basic OCR to get per-page markdown and image bboxes (with crops)
- Annotations to extract language, title, authors, chapter titles, URLs
- Headings extracted two ways for comparison:
  - From page markdown (ATX and setext syntax)
  - From an OCR-inferred outline (the model infers heading titles and levels; page and line locations are recovered afterwards by aligning to the markdown headings)

References: [Basic OCR](https://docs.mistral.ai/capabilities/document_ai/basic_ocr/), [Annotations](https://docs.mistral.ai/capabilities/document_ai/annotations/)


### 1) Setup
Configures the client, paths, and constants used across the notebook.


In [None]:
"""
Setup the Mistral client and define basic constants.
- Keep code simple and explicit; raise if required variables are missing.
"""

# ### CONSTANTS ###
from pathlib import Path
NOTEBOOK_NAME: str = "2025.09.08-test_mistral_ocr_annotations_clean"  # prefix for files written to ./outputs
# NOTE(review): machine-specific absolute path; point this at your own PDF.
PDF_PATH: Path = Path("/Users/Focus/Downloads/2212.14024v2.pdf")
MODEL: str = "mistral-ocr-latest"  # passed as `model=` to every client.ocr.process call
# Document annotation supports up to 8 pages per request.
# NOTE: PAGES is only printed by the setup cell; the annotation cells build their
# own 8-page batches over every OCR page, so editing PAGES does not change scope.
PAGES: list[int] = list(range(8))

# ### DEPENDENCIES ###
import os
import base64
from dotenv import load_dotenv
from mistralai import Mistral

# ### SETUP CLIENT ###
load_dotenv()
if not (api_key := os.environ.get("MISTRAL_API_KEY")):
    raise RuntimeError("MISTRAL_API_KEY is not set in environment.")
client = Mistral(api_key=api_key)

# Fail fast on a missing input before any encoding work.
if not PDF_PATH.exists():
    raise FileNotFoundError(f"PDF not found: {PDF_PATH}")

# The PDF is embedded inline as a base64 data: URL; DOCUMENT_SPEC is the
# `document=` payload for every client.ocr.process call in this notebook.
_pdf_bytes = PDF_PATH.read_bytes()
_pdf_b64 = base64.b64encode(_pdf_bytes).decode("utf-8")
DOCUMENT_SPEC = {
    "type": "document_url",
    "document_url": f"data:application/pdf;base64,{_pdf_b64}",
}

for _label, _value in (
    ("Ready. Model:", MODEL),
    ("PDF:", PDF_PATH),
    ("Pages (doc annotation scope):", PAGES),
):
    print(_label, _value)


### 2) Basic OCR (markdown + image crops)
Runs Basic OCR to get per-page markdown and image bboxes. Also prints page count and shows cropped images.


In [None]:
"""
Run Basic OCR and display results:
- Print per-page markdown
- Print bbox coordinates
- Display each cropped bbox image
- Print number of pages detected
"""

from IPython.display import display
from PIL import Image as PILImage
import io

def _image_bytes_from_base64(data_str):
    """Decode an image payload given either as bare base64 or as a data: URL."""
    if data_str.startswith("data:"):
        _, payload = data_str.split(",", 1)
    else:
        payload = data_str
    return base64.b64decode(payload)


ocr = client.ocr.process(
    model=MODEL,
    document=DOCUMENT_SPEC,
    include_image_base64=True,
)

print("Pages (len):", len(ocr.pages))
for page in ocr.pages:
    # Dimensions may be absent; every field then prints as None.
    dims = getattr(page, "dimensions", None)
    width, height, dpi = (getattr(dims, attr, None) for attr in ("width", "height", "dpi"))
    print(f"\n## Page {page.index} | dims: {width}x{height} dpi={dpi}")
    print(page.markdown)

    page_images = getattr(page, "images", []) or []
    if not page_images:
        print("(no image bboxes)")
    for img_no, img in enumerate(page_images, start=1):
        x0, y0 = img.top_left_x, img.top_left_y
        x1, y1 = img.bottom_right_x, img.bottom_right_y
        print(f"- Image {img_no}: id={getattr(img,'id',None)} bbox=({x0},{y0})→({x1},{y1}) size=({x1 - x0}x{y1 - y0})")

        # Only render crops the API actually returned.
        encoded = img.image_base64
        if encoded:
            display(PILImage.open(io.BytesIO(_image_bytes_from_base64(encoded))))

ocr_pages = ocr.pages  # used later


### 3) Annotations (language, title, authors, chapters, URLs)
Requests document-level annotations in 8-page batches (the API limit for document annotation), merges the per-batch results, and prints them in a readable format.


In [None]:
"""
Extract document-level fields via Document Annotation with batching.
Processes the document in 8-page chunks (API limit) and combines results.
- Schema: language, title, authors, chapter_titles, urls
"""

### IMPORTS ###
from pydantic import BaseModel, Field
from mistralai.extra import response_format_from_pydantic_model
import json as _json
from typing import Dict, List, Any

### SCHEMA ###
# Response schema for the document-level annotation pass. It is converted by
# response_format_from_pydantic_model and sent as `document_annotation_format`,
# so field names, types and descriptions shape what the model is asked to return.
# NOTE(review): comments rather than a class docstring on purpose -- pydantic
# typically copies a class docstring into the JSON schema description.
class DocumentAnnotation(BaseModel):
    language: str = Field(..., description="Language of the document")
    title: str | None = Field(None, description="Document title if present")  # optional; defaults to None
    authors: list[str] = Field(..., description="Author names")
    chapter_titles: list[str] = Field(..., description="Chapter titles in order")
    urls: list[str] = Field(..., description="URLs referenced in the document")

### HELPER FUNCTIONS ###
def create_page_batches(total_pages: int, batch_size: int = 8) -> List[List[int]]:
    """
    Split page indices ``0..total_pages-1`` into consecutive, non-overlapping batches.

    Args:
        total_pages: Total number of pages in document (<= 0 yields no batches)
        batch_size: Maximum pages per batch (API limit is 8); must be >= 1

    Returns:
        List of page number lists, e.g. [[0,1,2,3,4,5,6,7], [8,9,10,11,12,13,14]]

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # range() would raise for 0 and silently yield nothing for negative
        # values, which later surfaces as an unrelated "No batch results" error.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    return [
        list(range(start, min(start + batch_size, total_pages)))
        for start in range(0, total_pages, batch_size)
    ]

def process_annotation_batch(client, model: str, document_spec: Dict, pages: List[int], response_format) -> Dict[str, Any]:
    """
    Process a single batch of pages for document annotation.
    
    Args:
        client: Mistral client
        model: Model name
        document_spec: Document specification
        pages: List of page numbers to process
        response_format: Pydantic response format
        
    Returns:
        Parsed annotation data as dictionary
    """
    ann = client.ocr.process(
        model=model,
        document=document_spec,
        pages=pages,
        document_annotation_format=response_format,
        include_image_base64=False,
    )

    raw = ann.document_annotation
    if isinstance(raw, str):
        return _json.loads(raw)
    elif hasattr(raw, "model_dump"):
        return raw.model_dump()
    elif isinstance(raw, dict):
        return raw
    else:
        raise ValueError(f"Unsupported document_annotation type: {type(raw)}")

def combine_annotations(batch_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Merge per-batch document annotations into a single annotation.

    - ``language`` / ``title``: first non-empty value across batches in batch
      order, falling back to the first batch's value when none is non-empty
      (a batch that does not cover the title page may return null here).
    - ``authors``: concatenated in order and de-duplicated case-insensitively,
      ignoring whitespace differences; the first spelling seen is kept and
      empty names are dropped.
    - ``chapter_titles``: concatenated in batch order (no de-duplication).
    - ``urls``: concatenated in order with exact duplicates removed.

    A missing or null list field in a batch is treated as an empty list.

    Args:
        batch_results: List of annotation dictionaries from each batch

    Returns:
        Combined annotation dictionary with keys language, title, authors,
        chapter_titles, urls

    Raises:
        ValueError: If ``batch_results`` is empty.
    """
    if not batch_results:
        raise ValueError("No batch results to combine")

    def first_present(key: str) -> Any:
        for batch in batch_results:
            value = batch.get(key)
            if value:
                return value
        return batch_results[0].get(key)

    def list_field(batch: Dict[str, Any], key: str) -> List[Any]:
        # The model may emit null instead of []; `.get(key, [])` returns None then.
        return batch.get(key) or []

    combined: Dict[str, Any] = {
        "language": first_present("language"),
        "title": first_present("title"),
    }

    authors: List[str] = []
    seen_authors: set[str] = set()
    for batch in batch_results:
        for author in list_field(batch, "authors"):
            key = " ".join(author.split()).casefold() if author else ""
            if key and key not in seen_authors:
                seen_authors.add(key)
                authors.append(author)
    combined["authors"] = authors

    combined["chapter_titles"] = [
        title for batch in batch_results for title in list_field(batch, "chapter_titles")
    ]

    urls: List[str] = []
    seen_urls: set[str] = set()
    for batch in batch_results:
        for url in list_field(batch, "urls"):
            if url not in seen_urls:
                seen_urls.add(url)
                urls.append(url)
    combined["urls"] = urls

    return combined

### MAIN PROCESSING ###
doc_rf = response_format_from_pydantic_model(DocumentAnnotation)

# Plan: one request per 8-page window over every page OCR returned.
total_pages = len(ocr_pages)
page_batches = create_page_batches(total_pages, batch_size=8)
n_batches = len(page_batches)
print(f"Processing {total_pages} pages in {n_batches} batches:")
for batch_no, batch in enumerate(page_batches, start=1):
    print(f"  Batch {batch_no}: pages {batch[0]}-{batch[-1]} ({len(batch)} pages)")

# Execute: annotate each window in order.
batch_results = []
for batch_no, batch_pages in enumerate(page_batches, start=1):
    print(f"Processing batch {batch_no}/{n_batches}...")
    batch_results.append(
        process_annotation_batch(
            client=client,
            model=MODEL,
            document_spec=DOCUMENT_SPEC,
            pages=batch_pages,
            response_format=doc_rf,
        )
    )

# Merge the per-window answers into one annotation.
parsed = combine_annotations(batch_results)

# Report: scalar fields on one line, list fields as bullets.
print("\n=== COMBINED DOCUMENT ANNOTATIONS ===")
for scalar_key in ("language", "title"):
    print(f"{scalar_key}:", parsed.get(scalar_key))
for list_key in ("authors", "chapter_titles", "urls"):
    print(f"{list_key}:")
    for value in parsed.get(list_key, []):
        print(" -", value)


### 4) Headings from Markdown (by page)
Extracts headings from each page’s markdown using ATX and setext rules, and records page index + line number.


In [None]:
"""
Parse headings from per-page markdown.
- ATX: lines starting with 1..6 '#' characters
- Setext: lines followed by '===' or '---' underlines
"""

import re
from typing import List, Dict, Any

# Compiled once; applied to every line of every page.
_FENCE_RE = re.compile(r"^ {0,3}(`{3,}|~{3,})")
_ATX_RE = re.compile(r"^ {0,3}(#{1,6})\s+(.*)$")
_ATX_CLOSING_RE = re.compile(r"(?:^|\s+)#+\s*$")
_SETEXT_H1_RE = re.compile(r"^={3,}$")
_SETEXT_H2_RE = re.compile(r"^-{3,}$")


def _make_heading(page_index: int, line: int, level: int, text: str) -> Dict[str, Any]:
    """Build one heading record in the shape consumed by the later cells."""
    return {
        "source": "markdown",
        "page_index": page_index,
        "line": line,
        "level": level,
        "text": text,
    }


def extract_markdown_headings(markdown: str | None, page_index: int) -> List[Dict[str, Any]]:
    """
    Return the headings in one page's markdown, in line order.

    Recognized forms (a pragmatic subset of CommonMark):
    - ATX: up to 3 leading spaces, 1..6 '#', whitespace, text. A closing run
      of '#' is stripped; headings with no text are skipped.
    - Setext: a non-blank text line followed by a line of 3+ '=' (h1) or
      3+ '-' (h2). Only that single preceding line is used as the text. A
      '---' after a blank line, a heading, or another underline is a thematic
      break, not a heading.
    Lines inside ``` / ~~~ fenced code blocks are ignored.

    Args:
        markdown: Page markdown (None is treated as empty).
        page_index: Page index recorded on each heading.

    Returns:
        Heading dicts with keys source, page_index, line (1-based), level, text.
    """
    headings: List[Dict[str, Any]] = []
    lines = (markdown or "").splitlines()
    open_fence: str | None = None  # opening marker while inside a fenced block
    prev_can_be_setext_text = False

    for lineno, line in enumerate(lines, start=1):
        fence_m = _FENCE_RE.match(line)
        if open_fence is not None:
            # A fence closes on the same character with at least the same length.
            if (
                fence_m
                and fence_m.group(1)[0] == open_fence[0]
                and len(fence_m.group(1)) >= len(open_fence)
            ):
                open_fence = None
            prev_can_be_setext_text = False
            continue
        if fence_m:
            open_fence = fence_m.group(1)
            prev_can_be_setext_text = False
            continue

        atx = _ATX_RE.match(line)
        if atx:
            text = _ATX_CLOSING_RE.sub("", atx.group(2)).strip()
            if text:
                headings.append(_make_heading(page_index, lineno, len(atx.group(1)), text))
            prev_can_be_setext_text = False
            continue

        stripped = line.strip()
        if _SETEXT_H1_RE.match(stripped) or _SETEXT_H2_RE.match(stripped):
            if prev_can_be_setext_text:
                level = 1 if stripped[0] == "=" else 2
                # The heading text is the previous line, so record that line number.
                headings.append(_make_heading(page_index, lineno - 1, level, lines[lineno - 2].strip()))
            prev_can_be_setext_text = False
            continue

        prev_can_be_setext_text = bool(stripped)

    return headings


markdown_headings: List[Dict[str, Any]] = []
for page in ocr_pages:
    markdown_headings.extend(extract_markdown_headings(page.markdown, page.index))

print(f"Found {len(markdown_headings)} markdown headings")
for h in markdown_headings:
    print(f"[page {h['page_index']} line {h['line']}] h{h['level']}: {h['text']}")


### 5) Headings from OCR-Inferred Outline
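Requests a structured outline as a document annotation, using the same 8-page batches. The model infers heading titles and levels only; page indices are not requested and are recovered by the alignment in step 7.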


In [None]:
"""
Request a structured outline as a document annotation with batching.
Processes the document in 8-page chunks (API limit) and combines outlines.
- The model infers heading level (1..6) only. Page indices are not requested.
"""

### IMPORTS ###
from pydantic import BaseModel, Field
from typing import List, Dict, Any
from mistralai.extra import response_format_from_pydantic_model
import json as _json

### SCHEMA ###
# One heading in the model-inferred outline (pydantic response schema, so no
# class docstring: it could leak into the JSON schema description).
# NOTE(review): `level` is described as 1..6 but is not range-validated here.
class OutlineItem(BaseModel):
    title: str = Field(..., description="Heading text")
    level: int = Field(..., description="Heading level 1..6")

# Top-level response schema for the outline pass; process_outline_batch reads
# the "outline" key from the returned annotation.
class DocumentOutline(BaseModel):
    outline: List[OutlineItem] = Field(..., description="Document outline")

### HELPER FUNCTIONS ###
def process_outline_batch(client, model: str, document_spec: Dict, pages: List[int], response_format) -> List[Dict[str, Any]]:
    """
    Run one outline-annotation request over ``pages`` and return its items.

    Args:
        client: Mistral client (only ``client.ocr.process`` is used)
        model: Model name
        document_spec: Document specification (``document=`` payload)
        pages: List of page numbers to process
        response_format: Pydantic response format

    Returns:
        List of outline items from this batch; an absent or null "outline"
        yields an empty list so callers can always ``len()``/``extend()`` it.

    Raises:
        ValueError: If the annotation is not a JSON string, an object with
            ``model_dump``, or a dict, or if it does not decode to an object.
    """
    outline_resp = client.ocr.process(
        model=model,
        document=document_spec,
        pages=pages,
        document_annotation_format=response_format,
        include_image_base64=False,
    )

    raw_outline = outline_resp.document_annotation
    if isinstance(raw_outline, str):
        outline_parsed = _json.loads(raw_outline)
    elif hasattr(raw_outline, "model_dump"):
        outline_parsed = raw_outline.model_dump()
    elif isinstance(raw_outline, dict):
        outline_parsed = raw_outline
    else:
        raise ValueError(f"Unsupported outline type: {type(raw_outline)}")

    if not isinstance(outline_parsed, dict):
        raise ValueError(
            f"Unsupported outline payload: expected a JSON object, got {type(outline_parsed).__name__}"
        )

    # A batch with no headings may come back as `"outline": null`.
    return list(outline_parsed.get("outline") or [])

def combine_outlines(batch_outline_results: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """
    Flatten per-batch outline lists into one list, preserving batch order.

    Args:
        batch_outline_results: One list of outline items per batch

    Returns:
        All outline items, batch by batch, in their original order

    Raises:
        ValueError: If no batch results were supplied.
    """
    if not batch_outline_results:
        raise ValueError("No batch outline results to combine")
    return [entry for batch_outline in batch_outline_results for entry in batch_outline]

### MAIN PROCESSING ###
outline_rf = response_format_from_pydantic_model(DocumentOutline)

# Step 1: Reuse `page_batches` from the document-annotation cell (section 3),
# so the outline pass covers exactly the same page windows.
if "page_batches" not in globals():
    raise RuntimeError("Run the document-annotation cell (section 3) first: it defines `page_batches`.")
print(f"Processing outline for {len(page_batches)} batches:")

# Step 2: Process each batch for outline
batch_outline_results = []
for i, batch_pages in enumerate(page_batches):
    print(f"Processing outline batch {i+1}/{len(page_batches)}...")
    batch_outline = process_outline_batch(
        client=client,
        model=MODEL,
        document_spec=DOCUMENT_SPEC,
        pages=batch_pages,
        response_format=outline_rf
    )
    batch_outline_results.append(batch_outline)
    print(f"  Found {len(batch_outline)} headings in batch {i+1}")

# Step 3: Combine outline results from all batches (concatenated in batch order)
ocr_outline = combine_outlines(batch_outline_results)

# Step 4: Display final results
print("\n=== COMBINED OCR OUTLINE ===")
print(f"Found {len(ocr_outline)} outline headings total")
for item in ocr_outline:
    print(f"h{item.get('level')}: {item.get('title')}")


### 6) Compare Headings (Markdown vs OCR-Inferred)
Prints both heading lists one after the other so you can visually compare them for consistency.


In [None]:
"""
Print both headings lists for a quick manual comparison.
- Markdown headings include page and line
- OCR outline includes only level and title (no page)
"""

# Markdown headings carry their page/line location.
print("\nMarkdown-derived headings:")
for md in markdown_headings:
    page_no, line_no, lvl, txt = md["page_index"], md["line"], md["level"], md["text"]
    print(f"[page {page_no} line {line_no}] h{lvl}: {txt}")

# The OCR outline carries only level and title.
print("\nOCR-inferred outline headings:")
for entry in ocr_outline:
    print("h{}: {}".format(entry.get("level"), entry.get("title")))


### 7) Align OCR Outline to Markdown Lines (RapidFuzz)
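Matches each OCR-inferred outline heading to the closest markdown heading anywhere in the document, because the outline carries no page indices. The page index and markdown line number are then taken from that match. This preserves the model’s semantic levels while grounding each heading to a precise location; headings whose best match scores below the threshold are left unassigned.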

Note: requires `rapidfuzz` (for example, `pip install rapidfuzz`).


In [None]:
"""
Align OCR outline items to markdown headings using RapidFuzz.
Simplified:
- Ignore any page indices from the outline (not trustworthy).
- Match each OCR outline title against ALL markdown headings in the document.
- Assign page and line purely from the best markdown match.
Saves `headers_index_normalized.json` with aligned results.
"""

from typing import List, Dict, Any
from rapidfuzz import fuzz
import json as _json
from pathlib import Path as _Path

def normalize_heading_text(text: str) -> str:
    """
    Normalize heading text for fuzzy comparison.

    Applies Unicode NFKC, which folds compatibility characters such as the
    "ﬁ" ligature (common in text extracted from PDFs) to "fi". Then it
    casefolds, which is stricter than lower() for caseless matching, and
    collapses whitespace runs to single spaces. Falsy input (None, "")
    yields "".
    """
    import unicodedata  # local import keeps this cell self-contained

    if not text:
        return ""
    folded = unicodedata.normalize("NFKC", text).casefold()
    return " ".join(folded.split())

def _coerce_level(value: Any) -> int | None:
    """Return an outline level as int, or None when it is missing or unparseable."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def align_outline_to_markdown(
    outline: List[Dict[str, Any]],
    md_headings: List[Dict[str, Any]],
    threshold: float = 85,
) -> List[Dict[str, Any]]:
    """
    Map each outline item to its best fuzzy-matching markdown heading.

    Every outline title is scored against every markdown heading with
    ``fuzz.token_set_ratio`` on normalized text. A match at or above
    ``threshold`` takes its page index and line number from that heading. On
    equal scores, a heading not yet claimed by an earlier outline item is
    preferred, so repeated titles (e.g. several "Results" sections) map to
    successive occurrences instead of all collapsing onto the first one.

    Args:
        outline: Outline items with "title" and "level" keys.
        md_headings: Markdown heading dicts with "page_index", "line" and "text".
        threshold: Minimum score (0..100) required to accept a match.

    Returns:
        One dict per outline item with keys page_index, ocr_title, ocr_level,
        markdown_line, markdown_title and score. Unmatched items have None for
        the location fields; score is None only when there were no headings.
    """
    # Normalize each markdown heading once instead of once per outline item.
    md_norm = [normalize_heading_text(md_h["text"]) for md_h in md_headings]
    claimed: set[int] = set()
    results: List[Dict[str, Any]] = []

    for item in outline:
        title = str(item.get("title") or "")
        level = _coerce_level(item.get("level"))
        query = normalize_heading_text(title)

        best_idx: int | None = None
        best_score = -1.0
        for idx, md_text in enumerate(md_norm):
            score = fuzz.token_set_ratio(query, md_text)
            # Tie: move off an already-claimed heading onto an unclaimed one.
            tie_to_unclaimed = score == best_score and best_idx in claimed and idx not in claimed
            if score > best_score or tie_to_unclaimed:
                best_idx, best_score = idx, score

        if best_idx is not None and best_score >= threshold:
            claimed.add(best_idx)
            md_h = md_headings[best_idx]
            results.append({
                "page_index": md_h["page_index"],
                "ocr_title": title,
                "ocr_level": level,
                "markdown_line": md_h["line"],
                "markdown_title": md_h["text"],
                "score": int(best_score),
            })
        else:
            results.append({
                "page_index": None,
                "ocr_title": title,
                "ocr_level": level,
                "markdown_line": None,
                "markdown_title": None,
                "score": int(best_score) if best_score >= 0 else None,
            })

    return results


# Build a flat list of markdown headings across the whole document
all_md: List[Dict[str, Any]] = list(markdown_headings)
THRESHOLD: int = 85

aligned: List[Dict[str, Any]] = align_outline_to_markdown(ocr_outline, all_md, THRESHOLD)

print(f"Aligned {len(aligned)} outline items")
for a in aligned:
    print(f"[page {a['page_index']}] h{a['ocr_level']} → line {a['markdown_line']} (score={a['score']}): {a['ocr_title']}")

# Persist the alignment under ./outputs, prefixed with NOTEBOOK_NAME.
out_dir = _Path.cwd().joinpath("outputs")
out_dir.mkdir(exist_ok=True)

out_path = out_dir.joinpath(f"{NOTEBOOK_NAME}_headers_index_normalized.json")
with out_path.open("w", encoding="utf-8") as fh:
    _json.dump(aligned, fh, ensure_ascii=False, indent=2)
print("Saved:", out_path)
