In [1]:
# Paste this into a Jupyter notebook cell and run.
# Full interactive UI with fixed highlighting logic (handles overlaps, newlines, case-insensitive).

import json, os, random, re, time, html
from pathlib import Path
from typing import List, Dict

try:
    import ipywidgets as widgets
except Exception:
    # Any failure importing ipywidgets (not only ImportError) switches the
    # notebook to the console fallback further down.
    WIDGETS_AVAILABLE = False
else:
    WIDGETS_AVAILABLE = True
# Needed by both the widget UI and the fallback path.
from IPython.display import display, HTML, clear_output

# ----------------------------
# Config (edit if your files are elsewhere)
# ----------------------------
DATA_DIR = Path("processed_data/backup")
INPUTS_FILE = DATA_DIR.joinpath("processed_inputs_with_ids.jsonl")        # one JSON note per line
CHECKPOINT_FILE = DATA_DIR.joinpath("icd_verification_ckpt.json")         # parsed LLM verification output

SAMPLES_PER_CATEGORY = 25   # notes sampled from each flagged category
SEED = 42                   # RNG seed so the sampled subset is reproducible

# ----------------------------
# Helpers
# ----------------------------
def load_jsonl(path: Path) -> List[Dict]:
    """Read a JSON Lines file and return one parsed object per non-blank line.

    Lines are stripped before parsing; blank lines are skipped.

    Raises:
      FileNotFoundError: if `path` does not exist.
      json.JSONDecodeError: propagated from json.loads for a malformed line.
    """
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")
    with open(path, 'r', encoding='utf-8') as handle:
        stripped_rows = (raw.strip() for raw in handle)
        return [json.loads(row) for row in stripped_rows if row]

