## 3D postprocessing of Chartis lobe segmentation predictions

In [None]:
import os
import nibabel as nib
import numpy as np
from scipy.ndimage import label, binary_dilation
from pathlib import Path

# === Directories ===
# input_dir : folder that is scanned for the *.nii.gz predictions to clean.
# output_dir: folder that receives the cleaned masks (same file names).
input_dir = r"------ INSERT PATH HERE ------"
output_dir = r"------ INSERT PATH HERE ------"

# Make sure the destination exists before anything is written; missing parent
# folders are created and an already existing folder is left untouched.
Path(output_dir).mkdir(parents=True, exist_ok=True)

# === Function to assign an isolated component to the neighboring label with largest contact ===
def reassign_to_largest_border_component(island_mask, full_mask):
    """Choose the label an isolated component should be merged into.

    The island is grown by one dilation step (default structuring element of
    ``binary_dilation``); the labelled voxels in the resulting one-voxel shell
    that do not belong to the island itself are tallied per label.

    Args:
        island_mask: boolean array marking the voxels of the island.
        full_mask: label array of the same shape (0 = background).

    Returns:
        The label with the most voxels in the shell (the smallest label wins a
        tie because ``np.unique`` returns labels sorted), or ``0`` when the
        shell contains no labelled voxel at all.
    """
    grown = binary_dilation(island_mask, iterations=1)
    contact = (grown & (full_mask > 0)) & (~island_mask)
    touching = full_mask[contact]
    if touching.size == 0:
        # Island is surrounded by background only.
        return 0  # assign to background
    labels_found, votes = np.unique(touching, return_counts=True)
    winner = np.argmax(votes)
    return labels_found[winner]

# === Main postprocessing function ===
def remove_isolated_components(filepath, output_dir):
    """Keep only the largest connected component of every label 1-5.

    For each label the connected components are found with
    ``scipy.ndimage.label`` (default connectivity). The largest component keeps
    its label; every other component ("island") is relabelled to the
    neighbouring label with the largest contact, or to background (0) when it
    touches no labelled voxel. Neighbours are always looked up in the
    *original* mask, so the result does not depend on the processing order.
    Voxels carrying a label outside 1-5 are not copied and end up as 0.

    The cleaned mask is saved under the same file name in ``output_dir`` using
    the affine and header of the input image.

    Args:
        filepath: path of the input NIfTI file (``str`` or ``Path``).
        output_dir: folder the cleaned file is written to.

    Returns:
        ``(label_changes, to_background_volumes, total_voxels_changed)`` where
        ``label_changes`` maps ``(file name, old label, new label)`` to a voxel
        count, ``to_background_volumes`` maps the file name to the number of
        voxels turned into background, and ``total_voxels_changed`` is the
        number of voxels that belonged to any island.
    """
    filepath = Path(filepath)  # plain strings have no ``.name``
    img = nib.load(filepath)
    data = img.get_fdata().astype(np.uint8)
    new_data = np.zeros_like(data)

    label_changes = {}
    to_background_volumes = {}
    total_voxels_changed = 0

    for label_id in range(1, 6):  # Labels 1 to 5
        binary = data == label_id
        if not binary.any():
            continue

        labeled_cc, num = label(binary)
        sizes = np.bincount(labeled_cc.ravel())
        sizes[0] = 0  # index 0 is everything outside this label, never a candidate
        # ``binary`` is non-empty here, so at least component 1 exists.
        largest_cc = int(np.argmax(sizes))
        new_data[labeled_cc == largest_cc] = label_id

        for i in range(1, num + 1):
            if i == largest_cc:
                continue
            island_mask = labeled_cc == i
            new_label = reassign_to_largest_border_component(island_mask, data)
            vol = int(np.count_nonzero(island_mask))
            total_voxels_changed += vol
            if new_label == 0:
                to_background_volumes[filepath.name] = (
                    to_background_volumes.get(filepath.name, 0) + vol
                )
            else:
                key = (filepath.name, label_id, new_label)
                label_changes[key] = label_changes.get(key, 0) + vol
            new_data[island_mask] = new_label

    out_path = os.path.join(output_dir, filepath.name)
    nib.save(nib.Nifti1Image(new_data, img.affine, img.header), out_path)
    return label_changes, to_background_volumes, total_voxels_changed

# === Process all files ===
# Every *.nii.gz in input_dir is cleaned once; a file whose output already
# exists is skipped so the cell can be re-run after an interruption.
all_files = list(Path(input_dir).glob("*.nii.gz"))
total_label_changes = {}
total_bg_changes = {}

for f in all_files:
    out_path = Path(output_dir) / f.name
    if not out_path.exists():
        print(f"🔄 Processing {f.name} ...")
        label_chg, bg_chg, total_changed = remove_isolated_components(f, output_dir)
        # Fold the per-file counters into the running totals.
        for running, update in ((total_label_changes, label_chg), (total_bg_changes, bg_chg)):
            for k, v in update.items():
                running[k] = running.get(k, 0) + v
        print(f"✅ Done {f.name} — total voxels reassigned: {total_changed}")
    else:
        print(f"⏭️ Skipping {f.name} — already processed.")

# === Summary ===
# Largest reassignments first; only the top five of each kind are listed.
print("\n🔝 Top 5 conversions from one label to another:")
ranked_label_changes = sorted(total_label_changes.items(), key=lambda item: item[1], reverse=True)
for (name, from_label, to_label), vol in ranked_label_changes[:5]:
    print(f"{name}: {vol} voxels from label {from_label} → {to_label}")

print("\n🔝 Top 5 conversions from label to background:")
ranked_bg_changes = sorted(total_bg_changes.items(), key=lambda item: item[1], reverse=True)
for name, vol in ranked_bg_changes[:5]:
    print(f"{name}: {vol} voxels converted to background")


## Calculating the fissure completeness of Chartis dataset

In [None]:
import os, numpy as np, nibabel as nib, pandas as pd
from pathlib import Path
from tqdm import tqdm

# ── Input paths ─────────────────────────────────────────────────────────────
root_dir = r"------ INSERT PATH HERE ------"
pred_dir = os.path.join(root_dir, "Pred")

# One prediction folder per fissure model, keyed by fissure name.
dirs = {
    fissure: os.path.join(pred_dir, f"pred_folder_{fissure}134_best5000")
    for fissure in ("LOF", "ROF", "RHF")
}
lobe_dir = os.path.join(pred_dir, "pred_folder_lobes134_best25003dEdited")
out_xls  = os.path.join(pred_dir, "fissure_completeness_all_columnwise.xlsx")

# ── Algorithm params ────────────────────────────────────────────────────────
z_window = 5    # half-width of the window searched for fissure voxels around a border voxel
max_search = 4  # half-width of the window in which the other lobe must appear
max_gap = 5     # largest accepted jump of the border position between consecutive columns

# Lobe-label pairs whose shared border is evaluated for each fissure segment.
# NOTE(review): the anatomical meaning of labels 1-5 is not visible here —
# confirm against the lobe segmentation model.
junctions = dict(
    LOF=(4, 5),
    RHF=(1, 2),
    ROF1=(1, 3),
    ROF2=(2, 3),
)

