In [1]:
from enreg.tools.metrics import decay_mode_evaluator as dme
from enreg.tools.metrics import regression_evaluator as re
from enreg.tools.metrics import tagger_evaluator as te
from enreg.tools import general as g
import os
import awkward as ak
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the notebook configuration with Hydra: `initialize` is pointed at the
# "enreg/config/" directory and `compose` builds `cfg` from the "benchmarking"
# config. All evaluation helpers below read their settings from `cfg`.
with initialize(
    version_base=None,
    config_path="enreg/config/",
    job_name="test_app",
):
    cfg = compose(config_name="benchmarking")

In [2]:
def evaluate_binary_cls(cfg, base_data):
    """Evaluate binary-classification predictions for every algorithm and save the combined results.

    For each algorithm in ``cfg.comparison_algorithms`` the background
    predictions are loaded once, and for each signal sample in
    ``cfg.comparison_samples.signal_samples`` a ``te.TaggerEvaluator`` is
    built from the signal/background predictions and the matching entries of
    ``base_data``. All evaluators are then handed to a
    ``te.TaggerMultiEvaluator`` writing to
    ``<cfg.PLOTS_OUTPUT_DIR>/binary_classifier``.

    Args:
        cfg: Configuration object. Reads ``comparison_algorithms``,
            ``comparison_samples.signal_samples``,
            ``comparison_samples.background_samples``,
            ``metrics.classifier.base_dir`` and ``PLOTS_OUTPUT_DIR``.
        base_data: Mapping from sample name to the loaded sample data. Signal
            entries must expose ``gen_jet_tau_p4s`` and ``reco_jet_p4s``; the
            concatenated background entries must expose ``gen_jet_p4s`` and
            ``reco_jet_p4s``.

    Returns:
        None. Results are written by ``TaggerMultiEvaluator.save_results``.
    """
    background_samples = cfg.comparison_samples.background_samples

    # The background sample info depends neither on the algorithm nor on the
    # signal sample, so it is concatenated once instead of once per
    # (algorithm, signal sample) pair.
    bkg_info_data = ak.concatenate([base_data[background_sample] for background_sample in background_samples])

    evaluators = []
    for algorithm in cfg.comparison_algorithms:
        algorithm_pred_dir = os.path.join(cfg.metrics.classifier.base_dir, algorithm)
        # Background predictions are shared by all signal samples of this algorithm.
        bkg_data = g.load_all_data(
            [os.path.join(algorithm_pred_dir, f"{bkg_sample}.parquet") for bkg_sample in background_samples]
        )
        for signal_sample in cfg.comparison_samples.signal_samples:
            sig_info_data = base_data[signal_sample]
            sig_data = g.load_all_data(os.path.join(algorithm_pred_dir, f"{signal_sample}.parquet"))

            evaluator = te.TaggerEvaluator(
                signal_predictions=sig_data.binary_classification.pred,
                signal_truth=sig_data.binary_classification.target,
                signal_gen_tau_p4=sig_info_data.gen_jet_tau_p4s,
                signal_reco_jet_p4=sig_info_data.reco_jet_p4s,
                bkg_predictions=bkg_data.binary_classification.pred,
                bkg_truth=bkg_data.binary_classification.target,
                bkg_gen_jet_p4=bkg_info_data.gen_jet_p4s,
                bkg_reco_jet_p4=bkg_info_data.reco_jet_p4s,
                cfg=cfg,
                sample=signal_sample,
                algorithm=algorithm
            )
            evaluators.append(evaluator)

    # Combine the per-(algorithm, sample) evaluators into a single comparison.
    output_dir = os.path.join(cfg.PLOTS_OUTPUT_DIR, "binary_classifier")
    tme = te.TaggerMultiEvaluator(output_dir, cfg)
    tme.combine_results(evaluators)
    tme.save_results()