def load_checkpoint(path: Path) -> Dict:
    """Load the verification checkpoint JSON.

    A JSON object is returned as-is. A JSON array is converted to a dict keyed
    by the stringified array position ("0", "1", ...), so both shapes can be
    looked up the same way.

    Raises:
      FileNotFoundError: if `path` does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")
    with open(path, 'r', encoding='utf-8') as handle:
        data = json.load(handle)
    if not isinstance(data, list):
        return data
    # Positional list -> dict keyed by str(index).
    return {str(pos): value for pos, value in enumerate(data)}

def has_no_evidence(parsed_list: List[Dict]) -> bool:
    """Return True if any entry's 'evidence' is the string 'no evidence'.

    The comparison ignores case and surrounding whitespace; non-string evidence
    values never count.
    """
    def _is_no_evidence_marker(value) -> bool:
        return isinstance(value, str) and value.strip().lower() == 'no evidence'

    return any(_is_no_evidence_marker(item.get('evidence')) for item in parsed_list)

def has_better_alternative(parsed_list: List[Dict]) -> bool:
    """Return True if any entry's LLM accuracy check flags the code.

    An entry is flagged when its 'accuracy' dict has `is_accurate` exactly
    False, or has a `better_alternative` other than None, '' or the string
    'null'. A missing or falsy 'accuracy' is treated as an empty dict.
    """
    def _flags(item) -> bool:
        accuracy = item.get('accuracy') or {}
        if accuracy.get('is_accurate') is False:
            return True
        return accuracy.get('better_alternative') not in (None, '', 'null')

    return any(_flags(item) for item in parsed_list)

def escape_for_regex(s: str) -> str:
    """Return `s` with regex metacharacters escaped (thin wrapper over re.escape)."""
    escaped = re.escape(s)
    return escaped

def highlight_text(text: str, quotes: list, max_token_gap: int = 30) -> str:
    r"""
    Highlight (wrap with <mark>) approximate matches of each string in `quotes` inside `text`.

    For every distinct non-empty string quote, the first hit of these patterns is
    used. All patterns are case-insensitive.
      1. Whitespace-tolerant match: the quote's tokens joined by `\s+`, so line
         breaks and repeated spaces in the note still match. Anchored on word
         boundaries.
      2. Fuzzy in-order token match that allows up to `max_token_gap` arbitrary
         characters between successive tokens. This helps when evidence is
         paraphrased or has small differences.
      3. Last resort: pattern 1 without word-boundary anchors, i.e. a substring
         match that still tolerates whitespace differences. It reports the exact
         position of the hit.
    Word-boundary anchors are only applied at quote edges that are word
    characters, so quotes that begin or end with punctuation (e.g. "(fever)")
    can still match.

    Overlapping hits are resolved greedily: longer hits win, and ties go to the
    earlier hit.

    Parameters:
      - text: the original discharge note (string).
      - quotes: list of evidence strings (from the LLM) to locate in `text`;
        non-string items are ignored.
      - max_token_gap: max number of arbitrary characters allowed between tokens
        for the fuzzy match (negative values are treated as 0).

    Returns:
      HTML (a pre-wrap <div>) with non-matching parts html-escaped and matched
      spans wrapped in <mark>, suitable for display in Jupyter.
    """
    import re, html

    def _wrap(body: str) -> str:
        return "<div style='white-space:pre-wrap'>" + body + "</div>"

    if not text:
        return _wrap("")
    if not quotes:
        return _wrap(html.escape(text))

    gap = max(0, int(max_token_gap))
    flags = re.IGNORECASE | re.DOTALL

    def _is_word_char(ch: str) -> bool:
        return re.match(r'\w', ch) is not None

    def _patterns(q: str) -> list:
        """Return the candidate regexes for quote `q`, in priority order."""
        toks = q.split()
        if not toks:
            return []
        escaped = [re.escape(t) for t in toks]
        # `\b` next to a punctuation char would demand an adjacent *word* char and
        # reject natural matches like "(fever)", so only anchor word-char edges.
        lead = r'\b' if _is_word_char(toks[0][0]) else ''
        tail = r'\b' if _is_word_char(toks[-1][-1]) else ''
        ws_body = r'\s+'.join(escaped)
        gap_body = (r'.{0,' + str(gap) + r'}?').join(escaped)
        # dict.fromkeys drops duplicate patterns (e.g. single-token quotes) but keeps order.
        return list(dict.fromkeys([lead + ws_body + tail, lead + gap_body + tail, ws_body]))

    candidates = []  # (start, end) of the first successful pattern per distinct quote
    seen = set()
    for q in quotes:
        if not isinstance(q, str):
            continue
        q_str = q.strip()
        if not q_str or q_str in seen:
            continue
        seen.add(q_str)
        for pattern in _patterns(q_str):
            try:
                match = re.search(pattern, text, flags=flags)
            except re.error:
                match = None
            if match and match.end() > match.start():
                candidates.append((match.start(), match.end()))
                break

    if not candidates:
        return _wrap(html.escape(text))

    # Prefer longer matches (then earlier ones) and drop any that overlap an accepted span.
    accepted = []
    for s, e in sorted(candidates, key=lambda se: (se[0] - se[1], se[0])):
        if any(s < acc_e and acc_s < e for acc_s, acc_e in accepted):
            continue
        accepted.append((s, e))
    accepted.sort()

    # Escape everything; wrap only the accepted spans in <mark>.
    parts = []
    cursor = 0
    for s, e in accepted:
        parts.append(html.escape(text[cursor:s]))
        parts.append("<mark>" + html.escape(text[s:e]) + "</mark>")
        cursor = e
    parts.append(html.escape(text[cursor:]))
    return _wrap("".join(parts))


def get_gold_options(entry) -> List[tuple]:
    """Build (label, value) dropdown options from an entry's gold ICD codes.

    Each gold code may be a plain value or a dict with a 'code' or 'icd_code'
    key and an optional 'title'. The value is always a string, so downstream
    lookups and html.escape() work even when a code was stored as a number.

    Returns:
      A list of (label, code_str) tuples; [] when the entry has no gold codes.
    """
    opts = []
    for code in entry.get('icd_gold_standard') or []:
        if isinstance(code, dict):
            # Fall back to the dict's repr when neither code key is present.
            code_str = str(code.get('code') or code.get('icd_code') or code)
            title = code.get('title', '')
            label = f"{code_str} — {title}" if title else code_str
        else:
            code_str = str(code)
            label = code_str
        opts.append((label, code_str))
    return opts

def find_parsed_entry_for_code(parsed_list, code_str):
    """Return the first parsed LLM entry whose 'code' matches `code_str`, else None.

    Matching ignores case and surrounding whitespace. Entries that are not dicts
    or that have no 'code' value are skipped, so a missing code never matches
    the literal text "None".
    """
    if not parsed_list or code_str is None:
        return None
    target = str(code_str).strip().upper()
    for p in parsed_list:
        if not isinstance(p, dict):
            continue
        code = p.get('code')
        if code is not None and str(code).strip().upper() == target:
            return p
    return None

# ----------------------------
# Load data and build subset
# ----------------------------
print("Loading inputs and checkpoint...")
inputs = load_jsonl(INPUTS_FILE)
checkpoint = load_checkpoint(CHECKPOINT_FILE)
N = len(inputs)
print(f"Loaded {N} input notes and {len(checkpoint)} checkpoint entries.")

# The checkpoint is looked up by the stringified position of each note in
# `inputs`; notes without a checkpoint entry map to None.
parsed_by_index = [checkpoint.get(str(i)) for i in range(N)]

no_evidence_indices = [i for i, parsed in enumerate(parsed_by_index) if parsed and has_no_evidence(parsed)]
better_alt_indices = [i for i, parsed in enumerate(parsed_by_index) if parsed and has_better_alternative(parsed)]

random.seed(SEED)
sel_no_evidence = random.sample(no_evidence_indices, min(SAMPLES_PER_CATEGORY, len(no_evidence_indices)))
sel_better_alt = random.sample(better_alt_indices, min(SAMPLES_PER_CATEGORY, len(better_alt_indices)))

# A note can be in both categories; de-duplicate while keeping sample order.
combined = list(dict.fromkeys(sel_no_evidence + sel_better_alt))
desired_total = min(2 * SAMPLES_PER_CATEGORY, len(no_evidence_indices) + len(better_alt_indices))
if len(combined) < desired_total:
    # Refill slots lost to overlap with not-yet-selected *flagged* notes first,
    # and only then with other parsed notes, so the review set stays focused on
    # flagged cases. Sets keep the membership tests O(1).
    chosen = set(combined)
    flagged_pool = set(no_evidence_indices) | set(better_alt_indices)
    flagged_extras = sorted(flagged_pool - chosen)
    other_extras = [i for i in range(N) if parsed_by_index[i] and i not in flagged_pool]
    random.shuffle(flagged_extras)
    random.shuffle(other_extras)
    combined += (flagged_extras + other_extras)[:desired_total - len(combined)]

print(f"Selected {len(sel_no_evidence)} no-evidence notes, {len(sel_better_alt)} better-alt notes, combined subset size = {len(combined)}")

subset = []
for local_idx, global_idx in enumerate(combined):
    item = inputs[global_idx]
    parsed = parsed_by_index[global_idx]
    subset.append({
        'local_idx': local_idx,
        'global_idx': global_idx,
        'note_id': item.get('note_id') or item.get('id') or None,
        'discharge_summary': item.get('discharge_summary') or item.get('text') or '',
        'icd_gold_standard': item.get('icd_gold_standard') or item.get('icd_codes') or [],
        'parsed': parsed
    })

if not subset:
    raise RuntimeError("No entries in subset. Check checkpoint and inputs alignment.")

# ----------------------------
# UI: widgets (single output area) or fallback
# ----------------------------
def render_into_output(output_widget, idx: int, selected_code: str):
    """Render the review view for subset[idx] and one gold ICD code into `output_widget`.

    Left pane: the discharge summary with the LLM's evidence quotes highlighted.
    Right pane: the selected gold code, the evidence list, and the LLM accuracy check.
    The widget's previous contents are cleared before the new HTML is displayed.

    Parameters:
      - output_widget: the output widget to draw into (used as a context manager).
      - idx: position in the module-level `subset` list.
      - selected_code: gold ICD code to look up in the parsed LLM output; a falsy
        value means no code is selected. Non-string values are coerced to str.
    """
    entry = subset[idx]
    ds_text = entry['discharge_summary'] or ""
    parsed = entry['parsed'] or []
    code_label = str(selected_code) if selected_code else ""

    def _plain(text):
        # Un-highlighted note, escaped so note text can never inject markup.
        return f"<div style='white-space:pre-wrap'>{html.escape(text)}</div>"

    # note_id comes from the input file, so it is escaped before going into HTML.
    note_id_html = f" — note_id: {html.escape(str(entry['note_id']))}" if entry['note_id'] else ""
    header_html = f"<h3 style='margin:0 0 6px 0;'>Local #{idx} — Global idx {entry['global_idx']}" + note_id_html + "</h3>"

    parsed_entry = find_parsed_entry_for_code(parsed, code_label) if code_label else None

    evidence_html = "<b>Evidence:</b> <br>"
    if parsed_entry is None:
        evidence_html += "(No parsed LLM entry for this ICD in checkpoint.)"
        highlighted_html = _plain(ds_text)
    else:
        ev = parsed_entry.get('evidence')
        if isinstance(ev, str) and ev.strip().lower() == 'no evidence':
            evidence_html += "<i>No evidence reported by LLM.</i><br>"
            highlighted_html = _plain(ds_text)
        else:
            raw_quotes = ev if isinstance(ev, list) else [ev]
            # Drop missing/blank items so a null evidence field is not listed as "None".
            quotes = [q for q in raw_quotes if q is not None and str(q).strip()]
            if not quotes:
                evidence_html += "<i>(LLM entry has no evidence text.)</i><br>"
                highlighted_html = _plain(ds_text)
            else:
                items_html = "".join(f"<li>{html.escape(str(q))}</li>" for q in quotes)
                evidence_html += "<ol style='margin:6px 0 6px 18px;'>" + items_html + "</ol>"
                highlighted_html = highlight_text(ds_text, quotes)

    acc_html = "<b>Accuracy check (LLM):</b><br>"
    if parsed_entry is None:
        acc_html += "(no data)"
    else:
        acc = parsed_entry.get('accuracy') or {}
        ia = acc.get('is_accurate')
        better = acc.get('better_alternative')
        just = acc.get('justification')
        acc_html += f"is_accurate: {html.escape(str(ia))}<br>"
        acc_html += (f"better_alternative: {html.escape(str(better))}<br>") if better not in (None, '', 'null') else "better_alternative: None<br>"
        acc_html += f"justification:<div style='white-space:pre-wrap;margin:6px 0'>{html.escape(str(just))}</div>" if just else "justification: (none)"

    # layout: left = highlighted discharge, right = evidence + accuracy
    layout_html = f"""
    {header_html}
    <div style='display:flex; gap:18px; align-items:flex-start;'>
      <div style='flex:1; padding:8px; border:1px solid #ddd; border-radius:6px; background:#fff;'>{highlighted_html}</div>
      <div style='width:420px; padding:8px;'>
         <div style='margin-bottom:10px'><b>Gold ICD (selected):</b> {html.escape(code_label) if code_label else '(none selected)'}</div>
         <div style='margin-bottom:8px'>{evidence_html}</div>
         <hr style='margin:8px 0'>
         <div>{acc_html}</div>
      </div>
    </div>
    """
    with output_widget:
        clear_output(wait=True)
        display(HTML(layout_html))

if WIDGETS_AVAILABLE:
    # create controls
    slider = widgets.IntSlider(value=0, min=0, max=len(subset)-1, description='Note #', continuous_update=False, layout=widgets.Layout(width='50%'))
    code_dropdown = widgets.Dropdown(options=[("(no gold codes)","")], description='Gold ICD:', layout=widgets.Layout(width='45%'))
    prev_btn = widgets.Button(description='◀ Prev', layout=widgets.Layout(width='80px'))
    next_btn = widgets.Button(description='Next ▶', layout=widgets.Layout(width='80px'))
    output_area = widgets.Output(layout={'border': '1px solid #ddd', 'padding':'6px'})

    # While the dropdown is being re-populated for a new note, its value-change
    # events are ignored; the slider handler renders exactly once afterwards.
    _dropdown_sync = {'active': False}

    def update_dropdown_options_for_slider(idx):
        """Replace the dropdown options with subset[idx]'s gold codes and select the first."""
        opts = get_gold_options(subset[idx]) or [("(no gold codes)", "")]
        _dropdown_sync['active'] = True
        try:
            code_dropdown.options = opts
            code_dropdown.value = opts[0][1]
        finally:
            _dropdown_sync['active'] = False

    def on_slider_change(change):
        idx = change['new']
        update_dropdown_options_for_slider(idx)
        # Always render here: if the new note's first gold code equals the code
        # selected on the previous note, the dropdown value does not change and
        # fires no event, which would leave the old note on screen.
        render_into_output(output_area, idx, code_dropdown.value)

    def on_code_change(change):
        if _dropdown_sync['active']:
            return
        render_into_output(output_area, slider.value, change['new'])

    def on_prev(_):
        slider.value = max(slider.min, slider.value - 1)

    def on_next(_):
        slider.value = min(slider.max, slider.value + 1)

    slider.observe(on_slider_change, names='value')
    code_dropdown.observe(on_code_change, names='value')
    prev_btn.on_click(on_prev)
    next_btn.on_click(on_next)

    # initialize dropdown for first slider position (no render while syncing)
    update_dropdown_options_for_slider(0)
    # layout controls and output
    controls_row = widgets.HBox([prev_btn, next_btn, slider, code_dropdown], layout=widgets.Layout(margin='6px 0 12px 0'))
    display(controls_row)
    display(output_area)
    # initial render
    render_into_output(output_area, 0, code_dropdown.value)

    print("Interactive UI ready. Use Prev/Next or the slider; choose a gold ICD from the dropdown.")