def get_covered_and_total(fiss, lob, A, B, *, window=None, search=None, gap=None):
    """Count border voxels between lobes A and B and how many a fissure covers.

    The volume is walked slice by slice along axis 1 and, inside a slice,
    column by column along axis 0. In every column (a 1-D line along axis 2)
    that contains both labels, up to two border positions are taken:

    * the smallest index labelled ``A`` that has a ``B`` voxel within
      ``search`` voxels, and
    * the largest index labelled ``B`` that has an ``A`` voxel within
      ``search`` voxels.

    A position is only accepted when it lies within ``gap`` voxels of the
    previously accepted position of the same kind in this slice (or when there
    is none); a rejected position resets that history. Each accepted position
    counts once towards ``total`` and once towards ``covered`` if ``fiss`` is
    non-zero anywhere within ``window`` voxels of it along the column.

    Args:
        fiss: fissure mask; indexed like ``lob``.
        lob: 3-D lobe label volume.
        A, B: the two lobe labels forming the border.
        window: half-width of the fissure search window; defaults to the
            module-level ``z_window``.
        search: half-width of the neighbour-lobe search; defaults to the
            module-level ``max_search``.
        gap: maximum jump between consecutive border positions; defaults to
            the module-level ``max_gap``.

    Returns:
        ``(covered, total)`` as plain integers.
    """
    # Fall back to the notebook-level settings so existing positional calls
    # behave exactly as before.
    window = z_window if window is None else window
    search = max_search if search is None else search
    gap = max_gap if gap is None else gap

    h, n_slices, w = lob.shape
    covered = total = 0
    for y in range(n_slices):
        lob_sl, fiss_sl = lob[:, y, :], fiss[:, y, :]
        is_a_sl = lob_sl == A
        is_b_sl = lob_sl == B
        # Only columns holding both labels can contribute; find them in one
        # vectorized pass instead of testing every column in Python.
        columns = np.flatnonzero(is_a_sl.any(axis=1) & is_b_sl.any(axis=1))

        prev_a = prev_b = None  # continuity history is per slice
        for x in columns:
            is_a, is_b = is_a_sl[x], is_b_sl[x]
            cand_a = [z for z in np.flatnonzero(is_a)
                      if is_b[max(z - search, 0): z + search + 1].any()]
            cand_b = [z for z in np.flatnonzero(is_b)
                      if is_a[max(z - search, 0): z + search + 1].any()]

            border_z = []
            if cand_a:
                z_a = min(cand_a)
                if prev_a is None or abs(z_a - prev_a) <= gap:
                    border_z.append(z_a)
                    prev_a = z_a
                else:
                    prev_a = None
            if cand_b:
                z_b = max(cand_b)
                if prev_b is None or abs(z_b - prev_b) <= gap:
                    border_z.append(z_b)
                    prev_b = z_b
                else:
                    prev_b = None

            for z in border_z:
                total += 1
                lo, hi = max(z - window, 0), min(z + window + 1, w)
                if fiss_sl[x, lo:hi].any():
                    covered += 1
    return covered, total

def strip_nii(name: str) -> str:
    """Reduce a NIfTI file name to its case identifier.

    Every occurrence of ``.nii.gz`` and then of ``.nii`` is removed, and
    anything from the first ``_postprocessed`` onwards is dropped.
    """
    for token in (".nii.gz", ".nii"):
        name = name.replace(token, "")
    identifier, _, _ = name.partition("_postprocessed")
    return identifier

def _find_mask(folder, mrn):
    """Return ``<mrn>.nii.gz`` or, failing that, ``<mrn>.nii`` in *folder*; ``None`` if neither exists."""
    for ext in (".nii.gz", ".nii"):
        candidate = Path(folder) / f"{mrn}{ext}"
        if candidate.exists():
            return candidate
    return None


def _load_binary_mask(path):
    """Load a NIfTI volume and binarise it (every value > 0 becomes 1)."""
    return (nib.load(path).get_fdata() > 0).astype(np.uint8)


def _as_percent(cov, tot):
    """Turn covered/total counts into a percentage; NaN when nothing was counted."""
    return np.nan if tot == 0 else 100.0 * cov / tot


# Case identifiers are collected from every fissure prediction folder
# (recursively), so a case missing in one folder is still reported.
all_mrns = set()
for fdir in dirs.values():
    all_mrns.update(strip_nii(p.name) for p in Path(fdir).rglob("*.nii*"))

print(f"🔍 Found {len(all_mrns)} unique MRNs across all fissure predictions")

records = []

for mrn in tqdm(sorted(all_mrns), desc="Computing all completeness values"):
    print(f"\n📋 Processing MRN: {mrn}")

    lobepath = _find_mask(lobe_dir, mrn)
    if lobepath is None:
        print(f"   ❌ Lobe mask missing for {mrn} - skipped")
        continue

    lob = nib.load(lobepath).get_fdata().astype(np.uint8)
    entry = {"MRN": mrn}
    rml_cov = rml_tot = 0  # RHF + ROF2 (border 2↔3)
    rul_cov = rul_tot = 0  # RHF + ROF1 (border 1↔3)

    # LOF
    fpath = _find_mask(dirs["LOF"], mrn)
    if fpath is None:
        entry["LOF"] = np.nan
        print("   ❌ LOF prediction missing")
    else:
        fiss = _load_binary_mask(fpath)
        cov, tot = get_covered_and_total(fiss, lob, *junctions["LOF"])
        entry["LOF"] = _as_percent(cov, tot)
        print(f"   ✅ LOF: {entry['LOF']:.1f}% ({cov}/{tot})")

    # RHF (contributes to both composite scores)
    fpath = _find_mask(dirs["RHF"], mrn)
    if fpath is None:
        entry["RHF"] = np.nan
        print("   ❌ RHF prediction missing")
    else:
        fiss = _load_binary_mask(fpath)
        cov, tot = get_covered_and_total(fiss, lob, *junctions["RHF"])
        entry["RHF"] = _as_percent(cov, tot)
        rml_cov += cov
        rml_tot += tot
        rul_cov += cov
        rul_tot += tot
        print(f"   ✅ RHF: {entry['RHF']:.1f}% ({cov}/{tot})")

    # ROF: evaluated separately on the 1↔3 and 2↔3 borders, then pooled
    fpath = _find_mask(dirs["ROF"], mrn)
    if fpath is None:
        entry["ROF_combined"] = np.nan
        print("   ❌ ROF prediction missing")
    else:
        fiss = _load_binary_mask(fpath)
        cov1, tot1 = get_covered_and_total(fiss, lob, *junctions["ROF1"])
        rul_cov += cov1
        rul_tot += tot1
        print(f"   📊 ROF1 (1↔3): ({cov1}/{tot1})")
        cov2, tot2 = get_covered_and_total(fiss, lob, *junctions["ROF2"])
        rml_cov += cov2
        rml_tot += tot2
        print(f"   📊 ROF2 (2↔3): ({cov2}/{tot2})")
        rof_total_cov = cov1 + cov2
        rof_total_tot = tot1 + tot2
        entry["ROF_combined"] = _as_percent(rof_total_cov, rof_total_tot)
        print(f"   ✅ ROF Combined: {entry['ROF_combined']:.1f}% ({rof_total_cov}/{rof_total_tot})")

    entry["RML1"] = _as_percent(rml_cov, rml_tot)
    print(f"   ✅ RML1 (RHF+ROF2): {entry['RML1']:.1f}% ({rml_cov}/{rml_tot})")

    entry["RUL"] = _as_percent(rul_cov, rul_tot)
    print(f"   ✅ RUL (RHF+ROF1): {entry['RUL']:.1f}% ({rul_cov}/{rul_tot})")

    records.append(entry)

