In [None]:
import os
import re
import pickle
import xml.etree.ElementTree as ET

# -----------------------------
# Configuration
# -----------------------------
# NOTE(review): not referenced by the leakage scan in this script, which pairs
# real note i with synth_notes[i]; confirm it matches how the synthetic notes
# were generated.
SYN_NOTES_PER_REAL_NOTE = 1
# Number of characters of surrounding text shown on each side of a PHI match.
CONTEXT = 15

# PHI TYPE values whose tagged text is searched for verbatim in the synthetic notes;
# tags of any other TYPE are skipped.
PHI_LIST = [
    "AGE", "BIOID", "CITY", "DATE", "DEVICE", "EMAIL", "FAX", "HEALTHPLAN",
    "IDNUM", "MEDICALRECORD", "ORGANIZATION", "PATIENT", "PHONE", "STREET",
    "URL", "ZIP",
]
# Other possible tags:
# ["STATE", "COUNTRY", "LOCATION-OTHER", "HOSPITAL", "DOCTOR", "USERNAME", "PROFESSION"]

I2B2_XML_DIR = "/path/to/i2b2_xml_dir"                # e.g., "/workspace/i2b2/"
PROMPT_PKL_PATH = "/path/to/prompt_dictionary.pkl"
SYNTH_PKL_PATH = "/path/to/synthetic_note.pkl"

# "<note_id>.xml" filenames whose synthetic note reproduces at least one PHI string.
files_with_any_leak: set[str] = set()


# -----------------------------
# Helpers
# -----------------------------
def all_occurrences(haystack: str, needle: str) -> list[int]:
    """Find every non-overlapping literal occurrence of *needle* in *haystack*.

    The needle is matched as plain text; regex metacharacters are escaped.
    After a hit, the scan resumes at the end of that hit.

    Returns:
        Start indices of the matches in left-to-right order, or an empty list
        when *needle* is empty.
    """
    if needle:
        literal = re.compile(re.escape(needle))
        return [hit.start() for hit in literal.finditer(haystack)]
    return []


def get_context(s: str, start: int, length: int, context: int = CONTEXT) -> tuple[str, int, int]:
    """Cut a window of *s* around the span ``s[start:start + length]``.

    The window reaches up to *context* characters beyond each side of the span
    and is clipped to the bounds of *s*.

    Returns:
        ``(window_text, window_start, window_end)``, where
        ``window_text == s[window_start:window_end]``.
    """
    window_end = start + length + context
    if window_end > len(s):
        window_end = len(s)
    window_start = start - context if start > context else 0
    return (s[window_start:window_end], window_start, window_end)


def extract_prompt_substring(text: str) -> str:
    """Return the prefix of *text* before the first 'An example note is:' marker.

    The marker itself is not included. If the marker does not occur anywhere
    in *text*, an empty string is returned.
    """
    head, marker, _tail = text.partition("An example note is:")
    return head if marker else ""


# -----------------------------
# Load data
# -----------------------------
# pickle.load can execute arbitrary code: only point these paths at trusted files.
with open(SYNTH_PKL_PATH, "rb") as f:
    synth_notes = pickle.load(f)  # indexed positionally (synth_notes[i]) by the scan

with open(PROMPT_PKL_PATH, "rb") as f:
    prompts = pickle.load(f)

# Note IDs in the prompt dictionary's iteration order; each maps to "<id>.xml".
note_order = list(prompts)


# -----------------------------
# Scan for leakage
# -----------------------------
# For every real note, report each tracked PHI tag whose text appears verbatim
# in the paired synthetic note. Each (type, text) pair is printed once per file,
# with the original and synthetic contexts, and the file is added to
# files_with_any_leak.
# NOTE(review): pairing assumes synth_notes[i] was generated from note_order[i]
# (one synthetic note per real note); confirm before changing
# SYN_NOTES_PER_REAL_NOTE.
tracked_phi_types = set(PHI_LIST)  # hoisted so each per-tag membership test is O(1)

for note_idx, note_id in enumerate(note_order):
    filename = f"{note_id}.xml"
    xml_file = os.path.join(I2B2_XML_DIR, filename)

    tree = ET.parse(xml_file)
    root = tree.getroot()
    # A missing or empty <TEXT> element gives an empty note instead of an
    # AttributeError; original contexts are then reported as unavailable.
    text_elem = root.find("TEXT")
    raw_note = (text_elem.text or "") if text_elem is not None else ""

    # Get the synthetic note for this real note
    generated_note = synth_notes[note_idx]

    # Avoid printing the same (PHI_type, PHI_text) repeatedly within the same file
    printed_in_file = set()

    for tag in root.findall(".//TAGS/*"):
        phi_text = tag.attrib.get("text", "")
        phi_type = tag.attrib.get("TYPE", "")

        if phi_type not in tracked_phi_types or not phi_text:
            continue

        # Check the dedup set before searching. An already-reported pair would
        # find the same occurrences again, so the search would be wasted work.
        key = (phi_type, phi_text)
        if key in printed_in_file:
            continue

        occ = all_occurrences(generated_note, phi_text)
        if not occ:
            continue

        printed_in_file.add(key)
        files_with_any_leak.add(filename)

        start = int(tag.attrib.get("start", -1))
        end = int(tag.attrib.get("end", -1))

        # i2b2 2014 tag offsets are half-open: raw_note[start:end] is the tagged
        # span, so end may legitimately equal len(raw_note).
        # NOTE(review): spot-check that raw_note[start:end] == phi_text on a sample file.
        if 0 <= start <= end <= len(raw_note):
            orig_context, _, _ = get_context(raw_note, start, end - start, CONTEXT)
        else:
            orig_context = "[ORIGINAL CONTEXT UNAVAILABLE]"

        print(f"\n===== File: {filename} =====")
        print(f"[LEAKED] {phi_text} - {phi_type} (example XML start={start})")
        print(f"Original: ...{orig_context}...")

        for j, idx in enumerate(occ, start=1):
            red_context, _, _ = get_context(generated_note, idx, len(phi_text), CONTEXT)
            print(f"  Redacted ({j} of {len(occ)} @ idx={idx}): ...{red_context}...")


# -----------------------------
# Summary
# -----------------------------
# Final report: the number of distinct XML files with at least one leaked PHI string.
print(
    "\n\n================ SUMMARY ================",
    f"Total files with any leakage: {len(files_with_any_leak)}",
    sep="\n",
)
