In [1]:
import uproot
import awkward as ak
import numpy as np
from scipy.stats import poisson
import sklearn.metrics as m
import boost_histogram as bh
import glob
import os
from scipy.interpolate import interp1d

from matplotlib import pyplot as plt
import matplotlib as mpl
import matplotlib.gridspec as gridspec
from matplotlib.legend_handler import HandlerTuple
import matplotlib.lines as mlines

from cycler import cycler
import mplhep as hep
# plt.style.use(hep.style.ROOT)
mpl.rcParams['mathtext.fontset'] = 'stix'
mpl.rcParams['font.family'] = 'STIXGeneral'

from multiprocessing import Pool, cpu_count
from functools import partial

def _p4_from_ptetaphie(pt, eta, phi, energy):
    """Zip (pt, eta, phi, energy) into a ``vector`` record.

    ``vector`` is imported lazily inside the function and its awkward
    behaviours are (re-)registered on every call before zipping.
    """
    import vector
    vector.register_awkward()
    components = {'pt': pt, 'eta': eta, 'phi': phi, 'energy': energy}
    return vector.zip(components)
def _p4_from_ptetaphim(pt, eta, phi, mass):
    """Zip (pt, eta, phi, mass) into a ``vector`` record.

    ``vector`` is imported lazily inside the function and its awkward
    behaviours are (re-)registered on every call before zipping.
    """
    import vector
    vector.register_awkward()
    components = {'pt': pt, 'eta': eta, 'phi': phi, 'mass': mass}
    return vector.zip(components)

from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from operator import add
import re
from multiprocessing import Pool, cpu_count
from scipy.interpolate import interp1d, RectBivariateSpline
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures


In [2]:
# Configuration
# Global scale factor multiplied into every per-process weight below.
lumi_scale = 4.5  # 450 fb^-1

# Weight dictionary
# Maps a process name to the constant per-event weight written to the output
# ntuples. Every entry is lumi_scale times a process-specific ratio.
# NOTE(review): the ratios look like (cross-section) / (number of generated
# events), but neither the units nor the provenance are stated here — confirm.
weight_dict = {
    "QCD": lumi_scale * 4.5226e+06 * 1e5 / ((17600+38072) * 50e5),
    "ZJetsToQQ": lumi_scale * 1253.61 * 1e5 / (4000 * 1e5),
    "TTbar": lumi_scale * 83175900 / (40305472 + 120928855),
    "SingleTop": lumi_scale * 18487900 / 17032212,
    "TW": lumi_scale * 6496000 / 4694318,
    "TTbarW": lumi_scale * 74530 / 1000000,
    "TTbarZ": lumi_scale * 85900 / 1000000,
    "WW": lumi_scale * 11870000 / 14330905,
    "ZW": lumi_scale * 4674000 / 7117197,
    "ZZ": lumi_scale * 1691000 / 7055884,
    "ZZherwig": lumi_scale * 1691000 / 9899325,
    "ZZvincia": lumi_scale * 1691000 / 10000000,
    "SingleHiggs": lumi_scale * 4858000 / 10000000,
    "VBFH": lumi_scale * 378200 / 1000000,
    "WplusH": lumi_scale * 83990 / 499991,
    "WminusH": lumi_scale * 53270 / 499999,
    "ZH": lumi_scale * 76120 / 300000,
    "ttH": lumi_scale * 50710 / 300000,
    "ggHH": lumi_scale * 1051.7 / 10000000,
    "ggHHherwig": lumi_scale * 1051.7 / 10000000,
    "ggHHvincia": lumi_scale * 1051.7 / 10000000,
    "ggHHkl0": lumi_scale * 2361.8 / 10000000,
    "ggHHkl5": lumi_scale * 3107.2 / 10000000,
    "qqHH": lumi_scale * 58.5 / 3000000,
    "qqHHCV1C2V1kl0": lumi_scale * 155.8 / 3000000,
    "qqHHCV1C2V1kl2": lumi_scale * 48.2 / 3000000,
    "qqHHCV1C2V2kl1": lumi_scale * 482.3 / 3000000,
    "qqHHCV0p5C2V1kl1": lumi_scale * 365.6 / 3000000,
    "qqHHCV1p5C2V1kl1": lumi_scale * 2241.2 / 3000000,
}

# Process list
# Alternative (currently disabled) selection of processes:
# process_list = ['ggHH', 'QCD', 'TTbar', 'SingleTop', 'TW', 'TTbarW', 'TTbarZ', 
#                 'WW', 'ZW', 'ZZ', 'SingleHiggs', 'VBFH', 'WplusH', 'WminusH', 
#                 'ZH', 'ttH', 'ZJetsToQQ']

# Only input folders whose process name is in this list are picked up by
# collect_file_list().
process_list = ['ggHHkl0', 'ggHHkl5', "qqHH",
    "qqHHCV1C2V1kl0",
    "qqHHCV1C2V1kl2",
    "qqHHCV1C2V2kl1",
    "qqHHCV0p5C2V1kl1",
    "qqHHCV1p5C2V1kl1",]

# Processes for which process_single_file() additionally reads the fourth entry
# of ensemble_models ("ModelLite"). Note that none of the names currently in
# process_list appear here.
PROCESS_WITH_MODELLITE = ['ggHH', 'QCD', 'TTbar', 'SingleTop', 'TW', 'TTbarW', 'TTbarZ', 
                          'WW', 'ZW', 'ZZ', 'SingleHiggs', 'VBFH', 'WplusH', 'WminusH', 
                          'ZH', 'ttH', 'ZJetsToQQ']

# Prediction directories, indexed by model: 0 -> 'Model0', 1 -> 'Model1',
# 2 -> 'Model2', 3 -> 'ModelLite' (mapping applied in process_single_file()).
# The entries are joined onto pred_folder, so the leading "../../predict" is
# resolved relative to that folder.
ensemble_models = [
    "../../predict/hh4b_resolved_newsp4_allparts_nosel_138clswtop.noweights.ddp4-bs512-lr2e-3",
    "../../predict/hh4b_resolved_newsp4_allparts_nosel_138clswtop.noweights.ddp4-bs512-lr2e-3.model1", 
    "../../predict/hh4b_resolved_newsp4_allparts_nosel_138clswtop.noweights.ddp4-bs512-lr2e-3.model2",
    "../../predict/hh4b_resolved_newsp4_allparts_nosel_138clswtop.noweights.ndiv10.ddp4-bs400-lr1p6e-3.wd0p01"
]