# Naming the columns explicitly keeps the frame sortable (and the workbook
# well-formed) even when no case could be processed.
df = pd.DataFrame(records, columns=["MRN", "LOF", "RHF", "ROF_combined", "RML1", "RUL"])
df = df.sort_values("MRN").reset_index(drop=True)
df.to_excel(out_xls, index=False)
print(f"\n✅ Saved to {out_xls}")
print(f"📊 Processed {len(records)} MRNs total")


## Quality check of the predicted fissures vs lobe borders - Chartis dataset

In [None]:
import os, numpy as np, nibabel as nib, pandas as pd
from pathlib import Path

# === Parameters ===
# Both thresholds are distances in voxels along a column of the volume.
threshold_fissure_far_from_border = 30  # nearest fissure voxel farther than this from every border voxel
threshold_multiple_borders = 30         # border voxels of one column spread over more than this

# === Paths (Chartis CTs) ===
root_dir = r"------ INSERT PATH HERE ------"
lobes_dir = os.path.join(root_dir, "pred_folder_lobes134_best25003dEdited")
fissure_dirs = {
    fissure: os.path.join(root_dir, f"pred_folder_{fissure}134_best5000")
    for fissure in ("LOF", "ROF", "RHF")
}
out_xls = os.path.join(root_dir, "fissure_not_amenable_summary_chartis.xlsx")

# === Define fissure border pairs ===
# Each entry lists the lobe-label pairs whose border belongs to that fissure;
# "RUL" is a composite evaluated against two different fissure predictions.
fissure_defs = dict(
    LOF=[(4, 5)],
    ROF=[(1, 3), (2, 3)],
    RHF=[(1, 2)],
    RUL=[(1, 2), (1, 3)],
)

def strip_nii(name):
    """Remove every ``.nii.gz`` and then every ``.nii`` from a file name."""
    without_gz = name.replace(".nii.gz", "")
    return without_gz.replace(".nii", "")

def _find_nifti(folder, mrn):
    """Return ``<folder>/<mrn>.nii.gz`` if it exists, otherwise ``<folder>/<mrn>.nii``."""
    gz_path = os.path.join(folder, f"{mrn}.nii.gz")
    return gz_path if os.path.exists(gz_path) else os.path.join(folder, f"{mrn}.nii")


# Only cases that have a lobe mask and all three fissure predictions are checked.
lobes_mrns = {strip_nii(f.name) for f in Path(lobes_dir).glob("*.nii*")}
fissure_mrns = {
    k: {strip_nii(f.name) for f in Path(v).glob("*.nii*")} for k, v in fissure_dirs.items()
}
common_mrns = lobes_mrns & fissure_mrns["ROF"] & fissure_mrns["RHF"] & fissure_mrns["LOF"]

# For the composite "RUL" entry each border pair is judged against a different
# fissure prediction.
rul_prediction_for_pair = {(1, 2): "RHF", (1, 3): "ROF"}

all_results = []

for idx, mrn in enumerate(sorted(common_mrns), 1):
    print(f"\n🔄 [{idx}/{len(common_mrns)}] Processing MRN: {mrn}")
    lobes = nib.load(_find_nifti(lobes_dir, mrn)).get_fdata().astype(np.uint8)

    # Cases are discovered with "*.nii*", so uncompressed fissure files must be
    # accepted here just like uncompressed lobe masks.
    fissures = {
        k: (nib.load(_find_nifti(d, mrn)).get_fdata() > 0).astype(np.uint8)
        for k, d in fissure_dirs.items()
    }

    H, D, W = lobes.shape
    mrn_result = {"MRN": mrn}

    for fissure_name, border_pairs in fissure_defs.items():
        print(f"   ➤ Fissure: {fissure_name}")
        total = 0         # columns containing both lobes of a pair
        not_amenable = 0  # of those, columns flagged by one of the heuristics

        for (A, B) in border_pairs:
            if fissure_name == "RUL":
                # Use RHF for (1,2) and ROF for (1,3)
                if (A, B) not in rul_prediction_for_pair:
                    raise ValueError(f"Unexpected pair for RUL: {A}, {B}")
                fissure_pred = fissures[rul_prediction_for_pair[(A, B)]]
            else:
                fissure_pred = fissures[fissure_name]

            for y in range(D):
                for x in range(H):
                    line = lobes[x, y, :]
                    is_a = line == A
                    is_b = line == B
                    if not is_a.any() or not is_b.any():
                        continue
                    total += 1

                    # Border voxels: an A voxel directly next to a B voxel
                    # along the column, or the other way round.
                    beside_a = np.zeros(W, dtype=bool)
                    beside_a[1:] |= is_a[:-1]
                    beside_a[:-1] |= is_a[1:]
                    beside_b = np.zeros(W, dtype=bool)
                    beside_b[1:] |= is_b[:-1]
                    beside_b[:-1] |= is_b[1:]
                    border_rows = np.flatnonzero((is_a & beside_b) | (is_b & beside_a))

                    fissure_rows = np.flatnonzero(fissure_pred[x, y, :] > 0)

                    # Heuristic 1: the predicted fissure is far from every border voxel
                    far_from_border = False
                    if fissure_rows.size > 0 and border_rows.size > 0:
                        min_dist = np.abs(fissure_rows[:, None] - border_rows[None, :]).min()
                        far_from_border = min_dist > threshold_fissure_far_from_border

                    # Heuristic 2: border voxels are spread widely along the column
                    spread_out = (
                        border_rows.size > 1
                        and (border_rows.max() - border_rows.min() > threshold_multiple_borders)
                    )

                    if far_from_border or spread_out:
                        not_amenable += 1

        percent = 100 * not_amenable / total if total else 0
        print(f"     ⬅️ {fissure_name}: {not_amenable}/{total} not amenable ({percent:.2f}%)")

        # Store values in spread columns
        mrn_result[f"{fissure_name}_not_amenable"] = not_amenable
        mrn_result[f"{fissure_name}_total"] = total
        mrn_result[f"{fissure_name}_percent"] = round(percent, 2)

    all_results.append(mrn_result)

# === Save summary ===
df = pd.DataFrame(all_results)

# Fixed column order: MRN first, then the three metrics of every fissure.
fissure_order = ["LOF", "ROF", "RHF", "RUL"]
metrics = ["not_amenable", "total", "percent"]
ordered_cols = ["MRN"] + [f"{f}_{m}" for f in fissure_order for m in metrics]
df = df.reindex(columns=ordered_cols)

df.to_excel(out_xls, index=False)
print(f"\n✅ Saved summary to: {out_xls}")


## Diagnostic accuracy metrics of Model-derived fissure completeness scores for predicting Chartis CV negative status

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix

# === Load data ===
completeness = pd.read_excel(r"------ INSERT PATH HERE ------\fissure_completeness_all_columnwise.xlsx")
chartis_outcome = pd.read_excel(r"------ INSERT PATH HERE ------\ChartisList_withFEV1.xlsx")

# === Clean column names and MRNs ===
# The merge key is normalised exactly as in the other analysis cells: the two
# workbooks may store MRNs with different dtypes (text, int, float), which
# would otherwise make the merge fail or silently drop cases.
for frame in (completeness, chartis_outcome):
    frame.columns = frame.columns.str.strip()
    frame["MRN"] = frame["MRN"].astype(str).str.strip().str.replace(".0", "", regex=False)

# === Merge by MRN ===
df = completeness.merge(chartis_outcome, on="MRN", how="inner")
df = df.rename(columns={"ROF_combined": "ROF"})

# === Map fissures to Chartis outcome columns ===
fissure_to_outcome = {
    "LOF": "L",
    "ROF": "RLL1",
    "RUL": "RUL1"
}

# === Confusion matrix + metrics at fixed threshold ===
threshold = 92

for fissure, outcome_col in fissure_to_outcome.items():
    df_sub = df[[fissure, outcome_col]].dropna()

    y_true = df_sub[outcome_col].astype(int).values  # 1 = No CV, 0 = CV
    y_score = df_sub[fissure].astype(float).values
    y_pred = (y_score >= threshold).astype(int)  # ≥ threshold → predict No CV (1)

    # Rows = actual, columns = predicted, both ordered (1, 0).
    cm = confusion_matrix(y_true, y_pred, labels=[1, 0])
    cm_flipped = cm.T  # rows = predicted, cols = actual, for printing

    row_labels = [f"≥{threshold} (No CV)", f"<{threshold} (CV)"]
    col_labels = ["1 (No CV)", "0 (CV present)"]

    print(f"\n=== {fissure} — Confusion Matrix at Threshold {threshold}% ===")
    print("Confusion Matrix (rows = predicted, cols = actual):")
    print(" " * 20 + "".join([f"{label:>20}" for label in col_labels]))
    for row_label, row in zip(row_labels, cm_flipped):
        print(f"{row_label:<20}" + "".join([f"{val:>20}" for val in row]))

    # Same matrix, read cell by cell ("No CV" = 1 is the positive class).
    (tp, fn), (fp, tn) = cm

    n_cases = tp + tn + fp + fn
    sensitivity = tp / (tp + fn) if (tp + fn) else np.nan
    specificity = tn / (tn + fp) if (tn + fp) else np.nan
    ppv = tp / (tp + fp) if (tp + fp) else np.nan  # Precision
    npv = tn / (tn + fn) if (tn + fn) else np.nan
    accuracy = (tp + tn) / n_cases if n_cases else np.nan

    print(f"\nPerformance Metrics:")
    print(f"Sensitivity (No CV):      {sensitivity:.2f}")
    print(f"Specificity (CV):         {specificity:.2f}")
    print(f"PPV (Predicted No CV):    {ppv:.2f}")
    print(f"NPV (Predicted CV):       {npv:.2f}")
    print(f"Accuracy:                 {accuracy:.2f}")


## Diagnostic accuracy metrics of StratX-derived fissure completeness scores for predicting Chartis CV negative status

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix

# === Load data ===
stratx_file = r"------ INSERT PATH HERE ------\StratX_List_2.xlsx"
chartis_file = r"------ INSERT PATH HERE ------\ChartisList_withFEV1.xlsx"

stratx = pd.read_excel(stratx_file)
chartis = pd.read_excel(chartis_file)

# === Clean column names and MRNs ===
# MRNs are compared as text; a ".0" left behind by a float column is removed.
for frame in (stratx, chartis):
    frame.columns = frame.columns.str.strip()
    frame["MRN"] = frame["MRN"].astype(str).str.strip().str.replace(".0", "", regex=False)

# === Merge on MRN ===
df = stratx.merge(chartis, on="MRN", how="inner")

# === Fissure mapping ===
# StratX score column -> Chartis outcome column
fissure_to_outcome = dict(LUL="L", RLL="RLL1", RUL="RUL1")

# === Confusion matrix + metrics at threshold ===
threshold = 95

for stratx_col, chartis_col in fissure_to_outcome.items():
    df_sub = df[[stratx_col, chartis_col]].dropna()

    y_true = df_sub[chartis_col].astype(int).values  # 1 = No CV, 0 = CV
    y_score = df_sub[stratx_col].astype(float).values
    y_pred = (y_score >= threshold).astype(int)

    # Rows = actual, columns = predicted, both ordered (1, 0).
    cm = confusion_matrix(y_true, y_pred, labels=[1, 0])
    cm_flipped = cm.T

    row_labels = [f"≥{threshold} (No CV)", f"<{threshold} (CV)"]
    col_labels = ["1 (No CV)", "0 (CV present)"]

    print(f"\n=== {stratx_col} — Confusion Matrix at Threshold {threshold}% ===")
    print("Confusion Matrix (rows = predicted, cols = actual):")
    header_cells = [f"{label:>20}" for label in col_labels]
    print(" " * 20 + "".join(header_cells))
    for row_label, row in zip(row_labels, cm_flipped):
        body_cells = [f"{val:>20}" for val in row]
        print(f"{row_label:<20}" + "".join(body_cells))

    # Read the four cells straight from the matrix ("No CV" = 1 is positive).
    (tp, fn), (fp, tn) = cm
    sensitivity = tp / (tp + fn) if (tp + fn) else np.nan
    specificity = tn / (tn + fp) if (tn + fp) else np.nan
    ppv = tp / (tp + fp) if (tp + fp) else np.nan
    npv = tn / (tn + fn) if (tn + fn) else np.nan
    accuracy = (tp + tn) / (tp + tn + fp + fn)

    print(f"\nPerformance Metrics:")
    print(f"Sensitivity (No CV):      {sensitivity:.2f}")
    print(f"Specificity (CV):         {specificity:.2f}")
    print(f"PPV (Predicted No CV):    {ppv:.2f}")
    print(f"NPV (Predicted CV):       {npv:.2f}")
    print(f"Accuracy:                 {accuracy:.2f}")


## ROC of Model and StratX derived fissure completeness scores to predict negative collateral ventilation status on Chartis (all cases)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# === Load data ===
completeness = pd.read_excel(r"------ INSERT PATH HERE ------\fissure_completeness_all_columnwise.xlsx")
chartis_outcome = pd.read_excel(r"------ INSERT PATH HERE ------\ChartisList_withFEV1.xlsx")
stratx = pd.read_excel(r"------ INSERT PATH HERE ------\StratX_List_2.xlsx")

# === Clean column names and MRNs ===
# MRNs are compared as text; a ".0" left behind by a float column is removed.
for frame in (completeness, chartis_outcome, stratx):
    frame.columns = frame.columns.str.strip()
    frame["MRN"] = frame["MRN"].astype(str).str.strip().str.replace(".0", "", regex=False)
completeness = completeness.rename(columns={"ROF_combined": "ROF"})

# === Merge datasets ===
df_model = completeness.merge(chartis_outcome, on="MRN", how="inner")
df_stratx = stratx.merge(chartis_outcome, on="MRN", how="inner")

