In [3]:
import pandas as pd
import numpy as np
import simple_icd_10_cm as cm
import os, json, time
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict

In [14]:
# load, filter, and sample (n=10,000) the data
def load_sampled_mimic(path="processed_data/mimiciv_icd10.parquet", n=10_000, min_codes=5, max_codes=11):
    df = pd.read_parquet(path, columns=["text", "diagnosis_codes", "note_id"])
    df = df[df["diagnosis_codes"].apply(lambda x: isinstance(x, (list, np.ndarray)) and min_codes <= len(x) < max_codes)]
    return df.sample(n=min(n, len(df)), random_state=42)

# Get the title of an ICD code
def get_icd_title(code, path="processed_data/icd10cm_dict.parquet"):
    df = pd.read_parquet(path)
    row = df.loc[df['icd_code'] == code]
    return row.iloc[0]['long_title'] if not row.empty else "Unknown"

# Create structured input format
def format_row(row, title_dict_path="processed_data/icd10cm_dict.parquet"):
    summary = row['text']
    gold = row['diagnosis_codes']

    out = []
    for code in gold:
        cat = code.split('.')[0]
        diffs = [d for d in cm.get_descendants(cat) if d != code and d not in gold]
        title = cm.get_description(code) if cm.is_valid_item(code) else get_icd_title(code, title_dict_path)
        diff_list = [{"code": d, "title": cm.get_description(d) if cm.is_valid_item(d) else get_icd_title(d, title_dict_path)} for d in diffs]
        out.append({"code": code, "title": title, "differential_codes": diff_list})

    return {"discharge_summary": summary, "icd_gold_standard": out}

# Loads the input file or the df and then processes and saves the input file
def process_and_save(path="processed_data/processed_inputs_with_ids.jsonl",
                     mimic_path="processed_data/mimiciv_icd10.parquet", n=10_000, min_codes=5, max_codes=11):
    if os.path.exists(path):
        print(f"Loading saved data from {path}")
        with open(path) as f:
            return [json.loads(line) for line in f]
    df = load_sampled_mimic(mimic_path, n, min_codes, max_codes)
    inputs = [format_row(row) for _, row in tqdm(df.iterrows(), total=len(df))]
    with open(path, "w") as f:
        f.writelines(json.dumps(item) + "\n" for item in inputs)
    print(f"Saved {len(inputs)} entries to {path}")
    return inputs

In [15]:
inputs = process_and_save()

Loading saved data from processed_data/processed_inputs_with_ids.jsonl


In [16]:
import json

def json_structure(data, indent=0):
    prefix = "  " * indent
    if isinstance(data, dict):
        result = {}
        for k, v in data.items():
            result[k] = json_structure(v, indent+1)
        return result
    elif isinstance(data, list):
        if data:
            return [json_structure(data[0], indent+1)]
        else:
            return []
    else:
        return type(data).__name__

# Example
data = inputs

print(json.dumps(json_structure(data), indent=2))


[
  {
    "discharge_summary": "str",
    "icd_gold_standard": [
      {
        "code": "str",
        "title": "str",
        "differential_codes": [
          {
            "code": "str",
            "title": "str"
          }
        ]
      }
    ],
    "note_id": "str"
  }
]


In [28]:
import json

n = 0   # <-- Change this to the index you want to compare

# Load JSON output
with open("processed_data/icd_verification_ckpt.json", "r") as f:
    outputs = json.load(f)

# Load JSONL input
inputs = []
with open("processed_data/processed_inputs.jsonl", "r") as f:
    for line in f:
        inputs.append(json.loads(line))

# Extract only the ICD codes from input (excluding differential_codes)
input_icd = inputs[n].get("icd_gold_standard", [])

# Extract just the codes from the input
input_codes = [entry['code'] for entry in input_icd]

# If outputs is a dict, use str(n) as the key
if isinstance(outputs, dict):
    output_icd = outputs.get(str(n), None)
elif isinstance(outputs, list):
    output_icd = outputs[n] if n < len(outputs) else None
else:
    output_icd = None

input_codes

['Z72.0', 'E86.1', 'L03.012', 'N17.9', 'Z21']

In [17]:
 output_icd = outputs['2437']

In [18]:
output_icd

