In [3]:
import pandas as pd
import numpy as np
import simple_icd_10_cm as cm
import os, json, time
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict

In [14]:
# load, filter, and sample (n=10,000) the data
def load_sampled_mimic(path="processed_data/mimiciv_icd10.parquet", n=10_000, min_codes=5, max_codes=11):
    """Read the notes parquet, keep rows with an in-range code count, and sample.

    Args:
        path: Parquet file to read; only the ``text``, ``diagnosis_codes`` and
            ``note_id`` columns are loaded.
        n: Maximum number of rows to sample; capped at the number of rows that
            survive the filter.
        min_codes: Smallest allowed number of diagnosis codes (inclusive).
        max_codes: Upper bound on the number of diagnosis codes (exclusive).

    Returns:
        A DataFrame sampled with a fixed ``random_state`` of 42, so repeated
        calls on the same file yield the same rows in the same order.
    """
    def _code_count_in_range(codes):
        # Anything that is not a list / ndarray (e.g. None, a string) is dropped.
        if not isinstance(codes, (list, np.ndarray)):
            return False
        return min_codes <= len(codes) < max_codes

    notes = pd.read_parquet(path, columns=["text", "diagnosis_codes", "note_id"])
    keep_mask = notes["diagnosis_codes"].apply(_code_count_in_range)
    eligible = notes[keep_mask]
    sample_size = min(n, len(eligible))
    return eligible.sample(n=sample_size, random_state=42)

In [11]:
import json, pandas as pd

def add_note_ids(jsonl="processed_data/processed_inputs.jsonl",
                 parquet="processed_data/mimiciv_icd10.parquet",
                 out="processed_data/processed_inputs_with_ids.jsonl",
                 n=10_000, min_codes=5, max_codes=11):
    """Attach ``note_id`` to each JSONL record by re-drawing the sampled notes.

    The JSONL records are paired positionally with the rows returned by
    ``load_sampled_mimic`` (called with the same sampling arguments). Each
    pair is sanity-checked by comparing the length of the record's
    ``discharge_summary`` with the length of the row's ``text``.

    Args:
        jsonl: Input JSONL file, one JSON object per line.
        parquet: Parquet file passed to ``load_sampled_mimic``.
        out: Destination JSONL file (overwritten).
        n, min_codes, max_codes: Forwarded to ``load_sampled_mimic``.

    Raises:
        AssertionError: If the number of records differs from the number of
            sampled rows.
        ValueError: If a record's ``discharge_summary`` length differs from
            the paired row's ``text`` length.
    """
    df = load_sampled_mimic(path=parquet, n=n, min_codes=min_codes, max_codes=max_codes)

    # Context manager so the input handle is closed even if a line fails to
    # parse; whitespace-only lines (e.g. a trailing newline) are skipped.
    with open(jsonl) as src:
        data = [json.loads(line) for line in src if line.strip()]
    print(len(data))
    print(len(df))
    # Explicit raise (rather than `assert`) so the check survives `python -O`.
    if len(data) != len(df):
        raise AssertionError(
            f"record count mismatch: {len(data)} JSONL records vs {len(df)} sampled rows"
        )

    # itertuples avoids building a Series per row; fields are the column names.
    for record, row in zip(data, df.itertuples(index=False)):
        if len(record["discharge_summary"]) != len(row.text):
            raise ValueError("Length mismatch")
        record["note_id"] = row.note_id

    with open(out, "w") as f:
        f.writelines(json.dumps(record) + "\n" for record in data)
    print(f"Saved {len(data)} with note_id → {out}")


In [14]:
# Run with the default paths declared on add_note_ids; prints both record counts
# and then the path of the JSONL file it wrote.
add_note_ids()

10000
10000
Saved 10000 with note_id → processed_data/processed_inputs_with_ids.jsonl


In [1]:
import json
from pathlib import Path
from typing import List, Dict, Any

def _load_json_or_jsonl(p: str) -> List[Dict[str, Any]]:
    s = Path(p).read_text().strip()
    if not s:
        return []
    if s[0] == '[':
        return json.loads(s)
    return [json.loads(l) for l in s.splitlines() if l.strip()]

def _load_json(p: str) -> Dict[str, Any]:
    return json.loads(Path(p).read_text())

def _has_evidence(e) -> bool:
    if e is None:
        return False
    if isinstance(e, list):
        return any(str(x).strip() and str(x).strip().lower() != 'no evidence' for x in e)
    s = str(e).strip()
    return bool(s) and s.lower() != 'no evidence'