def evaluate_decay_mode_reco(cfg):
    """Run the decay-mode evaluation for every (algorithm, signal sample) pair.

    For each pair the prediction file
    ``<cfg.metrics.dm_reconstruction.base_dir>/<algorithm>/<sample>.parquet``
    is printed and loaded, the ``dm_multiclass`` predictions and targets are
    passed through ``g.one_hot_decoding``, and a ``dme.DecayModeEvaluator``
    saves its performance under ``<cfg.PLOTS_OUTPUT_DIR>/dm_reconstruction``.

    Args:
        cfg: Configuration object. Reads ``comparison_algorithms``,
            ``comparison_samples.signal_samples``,
            ``metrics.dm_reconstruction.base_dir`` and ``PLOTS_OUTPUT_DIR``.

    Returns:
        None.
    """
    for algo in cfg.comparison_algorithms:
        pred_dir = os.path.join(cfg.metrics.dm_reconstruction.base_dir, algo)
        for sample_name in cfg.comparison_samples.signal_samples:
            # Build the prediction path once; it is both logged and loaded.
            pred_path = os.path.join(pred_dir, f"{sample_name}.parquet")
            print(pred_path)
            sig_data = g.load_all_data(pred_path)

            output_dir = os.path.join(cfg.PLOTS_OUTPUT_DIR, "dm_reconstruction")
            decoded_pred = g.one_hot_decoding(sig_data.dm_multiclass.pred)
            decoded_target = g.one_hot_decoding(sig_data.dm_multiclass.target)
            dm_evaluator = dme.DecayModeEvaluator(
                decoded_pred,
                decoded_target,
                output_dir,
                sample_name,
                algo,
            )
            dm_evaluator.save_performance()


def evaluate_jet_regression(cfg):
    """Evaluate jet-regression predictions for every algorithm and save the combined results.

    For each algorithm in ``cfg.comparison_algorithms`` and each signal sample
    in ``cfg.comparison_samples.signal_samples`` the prediction file
    ``<cfg.metrics.regression.base_dir>/<algorithm>/<sample>.parquet`` is
    loaded and wrapped in a ``re.RegressionEvaluator``. All evaluators are
    then combined by a ``re.RegressionMultiEvaluator`` writing to
    ``<cfg.PLOTS_OUTPUT_DIR>/jet_regression``.

    Samples are labelled by the part of their name before the first
    underscore. The combined evaluator is labelled after the last configured
    signal sample.

    Args:
        cfg: Configuration object. Reads ``comparison_algorithms``,
            ``comparison_samples.signal_samples``,
            ``metrics.regression.base_dir`` and ``PLOTS_OUTPUT_DIR``.

    Returns:
        None. Results are written by ``RegressionMultiEvaluator.save``.

    Raises:
        ValueError: If no signal samples are configured, since the combined
            evaluator cannot be labelled.
    """
    signal_samples = list(cfg.comparison_samples.signal_samples)
    if not signal_samples:
        raise ValueError("evaluate_jet_regression requires at least one signal sample in cfg.comparison_samples.signal_samples")

    # Derive the label for the combined evaluator explicitly instead of
    # relying on the loop variable still being bound after the loops; that
    # failed with UnboundLocalError whenever a loop body never ran.
    combined_sample_label = signal_samples[-1].split("_")[0]

    evaluators = []
    for algorithm in cfg.comparison_algorithms:
        algorithm_pred_dir = os.path.join(cfg.metrics.regression.base_dir, algorithm)
        for signal_sample in signal_samples:
            sig_data = g.load_all_data(os.path.join(algorithm_pred_dir, f"{signal_sample}.parquet"))

            evaluator = re.RegressionEvaluator(
                sig_data.jet_regression.pred,
                sig_data.jet_regression.target,
                cfg,
                signal_sample.split("_")[0],
                algorithm,
            )
            evaluators.append(evaluator)

    output_dir = os.path.join(cfg.PLOTS_OUTPUT_DIR, "jet_regression")
    rme = re.RegressionMultiEvaluator(output_dir, cfg, combined_sample_label)
    rme.combine_results(evaluators)
    rme.save()

In [None]:
# Load the sample files from cfg.NTUPLE_BASE_DIR for every configured signal and
# background sample, keyed by sample name.
all_samples = cfg.comparison_samples.signal_samples + cfg.comparison_samples.background_samples
base_data = {
    sample_name: g.load_all_data(os.path.join(cfg.NTUPLE_BASE_DIR, f"{sample_name}.parquet"))
    for sample_name in all_samples
}

# Run only the evaluations requested in cfg.comparison_tasks. Only the binary
# classification evaluation takes the loaded sample data as an argument.
if "binary_classification" in cfg.comparison_tasks:
    evaluate_binary_cls(cfg, base_data)
if "dm_multiclass" in cfg.comparison_tasks:
    evaluate_decay_mode_reco(cfg)
if "jet_regression" in cfg.comparison_tasks:
    evaluate_jet_regression(cfg)

[1/1] Loading from /scratch/persistent/laurits/ml-tau/20241002_Training_ntuples_geq20gev/z_test.parquet
Input data loaded
[1/1] Loading from /scratch/persistent/laurits/ml-tau/20241002_Training_ntuples_geq20gev/qq_test.parquet
