# 04a — JSON Extraction Hardening
Harden the final JSON extraction step:
- Coerce any model output into a list of dicts
- Save raw model replies per chunk
- Page-by-page chunking + fallback model
- Quick QA of produced valid chunks


In [7]:
run_root = "outputs/run_001"
schema_json = "config/schema_prescription.json"
primary_model = "alibayram/medgemma:latest"
fallback_model = "mistral:7b"
chunk_size = 40
max_retries = 3


In [8]:
import re, json
from pathlib import Path
from typing import List, Dict, Any, Optional
from langchain.prompts import PromptTemplate
try:
    from langchain_ollama import ChatOllama
except Exception:
    from langchain_community.chat_models import ChatOllama
from langchain.schema import StrOutputParser

run_root = Path(run_root).expanduser().resolve()
schema_path = Path(schema_json).expanduser().resolve()
if not schema_path.exists():
    raise FileNotFoundError(schema_path)

out_dir = run_root / '04_jsonextracted'
out_dir.mkdir(parents=True, exist_ok=True)

search_order = ['03_llmcleaned', '02_cleaned', '01_blocks']
blocks_dir: Optional[Path] = None
for folder in search_order:
    candidate = run_root / folder
    if any(candidate.glob('page_*_blocks*.json')):
        blocks_dir = candidate
        break
if not blocks_dir:
    raise FileNotFoundError(f'No page_* block files found under {run_root}')

print('[INFO] Using input folder →', blocks_dir)
print('[INFO] Output folder →', out_dir)
print('[INFO] Schema →', schema_path.name)

schema_str = schema_path.read_text(encoding='utf-8')
schema = json.loads(schema_str)


[INFO] Using input folder → /Users/balijepalli/Documents/GitHub/entheory-ai/notebooks/outputs/run_001/03_llmcleaned
[INFO] Output folder → /Users/balijepalli/Documents/GitHub/entheory-ai/notebooks/outputs/run_001/04_jsonextracted
[INFO] Schema → schema_prescription.json


In [9]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def coerce_json_str(s: str) -> str:
    s = s.strip()
    if s.startswith("```"):
        s = re.sub(r"^```(?:json)?\s*", "", s)
        s = re.sub(r"\s*```$", "", s)
    return s

def coerce_to_list_of_dicts(obj):
    if obj is None:
        return []
    if isinstance(obj, str):
        s = obj.strip()
        if not s:
            return []
        try:
            obj = json.loads(s)
        except Exception:
            return [{'_raw': s}]
    if isinstance(obj, dict):
        return [obj]
    if isinstance(obj, list):
        out = []
        for x in obj:
            out.append(x if isinstance(x, dict) else {'_value': x})
        return out
    return [{'_value': obj}]

def is_effectively_empty(obj: dict) -> bool:
    def empty(x):
        return (x is None) or (isinstance(x, str) and x.strip()=="") or (isinstance(x, list) and len(x)==0)
    keys_scalar = ["diagnosis","complaints","advice","follow_up"]
    top = ["patient","doctor"]
    if all(empty(obj.get(k)) for k in keys_scalar) \
       and all(empty(obj.get(k,{}).get("name","")) for k in top) \
       and empty(obj.get("medications",[])) and empty(obj.get("tests",[])) and empty(obj.get("investigations",[])):
        return True
    return False

prompt = PromptTemplate.from_template(
    'Return **only** a JSON **array** of objects that conforms to this schema.\n'
    'Do not include prose or code fences.\n\n'
    'Schema:\n{schema}\n\nBlocks (list of objects with bbox & text):\n{blocks}\n'
)
fmt = StrOutputParser()

def make_llm(name: str):
    if not name:
        return None
    try:
        return ChatOllama(model=name, temperature=0, model_kwargs={'keep_alive': 0, 'format': 'json'})
    except Exception as e:
        print('[WARN] Could not init LLM', name, e)
        return None

def try_model(name: str, payload: str, max_retries: int, raw_outfile: Path):
    llm = make_llm(name)
    if not llm:
        return None
    chain = prompt | llm | fmt
    cur_payload = payload
    for attempt in range(max_retries):
        try:
            out = chain.invoke({'schema': schema_str, 'blocks': cur_payload})
            raw = (out or '').strip()
            raw_outfile.write_text(raw, encoding='utf-8')
            if raw.startswith('```'):
                import re
                raw = re.sub(r'^```(?:json)?\\s*', '', raw)
                raw = re.sub(r'\\s*```$', '', raw)
            try:
                parsed = json.loads(raw)
            except Exception:
                parsed = raw
            as_list = coerce_to_list_of_dicts(parsed)
            if not as_list or not any(isinstance(x, dict) and x for x in as_list):
                raise ValueError('Parsed to empty or non-dict content')
            return as_list
        except Exception as e:
            print(f'[WARN] {name} attempt {attempt+1}/{max_retries} failed: {e}')
            try:
                obj = json.loads(cur_payload)
                shrunk = obj[:max(1, int(len(obj) * 0.75))]
                cur_payload = json.dumps(shrunk, ensure_ascii=False)
            except Exception:
                pass
    return None

def verify_valid_jsons(valid_dir: Path):
    print(f"\n[QA] Verifying chunked valid JSONs in {valid_dir}...")
    total_records = 0
    files = sorted(valid_dir.glob('valid_chunk_*.json')) + sorted(valid_dir.glob('*_valid_*.json'))
    for f in files:
        try:
            data = json.loads(f.read_text(encoding='utf-8'))
            if isinstance(data, dict):
                sample_keys = list(data.keys())[:6]
                print(f'  ✓ {f.name}: dict (1 record) keys={sample_keys}')
                total_records += 1
            elif isinstance(data, list):
                recs = len(data)
                sample = data[0] if recs else {}
                sample_keys = list(sample.keys())[:6] if isinstance(sample, dict) else [type(sample).__name__]
                print(f'  ✓ {f.name}: list ({recs} records) sample_keys={sample_keys}')
                total_records += recs
            else:
                print(f'[WARN] {f.name}: unexpected type {type(data).__name__}')
        except Exception as e:
            print(f'[ERROR] {f.name}: {e}')
    print(f'[SUMMARY] Total merged records (pre-merge): {total_records}')