# Base folder for prediction files and glob pattern for the input sample folders.
pred_folder = "/home/olympus/tyyang99/weaver-core-dev/weaver/pheno/predict"
folder_pattern = "/data/bond/tyyang99/HH4b/sm_incl_derived_4j3bor2b/*"

# Output ntuples are written below this directory (one sub-folder per process);
# it is created here at import/run time if missing.
output_base_dir = "/data/bond/tyyang99/HH4b/ensemble_method2_ntuples_cdfv2"
os.makedirs(output_base_dir, exist_ok=True)

# Default size of the multiprocessing pool used for file processing.
N_WORKERS = 40

# ===============================================================================
# NEW: Load Universal Corrector
# ===============================================================================

from scipy.interpolate import interp1d, RectBivariateSpline
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# ==============================================================================
# Corrector Classes
# ==============================================================================

class SavedSmoothSurfaceInterpolator:
    """Bicubic-spline lookup of a pre-smoothed 2D histogram.

    The histogram is symmetrised (averaged with its transpose) and interpolated
    on the bin centres with ``RectBivariateSpline``. Evaluation clips the query
    points to the outer bin edges and floors the result at zero.
    """

    def __init__(self):
        self.fitted = False
        self.interpolator = None
        self.m_min = None
        self.m_max = None

    def fit(self, H_smooth, bins):
        """Build the spline from histogram ``H_smooth`` with edges ``bins``.

        The same ``bins`` are used for both axes. Always returns True.
        """
        symmetric = (H_smooth + H_smooth.T) / 2.0
        bin_centers = (bins[:-1] + bins[1:]) / 2

        self.interpolator = RectBivariateSpline(
            bin_centers, bin_centers, symmetric, kx=3, ky=3
        )
        self.m_min, self.m_max = bins[0], bins[-1]
        self.fitted = True

        print(f"  Source interpolator built: range [{self.m_min:.1f}, {self.m_max:.1f}] GeV")
        return True

    def evaluate(self, m1, m2):
        """Return the non-negative interpolated value at ``(m1, m2)``.

        Before ``fit`` this returns zeros shaped like ``m1``. A scalar ``m1``
        yields a Python float; otherwise the result has the shape of ``m1``.
        """
        if not self.fitted:
            return np.zeros_like(m1)

        lo, hi = self.m_min, self.m_max
        x = np.clip(np.atleast_1d(m1).flatten(), lo, hi)
        y = np.clip(np.atleast_1d(m2).flatten(), lo, hi)

        # The spline can undershoot below zero; a density must not be negative.
        density = np.maximum(self.interpolator.ev(x, y), 0)

        if not np.isscalar(m1):
            return density.reshape(np.shape(m1))
        return float(density[0])

class SmoothSurfaceFitter:
    """Polynomial fit of the log of a symmetrised 2D histogram.

    Bins with positive content are fitted with a linear regression on
    polynomial features of the two bin-centre coordinates; evaluation returns
    the exponential of the fitted polynomial.
    """

    def __init__(self, degree=4):
        self.degree = degree
        self.model = LinearRegression()
        self.poly = PolynomialFeatures(degree=degree)
        self.fitted = False

    def fit(self, H_coarse, bins_coarse):
        """Fit log(H) on the bin centres of ``bins_coarse`` (same edges on both axes).

        Returns False (and leaves the fitter unfitted) when fewer than 10 bins
        have positive content, True otherwise.
        """
        symmetric = (H_coarse + H_coarse.T) / 2.0
        centers = (bins_coarse[:-1] + bins_coarse[1:]) / 2

        grid_x, grid_y = np.meshgrid(centers, centers)
        values = symmetric.ravel()

        # Only strictly positive bins have a defined logarithm.
        populated = values > 0
        if np.sum(populated) < 10:
            print("  [Warning] Not enough data points for fitting!")
            return False

        coords = np.column_stack((grid_x.ravel()[populated], grid_y.ravel()[populated]))
        log_values = np.log(values[populated])

        design = self.poly.fit_transform(coords)
        self.model.fit(design, log_values)

        score = self.model.score(design, log_values)
        print(f"  Target polynomial fit R^2 = {score:.4f}")

        self.fitted = True
        return True

    def evaluate(self, m1, m2):
        """Return exp(fitted polynomial) at ``(m1, m2)``.

        Before a successful ``fit`` this returns zeros shaped like ``m1``. A
        scalar ``m1`` yields a Python float; otherwise the result has the shape
        of ``m1``.
        """
        if not self.fitted:
            return np.zeros_like(m1)

        coords = np.column_stack((
            np.atleast_1d(m1).flatten(),
            np.atleast_1d(m2).flatten(),
        ))
        density = np.exp(self.model.predict(self.poly.transform(coords)))

        if not np.isscalar(m1):
            return density.reshape(np.shape(m1))
        return float(density[0])