# === Fissure-to-outcome mapping ===
# score column -> Chartis outcome column, in the order of the figure panels
fissure_to_outcome_model = dict(LOF="L", ROF="RLL1", RUL="RUL1")
fissure_to_outcome_stratx = dict(LUL="L", RLL="RLL1", RUL="RUL1")
fissure_display_names = [
    "Left oblique fissure",
    "Right oblique fissure",
    "Fissures around right upper lobe",
]

# === Bootstrap ROC with CI function ===
def bootstrap_roc_ci(y_true, y_score, n_bootstraps=1000, seed=42):
    """Bootstrap a mean ROC curve with a TPR band and an AUC interval.

    Cases are resampled with replacement ``n_bootstraps`` times using a
    ``RandomState`` seeded with ``seed``; resamples containing a single class
    are discarded. Each resampled ROC curve is interpolated onto 101 evenly
    spaced FPR values (TPR forced to 0 at FPR 0).

    Args:
        y_true: NumPy array of binary outcomes.
        y_score: NumPy array of scores, same length as ``y_true``.

    Returns:
        ``(base_fpr, mean_tpr, tpr_lower, tpr_upper, mean_auc, auc_ci_lower,
        auc_ci_upper)``. The TPR band is mean ± 1.96 standard deviations
        clipped to [0, 1]; the AUC interval is the 2.5th-97.5th percentile of
        the bootstrap AUCs.
    """
    rng = np.random.RandomState(seed)
    base_fpr = np.linspace(0, 1, 101)
    n_cases = len(y_score)
    curves = []
    aucs = []
    for _ in range(n_bootstraps):
        draw = rng.randint(0, n_cases, n_cases)
        truth = y_true[draw]
        if len(np.unique(truth)) >= 2:
            fpr, tpr, _ = roc_curve(truth, y_score[draw])
            aucs.append(auc(fpr, tpr))
            curve = np.interp(base_fpr, fpr, tpr)
            curve[0] = 0.0
            curves.append(curve)

    stacked = np.array(curves)
    mean_tpr = np.mean(stacked, axis=0)
    half_width = 1.96 * np.std(stacked, axis=0)
    tpr_upper = np.minimum(mean_tpr + half_width, 1)
    tpr_lower = np.maximum(mean_tpr - half_width, 0)

    mean_auc = np.mean(aucs)
    auc_ci_lower = np.percentile(aucs, 2.5)
    auc_ci_upper = np.percentile(aucs, 97.5)
    return base_fpr, mean_tpr, tpr_lower, tpr_upper, mean_auc, auc_ci_lower, auc_ci_upper

# === Create figure with proper spacing ===
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Panel letters: top row = model-derived scores, bottom row = StratX-derived scores
top_labels = ['A', 'B', 'C']
bottom_labels = ['D', 'E', 'F']

panel_rows = (
    (df_model, fissure_to_outcome_model, top_labels),
    (df_stratx, fissure_to_outcome_stratx, bottom_labels),
)

# === Draw both rows with the same recipe ===
for row_idx, (frame, mapping, letters) in enumerate(panel_rows):
    for i, (fissure, outcome_col) in enumerate(mapping.items()):
        ax = axes[row_idx, i]
        df_sub = frame[[fissure, outcome_col]].dropna()
        y_true = df_sub[outcome_col].astype(int).values
        y_score = df_sub[fissure].astype(float).values
        fpr, mean_tpr, lower_tpr, upper_tpr, auc_mean, auc_lo, auc_hi = bootstrap_roc_ci(y_true, y_score)

        curve_label = f"AUC = {auc_mean:.2f} (95% CI: {auc_lo:.2f}–{auc_hi:.2f})"
        ax.plot(fpr, mean_tpr, label=curve_label, lw=2)
        ax.fill_between(fpr, lower_tpr, upper_tpr, color='b', alpha=0.2, label="95% CI")
        ax.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal
        ax.set_title(f"{letters[i]}. {fissure_display_names[i]}", fontsize=12)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.grid(True)
        ax.legend()

# === Apply tight_layout first, then add section titles ===
plt.tight_layout()

# === Add section titles with proper spacing ===
section_titles = (
    (0.95, "Model-derived fissure completeness scores"),
    (0.48, "StratX-derived fissure completeness scores"),
)
for y_pos, title in section_titles:
    fig.text(0.5, y_pos, title, ha='center', fontsize=15, weight='bold')

# === Adjust subplot positions to make room for titles ===
plt.subplots_adjust(top=0.92, bottom=0.08, hspace=0.4)

plt.savefig(r"------ INSERT PATH HERE ------\Figures\Figure5.png", dpi=300, bbox_inches='tight')  # Save at 300 DPI
plt.show()

## Comparing AUROC of Model and StratX derived fissure completeness scores to predict negative collateral ventilation status on Chartis (common in both datasets)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Use established DeLong test from MLstatkit
# The package is optional: HAS_MLSTATKIT records whether the import worked so
# later code can choose between the library and a fallback.
try:
    from MLstatkit.stats import Delong_test
except ImportError:
    HAS_MLSTATKIT = False
    print("MLstatkit not available. Install with: pip install MLstatkit")
    print("Using fallback method...")
else:
    HAS_MLSTATKIT = True
    print("Using MLstatkit DeLong test implementation")

def _delong_numpy(y_true, y_prob1, y_prob2):
    """DeLong test for two correlated AUROCs, computed with NumPy/SciPy only.

    Uses the midrank formulation of DeLong's structural components (Sun & Xu,
    2014). The larger of the two labels in ``y_true`` is the positive class.
    Returns ``(nan, nan)`` when the test is undefined (not exactly two classes,
    fewer than two cases in a class, or zero variance with a non-zero AUC
    difference).
    """
    y_true = np.asarray(y_true).ravel()
    p1 = np.asarray(y_prob1, dtype=float).ravel()
    p2 = np.asarray(y_prob2, dtype=float).ravel()
    if not (y_true.size == p1.size == p2.size):
        raise ValueError("y_true, y_prob1 and y_prob2 must have the same length")

    classes = np.unique(y_true)
    if classes.size != 2:
        return np.nan, np.nan
    pos = y_true == classes[1]
    m, n = int(pos.sum()), int((~pos).sum())
    if m < 2 or n < 2:
        # The covariance of the structural components needs >= 2 cases per class.
        return np.nan, np.nan

    aucs = np.empty(2)
    v01 = np.empty((2, m))  # per positive case: share of negatives ranked below it
    v10 = np.empty((2, n))  # per negative case: share of positives ranked above it
    for k, scores in enumerate((p1, p2)):
        x, y = scores[pos], scores[~pos]
        tx = stats.rankdata(x)  # midranks (ties get the average rank)
        ty = stats.rankdata(y)
        tz = stats.rankdata(np.concatenate([x, y]))
        aucs[k] = tz[:m].sum() / (m * n) - (m + 1.0) / (2.0 * n)
        v01[k] = (tz[:m] - tx) / n
        v10[k] = 1.0 - (tz[m:] - ty) / m

    diff = aucs[0] - aucs[1]
    if diff == 0:
        return 0.0, 1.0
    cov = np.cov(v01) / m + np.cov(v10) / n
    var_diff = cov[0, 0] + cov[1, 1] - 2.0 * cov[0, 1]
    if not var_diff > 0:
        return np.nan, np.nan
    z_stat = diff / np.sqrt(var_diff)
    p_value = 2.0 * stats.norm.sf(abs(z_stat))
    return float(z_stat), float(p_value)