else:
    # basic fallback
    print("ipywidgets not available. Install ipywidgets for full interactivity (pip install ipywidgets). Falling back to console navigation.")
    idx = 0
    while True:
        entry = subset[idx]
        print(f"\n=== Local #{idx} — Global idx {entry['global_idx']} — note_id: {entry['note_id']}")
        print("Gold ICDs:")
        for c in entry['icd_gold_standard']:
            print("  -", c)
        sel = input("Type ICD code to inspect (or press Enter to skip), or command (n=next, p=prev, q=quit): ").strip()
        if sel.lower() == 'q':
            break
        if sel.lower() == 'n':
            idx = min(idx+1, len(subset)-1)
            continue
        if sel.lower() == 'p':
            idx = max(0, idx-1)
            continue
        if sel == '':
            idx = min(idx+1, len(subset)-1)
            continue
        parsed_entry = find_parsed_entry_for_code(entry['parsed'] or [], sel)
        if parsed_entry is None:
            print("(No parsed LLM entry for this ICD in checkpoint.)")
        else:
            ev = parsed_entry.get('evidence')
            print("Evidence:", ev)
            print("Accuracy:", parsed_entry.get('accuracy'))
        input("Press Enter to continue...")
        idx = min(idx+1, len(subset)-1)

print("Done.")


  """


Loading inputs and checkpoint...
Loaded 10000 input notes and 9528 checkpoint entries.
Selected 25 no-evidence notes, 25 better-alt notes, combined subset size = 50


HBox(children=(Button(description='◀ Prev', layout=Layout(width='80px'), style=ButtonStyle()), Button(descript…

Output(layout=Layout(border_bottom='1px solid #ddd', border_left='1px solid #ddd', border_right='1px solid #dd…

Interactive UI ready. Use Prev/Next or the slider; choose a gold ICD from the dropdown.
Done.