class ContinuousSliceMatcher:
    """CDF matching of the mass splitting along one slice of constant mean mass.

    For a fixed mean ``M = (m1 + m2) / 2`` the source and target densities are
    sampled along the anti-diagonal ``m1 = M - Delta/2``, ``m2 = M + Delta/2``.
    ``transform`` then maps a splitting through the source CDF and the inverse
    target CDF, keeping the mean mass unchanged.
    """

    def __init__(self, M_value, m_min_correct=70, m_max_correct=170,
                 m_min_extended=60, m_max_extended=180, n_samples=1000):
        self.M = M_value
        self.m_min_correct, self.m_max_correct = m_min_correct, m_max_correct
        self.m_min_extended, self.m_max_extended = m_min_extended, m_max_extended
        self.n_samples = n_samples

        self.M_center = (m_min_correct + m_max_correct) / 2

        # Both stay None until fit() succeeds; transform() is a no-op until then.
        self.cdf_source = None
        self.inv_cdf_target = None

    def _get_sampling_bounds(self):
        """Return the (lower, upper) mass window used for sampling this slice.

        Starting from the correction window, the edge on the side of ``M`` is
        pushed towards the extended window in proportion to the distance of
        ``M`` from the window centre (capped at the full extension).
        """
        half_width = self.M_center - self.m_min_correct
        fraction = min(abs(self.M - self.M_center) / half_width, 1.0)

        if self.M < self.M_center:
            widen_low = (self.m_min_correct - self.m_min_extended) * fraction
            widen_high = 0
        else:
            widen_low = 0
            widen_high = (self.m_max_extended - self.m_max_correct) * fraction

        return self.m_min_correct - widen_low, self.m_max_correct + widen_high

    def fit(self, source_func, target_func):
        """Build the source CDF and inverse target CDF for this slice.

        ``source_func`` and ``target_func`` are called as ``f(m1, m2)`` on
        arrays. Returns False when the slice has no usable splitting range or
        fewer than 10 sample points fall inside the sampling window.
        """
        lower, upper = self._get_sampling_bounds()

        # Largest splitting that keeps both masses inside the window.
        self.Delta_max_sampling = 2 * min(self.M - lower, upper - self.M)
        if self.Delta_max_sampling <= 1e-6:
            return False

        self.m_min_sampling, self.m_max_sampling = lower, upper

        rho_grid = np.linspace(0, 1, self.n_samples)
        half_split = rho_grid * self.Delta_max_sampling / 2
        low_leg = self.M - half_split
        high_leg = self.M + half_split

        inside = np.logical_and.reduce([
            low_leg >= lower, low_leg <= upper,
            high_leg >= lower, high_leg <= upper,
        ])
        if np.sum(inside) < 10:
            return False

        low_leg, high_leg, rho_valid = low_leg[inside], high_leg[inside], rho_grid[inside]

        density_source = source_func(low_leg, high_leg)
        density_target = target_func(low_leg, high_leg)

        def _normalised_cdf(density):
            # Discrete pdf on the rho grid, then cumulative sum scaled to ~1.
            pdf = density * self.Delta_max_sampling
            pdf = pdf / (np.sum(pdf) + 1e-9)
            cdf = np.cumsum(pdf)
            return cdf / (cdf[-1] + 1e-9)

        cdf_source = _normalised_cdf(density_source)
        cdf_target = _normalised_cdf(density_target)

        first_rho, last_rho = rho_valid[0], rho_valid[-1]

        self.cdf_source = interp1d(
            rho_valid, cdf_source,
            kind='linear', bounds_error=False, fill_value=(0, 1)
        )
        self.inv_cdf_target = interp1d(
            cdf_target, rho_valid,
            kind='linear', bounds_error=False, fill_value=(first_rho, last_rho)
        )

        self.rho_min, self.rho_max = first_rho, last_rho
        return True

    def transform(self, m1, m2):
        """Map ``(m1, m2)`` to the corrected pair with the same mean mass.

        Only the absolute splitting ``|m1 - m2|`` is used, so the first returned
        value is always the lower mass and the second the higher one. Inputs are
        returned unchanged if ``fit`` has not succeeded.
        """
        if self.cdf_source is None or self.inv_cdf_target is None:
            return m1, m2

        mean_mass = (m1 + m2) / 2
        rho = np.clip(np.abs(m1 - m2) / self.Delta_max_sampling,
                      self.rho_min, self.rho_max)

        quantile = np.clip(self.cdf_source(rho), 0, 1)
        half_new = self.inv_cdf_target(quantile) * self.Delta_max_sampling / 2

        return mean_mass - half_new, mean_mass + half_new


class ContinuousDiagonalCorrector:
    """Corrects (m1, m2) pairs by CDF matching along slices of constant mean mass.

    ``n_slices`` mean-mass values are spread evenly over
    ``[m_min_correct, m_max_correct]``; ``fit_from_data`` builds one
    ``ContinuousSliceMatcher`` per value, and ``correct`` applies to each event
    the matcher whose slice is closest to the event's mean mass.

    ``m_min_fit`` / ``m_max_fit`` are stored but not used by this class.
    """

    def __init__(self, m_min_fit=50, m_max_fit=190,
                 m_min_correct=70, m_max_correct=170,
                 m_min_extended=60, m_max_extended=180,
                 n_slices=40):
        self.m_min_fit = m_min_fit
        self.m_max_fit = m_max_fit
        self.m_min_correct = m_min_correct
        self.m_max_correct = m_max_correct
        self.m_min_extended = m_min_extended
        self.m_max_extended = m_max_extended

        self.M_values = np.linspace(m_min_correct, m_max_correct, n_slices)

        self.source_func = None
        self.target_func = None

        # List of (M_slice, matcher) pairs for the slices whose fit succeeded.
        self.matchers = []

    def fit_from_data(self, H_source_smooth, bins_source, H_target, bins_target):
        """Build the source/target density functions and one matcher per slice.

        If either density cannot be built, an error is printed and the
        corrector is left as it was. On success the matcher list is replaced
        (not appended to), so calling this twice does not leave duplicates.
        """
        interp_source = SavedSmoothSurfaceInterpolator()
        success_src = interp_source.fit(H_source_smooth, bins_source)

        fitter_target = SmoothSurfaceFitter(degree=4)
        success_tgt = fitter_target.fit(H_target, bins_target)

        if not (success_src and success_tgt):
            print("  [Error] Failed to build functions!")
            return

        self.source_func = interp_source.evaluate
        self.target_func = fitter_target.evaluate

        fitted_matchers = []
        for M_val in self.M_values:
            matcher = ContinuousSliceMatcher(
                M_val,
                self.m_min_correct, self.m_max_correct,
                self.m_min_extended, self.m_max_extended,
                n_samples=1000
            )
            # Slices whose fit fails are simply left out.
            if matcher.fit(self.source_func, self.target_func):
                fitted_matchers.append((M_val, matcher))

        self.matchers = fitted_matchers

    @staticmethod
    def _float_copy(values):
        """Return a copy of ``values``, cast to float unless already floating.

        Corrected masses are written into this copy; an integer copy would
        silently truncate them.
        """
        copied = np.array(values)
        if not np.issubdtype(copied.dtype, np.floating):
            copied = copied.astype(float)
        return copied

    def correct(self, m1, m2):
        """Return corrected copies of the 1-D mass arrays ``m1`` and ``m2``.

        An event is corrected only if its mean mass lies in
        ``[m_min_correct, m_max_correct]`` and both masses lie inside the
        sampling window of the nearest slice (ties go to the earlier matcher);
        all other events are copied through unchanged. If no matchers are
        available a warning is printed and the inputs themselves are returned.
        """
        if len(self.matchers) == 0:
            print("  [Warning] No matchers available!")
            return m1, m2

        m1_in = np.asarray(m1)
        m2_in = np.asarray(m2)
        M = (m1_in + m2_in) / 2
        m1_new = self._float_copy(m1_in)
        m2_new = self._float_copy(m2_in)

        selected = np.where((M >= self.m_min_correct) & (M <= self.m_max_correct))[0]
        if selected.size == 0:
            return m1_new, m2_new

        # Nearest slice per selected event. The strict '<' keeps the first
        # matcher on ties, matching np.argmin over the slice distances.
        M_sel = M[selected]
        nearest = np.zeros(selected.size, dtype=np.intp)
        best_dist = np.full(selected.size, np.inf)
        for k, (M_slice, _) in enumerate(self.matchers):
            dist = np.abs(M_slice - M_sel)
            closer = dist < best_dist
            nearest[closer] = k
            best_dist[closer] = dist[closer]

        # One vectorised transform per slice instead of one call per event.
        for k, (_, matcher) in enumerate(self.matchers):
            rows = selected[nearest == k]
            if rows.size == 0:
                continue

            lo, hi = matcher.m_min_sampling, matcher.m_max_sampling
            a, b = m1_in[rows], m2_in[rows]
            rows = rows[(a >= lo) & (a <= hi) & (b >= lo) & (b <= hi)]
            if rows.size == 0:
                continue

            m1_new[rows], m2_new[rows] = matcher.transform(m1_in[rows], m2_in[rows])

        return m1_new, m2_new