def _build_combined(inputs: List[Dict], ckpt: Dict[str, Any]) -> List[Dict]:
    """Merge per-note LLM verification results with their source notes.

    Args:
        inputs: Note records. Each is read for ``note_id``,
            ``discharge_summary`` and ``icd_gold_standard`` (a list of dicts
            read for ``code`` and ``title``).
        ckpt: Mapping whose keys are positions into ``inputs`` (as strings)
            and whose values are lists of code entries, each read for
            ``code``, ``title``, ``evidence`` and ``accuracy``.

    Returns:
        One record per usable checkpoint key, sorted by ``index``. Each record
        holds the note's ``mimic_codes``, the entries that have evidence
        (``llm_validated_codes``), and the same entries with the code swapped
        for ``accuracy['better_alternative']`` when one is given
        (``llm_accurate_codes``). Keys that are not integers or fall outside
        ``inputs`` are skipped.
    """
    out = []
    for k, entries in ckpt.items():
        try:
            i = int(k)
        except (TypeError, ValueError):
            continue
        if not (0 <= i < len(inputs)):
            continue
        note = inputs[i]
        # `or []` also covers a key that is present but null.
        mimic_codes = [{'code': c.get('code'), 'title': c.get('title')}
                       for c in (note.get('icd_gold_standard') or [])]
        llm_validated = []
        llm_accurate = []
        for ent in (entries or []):
            if not _has_evidence(ent.get('evidence')):
                continue
            code = ent.get('code')
            title = ent.get('title')
            accuracy = ent.get('accuracy') or {}
            better = accuracy.get('better_alternative') if isinstance(accuracy, dict) else None
            llm_validated.append({
                'code': code, 'title': title,
                'evidence': ent.get('evidence'),
                'accuracy': accuracy
            })
            # Strip before use so surrounding whitespace neither ends up in the
            # stored code nor makes an identical code look like a replacement.
            replacement = better.strip() if isinstance(better, str) else ''
            code2 = replacement or code
            # NOTE(review): `title` remains the original code's title even when
            # the code is replaced — confirm that is intended.
            llm_accurate.append({
                'code': code2, 'title': title,
                'evidence': ent.get('evidence'),
                'replaced_from': code if code2 != code else None
            })
        out.append({
            'index': i,
            'note_id': note.get('note_id'),
            'discharge_summary': note.get('discharge_summary'),
            'mimic_codes': mimic_codes,
            'llm_validated_codes': llm_validated,
            'llm_accurate_codes': llm_accurate
        })
    return sorted(out, key=lambda x: x['index'])

def process_icd_validation(inputs_path: str,
                           ckpt_path: str,
                           out_dir: str = 'processed_data') -> Dict[str, Path]:
    """Combine notes with their verification checkpoint and write two JSON files.

    Args:
        inputs_path: JSON array or JSONL file of note records.
        ckpt_path: JSON file with the verification checkpoint.
        out_dir: Directory for the outputs; created if it does not exist.

    Returns:
        A dict with the paths written: ``"combined"`` (full records) and
        ``"combined_no_ds"`` (same records with ``discharge_summary`` blanked).
    """
    notes = _load_json_or_jsonl(inputs_path)
    checkpoint = _load_json(ckpt_path)
    merged = _build_combined(notes, checkpoint)

    target_dir = Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # (result key, file name, payload) for each variant, in write order.
    variants = (
        ("combined", 'combined.json', merged),
        ("combined_no_ds", 'combined_no_ds.json',
         [{**record, 'discharge_summary': ''} for record in merged]),
    )

    written = {}
    for label, file_name, payload in variants:
        destination = target_dir / file_name
        destination.write_text(json.dumps(payload, indent=2))
        written[label] = destination
    return written

# run: merge the note inputs (with note_ids attached) and the ICD verification
# checkpoint; the call returns the paths of the two JSON files it writes under
# out_dir.
process_icd_validation(
    'processed_data/processed_inputs_with_ids.jsonl',
    'processed_data/icd_verification_ckpt.json',
    out_dir='processed_data'
)


{'combined': PosixPath('processed_data/combined.json'),
 'combined_no_ds': PosixPath('processed_data/combined_no_ds.json')}

In [8]:
import json
from collections import Counter
import statistics

import json
from collections import Counter
import statistics

