In [3]:
from HH4b.boosted.TrainBDT import evaluate_model, roc_curve, _get_bdt_scores, get_legtitle
from HH4b.postprocessing import (
    get_evt_testing,
    load_run3_samples,
)
import mplhep as hep
from HH4b import hh_vars, plotting
import pickle
import importlib
import hist
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import xgboost as xgb
import pandas as pd

In [None]:
# definitions
# Axis used for the BDT-score histograms: 40 uniform bins between 0 and 1.
bdt_axis = hist.axis.Regular(
    bins=40,
    start=0,
    stop=1,
    name="bdt",
    label=r"BDT",
)

In [None]:
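# Loading data
# TODO: this cell is empty. evaluate_bdt() below reads notebook-level names such as
# events_dict_years, X_test, y_test, yt_test and weights_test; they must be defined
# (load_run3_samples is imported above, presumably for this) before it is called.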

In [None]:
# Loading the model

# Name of the training configuration. It is also the last component of the training
# directory and the name of the HH4b.boosted.bdt_trainings_run3 submodule that
# evaluate_bdt() imports.
config_name = "24May31_lr_0p02_md_8_AK4Away"

# Directory that holds the BDT trainings.
# NOTE(review): absolute, user-specific path — adjust when running on another machine.
bdt_trainings_dir = Path("/home/users/dprimosc/HH4b/src/HH4b/boosted/bdt_trainings_run3")

# A Path (not a str) so that joins such as ``model_dir / "evaluation"`` work;
# f-string uses such as f"{model_dir}/inferences/..." render the same text as before.
model_dir = bdt_trainings_dir / config_name


def load_model(model_path: Path) -> xgb.XGBClassifier:
    """Unpickle and return the object stored in the file at ``model_path``.

    The file is opened in binary mode and its content is handed to ``pickle.load``;
    whatever object was pickled is returned as-is.
    """
    with open(model_path, "rb") as model_file:
        return pickle.load(model_file)


# NOTE(review): model_dir is named like, and used below as, a directory (files are written
# underneath it), while load_model() opens its argument as a single pickle file — confirm
# whether this should instead be the path of the model file inside that directory.
model = load_model(model_dir)

In [None]:
def evaluate_bdt():
    """
    Evaluate the loaded BDT and write tables, pickles and plots to disk.

    1) Writes the feature importances, sorted in descending order, to
       ``{model_dir}/feature_importances.md``.
    2) For every signal in ``sig_keys``: computes the sample-weighted ROC curve on the
       testing data, pickles it and prints signal/background efficiencies at a few
       signal-efficiency working points.
    3) Runs inference on every sample in ``events_dict_years`` (restricted to the testing
       events for the samples in ``training_keys``), then saves BDT shape plots and
       per-background plus merged-background ROC curves under
       ``{model_dir}/evaluation/{sig_key}``.

    The function takes no arguments and returns nothing. It reads the following
    notebook-level names, which must be defined before it is called: ``model``,
    ``model_dir``, ``config_name``, ``X_test``, ``y_test``, ``yt_test``, ``weights_test``,
    ``sig_keys``, ``bg_keys``, ``training_keys``, ``multiclass``, ``events_dict_years``,
    ``pnet_mass_str``, ``pnet_xbb_str``, ``legacy``, ``bdt_axis``, ``cat_axis`` and
    ``bdt_cuts``.
    """
    # Path(...) so the ``/`` joins work whether ``model_dir`` is a str or a Path.
    plot_dir = Path(model_dir) / "evaluation"
    plot_dir.mkdir(exist_ok=True, parents=True)

    # sorting by importance
    importances = model.feature_importances_
    feature_importances = sorted(
        zip(list(X_test.columns), importances), key=lambda x: x[1], reverse=True
    )
    feature_importance_df = pd.DataFrame.from_dict({"Importance": feature_importances})
    feature_importance_df.to_markdown(f"{model_dir}/feature_importances.md")

    def find_nearest(array, value):
        """Return the index of the element of ``array`` that is closest to ``value``."""
        array = np.asarray(array)
        return (np.abs(array - value)).argmin()

    # make and save ROCs for testing data
    y_scores = model.predict_proba(X_test)
    y_scores = _get_bdt_scores(y_scores, sig_keys, multiclass)

    # Module providing ``bdt_dataframe``, which builds the BDT inputs from an events frame.
    make_bdt_dataframe = importlib.import_module(
        f".{config_name}", package="HH4b.boosted.bdt_trainings_run3"
    )

    # Samples that were not used for training. They are taken explicitly from the last year
    # in ``events_dict_years`` instead of relying on a loop variable leaking out of an
    # earlier loop; every such sample is expected to be present in all years.
    years = list(events_dict_years)
    other_keys = (
        [key for key in events_dict_years[years[-1]] if key not in training_keys]
        if years
        else []
    )

    for i, sig_key in enumerate(sig_keys):
        (plot_dir / sig_key).mkdir(exist_ok=True, parents=True)

        print(f"Evaluating {sig_key} performance")

        if multiclass:
            # selecting only this signal + BGs for ROC curves
            bgs = y_test >= len(sig_keys)
            sigs = y_test == i
            sel = np.logical_or(sigs, bgs).squeeze()
        else:
            sel = np.ones(len(y_test), dtype=bool)

        print("Test ROC with sample weights")
        fpr, tpr, thresholds = roc_curve(
            yt_test[sel], y_scores[sel][:, i], sample_weight=weights_test[sel]
        )

        roc_info = {
            "fpr": fpr,
            "tpr": tpr,
            "thresholds": thresholds,
        }
        with (plot_dir / sig_key / "roc_dict.pkl").open("wb") as f:
            pickle.dump(roc_info, f)

        # print FPR, TPR for a couple of tprs
        for tpr_val in [0.10, 0.12, 0.15]:
            idx = find_nearest(tpr, tpr_val)
            print(
                f"Signal efficiency: {tpr[idx]:.4f}, Background efficiency: {fpr[idx]:.5f}, BDT Threshold: {thresholds[idx]}"
            )

        print("Performing inference on all samples")
        # Per-sample arrays, concatenated over years.
        scores = {}
        weights = {}
        mass_dict = {}
        msd_dict = {}
        xbb_dict = {}

        # Training samples: get scores from full dataframe, but only use testing indices
        for key in training_keys:
            score, weight, mass, msd, xbb = [], [], [], [], []
            for year in events_dict_years:
                evt_list = get_evt_testing(f"{model_dir}/inferences/{year}", key)
                if evt_list is None:
                    continue

                events = events_dict_years[year][key]
                bdt_events = make_bdt_dataframe.bdt_dataframe(events)
                # Copy holding only the BDT inputs, taken before bookkeeping columns are added.
                test_bdt_dataframe = bdt_events.copy()
                bdt_events["event"] = events["event"].to_numpy()[:, 0]
                bdt_events["finalWeight"] = events["finalWeight"]
                # Column ``1`` of the per-jet quantities — presumably the second fat jet;
                # TODO confirm.
                bdt_events["mass"] = events[pnet_mass_str][1]
                bdt_events["msd"] = events["bbFatJetMsd"][1]
                bdt_events["xbb"] = events[pnet_xbb_str][1]
                mask = bdt_events["event"].isin(evt_list)
                test_dataset = bdt_events[mask]

                test_bdt_dataframe = test_bdt_dataframe[mask]
                test_preds = model.predict_proba(test_bdt_dataframe)

                score.append(_get_bdt_scores(test_preds, sig_keys, multiclass)[:, i])
                weight.append(test_dataset["finalWeight"])
                mass.append(test_dataset["mass"])
                msd.append(test_dataset["msd"])
                xbb.append(test_dataset["xbb"])

            scores[key] = np.concatenate(score)
            weights[key] = np.concatenate(weight)
            mass_dict[key] = np.concatenate(mass)
            msd_dict[key] = np.concatenate(msd)
            xbb_dict[key] = np.concatenate(xbb)

        # Remaining samples: all events are used.
        for key in other_keys:
            # All five accumulators are reset per sample; previously ``mass`` and ``msd``
            # kept the entries of earlier samples and leaked them into this sample's arrays.
            score, weight, mass, msd, xbb = [], [], [], [], []
            for year in events_dict_years:
                events = events_dict_years[year][key]
                preds = model.predict_proba(make_bdt_dataframe.bdt_dataframe(events))
                score.append(_get_bdt_scores(preds, sig_keys, multiclass)[:, i])
                weight.append(events["finalWeight"])
                xbb.append(events[pnet_xbb_str][1])
                msd.append(events["bbFatJetMsd"][1])
                mass.append(events[pnet_mass_str][1])
            scores[key] = np.concatenate(score)
            weights[key] = np.concatenate(weight)
            xbb_dict[key] = np.concatenate(xbb)
            msd_dict[key] = np.concatenate(msd)
            mass_dict[key] = np.concatenate(mass)

        print("Making BDT shape plots")

        legtitle = get_legtitle(legacy, pnet_xbb_str)

        h_bdt = hist.Hist(bdt_axis, cat_axis)
        h_bdt_weight = hist.Hist(bdt_axis, cat_axis)
        for key in scores:
            h_bdt.fill(bdt=scores[key], cat=key)
            h_bdt_weight.fill(scores[key], key, weight=weights[key])

        hists = {
            "weight": h_bdt_weight,
            "no_weight": h_bdt,
        }
        shape_colors = plotting.color_by_sample
        shape_legends = plotting.label_by_sample
        for h_key, h in hists.items():
            fig, ax = plt.subplots(1, 1, figsize=(12, 8))
            for key in scores:
                hep.histplot(
                    h[{"cat": key}],
                    ax=ax,
                    label=f"{shape_legends[key]}",
                    histtype="step",
                    linewidth=1,
                    color=shape_colors[key],
                    density=True,
                    flow="none",
                )
            ax.set_yscale("log")
            ax.legend(
                title=legtitle,
                bbox_to_anchor=(1.03, 1),
                loc="upper left",
            )
            ax.set_ylabel("Density")
            ax.set_title("Pre-Selection")
            ax.xaxis.grid(True, which="major")
            ax.yaxis.grid(True, which="major")
            fig.tight_layout()
            fig.savefig(plot_dir / sig_key / f"bdt_shape_{h_key}.png")
            fig.savefig(plot_dir / sig_key / f"bdt_shape_{h_key}.pdf", bbox_inches="tight")
            plt.close(fig)

        print("Making ROC Curves")

        bkg_colors = {**plotting.color_by_sample, "merged": "orange"}
        legends = {**plotting.label_by_sample, "merged": "Total Background"}
        plot_thresholds = bdt_cuts
        th_colours = ["#9381FF", "#1f78b4", "#a6cee3", "cyan", "blue"]

        # ROC of this signal against each background and against all backgrounds merged.
        # Computed and pickled once, then drawn on both the linear and the log figure.
        rocs = {}
        for bkg in [*bg_keys, "merged"]:
            members = list(bg_keys) if bkg == "merged" else [bkg]
            scores_roc = np.concatenate([scores[sig_key]] + [scores[m] for m in members])
            n_bkg = sum(len(scores[m]) for m in members)
            # truth labels: 1 for signal entries, 0 for background entries
            scores_true = np.concatenate([np.ones(len(scores[sig_key])), np.zeros(n_bkg)])
            scores_weights = np.concatenate([weights[sig_key]] + [weights[m] for m in members])
            fpr, tpr, thresholds = roc_curve(
                scores_true, scores_roc, sample_weight=scores_weights
            )
            # save background roc curves
            roc_info_bg = {
                "fpr": fpr,
                "tpr": tpr,
                "thresholds": thresholds,
            }
            with (plot_dir / sig_key / f"roc_dict_{bkg}.pkl").open("wb") as f:
                pickle.dump(roc_info_bg, f)
            rocs[bkg] = (fpr, tpr, thresholds)

        # Plot and save ROC figure
        for log, logstr in [(False, ""), (True, "_log")]:
            fig, ax = plt.subplots(1, 1, figsize=(18, 12))

            for bkg, (fpr, tpr, thresholds) in rocs.items():
                ax.plot(tpr, fpr, linewidth=2, color=bkg_colors[bkg], label=legends[bkg])

                if bkg != "merged":
                    continue

                # Mark the BDT working points on the merged-background curve.
                for k, th in enumerate(plot_thresholds):
                    idx = find_nearest(thresholds, th)
                    # cycle through the palette if there are more cuts than colours
                    colour = th_colours[k % len(th_colours)]
                    ax.scatter(
                        [tpr[idx]],
                        [fpr[idx]],
                        marker="o",
                        s=40,
                        label=rf"BDT > {th}",
                        color=colour,
                        zorder=100,
                    )
                    ax.vlines(
                        x=[tpr[idx]],
                        ymin=0,
                        ymax=[fpr[idx]],
                        color=colour,
                        linestyles="dashed",
                        alpha=0.5,
                    )
                    ax.hlines(
                        y=[fpr[idx]],
                        xmin=0,
                        xmax=[tpr[idx]],
                        color=colour,
                        linestyles="dashed",
                        alpha=0.5,
                    )

            ax.set_title(f"{plotting.label_by_sample[sig_key]} BDT ROC Curve")
            ax.set_xlabel("Signal efficiency")
            ax.set_ylabel("Background efficiency")

            if log:
                ax.set_xlim([0.0, 0.6])
                ax.set_ylim([1e-5, 1e-1])
                ax.set_yscale("log")
            else:
                ax.set_xlim([0.0, 0.7])
                ax.set_ylim([0, 0.08])

            ax.xaxis.grid(True, which="major")
            ax.yaxis.grid(True, which="major")
            ax.legend(
                title=legtitle,
                bbox_to_anchor=(1.03, 1),
                loc="upper left",
            )
            fig.tight_layout()
            fig.savefig(plot_dir / sig_key / f"roc_weights{logstr}.png")
            fig.savefig(plot_dir / sig_key / f"roc_weights{logstr}.pdf", bbox_inches="tight")
            plt.close(fig)