class UniversalCorrector:
    """Score-binned collection of ``ContinuousDiagonalCorrector`` objects.

    On construction, ``source_distributions_{model_name}.npz`` is read from
    ``source_dir``. Every key containing ``"_smooth"`` is treated as one score
    bin; the bin edges are parsed from the 2nd and 3rd ``_``-separated fields
    of the key with ``p`` standing for the decimal point. The matching target
    histogram is read from ``{tomography_dir}/range_<low>_<high>/data.npz``
    (keys ``corrected`` and ``bins``). Bins without a target file are skipped
    with a warning.
    """

    def __init__(self, source_dir, model_name, tomography_dir,
                 m_min_fit=50, m_max_fit=190,
                 m_min_correct=70, m_max_correct=170,
                 m_min_extended=60, m_max_extended=180):
        self.m_min_fit = m_min_fit
        self.m_max_fit = m_max_fit
        self.m_min_correct = m_min_correct
        self.m_max_correct = m_max_correct
        self.m_min_extended = m_min_extended
        self.m_max_extended = m_max_extended

        # Maps (score_low, score_high) -> ContinuousDiagonalCorrector.
        self.correctors = {}

        source_npz = os.path.join(source_dir, f"source_distributions_{model_name}.npz")

        print(f"Loading source distributions for {model_name}...")
        # NpzFile keeps the archive open until closed; use it as a context
        # manager so the handle is released once all arrays have been read.
        with np.load(source_npz) as src_data:
            src_bins = src_data['bins']

            saved_model_name = str(src_data['model_name'])
            if saved_model_name != model_name:
                print(f"  [Warning] Model name mismatch: expected {model_name}, got {saved_model_name}")

            keys = [k for k in src_data.keys() if "_smooth" in k]

            print(f"Found {len(keys)} score bins to process.\n")

            for k in keys:
                parts = k.replace("_smooth", "").split("_")
                low = float(parts[1].replace("p", "."))
                high = float(parts[2].replace("p", "."))

                tgt_folder = f"range_{parts[1]}_{parts[2]}"
                tgt_path = f"{tomography_dir}/{tgt_folder}/data.npz"

                if not os.path.exists(tgt_path):
                    print(f"  [Warning] Target not found for bin [{low}, {high})")
                    continue

                H_src_smooth = src_data[k]

                with np.load(tgt_path) as tgt_data:
                    H_tgt = tgt_data['corrected']
                    tgt_bins = tgt_data['bins']

                print(f"  [Corrector] Processing bin [{low:.3f}, {high:.3f})...")

                corrector = ContinuousDiagonalCorrector(
                    m_min_fit=m_min_fit, m_max_fit=m_max_fit,
                    m_min_correct=m_min_correct, m_max_correct=m_max_correct,
                    m_min_extended=self.m_min_extended, m_max_extended=self.m_max_extended,
                    n_slices=40
                )

                corrector.fit_from_data(H_src_smooth, src_bins, H_tgt, tgt_bins)

                self.correctors[(low, high)] = corrector

    @staticmethod
    def _float_copy(values):
        """Return a copy of ``values``, cast to float unless already floating.

        Corrected masses are written into this copy; an integer copy would
        silently truncate them.
        """
        copied = np.array(values)
        if not np.issubdtype(copied.dtype, np.floating):
            copied = copied.astype(float)
        return copied

    def correct(self, m1, m2, scores):
        """Return corrected copies of ``m1`` and ``m2``, chosen per score bin.

        Each event is handed to the corrector whose bin contains its score.
        Bins are half-open ``[low, high)``, except that a bin with
        ``high == 1.0`` also includes a score of exactly 1.0. Events whose
        score falls in no bin are copied through unchanged.
        """
        m1 = np.asarray(m1)
        m2 = np.asarray(m2)
        scores = np.asarray(scores)

        m1_out = self._float_copy(m1)
        m2_out = self._float_copy(m2)

        for (low, high), corrector in self.correctors.items():
            if high == 1.0:
                mask_score = (scores >= low) & (scores <= high)
            else:
                mask_score = (scores >= low) & (scores < high)

            if np.sum(mask_score) == 0:
                continue

            m1_corr, m2_corr = corrector.correct(m1[mask_score], m2[mask_score])

            m1_out[mask_score] = m1_corr
            m2_out[mask_score] = m2_corr

        return m1_out, m2_out


# Directories (relative to the working directory) holding the saved source
# distributions and the per-score-bin target histograms read by UniversalCorrector.
SOURCE_DIR = "./source_2d_distributions"
TOMOGRAPHY_DIR = "./tomography_output"

print("="*80)
print("Loading Universal Correctors for all models...")
print("="*80)

# Maps model name -> UniversalCorrector, or None if loading that model failed.
# Built once at module level; process_single_file() looks correctors up here.
CORRECTORS = {}
MODEL_NAMES = ["Ensemble", "Model0", "Model1", "Model2", "ModelLite"]

for model_name in MODEL_NAMES:
    print(f"\nInitializing corrector for {model_name}...")
    try:
        corrector = UniversalCorrector(
            SOURCE_DIR, model_name, TOMOGRAPHY_DIR,
            m_min_fit=50, m_max_fit=190,
            m_min_correct=70, m_max_correct=170,
            m_min_extended=60, m_max_extended=180
        )

        CORRECTORS[model_name] = corrector
        print(f"  Success: {model_name} corrector loaded")
    except Exception as e:
        # Best effort: a model whose corrector cannot be built is recorded as
        # None, which makes apply_correction() pass its peaks through unchanged.
        print(f"  Error loading {model_name} corrector: {e}")
        CORRECTORS[model_name] = None