def describe_groups(data):
    """Print descriptive statistics for the three code groups of each note.

    Args:
        data: List of note records; each is read for ``mimic_codes``,
            ``llm_validated_codes`` and ``llm_accurate_codes`` (lists whose
            items are dicts with a ``code`` key, or bare code values).

    Prints the group definitions, per-group counts (total, average and median
    per note, unique codes, five most common), pairwise overlap of the unique
    code sets, and evidence / accuracy totals. Nothing is returned.

    Empty input and empty groups are reported as zeros rather than raising
    ``ZeroDivisionError`` / ``statistics.StatisticsError``.
    """
    defs = {
        "mimic_codes": "MIMIC gold standard (original ICD codes in dataset)",
        "llm_validated_codes": "LLM validated codes (only those with evidence kept)",
        "llm_accurate_codes": "LLM accurate codes (validated + better alternatives replacing originals)"
    }

    def norm(code):
        # Codes are compared case-insensitively and without surrounding spaces.
        return code.strip().upper() if isinstance(code, str) else str(code)

    def ratio(numerator, denominator):
        # Safe division: an empty denominator yields 0.0 instead of raising.
        return numerator / denominator if denominator else 0.0

    n_notes = len(data)

    print("\n=== Group definitions ===")
    for k, v in defs.items():
        print(f"{k}: {v}")

    stats, sets = {}, {}
    for grp in defs:
        all_codes = [norm(c.get("code", "") if isinstance(c, dict) else c)
                     for note in data for c in note.get(grp, [])]
        sets[grp] = set(all_codes)
        counts_per_note = [len(note.get(grp, [])) for note in data]
        stats[grp] = {
            "notes": n_notes,
            "total": len(all_codes),
            "avg_per_note": round(ratio(sum(counts_per_note), n_notes), 2),
            # statistics.median raises on an empty sequence, hence the guard.
            "median_per_note": int(statistics.median(counts_per_note)) if counts_per_note else 0,
            "unique": len(sets[grp]),
            "top5": Counter(all_codes).most_common(5)
        }

    print("\n=== Descriptive stats ===")
    for grp, v in stats.items():
        print(f"\n{grp}: Notes={v['notes']}, Total={v['total']}, Avg/note={v['avg_per_note']}, Median/note={v['median_per_note']}, Unique={v['unique']}")
        print("  Top5:", ", ".join(f"{c}({n})" for c, n in v["top5"]))

    print("\n=== Overlap between groups ===")
    keys = list(defs.keys())
    for i in range(len(keys)):
        for j in range(i + 1, len(keys)):
            a, b = keys[i], keys[j]
            inter = sets[a] & sets[b]
            share_of_a = ratio(len(inter), len(sets[a])) * 100
            share_of_b = ratio(len(inter), len(sets[b])) * 100
            print(f"{a} ∩ {b}: {len(inter)} (≈{share_of_a:.1f}% of {a}, {share_of_b:.1f}% of {b})")

    # Evidence / accuracy
    total_mimic = sum(len(note.get("mimic_codes", [])) for note in data)
    total_validated = sum(len(note.get("llm_validated_codes", [])) for note in data)
    codes_lacking_evidence = total_mimic - total_validated

    def flagged_inaccurate(ent):
        # Only a dict `accuracy` can carry the flag; anything else is ignored.
        accuracy = ent.get("accuracy") or {}
        return isinstance(accuracy, dict) and accuracy.get("is_accurate") is False

    not_accurate = sum(1 for note in data
                       for ent in note.get("llm_validated_codes", [])
                       if flagged_inaccurate(ent))

    print("\n=== Evidence / Accuracy checks ===")
    print(f"Total validated entries: {total_validated}")
    print(f"Codes lacking evidence (MIMIC minus LLM validated): {codes_lacking_evidence} (avg {round(ratio(codes_lacking_evidence, n_notes), 3)} per note)")
    print(f"Codes with evidence but flagged NOT accurate: {not_accurate} (avg {round(ratio(not_accurate, n_notes), 3)} per note)")


In [9]:
# Load the merged records from combined.json and print their summary statistics.
# The context manager closes the file handle once the JSON has been parsed.
with open("processed_data/combined.json") as fh:
    data = json.load(fh)
describe_groups(data)


=== Group definitions ===
mimic_codes: MIMIC gold standard (original ICD codes in dataset)
llm_validated_codes: LLM validated codes (only those with evidence kept)
llm_accurate_codes: LLM accurate codes (validated + better alternatives replacing originals)

=== Descriptive stats ===

mimic_codes: Notes=9528, Total=73389, Avg/note=7.7, Median/note=8, Unique=4453
  Top5: I10(3574), E78.5(2326), Z87.891(1938), K21.9(1651), F32.9(1318)

llm_validated_codes: Notes=9528, Total=55638, Avg/note=5.84, Median/note=6, Unique=4075
  Top5: I10(3138), E78.5(1982), K21.9(1142), F32.9(953), F41.9(883)

llm_accurate_codes: Notes=9528, Total=55638, Avg/note=5.84, Median/note=6, Unique=4685
  Top5: I10(3141), E78.5(1922), K21.9(1126), F32.9(871), E11.9(793)

=== Overlap between groups ===
mimic_codes ∩ llm_validated_codes: 4075 (≈91.5% of mimic_codes, 100.0% of llm_validated_codes)
mimic_codes ∩ llm_accurate_codes: 3857 (≈86.6% of mimic_codes, 82.3% of llm_accurate_codes)
llm_validated_codes ∩ llm_accur