def delong_roc_test(y_true, y_prob1, y_prob2):
    """Compare the AUROCs of two scores measured on the same cases (DeLong test).

    MLstatkit's ``Delong_test`` is used when it was imported successfully
    (``HAS_MLSTATKIT`` is true). Otherwise the test is computed locally with
    NumPy/SciPy, so the reported p-value is a genuine DeLong p-value in both
    cases. A missing ``HAS_MLSTATKIT`` flag is treated as "not available".

    Args:
        y_true: binary outcomes.
        y_prob1: scores of the first method.
        y_prob2: scores of the second method, same cases and order.

    Returns:
        ``(z, p)``: the z statistic for ``AUC1 - AUC2`` and the two-sided
        p-value; ``(nan, nan)`` if the library call fails or the test is
        undefined for the data.
    """
    if globals().get("HAS_MLSTATKIT", False):
        try:
            z_score, p_value = Delong_test(y_true, y_prob1, y_prob2)
            return z_score, p_value
        except Exception as e:
            print(f"MLstatkit DeLong test failed: {str(e)}")
            return np.nan, np.nan
    return _delong_numpy(y_true, y_prob1, y_prob2)

# === Load and prepare data ===
completeness = pd.read_excel(r"------ INSERT PATH HERE ------\fissure_completeness_all_columnwise.xlsx")
chartis_outcome = pd.read_excel(r"------ INSERT PATH HERE ------\ChartisList_withFEV1.xlsx")
stratx = pd.read_excel(r"------ INSERT PATH HERE ------\StratX_List_2.xlsx")

# === Clean data ===
# MRNs are compared as text; a ".0" left behind by a float column is removed.
for frame in (completeness, chartis_outcome, stratx):
    frame.columns = frame.columns.str.strip()
    frame["MRN"] = frame["MRN"].astype(str).str.strip().str.replace(".0", "", regex=False)
completeness = completeness.rename(columns={"ROF_combined": "ROF"})

# === Merge datasets ===
df_model = completeness.merge(chartis_outcome, on="MRN", how="inner")
df_stratx = stratx.merge(chartis_outcome, on="MRN", how="inner")

# === Fissure mappings ===
# (score column, Chartis outcome column); the three lists are aligned by position.
model_mappings = list(zip(("LOF", "ROF", "RUL"), ("L", "RLL1", "RUL1")))
stratx_mappings = list(zip(("LUL", "RLL", "RUL"), ("L", "RLL1", "RUL1")))
fissure_names = [
    "Left Oblique Fissure",
    "Right Oblique Fissure",
    "Right Upper Lobe Fissures",
]

# === Bootstrap AUROC function ===
def bootstrap_auroc(y_true, y_score, n_bootstraps=2000, seed=42):
    """Bootstrap the AUROC of ``y_score`` against ``y_true``.

    Cases are resampled with replacement using a ``RandomState`` seeded with
    ``seed``; resamples that contain only one class are skipped.

    Args:
        y_true: NumPy array of binary outcomes.
        y_score: NumPy array of scores, same length as ``y_true``.

    Returns:
        ``(mean_auc, ci_lower, ci_upper, aucs)`` where the interval is the
        2.5th-97.5th percentile of the bootstrap AUCs and ``aucs`` is the array
        of all retained bootstrap values.
    """
    rng = np.random.RandomState(seed)
    n_cases = len(y_score)
    retained = []

    for _ in range(n_bootstraps):
        draw = rng.randint(0, n_cases, n_cases)
        truth = y_true[draw]
        if len(np.unique(truth)) >= 2:
            fpr, tpr, _ = roc_curve(truth, y_score[draw])
            retained.append(auc(fpr, tpr))

    aucs = np.array(retained)
    mean_auc = np.mean(aucs)
    ci_lower = np.percentile(aucs, 2.5)
    ci_upper = np.percentile(aucs, 97.5)
    return mean_auc, ci_lower, ci_upper, aucs



banner = "=" * 80
print(banner)
print("COMBINED AUROC ANALYSIS: ALL FISSURES TOGETHER")
print(banner)

# === Collect all paired observations across all fissures ===
# Only cases that have both a model score and a StratX score (plus the Chartis
# outcome) for a fissure are kept, so both methods are judged on the same cases.
all_y_true_model = []
all_y_score_model = []
all_y_true_stratx = []
all_y_score_stratx = []
all_patient_info = []

total_paired_observations = 0

for (model_fissure, model_outcome), (stratx_fissure, stratx_outcome), fissure_name in zip(
    model_mappings, stratx_mappings, fissure_names
):
    # Method-specific column names keep the two sides apart after merging.
    model_score_col = f'{model_fissure}_model_score'
    model_outcome_col = f'{model_outcome}_model_outcome'
    stratx_score_col = f'{stratx_fissure}_stratx_score'
    stratx_outcome_col = f'{stratx_outcome}_stratx_outcome'

    df_model_temp = (
        df_model[[model_fissure, model_outcome, 'MRN']]
        .dropna()
        .rename(columns={model_fissure: model_score_col, model_outcome: model_outcome_col})
    )
    df_stratx_temp = (
        df_stratx[[stratx_fissure, stratx_outcome, 'MRN']]
        .dropna()
        .rename(columns={stratx_fissure: stratx_score_col, stratx_outcome: stratx_outcome_col})
    )

    # Merge for paired comparison
    df_paired = df_model_temp.merge(df_stratx_temp, on='MRN')
    n_paired = len(df_paired)
    if n_paired == 0:
        continue

    y_true_model = df_paired[model_outcome_col].astype(int).values
    y_score_model = df_paired[model_score_col].astype(float).values
    y_true_stratx = df_paired[stratx_outcome_col].astype(int).values
    y_score_stratx = df_paired[stratx_score_col].astype(float).values

    # Pool this fissure's observations with the others.
    all_y_true_model.extend(y_true_model)
    all_y_score_model.extend(y_score_model)
    all_y_true_stratx.extend(y_true_stratx)
    all_y_score_stratx.extend(y_score_stratx)

    # Remember which case/fissure every pooled observation came from.
    all_patient_info.extend((mrn, fissure_name) for mrn in df_paired['MRN'])

    total_paired_observations += n_paired
    print(f"{fissure_name}: {n_paired} paired observations")

# Convert to numpy arrays
all_y_true_model = np.array(all_y_true_model)
all_y_score_model = np.array(all_y_score_model)
all_y_true_stratx = np.array(all_y_true_stratx)
all_y_score_stratx = np.array(all_y_score_stratx)

n_positive = np.sum(all_y_true_model)
n_pooled = len(all_y_true_model)
print(f"\nTotal combined observations: {total_paired_observations}")
print(f"Positive outcomes: {n_positive} / {n_pooled} ({100*np.mean(all_y_true_model):.1f}%)")

# === Calculate combined AUROCs ===
print(f"\nCombined AUROC Analysis:")
print("-" * 40)