print("\n" + "="*80)
print("Corrector initialization complete")
print("="*80 + "\n")

# ===============================================================================
# Helper Functions
# ===============================================================================

def apply_correction(original_x, original_y, scores, corrector):
    """
    Run ``corrector.correct`` on the peak positions, falling back to the inputs.

    Parameters:
    -----------
    original_x, original_y : array-like
        Uncorrected peak positions
    scores : array-like
        Model scores forwarded to the corrector
    corrector : UniversalCorrector or None
        Object exposing ``correct(x, y, scores)``; ``None`` disables correction

    Returns:
    --------
    corrected_x, corrected_y
        The corrector's output, or the original inputs when ``corrector`` is
        ``None`` or the correction raises (a warning is printed in that case)
    """
    if corrector is not None:
        try:
            fixed_x, fixed_y = corrector.correct(original_x, original_y, scores)
        except Exception as e:
            print(f"  Warning: Correction failed, using original peaks. Error: {e}")
        else:
            return fixed_x, fixed_y

    return original_x, original_y

# ===============================================================================
# File Processing Function
# ===============================================================================

def process_single_file(file_info):
    """
    Process a single file: read the per-model predictions and peak-fit results,
    apply the per-model and ensemble corrections, and write one output ROOT
    file containing a single tree named "tree".
    
    Parameters:
    - file_info: tuple of (ifile, name, proc_name, weight, file_index, total_files)
      (ifile is unpacked but not used in this function; inputs are located via
      `name` under pred_folder and the fit-results directory)
    
    Returns:
    - dict with processing results. On success: 'success' True plus 'name',
      'proc_name', 'n_events', 'n_saved', 'n_pass_combined', 'output_file',
      'has_modellite'. On failure: 'success' False plus 'name' and 'reason'.
      Exceptions are caught and reported through the failure dict.
    """
    ifile, name, proc_name, weight, file_index, total_files = file_info
    
    try:
        print(f"[Worker {os.getpid()}] [{file_index}/{total_files}] Processing: {name}")
        
        # The fourth model ("ModelLite") is only read for selected processes.
        process_modellite = (proc_name in PROCESS_WITH_MODELLITE)
        num_models = 4 if process_modellite else 3
        
        # Per-model results keyed by model index; an entry is None if that
        # model could not be processed.
        model_data = {}
        model_peak_positions = {}
        model_fit_info = {}
        
        for model_idx in range(num_models):
            ensemble_model_name = ensemble_models[model_idx]
            
            pred_file = f"{pred_folder}/{ensemble_model_name}/pred_{name}.root"
            dcbfit_file = f"/data/bond/tyyang99/HH4b/dcb_results/FullNewMethodv2/{ensemble_model_name.split('/')[-1]}/pred_{name}_combined_fit_results.npz"
            
            try:
                pred_data_tmp = uproot.lazy(pred_file)
                # Keep events with both selection flags equal to 1.
                pred_data = pred_data_tmp[(pred_data_tmp['pass_selection']==1) & 
                                         (pred_data_tmp['pass_4j3b_selection']==1)]
                
                scores_ALLHH4b = np.zeros_like(pred_data['score_0'])
                
                # Sum of score_0 ... score_135.
                for j in range(136):
                    scores_ALLHH4b = scores_ALLHH4b + pred_data[f'score_{j}']
                
                # Discriminants: the summed score relative to itself plus
                # score_136 (+ score_137 for the "vsboth" variant).
                # NOTE(review): presumably score_136 is the QCD class and
                # score_137 the other background class — confirm.
                pred_data['score_hh4bvsboth'] = scores_ALLHH4b / (
                    scores_ALLHH4b + pred_data['score_136'] + pred_data['score_137'])
                pred_data['score_hh4bvsqcd'] = scores_ALLHH4b / (
                    scores_ALLHH4b + pred_data['score_136'])
                
                # With cut_value = 0 this keeps events whose discriminant is > 0.
                cut_value = 0
                cut = (pred_data['score_hh4bvsboth'] > cut_value)
                pred_data = pred_data[cut]
                
                # The same mask is applied to the fit-result arrays.
                # NOTE(review): this assumes the fit-result arrays are aligned
                # event-by-event with the selected prediction rows — confirm.
                fit_results = np.load(dcbfit_file)
                
                fit_success = fit_results['fit_success'][cut]
                p1_amp = fit_results['p1_amp'][cut]
                # Replace zero p2 amplitudes by 1e-9 so the ratio below is defined.
                p2_amp = np.where(fit_results['p2_amp'][cut], 
                                 fit_results['p2_amp'][cut], 1e-9)
                p1_x_mean = fit_results['p1_x_mean'][cut]
                p1_y_mean = fit_results['p1_y_mean'][cut]
                
                # Order each peak so that x <= y.
                swap_mask = p1_x_mean > p1_y_mean
                original_peak_x = np.where(swap_mask, p1_y_mean, p1_x_mean)
                original_peak_y = np.where(swap_mask, p1_x_mean, p1_y_mean)
                
                # Model index -> key into CORRECTORS.
                if model_idx == 0:
                    model_name = 'Model0'
                elif model_idx == 1:
                    model_name = 'Model1'
                elif model_idx == 2:
                    model_name = 'Model2'
                else:
                    model_name = 'ModelLite'
                
                corrector = CORRECTORS[model_name]
                scores_for_correction = ak.to_numpy(pred_data['score_hh4bvsboth'])
                
                final_peak_x, final_peak_y = apply_correction(
                    original_peak_x, original_peak_y, scores_for_correction, corrector)
                
                # True whenever the amplitude ratio is non-zero (and not NaN).
                amp_cut = (abs(p1_amp/p2_amp) > 0)
                basic_fit_cut = (fit_success == 1) & amp_cut
                
                model_data[model_idx] = {
                    'score_hh4bvsqcd': ak.to_numpy(pred_data['score_hh4bvsqcd']),
                    'score_hh4bvsboth': scores_for_correction,
                }
                model_peak_positions[model_idx] = {
                    'original_x': original_peak_x,
                    'original_y': original_peak_y,
                    'final_x': final_peak_x,
                    'final_y': final_peak_y,
                }
                model_fit_info[model_idx] = {
                    'fit_success': fit_success,
                    'basic_fit_cut': basic_fit_cut,
                    'p1_amp': p1_amp,
                    'p2_amp': p2_amp
                }
                
            except Exception as e:
                # Any failure for one model marks it as missing; the file is
                # then rejected as a whole just below.
                print(f"[Worker {os.getpid()}] Error processing model {model_idx} for {name}: {e}")
                model_data[model_idx] = None
                model_peak_positions[model_idx] = None
                model_fit_info[model_idx] = None
        
        if any(model_data.get(i) is None for i in range(num_models)):
            return {
                'success': False,
                'name': name,
                'reason': 'Missing model data'
            }
        
        # All models must have the same number of selected events, since the
        # ensemble below combines them element-wise.
        lengths = [len(model_data[i]['score_hh4bvsqcd']) for i in range(num_models)]
        if not all(length == lengths[0] for length in lengths):
            return {
                'success': False,
                'name': name,
                'reason': f'Model data lengths mismatch: {lengths}'
            }
        
        n_events = lengths[0]
        
        if n_events == 0:
            return {
                'success': False,
                'name': name,
                'reason': 'No events after basic selection'
            }
        
        # The ensemble uses only the first three models (indices 0-2); the
        # optional fourth model is stored separately further down.
        original_peaks_x = np.array([model_peak_positions[i]['original_x'] for i in range(3)])
        original_peaks_y = np.array([model_peak_positions[i]['original_y'] for i in range(3)])
        
        final_peaks_x = np.array([model_peak_positions[i]['final_x'] for i in range(3)])
        final_peaks_y = np.array([model_peak_positions[i]['final_y'] for i in range(3)])
        
        # Pairwise Euclidean distances between the three models' original peaks.
        dist_01 = np.sqrt((original_peaks_x[0] - original_peaks_x[1])**2 + 
                         (original_peaks_y[0] - original_peaks_y[1])**2)
        dist_02 = np.sqrt((original_peaks_x[0] - original_peaks_x[2])**2 + 
                         (original_peaks_y[0] - original_peaks_y[2])**2)
        dist_12 = np.sqrt((original_peaks_x[1] - original_peaks_x[2])**2 + 
                         (original_peaks_y[1] - original_peaks_y[2])**2)
        
        # Per event: index of the closest pair (0 -> models 0&1, 1 -> 0&2, 2 -> 1&2).
        all_distances = np.stack([dist_01, dist_02, dist_12], axis=0)
        min_dist_idx = np.argmin(all_distances, axis=0)
        
        conditions = [
            min_dist_idx == 0,
            min_dist_idx == 1,
            min_dist_idx == 2
        ]
        
        # Ensemble peak = midpoint of the closest pair; 125.0 is only the
        # np.select default for events matching none of the conditions.
        choices_x_original = [
            (original_peaks_x[0] + original_peaks_x[1]) / 2.0,
            (original_peaks_x[0] + original_peaks_x[2]) / 2.0,
            (original_peaks_x[1] + original_peaks_x[2]) / 2.0
        ]
        ensemble_original_peak_x = np.select(conditions, choices_x_original, default=125.0)
        
        choices_y_original = [
            (original_peaks_y[0] + original_peaks_y[1]) / 2.0,
            (original_peaks_y[0] + original_peaks_y[2]) / 2.0,
            (original_peaks_y[1] + original_peaks_y[2]) / 2.0
        ]
        ensemble_original_peak_y = np.select(conditions, choices_y_original, default=125.0)
        
        # Ensemble score = mean of the three models' hh4bvsboth scores.
        ensemble_scores = (model_data[0]['score_hh4bvsboth'] + 
                          model_data[1]['score_hh4bvsboth'] + 
                          model_data[2]['score_hh4bvsboth']) / 3.0
        
        # Variant 1: correct the ensemble of the ORIGINAL peaks with the
        # dedicated 'Ensemble' corrector.
        ensemble_corrector = CORRECTORS['Ensemble']
        ensemble_final_peak_x, ensemble_final_peak_y = apply_correction(
            ensemble_original_peak_x, ensemble_original_peak_y, ensemble_scores, ensemble_corrector)
        
        # Variant 2 ("v2"): repeat the closest-pair ensembling on the peaks that
        # were already corrected per model.
        dist_01_final = np.sqrt((final_peaks_x[0] - final_peaks_x[1])**2 + 
                               (final_peaks_y[0] - final_peaks_y[1])**2)
        dist_02_final = np.sqrt((final_peaks_x[0] - final_peaks_x[2])**2 + 
                               (final_peaks_y[0] - final_peaks_y[2])**2)
        dist_12_final = np.sqrt((final_peaks_x[1] - final_peaks_x[2])**2 + 
                               (final_peaks_y[1] - final_peaks_y[2])**2)
        
        all_distances_final = np.stack([dist_01_final, dist_02_final, dist_12_final], axis=0)
        min_dist_idx_final = np.argmin(all_distances_final, axis=0)
        
        conditions_final = [
            min_dist_idx_final == 0,
            min_dist_idx_final == 1,
            min_dist_idx_final == 2
        ]
        
        choices_x_final = [
            (final_peaks_x[0] + final_peaks_x[1]) / 2.0,
            (final_peaks_x[0] + final_peaks_x[2]) / 2.0,
            (final_peaks_x[1] + final_peaks_x[2]) / 2.0
        ]
        ensemble_final_peak_v2_x = np.select(conditions_final, choices_x_final, default=125.0)
        
        choices_y_final = [
            (final_peaks_y[0] + final_peaks_y[1]) / 2.0,
            (final_peaks_y[0] + final_peaks_y[2]) / 2.0,
            (final_peaks_y[1] + final_peaks_y[2]) / 2.0
        ]
        ensemble_final_peak_v2_y = np.select(conditions_final, choices_y_final, default=125.0)
        
        # An event passes the combined cut only if all three models pass theirs.
        combined_basic_fit_cut = (
            np.sum([model_fit_info[0]['basic_fit_cut'], 
                   model_fit_info[1]['basic_fit_cut'], 
                   model_fit_info[2]['basic_fit_cut']], axis=0) >= 3
        )
        
        # Branches of the output tree. Boolean/flag arrays are stored as int32.
        output_data = {
            'ensemble_original_peak_x': ensemble_original_peak_x,
            'ensemble_original_peak_y': ensemble_original_peak_y,
            'ensemble_final_peak_x': ensemble_final_peak_x,
            'ensemble_final_peak_y': ensemble_final_peak_y,
            'ensemble_final_peak_v2_x': ensemble_final_peak_v2_x,
            'ensemble_final_peak_v2_y': ensemble_final_peak_v2_y,
            'min_dist_pair_idx_final': min_dist_idx_final.astype(np.int32),
            
            'model0_original_peak_x': original_peaks_x[0],
            'model0_original_peak_y': original_peaks_y[0],
            'model0_final_peak_x': final_peaks_x[0],
            'model0_final_peak_y': final_peaks_y[0],
            
            'model1_original_peak_x': original_peaks_x[1],
            'model1_original_peak_y': original_peaks_y[1],
            'model1_final_peak_x': final_peaks_x[1],
            'model1_final_peak_y': final_peaks_y[1],
            
            'model2_original_peak_x': original_peaks_x[2],
            'model2_original_peak_y': original_peaks_y[2],
            'model2_final_peak_x': final_peaks_x[2],
            'model2_final_peak_y': final_peaks_y[2],
            
            'model0_score_hh4bvsqcd': model_data[0]['score_hh4bvsqcd'],
            'model1_score_hh4bvsqcd': model_data[1]['score_hh4bvsqcd'],
            'model2_score_hh4bvsqcd': model_data[2]['score_hh4bvsqcd'],
            'model0_score_hh4bvsboth': model_data[0]['score_hh4bvsboth'],
            'model1_score_hh4bvsboth': model_data[1]['score_hh4bvsboth'],
            'model2_score_hh4bvsboth': model_data[2]['score_hh4bvsboth'],
            
            'model0_fit_success': model_fit_info[0]['fit_success'].astype(np.int32),
            'model1_fit_success': model_fit_info[1]['fit_success'].astype(np.int32),
            'model2_fit_success': model_fit_info[2]['fit_success'].astype(np.int32),
            'model0_basic_fit_cut': model_fit_info[0]['basic_fit_cut'].astype(np.int32),
            'model1_basic_fit_cut': model_fit_info[1]['basic_fit_cut'].astype(np.int32),
            'model2_basic_fit_cut': model_fit_info[2]['basic_fit_cut'].astype(np.int32),
            'model0_p1_amp': model_fit_info[0]['p1_amp'],
            'model1_p1_amp': model_fit_info[1]['p1_amp'],
            'model2_p1_amp': model_fit_info[2]['p1_amp'],
            'model0_p2_amp': model_fit_info[0]['p2_amp'],
            'model1_p2_amp': model_fit_info[1]['p2_amp'],
            'model2_p2_amp': model_fit_info[2]['p2_amp'],
            
            'combined_basic_fit_cut': combined_basic_fit_cut.astype(np.int32),
            'min_dist_pair_idx': min_dist_idx.astype(np.int32),
            
            # Same constant per-process weight for every event in the file.
            'weight': np.ones(n_events) * weight
        }
        
        # Extra branches for the fourth model, when it was processed.
        if process_modellite:
            output_data.update({
                'modellite_original_peak_x': model_peak_positions[3]['original_x'],
                'modellite_original_peak_y': model_peak_positions[3]['original_y'],
                'modellite_final_peak_x': model_peak_positions[3]['final_x'],
                'modellite_final_peak_y': model_peak_positions[3]['final_y'],
                'modellite_score_hh4bvsqcd': model_data[3]['score_hh4bvsqcd'],
                'modellite_score_hh4bvsboth': model_data[3]['score_hh4bvsboth'],
                'modellite_fit_success': model_fit_info[3]['fit_success'].astype(np.int32),
                'modellite_basic_fit_cut': model_fit_info[3]['basic_fit_cut'].astype(np.int32),
                'modellite_p1_amp': model_fit_info[3]['p1_amp'],
                'modellite_p2_amp': model_fit_info[3]['p2_amp'],
            })
        
        output_proc_dir = os.path.join(output_base_dir, proc_name)
        os.makedirs(output_proc_dir, exist_ok=True)
        
        output_file = os.path.join(output_proc_dir, f"ensemble_method2_{name}.root")
        
        # uproot.recreate overwrites any existing file of the same name.
        with uproot.recreate(output_file) as f:
            f["tree"] = output_data
        
        n_saved = len(output_data['ensemble_original_peak_x'])
        n_pass_combined = np.sum(combined_basic_fit_cut)
        
        modellite_status = "with modellite" if process_modellite else "without modellite"
        print(f"[Worker {os.getpid()}] Saved {n_saved} events ({n_pass_combined} pass combined_basic_fit_cut) "
              f"to: {output_file} ({modellite_status})")
        
        return {
            'success': True,
            'name': name,
            'proc_name': proc_name,
            'n_events': n_events,
            'n_saved': n_saved,
            'n_pass_combined': n_pass_combined,
            'output_file': output_file,
            'has_modellite': process_modellite,
        }
        
    except Exception as e:
        # Never raise out of this function: report the failure as a result
        # dict (with the traceback printed) instead.
        print(f"[Worker {os.getpid()}] Error processing {name}: {e}")
        import traceback
        traceback.print_exc()
        return {
            'success': False,
            'name': name,
            'reason': str(e)
        }

