# 04 — JSON Extraction (Strict, Chunked)

Convert cleaned blocks into strict schema JSON. Emits chunk inputs, raw LLM outputs, and validated JSON per chunk.

In [None]:
# --- CONFIG ---
blocks_dir = "outputs/run_001/03_llmcleaned"
schema_json = "config/schema_prescription.json"
primary_model = "alibayram/medgemma:latest"
fallback_model = "mistral:7b"
chunk_size = 15
max_retries = 3
output_dir = "outputs/run_001/04_extracted_json"

In [5]:
# --- IMPORTS ---
import re, json, math
from pathlib import Path
from typing import List, Dict, Any, Iterable
from langchain.prompts import PromptTemplate
try:
    from langchain_ollama import ChatOllama
except Exception:
    from langchain_community.chat_models import ChatOllama
from langchain.schema import StrOutputParser

# ----------------------
# PATH SETUP (auto-root)
# ----------------------
run_root = Path("outputs/run_001")

# Prefer latest stage that actually exists
search_order = ["03_llmcleaned", "02_cleaned", "01_blocks"]
blocks_dir = None
for folder in search_order:
    cand = run_root / folder
    if any(cand.glob("page_*_blocks*.json")):
        blocks_dir = cand
        print(f"[INFO] Using input folder → {cand}")
        break
if not blocks_dir:
    raise FileNotFoundError(f"No JSON block files found under {run_root}")

out_dir = run_root / "04_jsonextracted"
out_dir.mkdir(parents=True, exist_ok=True)

# ----------------
# SCHEMA LOADING
# ----------------
schema_path = Path(schema_json).expanduser().resolve()
if not schema_path.exists():
    raise FileNotFoundError(schema_path)
schema_str = schema_path.read_text(encoding="utf-8")
try:
    schema = json.loads(schema_str)
except Exception as e:
    raise ValueError(f"Schema must be JSON. Could not parse {schema_path}: {e}")
print(f"[INFO] Loaded schema from {schema_path.name}")

# ----------------
# FILE DISCOVERY
# ----------------
def normalize_page_id(name: str) -> str:
    # page_1_blocks → page_001_blocks (helps consistency if mixed sources exist)
    return re.sub(r"(page_)(\d{1,3})(_)", lambda m: f"{m.group(1)}{int(m.group(2)):03d}{m.group(3)}", name)

page_files = sorted(blocks_dir.glob("page_*_blocks*.json"))
if not page_files:
    raise FileNotFoundError(f"No page_* JSON files found in {blocks_dir}")
print(f"[INFO] Found {len(page_files)} page files.")

# -------------
# LLM PLUMBING
# -------------
fmt = StrOutputParser()

SYSTEM_PROMPT = (
    "You convert OCR text blocks into a single JSON object that STRICTLY conforms to the schema below.\n"
    "Rules:\n"
    " - Output ONLY raw JSON. No prose, no markdown fences, no comments.\n"
    " - Do NOT invent facts. If a field is missing, leave it empty or null as per the schema.\n"
    " - Preserve medical wording, numbers, signs, and units.\n"
    " - Use the input blocks' order and content faithfully.\n\n"
    "Schema:\n{schema}\n\n"
    "Blocks (list of objects with bbox & text):\n{blocks}"
)

prompt = PromptTemplate.from_template(SYSTEM_PROMPT)

def make_llm(name: str):
    try:
        # `format=json` nudges some Ollama models to keep JSON
        return ChatOllama(model=name, temperature=0,
                          model_kwargs={"keep_alive": 0, "format": "json"})
    except Exception as e:
        print("[WARN] Could not init LLM", name, e)
        return None

def strip_fences(s: str) -> str:
    s = s.strip()
    if s.startswith("```"):
        s = re.sub(r"^```(?:json)?\s*", "", s)
        s = re.sub(r"\s*```$", "", s)
    return s.strip()

def light_repair(s: str) -> str:
    # Common tiny repairs (still safe). We do NOT try to invent keys.
    s = strip_fences(s)
    # Remove trailing commas before ] or }
    s = re.sub(r",(\s*[\]\}])", r"\1", s)
    # Ensure it starts with { or [
    m = re.search(r"[\{\[]", s)
    if m and m.start() > 0:
        s = s[m.start():]
    return s

def parse_json_maybe(s: str):
    s1 = strip_fences(s)
    try:
        return json.loads(s1)
    except Exception:
        s2 = light_repair(s1)
        return json.loads(s2)  # will raise if still invalid

# -----------------
# TEXT SELECTION
# -----------------
def best_text(b: Dict[str, Any]) -> str:
    # Priority: LLM → domain cleaned → raw
    for k in ("text_llm", "text_cleaned", "text"):
        v = b.get(k)
        if isinstance(v, str) and v.strip():
            return v
    return ""

def blocks_payload(chunk: List[Dict[str, Any]]) -> str:
    slim = []
    for b in chunk:
        slim.append({
            "bbox": b.get("bbox", [0, 0, 1, 1]),
            "text": best_text(b),
            "source": b.get("source", "")
        })
    return json.dumps(slim, ensure_ascii=False)