# Both calls use bootstrap_auroc's default seed, so when the outcome vectors
# are identical the two methods are evaluated on the same resamples.
# Model-derived combined AUROC
auc_model_combined, ci_low_model, ci_high_model, aucs_model = bootstrap_auroc(
    all_y_true_model, all_y_score_model
)

# StratX-derived combined AUROC
auc_stratx_combined, ci_low_stratx, ci_high_stratx, aucs_stratx = bootstrap_auroc(
    all_y_true_stratx, all_y_score_stratx
)

print(f"Model-derived AUROC:  {auc_model_combined:.2f} (95% CI: {ci_low_model:.2f}–{ci_high_model:.2f})")
print(f"StratX-derived AUROC: {auc_stratx_combined:.2f} (95% CI: {ci_low_stratx:.2f}–{ci_high_stratx:.2f})")
print(f"Difference (Model - StratX): {auc_model_combined - auc_stratx_combined:.2f}")

# === Statistical comparison ===
print(f"\nStatistical Comparison:")
print("-" * 30)

# Statistical test using established DeLong test from MLstatkit
try:
    z_stat, p_delong = delong_roc_test(all_y_true_model, all_y_score_model, all_y_score_stratx)
    print(f"DeLong test: z={z_stat:.2f}, p={p_delong:.4f}")
except Exception as e:
    print(f"DeLong test failed: {str(e)}")
    z_stat = np.nan
    p_delong = np.nan

# Bootstrap difference test
if aucs_model.shape == aucs_stratx.shape:
    auc_differences = aucs_model - aucs_stratx
    # Two-sided p-value; doubling the smaller tail can exceed 1 (e.g. when
    # every difference is exactly 0), so it is capped.
    p_bootstrap = min(
        1.0,
        2 * min(np.mean(auc_differences >= 0), np.mean(auc_differences <= 0)),
    )
    print(f"Bootstrap test: p={p_bootstrap:.4f}")
    print(f"Mean difference: {np.mean(auc_differences):.2f} ± {np.std(auc_differences):.2f}")
else:
    # Different numbers of retained resamples mean the two bootstrap series
    # are not paired, so an element-wise difference would be meaningless.
    auc_differences = np.array([])
    p_bootstrap = np.nan
    print("Bootstrap test skipped: bootstrap samples of the two methods are not paired")

# Statistical interpretation
if not np.isnan(p_delong):
    if p_delong < 0.001:
        sig_level = "*** (p<0.001)"
    elif p_delong < 0.01:
        sig_level = "** (p<0.01)"
    elif p_delong < 0.05:
        sig_level = "* (p<0.05)"
    else:
        sig_level = "ns (not significant)"
    print(f"Significance: {sig_level}")

# === Create 2x2 subplot layout ===
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()  # index the four panels as axes[0] .. axes[3]

# Bootstrap for confidence intervals on ROC curves
def bootstrap_roc_curves(y_true, y_score, n_bootstraps=1000, seed=42):
    """Estimate a pointwise 95% bootstrap band for a ROC curve.

    Samples are drawn with replacement, a ROC curve is computed per resample
    and interpolated onto a fixed grid of 101 false-positive rates in [0, 1].
    Resamples that contain fewer than two distinct labels are skipped.

    Args:
        y_true: Binary ground-truth labels (any array-like).
        y_score: Scores, same length as ``y_true`` (any array-like).
        n_bootstraps: Number of resamples to draw.
        seed: Seed for the ``RandomState`` generator, for reproducibility.

    Returns:
        Tuple ``(base_fpr, mean_tpr, tpr_lower, tpr_upper)`` of arrays of
        length 101: the FPR grid, the mean TPR, and the 2.5th / 97.5th
        percentile TPR at each grid point. If no resample contains both
        classes (including empty input), the three TPR arrays are all NaN
        and a ``RuntimeWarning`` is emitted.

    Raises:
        ValueError: If ``y_true`` and ``y_score`` differ in length.
    """
    import warnings  # local import: only needed on the degenerate path

    # Convert up front so lists and pandas Series are indexed positionally.
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    n_samples = len(y_score)
    if len(y_true) != n_samples:
        raise ValueError(
            f"y_true and y_score must have the same length "
            f"(got {len(y_true)} and {n_samples})"
        )

    rng = np.random.RandomState(seed)
    base_fpr = np.linspace(0, 1, 101)
    tpr_list = []

    # An empty input has nothing to resample, so the loop is skipped entirely.
    for _ in range(n_bootstraps if n_samples else 0):
        indices = rng.randint(0, n_samples, n_samples)
        y_true_boot = y_true[indices]
        if len(np.unique(y_true_boot)) < 2:
            # A ROC curve needs both classes present in the resample.
            continue
        fpr, tpr, _thresholds = roc_curve(y_true_boot, y_score[indices])
        tpr_interp = np.interp(base_fpr, fpr, tpr)
        tpr_interp[0] = 0.0  # pin the curve to the origin
        tpr_list.append(tpr_interp)

    if not tpr_list:
        # Without this guard the reductions below would run on an empty array
        # and not yield per-grid-point bands.
        warnings.warn(
            "bootstrap_roc_curves: no resample contained both classes; "
            "returning NaN bands",
            RuntimeWarning,
            stacklevel=2,
        )
        nan_band = np.full_like(base_fpr, np.nan)
        return base_fpr, nan_band, nan_band.copy(), nan_band.copy()

    tpr_array = np.array(tpr_list)
    mean_tpr = np.mean(tpr_array, axis=0)
    tpr_lower = np.percentile(tpr_array, 2.5, axis=0)
    tpr_upper = np.percentile(tpr_array, 97.5, axis=0)

    return base_fpr, mean_tpr, tpr_lower, tpr_upper

# Subplot labels (one letter per panel of the 2x2 figure)
subplot_labels = ['A', 'B', 'C', 'D']

# === First three subplots: Individual fissures ===
# One summary dict is appended per fissure that has paired data.
individual_results = []

