In [None]:
import sys
sys.path.append("/home/belle2/amubarak/Ds2D0enue_Analysis/08-Python_Functions")

import uproot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

from Ds2D0e_config import DECAY_CONFIG, BACKGROUND_SAMPLES, get_signal_file, get_generic_file

In [None]:
# Notebook-wide matplotlib font sizes, set once so every figure below shares them.
plt.rcParams["axes.labelsize"] = 14
plt.rcParams["xtick.labelsize"] = 12
plt.rcParams["ytick.labelsize"] = 12
plt.rcParams["legend.fontsize"] = 12
plt.rcParams["figure.titlesize"] = 16

  # Real vs Fake D⁰ sanity study

  * Load the same signal/generic samples as Phase 1 (centralized in Ds2D0e_config)
  * Apply the D⁰ mass window cuts
  * Define "real D⁰" and "fake D⁰"
    - Real D⁰: `abs(D0_mcPDG) == 421` (D⁰ and D̄⁰)
    - Fake D⁰: `abs(D0_mcPDG) != 421` or NaN
  * Correlations:
    - Full correlation matrix (D⁰ vars + fit vars)
    - Heatmaps of D⁰ vars vs each fit variable, using corr()[[fit_var]]
  * Real vs fake D⁰ distributions for all K_*, pi_*, pi0_*, D0_* variables

In [None]:
# ========================================
# CONFIG
# ========================================

# --- Gamma veto -------------------------------------------------------
# The flag is only echoed in the loading summary of this notebook; the
# threshold is reserved for later phases.
APPLY_GAMMA_VETO = False
GAMMAVETO_THRESHOLD = 0.1

# --- Control sample ---------------------------------------------------
# When enabled, the file getters are asked for control-sample ROOT files.
# Accepted tags: "noEID", "wrongCharge", ["noEID", "wrongCharge"], or "all".
USE_CONTROL_SAMPLE = False
CONTROL_SAMPLE_TAG = "noEID"

# --- Real vs fake histogram styling -----------------------------------
BINS_HIST = 50          # number of bins per histogram
DENSITY_HIST = True     # forwarded to matplotlib's `density` argument
LINEWIDTH = 2.5         # step-histogram line width
COLOR_REAL = "purple"   # real D0 curve
COLOR_FAKE = "#D55E00"  # fake D0 curve

# --- Fit variables ----------------------------------------------------
fit_vars = [
    "Ds_massDifference_0",
    "Ds_diff_D0pi",
]

# --- Plot titles ------------------------------------------------------
# LaTeX decay chain per mode key; modes without an entry fall back to the raw key.
mode_titles = dict(
    kmpip=r"$D_s^{+} \rightarrow [D^{0} \rightarrow K^{-} \pi^{+}] e^{+} \nu_{e}$",
    kmpippi0_eff20_May2020=r"$D_s^{+} \rightarrow [D^{0} \rightarrow K^{-} \pi^{+} \pi^{0}] e^{+} \nu_{e}$",
    km3pi=r"$D_s^{+} \rightarrow [D^{0} \rightarrow K^{-} 3\pi] e^{+} \nu_{e}$",
)

# Lower-case local aliases for the shared objects imported from Ds2D0e_config.
decay_config, background_samples = DECAY_CONFIG, BACKGROUND_SAMPLES

# =====================================================================
# uproot read settings
# =====================================================================
# Branch-name patterns passed as `filter_name` when reading the trees.
# "pi_*" and "pi0_*" are listed separately because the "pi_" prefix does
# not cover names that start with "pi0_".
UROOT_BRANCH_FILTER = (
    # every branch of the D0 candidate and of its K / pi / pi0 daughters
    [f"{prefix}_*" for prefix in ("D0", "K", "pi", "pi0")]
    # individual Ds branches
    + [
        "Ds_massDifference_0",        # Δm_e
        "Ds_diff_D0pi",               # Δm_π
        "Ds_isSignal",                # truth flag
        "Ds_gammaveto_M_Correction",  # kept for the later gamma-veto phases
    ]
)

# Passed to uproot as `num_workers` by the loading cell.
UROOT_NUM_WORKERS = 16

  ## Data loading (with D⁰ mass window cuts, branch filtered)

In [None]:
# Maps "Signal_<mode>", "<sample>_<mode>" and "All_<mode>" to the DataFrame
# holding the rows that passed that mode's cut.
DataFrames = {}


def _build_tree_paths(file_or_files, tree_name):
    """Return uproot-style "<file>:<tree>" path strings.

    Parameters
    ----------
    file_or_files : object or list/tuple/set
        A single file spec, or a collection of them.
    tree_name : str
        Name of the tree to address inside each file.

    Returns
    -------
    list of str
        One "<file>:<tree>" entry per input file.
    """
    if isinstance(file_or_files, (list, tuple, set)):
        return [f"{f}:{tree_name}" for f in file_or_files]
    return [f"{file_or_files}:{tree_name}"]


def _load_filtered_frame(tree_paths, cut):
    """Read the filtered branches of ``tree_paths`` into one DataFrame and apply ``cut``.

    Parameters
    ----------
    tree_paths : list of str
        "<file>:<tree>" strings as produced by ``_build_tree_paths``.
    cut : str
        Selection expression evaluated with ``DataFrame.query``.

    Returns
    -------
    pandas.DataFrame
        Rows of the concatenated trees that satisfy ``cut``; only branches
        matching ``UROOT_BRANCH_FILTER`` are read.
    """
    read_kwargs = dict(library="pd", filter_name=UROOT_BRANCH_FILTER)
    try:
        frame = uproot.concatenate(
            tree_paths, num_workers=UROOT_NUM_WORKERS, **read_kwargs
        )
    except TypeError as exc:
        # Fallback for uproot versions that do not accept num_workers. The
        # reason is printed so an unrelated TypeError is not silently hidden
        # behind a second full read.
        print(f"  [uproot] retrying without num_workers ({exc})")
        frame = uproot.concatenate(tree_paths, **read_kwargs)
    return frame.query(cut)


print("Loading Signal files (branch filtered)...")
for mode, config in tqdm(list(decay_config.items()), desc="Signal modes"):
    signal_file = get_signal_file(
        mode,
        use_control_sample=USE_CONTROL_SAMPLE,
        control_sample_tag=CONTROL_SAMPLE_TAG,
    )
    signal_tree_paths = _build_tree_paths(signal_file, config["ds_tree"])
    DataFrames[f"Signal_{mode}"] = _load_filtered_frame(signal_tree_paths, config["cut"])

print("\nLoading Background files (branch filtered)...")
for sample in tqdm(background_samples, desc="Background samples"):
    for mode, config in decay_config.items():
        generic_file = get_generic_file(
            sample,
            mode,
            use_control_sample=USE_CONTROL_SAMPLE,
            control_sample_tag=CONTROL_SAMPLE_TAG,
        )
        generic_tree_paths = _build_tree_paths(generic_file, config["ds_tree"])
        DataFrames[f"{sample}_{mode}"] = _load_filtered_frame(generic_tree_paths, config["cut"])

print("\nCombining background samples by mode...")
for mode in decay_config.keys():
    # One "All_<mode>" frame per mode: every background sample stacked together.
    dfs_list = [DataFrames[f"{sample}_{mode}"] for sample in background_samples]
    DataFrames[f"All_{mode}"] = pd.concat(dfs_list, ignore_index=True)

print("\nData loading complete!")
print(f"Successfully loaded {len(DataFrames)} dataframes")
print(f"APPLY_GAMMA_VETO = {APPLY_GAMMA_VETO}")
print(f"USE_CONTROL_SAMPLE = {USE_CONTROL_SAMPLE} (type: {CONTROL_SAMPLE_TAG})")
print(f"Branch filter: {UROOT_BRANCH_FILTER}")
print(f"uPROOT workers: {UROOT_NUM_WORKERS}")

In [None]:
# Raise pandas' display limits so large frames are shown without truncation.
pd.options.display.max_rows = 200000
pd.options.display.max_columns = 200000

  ## Global drop rules for variables

  These are removed from:
  * Correlation matrices
  * Real vs fake D⁰ histograms
  but are still loaded from ROOT (e.g. D0_mcPDG is needed for truth labels).

In [None]:
# Drop any variable whose name ends with one of these suffixes.
# Keep a trailing comma on every row: adjacent string literals without a comma
# are silently concatenated by Python, which would merge two suffixes into one
# that matches nothing.
GLOBAL_DROP_SUFFIXES = [
    "seenInCDC",
    "seenInPXD",
    "seenInSVD",
    "seenInTOP",
    "mcP", "mcE",
    "isSignal",
    "nMCMatches", "mcMatchWeight",
    "nMCDaughters",
    "mcPDG",
    "genMotherPDG", "genMotherPDG_1", "genMotherPDG_2",
    "genMotherID", "genMotherID_1", "genMotherID_2",
    "mcErrors",
]

# Variables dropped by exact name, whether or not a suffix rule also catches them.
GLOBAL_EXPLICIT_DROP_VARS = [
    "D0_mcDecayTime",
    "D0_mcLifeTime",
    "D0_mcFlightTime",
    "D0_D0Mode",
    "D0_Dbar0Mode",
    "D0_D0orD0bar",
    "D0_mcMother_nMCDaughters",
    # D0_mcMother_mcDaughter_1_* family
    *(
        f"D0_mcMother_mcDaughter_1_{tail}"
        for tail in (
            "PDG",
            "nMCDaughters",
            "pt",
            "pz",
            "cos_theta",
            "mcDaughter_0_nMCDaughters",
            "mcDaughter_1_nMCDaughters",
        )
    ),
    # D0_isSignalAccept* family
    *(
        f"D0_isSignalAccept{variant}"
        for variant in (
            "BremsPhotons",
            "Missing",
            "MissingGamma",
            "MissingMassive",
            "MissingNeutrino",
            "WrongFSPs",
        )
    ),
]

  ## Per-mode manual variable drop lists

  Any variable listed for a mode will be excluded from:
  * Correlation matrices
  * Real vs fake D⁰ histograms

  This is per mode so you can drop things that are only useless or noisy in a specific decay channel.

In [None]:
# Per-mode variables to exclude, keyed by mode name. All lists start empty;
# add column names to drop them for that mode only.
manual_drop_vars = dict(
    # Example entries for kmpip:
    #   "D0_mcPDG", "D0_genMotherPDG", "K_Ch1_mcPDG", "pi_Ch1_mcPDG"
    kmpip=[],
    # Example entries for km3pi go here
    km3pi=[],
    # Example entries for kmpippi0_eff20_May2020 go here
    kmpippi0_eff20_May2020=[],
)

In [None]:
def get_var_lists(df_example, mode):
    """
    Group the plottable numeric columns of ``df_example`` by name prefix.

    A column is kept only if it has a numeric dtype and is not excluded by
    any of the drop rules:
      - it is listed in manual_drop_vars[mode],
      - it is listed in GLOBAL_EXPLICIT_DROP_VARS,
      - its name ends with one of GLOBAL_DROP_SUFFIXES.

    Parameters
    ----------
    df_example : pandas.DataFrame
        Frame whose columns are inspected (e.g. Signal_kmpip).
    mode : str
        Mode key used to look up the per-mode drop list; unknown modes
        contribute no extra drops.

    Returns
    -------
    tuple of five lists
        (D0_vars, K_vars, pi_vars, pi0_vars, all_vars), where all_vars is the
        four prefix lists concatenated in that order.
    """
    # Exact-name exclusions: per-mode list plus the global explicit list.
    excluded_names = set(manual_drop_vars.get(mode, [])) | set(GLOBAL_EXPLICIT_DROP_VARS)
    # str.endswith accepts a tuple and tests every suffix in one call.
    excluded_suffixes = tuple(GLOBAL_DROP_SUFFIXES)

    kept_cols = [
        col
        for col in df_example.select_dtypes(include=[np.number]).columns
        if col not in excluded_names and not col.endswith(excluded_suffixes)
    ]

    def with_prefix(prefix):
        """Return the kept columns whose name starts with ``prefix``."""
        return [col for col in kept_cols if col.startswith(prefix)]

    D0_vars = with_prefix("D0_")
    K_vars = with_prefix("K_")
    pi_vars = with_prefix("pi_")
    pi0_vars = with_prefix("pi0_")

    all_vars = [*D0_vars, *K_vars, *pi_vars, *pi0_vars]

    return D0_vars, K_vars, pi_vars, pi0_vars, all_vars

  ## Correlation heatmaps

  For each mode:

  * Real D⁰: `abs(D0_mcPDG) == 421` from signal and generic
  * Fake D⁰: `abs(D0_mcPDG) != 421` or NaN from generic only

  Correlations:
  * Full correlation matrix heatmap:
    ```python
    heatmap = sns.heatmap(df[features].corr(), annot=False,
                          cmap="coolwarm", vmin=-1, vmax=1)
    ```
  * Pairwise (1-column) heatmap vs fit variable:
    ```python
    heatmap = sns.heatmap(df[features].corr()[['Ds_massDifference_0']]
                          .sort_values(by='Ds_massDifference_0', ascending=False),
                          cmap="coolwarm", annot=True, vmin=-1, vmax=1)
    ```

  ## Correlation study: D⁰ variables vs fit variables

In [None]:
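# NOTE(review): this whole cell is commented out, so the correlation heatmaps
# described in the markdown above are not produced when the notebook runs.
# Uncomment the cell (or wrap it in a function behind a config flag) to
# re-enable the study.
#
# for mode in decay_config.keys():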
#     print(f"\n{'='*80}")
#     print(f"Correlation study for mode: {mode}")
#     print('='*80)

#     title_base = mode_titles.get(mode, mode)

#     df_sig = DataFrames[f"Signal_{mode}"].copy()
#     df_all = DataFrames[f"All_{mode}"].copy()

#     # Build var lists from this mode's signal dataframe
#     D0_vars, K_vars, pi_vars, pi0_vars, all_vars = get_var_lists(df_sig, mode)

#     # D0 vars to use in correlation (excluding fit vars themselves)
#     D0_vars_corr = [v for v in D0_vars if v not in fit_vars]

#     # Fit vars that actually exist
#     used_fit_vars = [fv for fv in fit_vars if fv in df_sig.columns]
#     if len(used_fit_vars) == 0:
#         print("  No fit variables found in dataframe, skipping mode.")
#         continue

#     # Define real and fake D0 (mode by mode)
#     if "D0_mcPDG" not in df_sig.columns or "D0_mcPDG" not in df_all.columns:
#         print("  D0_mcPDG missing, skipping mode.")
#         continue

#     # Real: abs(D0_mcPDG) == 421 (D0 and D̄0), from signal and generic
#     real_sig = df_sig[abs(df_sig["D0_mcPDG"]) == 421].copy()
#     real_bkg = df_all[abs(df_all["D0_mcPDG"]) == 421].copy()

#     # Fake: everything else in generic (including NaN)
#     fake_bkg = df_all[(abs(df_all["D0_mcPDG"]) != 421) | (df_all["D0_mcPDG"].isna())].copy()

#     df_real = pd.concat([real_sig, real_bkg], ignore_index=True)
#     df_fake = fake_bkg

#     print(f"  N_real D0 (signal + generic): {len(df_real)}")
#     print(f"  N_fake D0 (generic only)    : {len(df_fake)}")

#     # --------------------------------------------------------
#     # Helper to compute and plot correlations for one sample
#     # --------------------------------------------------------
#     def do_corr_plots(df_sample, sample_label):
#         if len(df_sample) == 0:
#             print(f"  No entries for {sample_label}, skipping.")
#             return

#         # Features = D0 vars + fit vars
#         features = [v for v in D0_vars_corr + used_fit_vars if v in df_sample.columns]
#         if len(features) == 0:
#             print(f"  No usable features for {sample_label}, skipping.")
#             return

#         # Let corr() handle NaNs pairwise
#         df_corr = df_sample[features]

#         # ---------------------------
#         # Full correlation matrix
#         # ---------------------------
#         plt.figure(figsize=(30, 20))
#         heatmap = sns.heatmap(
#             df_corr.corr(),
#             annot=False,
#             cmap="coolwarm",
#             vmin=-1,
#             vmax=1
#         )
#         heatmap.set_title(
#             f"{title_base} ({sample_label}) Correlation Heatmap",
#             fontdict={'fontsize': 18},
#             pad=16
#         )
#         plt.show()

#         # ---------------------------
#         # 1-column heatmaps vs each fit var
#         # ---------------------------
#         for fv in used_fit_vars:
#             if fv not in df_corr.columns:
#                 continue

#             corr_col = df_corr.corr()[[fv]].sort_values(
#                 by=fv, ascending=False
#             )

#             plt.figure(figsize=(8, 100))
#             heatmap = sns.heatmap(
#                 corr_col,
#                 cmap="coolwarm",
#                 annot=True,
#                 vmin=-1,
#                 vmax=1
#             )
#             heatmap.set_title(
#                 f"{title_base} ({sample_label}) Correlation Heatmap w.r.t {fv}",
#                 fontdict={'fontsize': 18},
#                 pad=16
#             )
#             plt.show()

#     # Real D0 correlations
#     print("\n  -> Real D0 correlations")
#     do_corr_plots(df_real, sample_label="real $D^{0}$")

#     # Fake D0 correlations
#     print("\n  -> Fake D0 correlations")
#     do_corr_plots(df_fake, sample_label="fake $D^{0}$")

  ## Real vs fake D⁰ distributions for K_*, pi_*, pi0_*, D0_* variables

  For each mode:
  * Real/fake definition same as above
  * Lists of:
    * D0_*,
    * K_*,
    * pi_*,
    * pi0_*,
    after applying the global drop rules (suffixes and explicit names), the per-mode manual_drop_vars, and the numeric-only filter
  * For each variable:
    * Range = [1st percentile, 99th percentile] (combined real+fake)
    * Real vs fake with:
      * "purple" for real,
      * "#D55E00" for fake,
      * linewidth 2.5,
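      * normalisation set by `DENSITY_HIST` (currently `True`: each histogram is normalised to unit area, so shapes are compared rather than raw yields)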
  * Title = decay chain only
  * Y label includes the bin width value

In [None]:
# Real vs fake D0 comparison: for every mode, overlay one step histogram per
# variable for real D0 (signal + generic) and fake D0 (generic only).
for mode in decay_config.keys():
    print(f"\n{'='*80}")
    print(f"Real vs fake D0 distributions - mode: {mode}")
    print('='*80)

    title_base = mode_titles.get(mode, mode)

    # No defensive copies: boolean indexing and pd.concat below already build
    # new frames, and nothing in this loop writes to the stored ones.
    df_sig = DataFrames[f"Signal_{mode}"]
    df_all = DataFrames[f"All_{mode}"]

    if "D0_mcPDG" not in df_sig.columns or "D0_mcPDG" not in df_all.columns:
        print("  D0_mcPDG missing, skipping mode.")
        continue

    sig_abs_pdg = df_sig["D0_mcPDG"].abs()
    bkg_abs_pdg = df_all["D0_mcPDG"].abs()

    # Real: abs(D0_mcPDG) == 421 (D0 and D̄0), from signal and generic
    df_real = pd.concat(
        [df_sig[sig_abs_pdg == 421], df_all[bkg_abs_pdg == 421]],
        ignore_index=True,
    )
    # Fake: everything else in generic (including NaN)
    df_fake = df_all[(bkg_abs_pdg != 421) | bkg_abs_pdg.isna()]

    print(f"  N_real D0 (signal + generic): {len(df_real)}")
    print(f"  N_fake D0 (generic only)    : {len(df_fake)}")

    if len(df_real) == 0 or len(df_fake) == 0:
        print("  Either real or fake sample is empty, skipping histograms for this mode.")
        continue

    # Build lists of variables for this mode (from real sample)
    D0_vars, K_vars, pi_vars, pi0_vars, all_vars = get_var_lists(df_real, mode)

    print(f"  N_D0 vars  : {len(D0_vars)}")
    print(f"  N_K vars   : {len(K_vars)}")
    print(f"  N_pi vars  : {len(pi_vars)}")
    print(f"  N_pi0 vars : {len(pi0_vars)}")

    # Loop over all vars (D0 + daughters)
    for v in all_vars:
        if v not in df_real.columns or v not in df_fake.columns:
            continue

        # keep finite values only
        x_real = df_real[v].to_numpy()
        x_real = x_real[np.isfinite(x_real)]
        x_fake = df_fake[v].to_numpy()
        x_fake = x_fake[np.isfinite(x_fake)]

        if x_real.size == 0 or x_fake.size == 0:
            continue

        # Common range from the 1st-99th percentile of real+fake combined, so
        # extreme tails do not stretch the axis. Inputs are already finite.
        low, high = np.percentile(np.concatenate([x_real, x_fake]), [1, 99])

        if not (np.isfinite(low) and np.isfinite(high)) or low == high:
            continue

        # bin width in the variable's native units
        per_bin = (high - low) / BINS_HIST
        print(f"  {mode}, {v}: bin width = {per_bin:.6g}")

        # Settings shared by both curves so they are always binned identically.
        hist_style = dict(
            bins=BINS_HIST,
            range=(low, high),
            histtype='step',
            linewidth=LINEWIDTH,
            density=DENSITY_HIST,
        )

        fig, ax = plt.subplots()
        ax.hist(x_real, color=COLOR_REAL, label='Real $D^{0}$ (|mcPDG| = 421)', **hist_style)
        ax.hist(x_fake, color=COLOR_FAKE, label='Fake $D^{0}$', **hist_style)

        ax.set_title(title_base, loc="left")
        # The label must match the normalisation: with density=True the y axis
        # is a unit-area density, not a raw count per bin. The bin width is
        # quoted as a bare number in both cases (no unit label).
        if DENSITY_HIST:
            ax.set_ylabel(f"Probability density (bin width = {per_bin:.3g})")
        else:
            ax.set_ylabel(f"Entries / ({per_bin:.3g})")
        ax.set_xlabel(v)
        ax.legend()
        plt.show()
        # Release the figure so memory does not grow over many variables.
        plt.close(fig)

  ## Dalitz plot: real vs fake D⁰ (kmpippi0_eff20_May2020, generic MC only)
     x = m²(K⁻π⁺), y = m²(π⁺π⁰)

In [None]:
# Dalitz comparison for the K pi pi0 mode, generic MC only:
# m^2 of daughters (0, 1) on x against m^2 of daughters (1, 2) on y.
mode_dalitz = "kmpippi0_eff20_May2020"

if f"All_{mode_dalitz}" in DataFrames:
    df_bkg = DataFrames[f"All_{mode_dalitz}"].copy()

    required_cols = [
        "D0_mcPDG",
        "D0_daughterInvM_0_1",  # m(K- pi+)
        "D0_daughterInvM_1_2",  # m(pi+ pi0)
    ]
    missing = [c for c in required_cols if c not in df_bkg.columns]

    if not missing:
        # Truth split on |D0_mcPDG|; rows with NaN PDG count as fake.
        abs_pdg = df_bkg["D0_mcPDG"].abs()
        mask_real = (abs_pdg == 421)
        mask_fake = (abs_pdg != 421) | (df_bkg["D0_mcPDG"].isna())

        def extract_xy(df_sel):
            """Return (m_01**2, m_12**2) float arrays, or (None, None) if no row has both masses."""
            masses = df_sel[["D0_daughterInvM_0_1", "D0_daughterInvM_1_2"]].dropna()
            if masses.empty:
                return None, None
            x = masses["D0_daughterInvM_0_1"].to_numpy(dtype=float) ** 2
            y = masses["D0_daughterInvM_1_2"].to_numpy(dtype=float) ** 2
            return x, y

        x_real, y_real = extract_xy(df_bkg[mask_real])
        x_fake, y_fake = extract_xy(df_bkg[mask_fake])

        n_real = len(x_real) if x_real is not None else 0
        n_fake = len(x_fake) if x_fake is not None else 0
        print(f"[Dalitz] N_real D0 (generic, Dalitz vars non-NaN): {n_real}")
        print(f"[Dalitz] N_fake D0 (generic, Dalitz vars non-NaN): {n_fake}")

        if x_real is not None and x_fake is not None:
            fig, axes = plt.subplots(1, 2, figsize=(11, 5))
            ax_real, ax_fake = axes

            # Both panels share the same recipe; only data and title differ.
            panels = (
                (ax_real, x_real, y_real, r"Real $D^{0}$ (|mcPDG| = 421)"),
                (ax_fake, x_fake, y_fake, r"Fake $D^{0}$"),
            )
            for ax, x_vals, y_vals, panel_title in panels:
                hb = ax.hexbin(
                    x_vals,
                    y_vals,
                    gridsize=60,
                    bins="log",
                    cmap="viridis",
                )
                fig.colorbar(hb, ax=ax).set_label("Counts")

                ax.set_xlabel(r"$m^{2}(K^{-}\pi^{+})\;[\mathrm{GeV}^{2}/c^{4}]$")
                ax.set_ylabel(r"$m^{2}(\pi^{+}\pi^{0})\;[\mathrm{GeV}^{2}/c^{4}]$")
                ax.set_title(panel_title, loc="left")

            fig.suptitle(
                r"$D^{0}\!\to K^{-}\pi^{+}\pi^{0}$ Dalitz: Real vs Fake $D^{0}$ (generic MC)",
                fontsize=16,
            )
            fig.subplots_adjust(
                left=0.08, right=0.95, bottom=0.12, top=0.88,
                wspace=0.30
            )
            plt.show()
        else:
            print("[Dalitz] One of the samples is empty, skipping Dalitz plot.")
    else:
        print(f"[Dalitz] Missing columns for Dalitz plot: {missing}")
else:
    print(f"[Dalitz] Background dataframe All_{mode_dalitz} not found, skipping.")