# ZZ CR Data/MC inspector

This notebook is designed for **single-file** outputs from the `ZZ_CR` module (e.g. `mkShapes__ZH_4lMET_ZZCR_2024v15.root`) and reproduces the key Data/MC plotting elements used in `2026_01_26_AN_update.ipynb`:

- stacked MC templates,
- data markers with Poisson errors,
- total MC line + uncertainty band,
- Data/MC ratio panel with propagated bands,
- per-process yields in the legend.

## Approach for this tree format

Because the ZZ_CR output is a ROOT file organised in category/sample directories (`trees/<category>/<sample>/Events`) rather than a pickle of dataframes, the approach is:

1. Read branches directly with `uproot`, per category and sample.
2. Treat `DATA` as unweighted counts.
3. Treat all non-`DATA` samples as MC, weighted by the `weight` branch.
4. Build weighted MC histograms together with their variances (`sumw`, `sumw2`).
5. Reproduce the same visual elements and ratio logic as in the AN notebook.


In [None]:
import os

import numpy as np
import uproot
import matplotlib.pyplot as plt
import mplhep as hep
from hist.intervals import poisson_interval

# Apply mplhep's "CMS" matplotlib style to every figure drawn below.
hep.style.use("CMS")


In [None]:
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Single ROOT file written by the ZZ_CR module.
INPUT_FILE = "../rootFiles/ZH_4lMET/rootFiles__ZH_4lMET_ZZCR_2024v15/mkShapes__ZH_4lMET_ZZCR_2024v15.root"

# Name of the branch holding the per-event MC weight.
MC_WEIGHT_BRANCH = "weight"

# Process names that are filled without weights (i.e. treated as data).
DATA_PROCESSES = {"DATA"}

# Explicit list of category names to plot, or None to use every category
# found below trees/ in INPUT_FILE.
CATEGORIES = None

# Plots are written below this directory; make sure it exists up front.
OUTPUT_DIR = "plots_zzcr_datamc"
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [None]:
# Histogram definition per branch to plot: bin edges, x-axis label and the
# y-range of the Data/MC ratio panel. Add or remove entries freely.
OBSERVABLES = dict(
    Z0_mass=dict(
        bins=np.linspace(60, 120, 31),
        xlabel=r"$m_{Z0}$ [GeV]",
        ratio_ylim=(0.5, 1.5),
    ),
    X_mass=dict(
        bins=np.linspace(0, 200, 41),
        xlabel=r"$m_{X}$ [GeV]",
        ratio_ylim=(0.4, 1.8),
    ),
    m4l=dict(
        bins=np.linspace(0, 800, 41),
        xlabel=r"$m_{4\ell}$ [GeV]",
        ratio_ylim=(0.4, 1.8),
    ),
    PuppiMET_pt=dict(
        bins=np.linspace(0, 150, 31),
        xlabel=r"$p_{T}^{miss}$ [GeV]",
        ratio_ylim=(0.4, 1.8),
    ),
    pT4l=dict(
        bins=np.linspace(0, 200, 41),
        xlabel=r"$p_{T}^{4\ell}$ [GeV]",
        ratio_ylim=(0.4, 1.8),
    ),
    # Integer-valued multiplicity: half-integer edges centre each bin on an integer.
    nCleanJet=dict(
        bins=np.arange(-0.5, 8.5, 1.0),
        xlabel=r"$N_{\mathrm{clean\ jet}}$",
        ratio_ylim=(0.2, 2.2),
    ),
)


In [None]:
def list_categories(root_file):
    with uproot.open(root_file) as f:
        if "trees" not in f:
            raise KeyError("No 'trees' directory found in ROOT file.")
        trees_dir = f["trees"]
        cats = []
        for k in trees_dir.keys():
            cats.append(k.split(";")[0])
        return sorted(cats)


def list_processes(root_file, category):
    with uproot.open(root_file) as f:
        d = f[f"trees/{category}"]
        out = []
        for k in d.keys():
            out.append(k.split(";")[0])
        return sorted(out)


def load_category_arrays(root_file, category, branches, weight_branch="weight"):
    """Load selected branches for all processes in one category.

    Parameters
    ----------
    root_file : str
        Path to the ROOT file.
    category : str
        Directory name below ``trees/``.
    branches : iterable of str
        Branches to read from every ``trees/<category>/<process>/Events`` tree.
        Requested branches that a tree does not contain are silently skipped.
    weight_branch : str
        Extra branch read for every process not listed in ``DATA_PROCESSES``.

    Returns
    -------
    dict
        ``process -> {branch: np.ndarray}``. Processes without an ``Events``
        tree, or containing none of the requested branches, are omitted.
    """
    out = {}
    with uproot.open(root_file) as f:
        # List the samples from the already-open file (no second open of the
        # same file) and only one level deep: uproot's keys() recurses by
        # default and would also yield nested "<sample>/Events" paths.
        cat_dir = f[f"trees/{category}"]
        processes = sorted({k.split(";")[0] for k in cat_dir.keys(recursive=False, cycle=False)})

        for proc in processes:
            tree_path = f"trees/{category}/{proc}/Events"
            if tree_path not in f:
                continue

            needed = list(branches)
            if proc not in DATA_PROCESSES and weight_branch not in needed:
                needed.append(weight_branch)

            tree = f[tree_path]
            avail = set(tree.keys())
            read_branches = [b for b in needed if b in avail]
            if not read_branches:
                continue

            out[proc] = tree.arrays(read_branches, library="np")
    return out


def weighted_poisson_errors(sumw, sumw2):
    """Garwood-like asymmetric errors for weighted counts, using effective entries.

    For each bin the effective number of entries ``neff = sumw**2 / sumw2``
    gets a Poisson interval (``hist.intervals.poisson_interval``). That interval
    is scaled back to the weighted scale by the factor ``|sumw| / neff``.

    Parameters
    ----------
    sumw : array-like
        Sum of weights per bin.
    sumw2 : array-like
        Sum of squared weights per bin.

    Returns
    -------
    err_lo, err_hi : np.ndarray
        Non-negative downward / upward errors per bin.
        - Bins with ``sumw2 == 0`` get zero error.
        - Bins whose weights cancel exactly (``sumw == 0`` but ``sumw2 > 0``)
          fall back to the symmetric error ``sqrt(sumw2)``.
    """
    sumw = np.asarray(sumw, dtype=float)
    sumw2 = np.asarray(sumw2, dtype=float)

    neff = np.divide(sumw**2, sumw2, out=np.zeros_like(sumw), where=sumw2 > 0)
    lo_eff, hi_eff = poisson_interval(neff)

    # Scale with |sumw| so the errors stay non-negative in bins whose total
    # weight is negative (possible with negative MC weights).
    scale = np.divide(np.abs(sumw), neff, out=np.zeros_like(sumw), where=neff > 0)
    err_lo = scale * (neff - lo_eff)
    err_hi = scale * (hi_eff - neff)

    # With fully cancelling weights neff is 0, the scaling above is undefined
    # and would report zero error; use the plain sqrt(sumw2) there instead.
    cancelled = (neff == 0) & (sumw2 > 0)
    sym = np.sqrt(sumw2)
    err_lo = np.where(cancelled, sym, err_lo)
    err_hi = np.where(cancelled, sym, err_hi)
    return err_lo, err_hi


In [None]:
def _fill_histograms(proc_arrays, observable, bins):
    """Histogram ``observable``: summed data counts plus per-process MC sumw / sumw2.

    Processes lacking the observable are ignored. MC processes lacking
    ``MC_WEIGHT_BRANCH`` are ignored as well. Entries outside ``bins`` are
    dropped (no under/overflow folding).

    Returns ``(data_counts, mc_hists, mc_vars)``; the last two map process -> array.
    """
    data_counts = np.zeros(len(bins) - 1, dtype=float)
    mc_hists, mc_vars = {}, {}
    for proc, arrs in proc_arrays.items():
        if observable not in arrs:
            continue
        x = arrs[observable]
        if proc in DATA_PROCESSES:
            data_counts += np.histogram(x, bins=bins)[0]
            continue
        if MC_WEIGHT_BRANCH not in arrs:
            continue
        # np.histogram accumulates in the weights' dtype. Cast to float64 so
        # float32 weight branches do not lose precision in the per-bin sums.
        w = np.asarray(arrs[MC_WEIGHT_BRANCH], dtype=np.float64)
        mc_hists[proc] = np.histogram(x, bins=bins, weights=w)[0]
        mc_vars[proc] = np.histogram(x, bins=bins, weights=w * w)[0]
    return data_counts, mc_hists, mc_vars


def make_datamc_plot(category, observable, cfg, proc_arrays, output_dir=OUTPUT_DIR):
    """Draw, show and save a stacked Data/MC comparison of one observable in one category.

    Parameters
    ----------
    category : str
        Category name; used in the title and as the output sub-directory.
    observable : str
        Branch name to histogram.
    cfg : dict
        Plot configuration: ``"bins"`` (bin edges, required), plus the
        optional ``"xlabel"`` and ``"ratio_ylim"``.
    proc_arrays : dict
        ``process -> {branch: np.ndarray}``, as returned by ``load_category_arrays``.
        Processes in ``DATA_PROCESSES`` are filled unweighted; all others are
        weighted by ``MC_WEIGHT_BRANCH``.
    output_dir : str
        The PNG is written to ``<output_dir>/<category>/<observable>.png``.

    Returns None. When no MC histogram could be filled, it prints a warning
    and draws nothing.
    """
    bins = np.asarray(cfg["bins"])
    centers = 0.5 * (bins[1:] + bins[:-1])
    widths = np.diff(bins)

    data_counts, mc_hists, mc_vars = _fill_histograms(proc_arrays, observable, bins)
    if not mc_hists:
        print(f"[WARN] No MC histograms found for {category}/{observable}")
        return

    # Ascending yield: the smallest process is drawn first, i.e. at the bottom of the stack.
    procs_sorted = sorted(mc_hists, key=lambda p: float(np.sum(mc_hists[p])))
    mc_total = np.sum([mc_hists[p] for p in procs_sorted], axis=0)
    mc_total_var = np.sum([mc_vars[p] for p in procs_sorted], axis=0)
    mc_err_lo, mc_err_hi = weighted_poisson_errors(mc_total, mc_total_var)

    # Asymmetric Poisson interval on the unweighted data counts.
    d_lo, d_hi = poisson_interval(data_counts)
    d_err_lo = data_counts - d_lo
    d_err_hi = d_hi - data_counts

    # Layout: main panel (top-left), ratio panel (bottom-left), legend column (right).
    fig = plt.figure(figsize=(10, 7))
    gs = fig.add_gridspec(2, 2, width_ratios=[8, 1], height_ratios=[3, 1], wspace=0.05, hspace=0.05)
    ax = fig.add_subplot(gs[0, 0])
    rax = fig.add_subplot(gs[1, 0], sharex=ax)
    lax = fig.add_subplot(gs[:, 1])
    lax.axis("off")
    ax.tick_params(labelbottom=False)

    color_cycle = plt.cm.tab20.colors
    bottoms = np.zeros_like(centers)
    stack_handles, stack_labels = [], []

    for i, p in enumerate(procs_sorted):
        vals = mc_hists[p]
        bar = ax.bar(
            centers,
            vals,
            width=widths,
            bottom=bottoms,
            align="center",
            color=color_cycle[i % len(color_cycle)],
            alpha=0.8,
            linewidth=0,
        )
        bottoms = bottoms + vals
        stack_handles.append(bar[0])
        stack_labels.append(f"{p} ({np.sum(vals):.2f})")

    # Each per-bin value is repeated at both of its bin edges, so the total-MC
    # line and band follow the bin boundaries.
    step_x = np.repeat(bins, 2)[1:-1]
    step_y = np.repeat(mc_total, 2)
    step_lo = np.repeat(mc_err_lo, 2)
    step_hi = np.repeat(mc_err_hi, 2)

    mc_line = ax.plot(step_x, step_y, color="red", lw=1.5, drawstyle="steps-mid")[0]
    mc_band = ax.fill_between(step_x, step_y - step_lo, step_y + step_hi, step="mid", color="red", alpha=0.15)

    d_err_lines = ax.vlines(centers, data_counts - d_err_lo, data_counts + d_err_hi, color="black", lw=1.0)
    d_markers = ax.plot(centers, data_counts, "o", color="black", ms=4)[0]

    ax.set_ylabel("Events")
    ax.grid(alpha=0.25)

    # Ratio panel. Bins without an MC prediction get NaN, so no data point is
    # drawn there instead of a misleading ratio of 0.
    has_mc = mc_total > 0
    nan_fill = np.full_like(data_counts, np.nan)
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.divide(data_counts, mc_total, out=nan_fill.copy(), where=has_mc)
        rlo = np.divide(d_err_lo, mc_total, out=nan_fill.copy(), where=has_mc)
        rhi = np.divide(d_err_hi, mc_total, out=nan_fill.copy(), where=has_mc)

        mc_rel_lo = np.divide(mc_err_lo, mc_total, out=np.zeros_like(mc_total), where=has_mc)
        mc_rel_hi = np.divide(mc_err_hi, mc_total, out=np.zeros_like(mc_total), where=has_mc)

    rax.errorbar(centers, ratio, yerr=[rlo, rhi], fmt="o", color="black", ms=4)
    rax.plot(step_x, np.ones_like(step_y), color="red", lw=1.0, drawstyle="steps-mid")
    rax.fill_between(step_x, 1 - np.repeat(mc_rel_lo, 2), 1 + np.repeat(mc_rel_hi, 2), step="mid", color="red", alpha=0.15)
    rax.set_ylabel("Data/MC")
    rax.set_xlabel(cfg.get("xlabel", observable))
    rax.set_ylim(*cfg.get("ratio_ylim", (0.5, 1.5)))
    rax.grid(alpha=0.25)

    # Legend: data, total MC, then the stack components from top to bottom.
    handles = [(d_err_lines, d_markers), (mc_line, mc_band)] + stack_handles[::-1]
    labels = [
        f"Data ({np.sum(data_counts):.0f})",
        f"Total MC ({np.sum(mc_total):.2f})",
    ] + stack_labels[::-1]
    lax.legend(handles, labels, loc="upper left", fontsize="xx-small", frameon=False, ncol=1)

    ax.set_title(f"{category} • {observable}", fontsize=11)
    fig.tight_layout()

    out_cat = os.path.join(output_dir, category)
    os.makedirs(out_cat, exist_ok=True)
    out_png = os.path.join(out_cat, f"{observable}.png")
    fig.savefig(out_png, dpi=180, bbox_inches="tight")
    plt.show()
    # Release the figure: this function runs once per category x observable,
    # and open figures would otherwise accumulate.
    plt.close(fig)
    print(f"Saved: {out_png}")


In [None]:
# Fall back to auto-discovery when no explicit category list was configured,
# then print the processes available in each category.
CATEGORIES = list_categories(INPUT_FILE) if CATEGORIES is None else CATEGORIES

print("Categories:", CATEGORIES)
for c in CATEGORIES:
    print(f"  {c}:", list_processes(INPUT_FILE, c))


In [None]:
# The set of branches to read is the same for every category: build it once.
needed_branches = set(OBSERVABLES.keys()) | {MC_WEIGHT_BRANCH}

for category in CATEGORIES:
    proc_arrays = load_category_arrays(
        root_file=INPUT_FILE,
        category=category,
        branches=needed_branches,
        weight_branch=MC_WEIGHT_BRANCH,
    )
    # One Data/MC figure per configured observable in this category.
    for obs, cfg in OBSERVABLES.items():
        make_datamc_plot(
            category=category,
            observable=obs,
            cfg=cfg,
            proc_arrays=proc_arrays,
            output_dir=OUTPUT_DIR,
        )


### Notes

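- To add observables, append entries to `OBSERVABLES`.
- To restrict the categories, set `CATEGORIES = ["zz_cr_XSF_ZEE", ...]` in the configuration cell (leave it `None` to auto-discover all of them).
- If MC processes other than `ZZ` are added to the ROOT file later, they are stacked automatically, with the smallest yield at the bottom.
- Data are always filled unweighted. MC is weighted only by the `weight` branch, and MC samples without that branch are skipped.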