# ---------------
# CHUNKING UTILS
# ---------------
def chunks_by_size(items: List[Dict[str, Any]], max_chars: int) -> Iterable[List[Dict[str, Any]]]:
    """Chunk list so that json.dumps(list) stays under ~max_chars."""
    current, size = [], 0
    for it in items:
        t = best_text(it)
        # estimate addition size (text + bbox + overhead)
        est = len(t) + 180
        if current and size + est > max_chars:
            yield current
            current, size = [], 0
        current.append(it)
        size += est
    if current:
        yield current

# --------------
# CORE INFERENCE
# --------------
def try_model(model_name: str, blocks_json: str, max_retries: int):
    llm = make_llm(model_name)
    if not llm:
        return None
    chain = prompt | llm | fmt

    payload = blocks_json
    for attempt in range(max_retries):
        try:
            out = chain.invoke({"schema": schema_str, "blocks": payload})
            obj = parse_json_maybe(out)
            return obj
        except Exception as e:
            # progressive shrink: trim characters from payload text
            print(f"[WARN] {model_name} attempt {attempt+1}/{max_retries} failed: {e}")
            try:
                # shrink by 20% characters
                cut = max(1, int(len(payload) * 0.8))
                payload = payload[:cut]
                # try to end on a JSON boundary if possible
                payload = re.sub(r",\s*?\Z", "", payload)
            except Exception:
                pass
    return None

# ----------------
# MAIN (per page)
# ----------------
piece_paths: List[str] = []
primary = primary_model
fallback = fallback_model if 'fallback_model' in globals() else None
max_retries = int(max_retries)

# If user provided chunk_size (count), we still respect a char budget to avoid overflows.
# Default char budget is ~90k chars which is safe for many local models.
char_budget = int(globals().get("char_budget", 90000))

for pf in page_files:
    # Keep pages separate for locality/coherence
    try:
        blocks = json.loads(pf.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"[WARN] Could not read {pf.name}: {e}")
        continue

    page_id = normalize_page_id(pf.stem)  # for consistent filenames/logs
    total = len(blocks)
    print(f"\n[INFO] Page {page_id}: {total} blocks")

    # Chunk this single page by char budget (or by user chunk_size if smaller)
    max_count = int(globals().get("chunk_size", 999999))
    page_chunks: List[List[Dict[str, Any]]] = []
    # First generate by char budget, then split any oversized chunk by count
    for ch in chunks_by_size(blocks, char_budget):
        if len(ch) <= max_count:
            page_chunks.append(ch)
        else:
            # split by count
            for i in range(0, len(ch), max_count):
                page_chunks.append(ch[i:i+max_count])

    for ci, chunk in enumerate(page_chunks, start=1):
        payload = blocks_payload(chunk)
        # Trace input
        (out_dir / f"{page_id}_input_{ci}.json").write_text(payload, encoding="utf-8")

        parsed = try_model(primary, payload, max_retries)
        if parsed is None and fallback:
            print(f"[INFO] Fallback model on {page_id} chunk {ci} → {fallback}")
            parsed = try_model(fallback, payload, max_retries)

        if parsed is not None:
            vp = out_dir / f"{page_id}_valid_{ci}.json"
            vp.write_text(json.dumps(parsed, ensure_ascii=False, indent=2), encoding="utf-8")
            piece_paths.append(str(vp))
            print(f"  ✓ chunk {ci}/{len(page_chunks)} ok → {vp.name}")
        else:
            print(f"  ✗ chunk {ci}/{len(page_chunks)} failed")