In [10]:
pages = sorted((run_root / '03_llmcleaned').glob('page_*_blocks*.json'))
if not pages:
    pages = sorted((run_root / '02_cleaned').glob('page_*_blocks*.json'))
if not pages:
    pages = sorted((run_root / '01_blocks').glob('page_*_blocks*.json'))
if not pages:
    raise FileNotFoundError('No page_* JSON files found')
print(f'[INFO] Found {len(pages)} page files.')

piece_paths = []
for page_path in pages:
    name = page_path.stem
    blocks = json.loads(page_path.read_text(encoding='utf-8'))
    print(f"\n[INFO] Page {name}: {len(blocks)} blocks")
    for idx, chunk in enumerate(chunks(blocks, int(chunk_size)), start=1):
        payload = json.dumps([
            {'bbox': b.get('bbox', [0,0,1,1]), 'text': b.get('text', ''), 'source': b.get('source', '')}
            for b in chunk
        ], ensure_ascii=False)
        (out_dir / f'{name}_input_{idx}.json').write_text(payload, encoding='utf-8')
        raw_path = out_dir / f'{name}_raw_{idx}.txt'
        parsed = try_model(primary_model, payload, int(max_retries), raw_outfile=raw_path)
        if parsed is None and fallback_model:
            print(f'  [INFO] Retrying chunk {idx} with fallback {fallback_model}')
            parsed = try_model(fallback_model, payload, int(max_retries), raw_outfile=raw_path)
        if parsed is not None:
            vp = out_dir / f'{name}_valid_{idx}.json'
            vp.write_text(json.dumps(parsed, ensure_ascii=False, indent=2), encoding='utf-8')
            piece_paths.append(vp)
            print(f'  ✓ chunk {idx} ok → {vp.name} (items={len(parsed)})')
        else:
            print(f'  ✗ chunk {idx} failed')


[INFO] Found 8 page files.

[INFO] Page page_001_blocks.domain.llm: 8 blocks
  ✓ chunk 1 ok → page_001_blocks.domain.llm_valid_1.json (items=1)

[INFO] Page page_002_blocks.domain.llm: 12 blocks
  ✓ chunk 1 ok → page_002_blocks.domain.llm_valid_1.json (items=1)

[INFO] Page page_003_blocks.domain.llm: 8 blocks
  ✓ chunk 1 ok → page_003_blocks.domain.llm_valid_1.json (items=1)

[INFO] Page page_004_blocks.domain.llm: 9 blocks
  ✓ chunk 1 ok → page_004_blocks.domain.llm_valid_1.json (items=1)

[INFO] Page page_1_blocks.domain.llm: 56 blocks
  ✓ chunk 1 ok → page_1_blocks.domain.llm_valid_1.json (items=1)
  ✓ chunk 2 ok → page_1_blocks.domain.llm_valid_2.json (items=1)

[INFO] Page page_2_blocks.domain.llm: 74 blocks
  ✓ chunk 1 ok → page_2_blocks.domain.llm_valid_1.json (items=1)
  ✓ chunk 2 ok → page_2_blocks.domain.llm_valid_2.json (items=1)

[INFO] Page page_3_blocks.domain.llm: 17 blocks
  ✓ chunk 1 ok → page_3_blocks.domain.llm_valid_1.json (items=1)

[INFO] Page page_4_blocks.domai

In [11]:
final = []
for p in piece_paths:
    try:
        obj = json.loads(p.read_text(encoding='utf-8'))
        final.extend(coerce_to_list_of_dicts(obj))
    except Exception as e:
        print(f'[WARN] Skipping {p.name}: {e}')

final_path = out_dir / 'final_prescription.json'
final_path.write_text(json.dumps(final, indent=2, ensure_ascii=False), encoding='utf-8')
print(f"\n🩺 merged → {final_path} (records={len(final)})")
verify_valid_jsons(out_dir)
print('\n✅ Extraction hardening complete.')



🩺 merged → /Users/balijepalli/Documents/GitHub/entheory-ai/notebooks/outputs/run_001/04_jsonextracted/final_prescription.json (records=10)

[QA] Verifying chunked valid JSONs in /Users/balijepalli/Documents/GitHub/entheory-ai/notebooks/outputs/run_001/04_jsonextracted...
  ✓ page_001_blocks.domain.llm_valid_1.json: list (1 records) sample_keys=['_raw']
  ✓ page_002_blocks.domain.llm_valid_1.json: list (1 records) sample_keys=['_raw']
  ✓ page_003_blocks.domain.llm_valid_1.json: list (1 records) sample_keys=['_raw']
  ✓ page_004_blocks.domain.llm_valid_1.json: list (1 records) sample_keys=['_raw']
  ✓ page_1_blocks.domain.llm_valid_1.json: list (1 records) sample_keys=['_raw']
  ✓ page_1_blocks.domain.llm_valid_2.json: list (1 records) sample_keys=['_raw']
  ✓ page_2_blocks.domain.llm_valid_1.json: list (1 records) sample_keys=['_raw']
  ✓ page_2_blocks.domain.llm_valid_2.json: list (1 records) sample_keys=['_raw']
  ✓ page_3_blocks.domain.llm_valid_1.json: list (1 records) sample_keys