for i in range(3):  # first three panels; axes[3] is used by the combined plot
    ax = axes[i]
    model_fissure, model_outcome = model_mappings[i]
    stratx_fissure, stratx_outcome = stratx_mappings[i]
    fissure_name = fissure_names[i]

    # Keep only rows where score, outcome and MRN are all present
    df_model_temp = df_model[[model_fissure, model_outcome, 'MRN']].dropna().copy()
    df_stratx_temp = df_stratx[[stratx_fissure, stratx_outcome, 'MRN']].dropna().copy()

    # Suffix the columns so the two sources cannot collide after merging
    df_model_temp = df_model_temp.rename(columns={
        model_fissure: f'{model_fissure}_model_score',
        model_outcome: f'{model_outcome}_model_outcome'
    })
    df_stratx_temp = df_stratx_temp.rename(columns={
        stratx_fissure: f'{stratx_fissure}_stratx_score',
        stratx_outcome: f'{stratx_outcome}_stratx_outcome'
    })

    # Inner merge on MRN: only cases present in both sources are compared
    df_paired = df_model_temp.merge(df_stratx_temp, on='MRN')

    if len(df_paired) > 0:
        # NOTE(review): the model-side outcome is used as ground truth for
        # both score sources; the StratX outcome column is merged but never
        # compared against it -- confirm the two always agree.
        y_true = df_paired[f'{model_outcome}_model_outcome'].astype(int).values
        y_score_model = df_paired[f'{model_fissure}_model_score'].astype(float).values
        y_score_stratx = df_paired[f'{stratx_fissure}_stratx_score'].astype(float).values

        # Per-fissure AUROCs with confidence bounds (per-resample values unused here)
        auc_model_ind, ci_low_model_ind, ci_high_model_ind, _ = bootstrap_auroc(y_true, y_score_model)
        auc_stratx_ind, ci_low_stratx_ind, ci_high_stratx_ind, _ = bootstrap_auroc(y_true, y_score_stratx)

        # Empirical ROC curves
        fpr_model, tpr_model, _ = roc_curve(y_true, y_score_model)
        fpr_stratx, tpr_stratx, _ = roc_curve(y_true, y_score_stratx)

        # Bootstrap bands for the ROC curves (only the lower/upper bounds are plotted)
        fpr_model_ci, tpr_model_ci, tpr_model_lower, tpr_model_upper = bootstrap_roc_curves(y_true, y_score_model)
        fpr_stratx_ci, tpr_stratx_ci, tpr_stratx_lower, tpr_stratx_upper = bootstrap_roc_curves(y_true, y_score_stratx)

        # Plot ROC curves with confidence bands: model in blue, StratX in red
        ax.plot(fpr_model, tpr_model, 'b-', lw=2,
                label=f'Model: {auc_model_ind:.2f} ({ci_low_model_ind:.2f}–{ci_high_model_ind:.2f})')
        ax.fill_between(fpr_model_ci, tpr_model_lower, tpr_model_upper,
                        color='blue', alpha=0.2)

        ax.plot(fpr_stratx, tpr_stratx, 'r-', lw=2,
                label=f'StratX: {auc_stratx_ind:.2f} ({ci_low_stratx_ind:.2f}–{ci_high_stratx_ind:.2f})')
        ax.fill_between(fpr_stratx_ci, tpr_stratx_lower, tpr_stratx_upper,
                        color='red', alpha=0.2)

        # Chance diagonal
        ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)

        # Paired DeLong test for this fissure. Only ordinary errors are
        # caught, so KeyboardInterrupt / SystemExit still propagate, and the
        # cause is reported instead of being discarded.
        try:
            z_stat_ind, p_value_ind = delong_roc_test(y_true, y_score_model, y_score_stratx)
            print(f"  {fissure_name}: DeLong z={z_stat_ind:.3f}, p={p_value_ind:.4f}")
        except Exception as e:
            z_stat_ind, p_value_ind = np.nan, np.nan
            print(f"  {fissure_name}: DeLong test failed: {e}")

        # Formatting
        ax.set_title(f'{subplot_labels[i]}. {fissure_name} (n={len(df_paired)})', fontsize=12, weight='bold')
        ax.set_xlabel('False Positive Rate', fontsize=10)
        ax.set_ylabel('True Positive Rate', fontsize=10)
        ax.legend(fontsize=9, loc='lower right')
        ax.grid(True, alpha=0.3)

        # Annotate the panel with this fissure's own p-value when available
        if not np.isnan(p_value_ind):
            ax.text(0.05, 0.95, f'p = {p_value_ind:.3f}',
                    fontsize=9, bbox=dict(boxstyle="round,pad=0.2", facecolor="lightgray"),
                    transform=ax.transAxes, verticalalignment='top')

        # Store results with the individual p-value (NaN if the test failed)
        individual_results.append({
            'fissure': fissure_name,
            'n_paired': len(df_paired),
            'model_auc': auc_model_ind,
            'stratx_auc': auc_stratx_ind,
            'p_value': p_value_ind
        })
    else:
        # Make the skipped panel visible in the log instead of failing silently
        print(f"  {fissure_name}: no paired cases after merging on MRN; panel left empty")

# === Fourth subplot: Combined analysis ===
ax = axes[3]

# ROC curves on the pooled data (the inputs and AUROCs were computed earlier)
fpr_model_combined, tpr_model_combined, _ = roc_curve(all_y_true_model, all_y_score_model)
fpr_stratx_combined, tpr_stratx_combined, _ = roc_curve(all_y_true_stratx, all_y_score_stratx)

# Bootstrap bands for the combined ROC curves
fpr_model_ci_comb, tpr_model_ci_comb, tpr_model_lower_comb, tpr_model_upper_comb = bootstrap_roc_curves(
    all_y_true_model, all_y_score_model)
fpr_stratx_ci_comb, tpr_stratx_ci_comb, tpr_stratx_lower_comb, tpr_stratx_upper_comb = bootstrap_roc_curves(
    all_y_true_stratx, all_y_score_stratx)

# Plot combined ROC curves with confidence bands: model in blue, StratX in red
ax.plot(fpr_model_combined, tpr_model_combined, 'b-', lw=2,
        label=f'Model: {auc_model_combined:.2f} ({ci_low_model:.2f}–{ci_high_model:.2f})')
ax.fill_between(fpr_model_ci_comb, tpr_model_lower_comb, tpr_model_upper_comb,
                color='blue', alpha=0.2)

ax.plot(fpr_stratx_combined, tpr_stratx_combined, 'r-', lw=2,
        label=f'StratX: {auc_stratx_combined:.2f} ({ci_low_stratx:.2f}–{ci_high_stratx:.2f})')
ax.fill_between(fpr_stratx_ci_comb, tpr_stratx_lower_comb, tpr_stratx_upper_comb,
                color='red', alpha=0.2)

# Chance diagonal
ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)

# The sample count is taken from the plotted data rather than hard-coded, so
# the title stays correct when the cohort changes.
# NOTE(review): this is the number of pooled observations behind the curve;
# confirm that is the intended "n" for this panel.
ax.set_title(f'{subplot_labels[3]}. Combined: All Fissures (n={len(all_y_true_model)})',
             fontsize=12, weight='bold')
ax.set_xlabel('False Positive Rate', fontsize=10)
ax.set_ylabel('True Positive Rate', fontsize=10)
ax.legend(fontsize=9, loc='lower right')
ax.grid(True, alpha=0.3)

# Add combined p-value (skipped when the DeLong test failed and left NaN)
if not np.isnan(p_delong):
    ax.text(0.05, 0.95, f'p = {p_delong:.3f}',
            fontsize=9, bbox=dict(boxstyle="round,pad=0.2", facecolor="lightgray"),
            transform=ax.transAxes, verticalalignment='top')

plt.tight_layout()
plt.savefig(r"------ INSERT PATH HERE ------\Figures\Figure6.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"\n{'='*80}")
print("SUMMARY:")
if np.isnan(p_delong):
    # NaN compares False against 0.05, which would otherwise be reported as
    # "no significant difference" even though no test result exists.
    print("Combined analysis could not test (DeLong test failed) the difference")
else:
    print(f"Combined analysis shows {'significant' if p_delong < 0.05 else 'no significant'} difference")
print(f"between Model-derived (AUC={auc_model_combined:.2f}) and StratX-derived (AUC={auc_stratx_combined:.2f})")
print("fissure completeness scores across all three fissure types.")
print("="*80)