# -------------
# MERGE OUTPUT
# -------------
if piece_paths:
    merged: List[Any] = []
    for p in piece_paths:
        try:
            merged.extend(json.loads(Path(p).read_text(encoding="utf-8")))
        except Exception as e:
            print(f"[WARN] Skipping {Path(p).name}: {e}")
    final_path = out_dir / "final_prescription.json"
    final_path.write_text(json.dumps(merged, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\n🩺 merged → {final_path} ({len(merged)} records)")
else:
    print("\n[WARN] No valid chunks to merge.")

print("\n✅ Extraction complete.")


[INFO] Using input folder → outputs/run_001/03_llmcleaned
[INFO] Loaded schema from schema_prescription.json
[INFO] Found 8 page files.

[INFO] Page page_001_blocks.domain.llm: 8 blocks
[WARN] alibayram/medgemma:latest attempt 1/3 failed: Expecting value: line 6787 column 17 (char 115950)
  ✓ chunk 1/1 ok → page_001_blocks.domain.llm_valid_1.json

[INFO] Page page_002_blocks.domain.llm: 12 blocks
  ✓ chunk 1/1 ok → page_002_blocks.domain.llm_valid_1.json

[INFO] Page page_003_blocks.domain.llm: 8 blocks
  ✓ chunk 1/1 ok → page_003_blocks.domain.llm_valid_1.json

[INFO] Page page_004_blocks.domain.llm: 9 blocks
  ✓ chunk 1/1 ok → page_004_blocks.domain.llm_valid_1.json

[INFO] Page page_001_blocks.domain.llm: 56 blocks
  ✓ chunk 1/4 ok → page_001_blocks.domain.llm_valid_1.json
  ✓ chunk 2/4 ok → page_001_blocks.domain.llm_valid_2.json
  ✓ chunk 3/4 ok → page_001_blocks.domain.llm_valid_3.json
  ✓ chunk 4/4 ok → page_001_blocks.domain.llm_valid_4.json

[INFO] Page page_002_blocks.domain.

In [8]:
from pathlib import Path
import json

def verify_input_jsons(input_dir: Path):
    print(f"\n[QA] Verifying input files in {input_dir}...")
    for f in sorted(input_dir.glob("page_*_blocks*.json")):
        try:
            data = json.loads(f.read_text(encoding="utf-8"))
            if not isinstance(data, list) or not data:
                print(f"[WARN] {f.name}: empty or non-list")
                continue
            if not all(isinstance(x, dict) for x in data):
                print(f"[WARN] {f.name}: malformed entries (not dicts)")
                continue
            sample = data[0]
            keys = list(sample.keys())[:5]
            print(f"  ✓ {f.name}: {len(data)} blocks, sample keys={keys}")
        except Exception as e:
            print(f"[ERROR] {f.name}: {e}")

# Verify input files
verify_input_jsons(blocks_dir)



[QA] Verifying input files in outputs/run_001/03_llmcleaned...
  ✓ page_001_blocks.domain.llm.json: 8 blocks, sample keys=['bbox', 'text', 'source', 'confidence', 'section']
  ✓ page_002_blocks.domain.llm.json: 12 blocks, sample keys=['bbox', 'text', 'source', 'confidence', 'section']
  ✓ page_003_blocks.domain.llm.json: 8 blocks, sample keys=['bbox', 'text', 'source', 'confidence', 'section']
  ✓ page_004_blocks.domain.llm.json: 9 blocks, sample keys=['bbox', 'text', 'source', 'confidence', 'section']
  ✓ page_1_blocks.domain.llm.json: 56 blocks, sample keys=['bbox', 'text', 'source', 'confidence', 'section']
  ✓ page_2_blocks.domain.llm.json: 74 blocks, sample keys=['bbox', 'text', 'source', 'confidence', 'section']
  ✓ page_3_blocks.domain.llm.json: 17 blocks, sample keys=['bbox', 'text', 'source', 'confidence', 'section']
  ✓ page_4_blocks.domain.llm.json: 30 blocks, sample keys=['bbox', 'text', 'source', 'confidence', 'section']


In [9]:
def verify_valid_jsons(valid_dir: Path):
    print(f"\n[QA] Verifying chunked valid JSONs in {valid_dir}...")
    for f in sorted(valid_dir.glob("*_valid_*.json")):
        try:
            data = json.loads(f.read_text(encoding="utf-8"))
            if not isinstance(data, list) or not data:
                print(f"[WARN] {f.name}: empty or non-list")
                continue
            empty_count = sum(1 for x in data if not any(v for v in x.values()))
            print(f"  ✓ {f.name}: {len(data)} records ({empty_count} empty)")
        except Exception as e:
            print(f"[ERROR] {f.name}: {e}")
# Verify valid chunked files
verify_valid_jsons(out_dir) 


[QA] Verifying chunked valid JSONs in outputs/run_001/04_jsonextracted...
[WARN] page_001_blocks.domain.llm_valid_1.json: empty or non-list
[WARN] page_001_blocks.domain.llm_valid_2.json: empty or non-list
[WARN] page_001_blocks.domain.llm_valid_3.json: empty or non-list
[WARN] page_001_blocks.domain.llm_valid_4.json: empty or non-list
[WARN] page_002_blocks.domain.llm_valid_1.json: empty or non-list
[WARN] page_002_blocks.domain.llm_valid_2.json: empty or non-list
[WARN] page_002_blocks.domain.llm_valid_3.json: empty or non-list
[WARN] page_002_blocks.domain.llm_valid_4.json: empty or non-list
[WARN] page_002_blocks.domain.llm_valid_5.json: empty or non-list
[WARN] page_003_blocks.domain.llm_valid_1.json: empty or non-list
[WARN] page_003_blocks.domain.llm_valid_2.json: empty or non-list
[WARN] page_004_blocks.domain.llm_valid_1.json: empty or non-list
[WARN] page_004_blocks.domain.llm_valid_2.json: empty or non-list


In [10]:
def summarize_extraction(valid_dir: Path):
    valid_files = sorted(valid_dir.glob("*_valid_*.json"))
    total_records = 0
    for f in valid_files:
        data = json.loads(f.read_text(encoding="utf-8"))
        total_records += len(data)
    print(f"\n[SUMMARY] Total valid chunks: {len(valid_files)}")
    print(f"[SUMMARY] Total merged records (pre-merge): {total_records}")

summarize_extraction(out_dir)


[SUMMARY] Total valid chunks: 13
[SUMMARY] Total merged records (pre-merge): 130