# ===============================================================================
# Main Processing Functions
# ===============================================================================

def collect_file_list():
    """
    Collect all files to be processed

    Globs ``folder_pattern`` for process folders, keeps the ones whose process
    name (folder basename up to the first "_") is listed in ``process_list``,
    and emits one entry per file found inside. Folders whose path contains
    "forInfer2" are never visited directly; for QCD and TTbar their files are
    pulled in through the matching primary folder (path with "forInfer"
    replaced by "forInfer2") and get an "EXT" marker in their name.

    Folders and files are visited in sorted order so the returned list is
    reproducible from run to run.

    Returns:
    - List of file_info tuples (file_path, name, proc_name, weight)

    Raises:
    - KeyError if a selected process has no entry in ``weight_dict``
    """
    file_list = []

    # glob returns entries in arbitrary order; sort for a reproducible job list.
    for ifolder in sorted(glob.glob(folder_pattern)):
        if "forInfer2" in ifolder:
            continue

        # normpath drops a trailing separator so the basename is never empty.
        proc_name = os.path.basename(os.path.normpath(ifolder)).split("_")[0]
        if proc_name not in process_list:
            continue

        weight = weight_dict[proc_name]
        matching_files = sorted(glob.glob(ifolder + "/*"))

        if proc_name == "QCD" or proc_name == "TTbar":
            ext_folder = ifolder.replace("forInfer", "forInfer2")
            # When the path holds no "forInfer" the replace is a no-op;
            # globbing the same folder again would queue every file twice.
            if ext_folder != ifolder:
                matching_files += sorted(glob.glob(ext_folder + "/*"))

        for ifile in matching_files:
            # Take the tag from the file's own name: splitting the whole path
            # on "_" would pick up a directory component (and a path
            # separator) whenever the file name itself contains no "_".
            file_tag = os.path.basename(ifile).replace(".root", "").split("_")[-1]

            if "forInfer2" in ifile:
                name = proc_name + "EXT_" + file_tag
            else:
                name = proc_name + "_" + file_tag

            file_list.append((ifile, name, proc_name, weight))

    return file_list