[{'code': 'r64',
  'title': 'cachexia',
  'evidence': ['general - cachectic, in no distress'],
  'accuracy': {'is_accurate': True,
   'better_alternative': None,
   'justification': "the discharge summary directly describes the patient as 'cachectic,' supporting the use of code r64."}},
 {'code': 'r00.1',
  'title': 'bradycardia, unspecified',
  'evidence': ['patient states that she has had difficulty with eating for about a year.',
   'cardiovasc - regular, borderlien bradycardic, normal s1/s2, no murmur',
   '# anorexia nervosa presented voluntarily with symptoms of severe malnutrition/underweight with bradycardia.'],
  'accuracy': {'is_accurate': True,
   'better_alternative': None,
   'justification': "the discharge summary mentions bradycardia multiple times in the context of the patient's anorexia nervosa and malnutrition, supporting the use of r00.1."}},
 {'code': 'n39.0',
  'title': 'urinary tract infection, site not specified',
  'evidence': 'no evidence',
  'accuracy': {'is_a

In [11]:
import json, pandas as pd

def add_note_ids(jsonl="processed_data/processed_inputs.jsonl",
                 parquet="processed_data/mimiciv_icd10.parquet",
                 out="processed_data/processed_inputs_with_ids.jsonl",
                 n=10_000, min_codes=5, max_codes=11):

    df = load_sampled_mimic(path=parquet, n=n, min_codes=min_codes, max_codes=max_codes)

    data = [json.loads(l) for l in open(jsonl)]
    print(len(data))
    print(len(df))
    assert len(data)==len(df)

    for d,(_,r) in zip(data,df.iterrows()):
        if len(d["discharge_summary"])!=len(r.text):
            raise ValueError("Length mismatch")
        d["note_id"]=r.note_id

    with open(out,"w") as f: f.writelines(json.dumps(d)+"\n" for d in data)
    print(f"Saved {len(data)} with note_id → {out}")


In [14]:
add_note_ids()

10000
10000
Saved 10000 with note_id → processed_data/processed_inputs_with_ids.jsonl


In [15]:
import json
import pandas as pd

def validate_icd_codes(jsonl="processed_data/processed_inputs_with_ids.jsonl",
                       parquet="processed_data/mimiciv_icd10.parquet"):
    # Load dataframe with note_id and diagnosis_codes
    df = pd.read_parquet(parquet, columns=["note_id","diagnosis_codes"])
    df_dict = df.set_index("note_id")["diagnosis_codes"].to_dict()

    # Load JSONL
    data = [json.loads(l) for l in open(jsonl)]

    mismatches = []
    for item in data:
        nid = item.get("note_id")
        if nid not in df_dict:
            mismatches.append((nid, "note_id not in DF"))
            continue
        json_codes = sorted([x["code"] for x in item["icd_gold_standard"]])
        df_codes = sorted(df_dict[nid])
        if json_codes != df_codes:
            mismatches.append((nid, {"json": json_codes, "df": df_codes}))

    if not mismatches:
        print("✅ All ICD codes match the original dataframe.")
    else:
        print(f"❌ {len(mismatches)} mismatches found:")
        for m in mismatches[:10]:  # show first 10 for brevity
            print(m)

    return mismatches


In [19]:
import json, random, math
from pathlib import Path
from typing import List, Dict, Any

def _load_json_or_jsonl(p: str) -> List[Dict[str,Any]]:
    s = Path(p).read_text().strip()
    if not s:
        return []
    if s[0] == '[':
        return json.loads(s)
    return [json.loads(l) for l in s.splitlines() if l.strip()]

def _load_json(p: str) -> Dict[str,Any]:
    return json.loads(Path(p).read_text())

def _has_evidence(e) -> bool:
    if e is None: return False
    if isinstance(e, list):
        return any(str(x).strip() and str(x).strip().lower() != 'no evidence' for x in e)
    s = str(e).strip()
    return bool(s) and s.lower() != 'no evidence'

def _build_combined(inputs: List[Dict], ckpt: Dict[str,Any]) -> List[Dict]:
    out = []
    for k, entries in ckpt.items():
        try:
            i = int(k)
        except ValueError:
            continue
        if not (0 <= i < len(inputs)):
            continue
        note = inputs[i]
        mimic_codes = [{'code': c.get('code'), 'title': c.get('title')} 
                       for c in note.get('icd_gold_standard', [])]
        llm_validated = []
        llm_accurate = []
        for ent in entries:
            if not _has_evidence(ent.get('evidence')):
                continue
            code = ent.get('code')
            title = ent.get('title')
            accuracy = ent.get('accuracy') or {}
            better = accuracy.get('better_alternative') if isinstance(accuracy, dict) else None
            llm_validated.append({
                'code': code, 'title': title,
                'evidence': ent.get('evidence'),
                'accuracy': accuracy
            })
            code2 = better if (isinstance(better, str) and better.strip()) else code
            llm_accurate.append({
                'code': code2, 'title': title,
                'evidence': ent.get('evidence'),
                'replaced_from': code if code2 != code else None
            })
        out.append({
            'index': i,
            'note_id': note.get('note_id'),
            'discharge_summary': note.get('discharge_summary'),
            'mimic_codes': mimic_codes,
            'llm_validated_codes': llm_validated,
            'llm_accurate_codes': llm_accurate
        })
    return sorted(out, key=lambda x: x['index'])

def _split_and_write(all_notes: List[Dict], out_dir: str,
                     splits=(0.729, 0.109, 0.162), seed=42) -> Dict[str, Path]:
    total = len(all_notes)
    rng = random.Random(seed)
    idxs = list(range(total))
    rng.shuffle(idxs)
    n_train = int(round(total * splits[0]))
    n_val = int(round(total * splits[1]))
    n_train = min(n_train, total)
    n_val = min(n_val, total - n_train)
    n_test = total - n_train - n_val
    train = [all_notes[i] for i in idxs[:n_train]]
    val = [all_notes[i] for i in idxs[n_train:n_train + n_val]]
    test = [all_notes[i] for i in idxs[n_train + n_val:]]
    outp = Path(out_dir)
    outp.mkdir(parents=True, exist_ok=True)
    (outp / 'combined.json').write_text(json.dumps(all_notes, indent=2))
    (outp / 'train.json').write_text(json.dumps(train, indent=2))
    (outp / 'val.json').write_text(json.dumps(val, indent=2))
    (outp / 'test.json').write_text(json.dumps(test, indent=2))
    return {'combined': outp/'combined.json', 'train': outp/'train.json',
            'val': outp/'val.json', 'test': outp/'test.json'}

def process_icd_validation(inputs_path: str,
                           ckpt_path: str,
                           out_dir: str = 'processed_data',
                           splits=(0.729, 0.109, 0.162),
                           seed: int = 42) -> Dict[str, Path]:
    inputs = _load_json_or_jsonl(inputs_path)
    ckpt = _load_json(ckpt_path)
    combined = _build_combined(inputs, ckpt)
    return _split_and_write(combined, out_dir, splits, seed)

# one-line run (will execute when you run the file / paste into REPL):
process_icd_validation('processed_data/processed_inputs_with_ids.jsonl',
                       'processed_data/icd_verification_ckpt.json',
                       out_dir='processed_data', splits=(0.729,0.109,0.162), seed=42)


{'combined': PosixPath('processed_data/combined.json'),
 'train': PosixPath('processed_data/train.json'),
 'val': PosixPath('processed_data/val.json'),
 'test': PosixPath('processed_data/test.json')}

In [4]:
import json
from collections import Counter

def normalize_code(c):
    return c.strip().upper()

codes = {normalize_code(c) for c in codes}


def describe_groups(data):
    defs = {
        "mimic_codes": "MIMIC gold standard (original ICD codes in dataset)",
        "llm_validated_codes": "LLM validated codes (only those with evidence kept)",
        "llm_accurate_codes": "LLM accurate codes (validated + better alternatives replacing originals)"
    }
    print("\n=== Group Definitions ===")
    for k,v in defs.items():
        print(f"{k}: {v}")

    stats, all_sets = {}, {}
    for grp in defs:
        codes_per_note = [len(note.get(grp, [])) for note in data]
        all_codes = [c['code'] for note in data for c in note.get(grp,[])]
        all_sets[grp] = set(all_codes)
        stats[grp] = {
            "Notes": len(data),
            "Total codes": len(all_codes),
            "Avg codes/note": round(sum(codes_per_note)/len(data),2),
            "Median codes/note": sorted(codes_per_note)[len(data)//2],
            "Unique codes": len(set(all_codes)),
            "Top5": Counter(all_codes).most_common(5)
        }

    print("\n=== Descriptive Stats ===")
    for grp,vals in stats.items():
        print(f"\n{grp}:")
        for k,v in vals.items():
            print(f"  {k}: {v}")

    print("\n=== Overlap Between Groups (unique codes) ===")
    keys = list(defs.keys())
    for i in range(len(keys)):
        for j in range(i+1,len(keys)):
            a,b = keys[i],keys[j]
            inter = len(all_sets[a] & all_sets[b])
            print(f"{a} ∩ {b}: {inter}")

    # evidence / accuracy checks
    no_evidence, with_evidence_not_acc = 0,0
    for note in data:
        for ent in note.get("llm_validated_codes",[]):
            ev = ent.get("evidence")
            acc = ent.get("accuracy",{})
            if not ev or (isinstance(ev,str) and ev.strip().lower()=="no evidence") or (isinstance(ev,list) and not any(x.strip() for x in ev)):
                no_evidence+=1
            elif acc and acc.get("is_accurate") is False:
                with_evidence_not_acc+=1

    print("\n=== Evidence / Accuracy Checks ===")
    print(f"Codes lacking evidence (but present in checkpoint): {no_evidence}")
    print(f"Codes with evidence but flagged not accurate: {with_evidence_not_acc}")

In [5]:
data = json.load(open("processed_data/combined.json"))
describe_groups(data)


=== Group Definitions ===
mimic_codes: MIMIC gold standard (original ICD codes in dataset)
llm_validated_codes: LLM validated codes (only those with evidence kept)
llm_accurate_codes: LLM accurate codes (validated + better alternatives replacing originals)

=== Descriptive Stats ===

mimic_codes:
  Notes: 9528
  Total codes: 73389
  Avg codes/note: 7.7
  Median codes/note: 8
  Unique codes: 4453
  Top5: [('I10', 3574), ('E78.5', 2326), ('Z87.891', 1938), ('K21.9', 1651), ('F32.9', 1318)]

llm_validated_codes:
  Notes: 9528
  Total codes: 55638
  Avg codes/note: 5.84
  Median codes/note: 6
  Unique codes: 4081
  Top5: [('i10', 3137), ('e78.5', 1982), ('k21.9', 1142), ('f32.9', 953), ('f41.9', 882)]

llm_accurate_codes:
  Notes: 9528
  Total codes: 55638
  Avg codes/note: 5.84
  Median codes/note: 6
  Unique codes: 4691
  Top5: [('i10', 3140), ('e78.5', 1922), ('k21.9', 1126), ('f32.9', 871), ('e11.9', 793)]

=== Overlap Between Groups (unique codes) ===
mimic_codes ∩ llm_validated_code

In [1]:
import json
from pathlib import Path
from typing import Dict, Any

def load_json_or_jsonl(p: str):
    s = Path(p).read_text().strip()
    if not s:
        return []
    if s[0] == '[':
        return json.loads(s)
    return [json.loads(l) for l in s.splitlines() if l.strip()]

def load_json(p: str):
    return json.loads(Path(p).read_text())

def validate_index(idx: int,
                   inputs_path='processed_data/processed_inputs_with_ids.jsonl',
                   ckpt_path='processed_data/icd_verification_ckpt.json',
                   combined_path='processed_data/combined.json') -> None:
    inputs = load_json_or_jsonl(inputs_path)
    ckpt = load_json(ckpt_path)
    combined = load_json(combined_path)

    note = inputs[idx]
    mimic = [{'code': c['code'], 'title': c['title']}
             for c in note.get('icd_gold_standard', [])]

    ckpt_entry = ckpt.get(str(idx), [])
    comb_entry = next((c for c in combined if c['index'] == idx), None)

    print(f"\n=== Index {idx} / Note ID {note.get('note_id')} ===")
    print("\n--- MIMIC gold-standard ---")
    for c in mimic:
        print(f"  {c['code']} | {c['title']}")

    print("\n--- Checkpoint (LLM raw) ---")
    for e in ckpt_entry:
        code, title = e.get('code'), e.get('title')
        ev = e.get('evidence')
        acc = e.get('accuracy')
        print(f"  {code} | {title}\n    evidence: {ev}\n    accuracy: {acc}")

    if comb_entry:
        print("\n--- LLM validated codes ---")
        for c in comb_entry['llm_validated_codes']:
            print(f"  {c['code']} | {c['title']}")

        print("\n--- LLM accurate codes ---")
        for c in comb_entry['llm_accurate_codes']:
            rep = f" (replaced from {c['replaced_from']})" if c.get('replaced_from') else ""
            print(f"  {c['code']} | {c['title']}{rep}")
    else:
        print("\n(No combined entry for this index)")