def process_ensemble_method2_parallel(n_workers=None):
    """
    Process all files using ensemble method2 with parallel processing

    Builds the job list with collect_file_list(), extends every job tuple with
    its 1-based index and the total file count, maps process_single_file over
    the jobs in a multiprocessing Pool, and prints an overall and a
    per-process summary of the returned result dictionaries.

    Parameters:
    - n_workers: Number of parallel workers (default: use N_WORKERS)

    Returns:
    - None. All reporting goes to stdout.
    """
    if n_workers is None:
        n_workers = N_WORKERS

    print("Collecting file list...")
    file_list = collect_file_list()
    total_files = len(file_list)

    if total_files == 0:
        print("No files to process!")
        return

    print(f"Found {total_files} files to process")
    print(f"Using {n_workers} parallel workers")

    file_list_with_index = [
        (ifile, name, proc_name, weight, idx + 1, total_files)
        for idx, (ifile, name, proc_name, weight) in enumerate(file_list)
    ]

    with Pool(processes=n_workers) as pool:
        results = pool.map(process_single_file, file_list_with_index)

    print(f"\n{'='*80}")
    print("Processing Summary")
    print(f"{'='*80}")

    successful = [r for r in results if r['success']]
    failed = [r for r in results if not r['success']]

    print(f"Total files processed: {total_files}")
    print(f"Successful: {len(successful)}")
    print(f"Failed: {len(failed)}")

    if successful:
        total_saved = sum(r['n_saved'] for r in successful)
        total_pass_combined = sum(r['n_pass_combined'] for r in successful)
        n_with_modellite = sum(1 for r in successful if r.get('has_modellite', False))
        n_without_modellite = len(successful) - n_with_modellite

        # Every successful file may have saved zero events; report 0% in that
        # case (same convention as the per-process rate below) rather than
        # dividing by zero and losing the rest of the summary.
        total_pass_rate = 100 * total_pass_combined / total_saved if total_saved > 0 else 0

        print(f"\nTotal events saved: {total_saved}")
        print(f"Events passing combined_basic_fit_cut: {total_pass_combined} ({total_pass_rate:.2f}%)")
        print(f"Files with modellite: {n_with_modellite}")
        print(f"Files without modellite: {n_without_modellite}")

        # Aggregate file and event counts per physics process.
        proc_summary = {}
        for r in successful:
            stats = proc_summary.setdefault(
                r['proc_name'], {'files': 0, 'events': 0, 'pass_combined': 0}
            )
            stats['files'] += 1
            stats['events'] += r['n_saved']
            stats['pass_combined'] += r['n_pass_combined']

        print("\nSummary by process:")
        for proc, stats in sorted(proc_summary.items()):
            pass_rate = 100 * stats['pass_combined'] / stats['events'] if stats['events'] > 0 else 0
            modellite_marker = "✓" if proc in PROCESS_WITH_MODELLITE else "✗"
            print(f"  {proc} [modellite:{modellite_marker}]: {stats['files']} files, {stats['events']} events, "
                  f"{stats['pass_combined']} pass combined cut ({pass_rate:.2f}%)")

    if failed:
        print("\nFailed files:")
        for r in failed:
            print(f"  {r['name']}: {r.get('reason', 'Unknown error')}")

if __name__ == "__main__":
    # Opening banner.
    print("=" * 80)
    print(
        "Starting Ensemble Method2 Ntuple Generation",
        "with Diagonal Slice CDF Correction",
        sep="\n",
    )
    print("=" * 80)

    # Run-time configuration read from module-level settings.
    print(f"\nOutput directory: {output_base_dir}")
    print(f"Available CPU cores: {cpu_count()}")
    print(f"Using {N_WORKERS} workers")

    # Fixed description of the correction scheme, one output line per entry.
    print(
        "\nCorrection method: Diagonal slice-based CDF matching with dynamic extension",
        "  - Source: Saved smooth 2D distributions",
        "  - Target: Tomography-corrected distributions",
        "  - Correction range (M): [70, 170] GeV",
        "  - Extended sampling range: [60, 180] GeV (linear from center)",
        "    * M = 120: sampling in [70, 170]",
        "    * M = 70:  sampling in [60, 170]",
        "    * M = 170: sampling in [70, 180]",
        "\nModellite processing:",
        sep="\n",
    )
    print(f"  - Enabled for: {', '.join(PROCESS_WITH_MODELLITE)}")
    print("=" * 80 + "\n")

    # Run the full parallel ntuple production.
    process_ensemble_method2_parallel()

    # Closing banner.
    print("\n" + "=" * 80)
    print("Processing completed!")
    print("=" * 80)


Loading Universal Correctors for all models...

Initializing corrector for Ensemble...
Loading source distributions for Ensemble...
Found 9 score bins to process.

  [Corrector] Processing bin [0.000, 0.100)...
  Source interpolator built: range [50.0, 190.0] GeV
  Target polynomial fit R^2 = 0.9888
  [Corrector] Processing bin [0.100, 0.300)...
  Source interpolator built: range [50.0, 190.0] GeV
  Target polynomial fit R^2 = 0.9933
  [Corrector] Processing bin [0.300, 0.500)...
  Source interpolator built: range [50.0, 190.0] GeV
  Target polynomial fit R^2 = 0.9942
  [Corrector] Processing bin [0.500, 0.700)...
  Source interpolator built: range [50.0, 190.0] GeV
  Target polynomial fit R^2 = 0.9948
  [Corrector] Processing bin [0.700, 0.900)...
  Source interpolator built: range [50.0, 190.0] GeV
  Target polynomial fit R^2 = 0.9956
  [Corrector] Processing bin [0.900, 0.950)...
  Source interpolator built: range [50.0, 190.0] GeV
  Target polynomial fit R^2 = 0.9961
  [Corrector] 