# In [None]:
import os
import glob
import numpy as np
import uproot
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import roc_curve, auc
from scipy.stats import ks_2samp

# --- Config ---
# ROOT inputs are read from input_dir; every plot, table and model produced
# below is written to output_dir (created here if it does not exist yet).
input_dir = "/Users/artemis/Desktop/bdt/bdt_inputs"
output_dir = "/Users/artemis/Desktop/bdt/bdt_results"
os.makedirs(output_dir, exist_ok=True)

# Signal file names (relative to input_dir) used to train each regime.
signal_files_map = dict(
    boosted=["bdt_ZH12.root", "bdt_ZH15.root"],
    merged=["bdt_ZH20.root", "bdt_ZH25.root"],
    resolved=["bdt_ZH30.root", "bdt_ZH40.root", "bdt_ZH50.root", "bdt_ZH60.root"],
)

# Background files: everything in input_dir matching one of these patterns
# (QCD, TT, Zto). The same background list is shared by all regimes.
background_patterns = ["bdt_QCD*.root", "bdt_TT*.root", "bdt_Zto*.root"]
background_files = []
for pattern in background_patterns:
    background_files += glob.glob(os.path.join(input_dir, pattern))

# Regimes are processed in the order they are declared in signal_files_map.
regimes = list(signal_files_map)

# --- Data Loader ---
def load_tree_data(filename, regime):
    """Load the arrays stored under the ``regime`` key of a ROOT file.

    Args:
        filename: Path handed to ``uproot.open``.
        regime: Key of the object to read from the opened file.

    Returns:
        Whatever ``arrays(library="np")`` returns for that object, or
        ``None`` when the file has no ``regime`` key.
    """
    with uproot.open(filename) as root_file:
        if regime in root_file:
            return root_file[regime].arrays(library="np")
    # Key not present in this file: callers treat None as "skip this file".
    return None

def load_data(files, regime, label, weight_scale=180000):
    """Load the ``regime`` data of several files and stack it into arrays.

    Files for which ``load_tree_data`` returns ``None`` are skipped. Every
    key other than ``"weight"`` is treated as an input feature. The feature
    order of the first loaded file is the reference: all later files are
    read in that same order, so the stacked columns stay aligned even when a
    file lists its keys in a different order.

    Args:
        files: Iterable of ROOT file paths.
        regime: Key of the object to read from each file.
        label: Class label assigned to every row loaded by this call.
        weight_scale: Factor multiplied into each file's ``"weight"`` array.

    Returns:
        A tuple ``(X, y, w, features)`` where ``X`` has one column per
        feature, ``y`` is filled with ``label``, ``w`` holds the scaled
        weights and ``features`` names the columns of ``X``. When no file
        provided data the result is ``(None, None, None, [])``.

    Raises:
        ValueError: If a file's feature names differ from those of the first
            loaded file (stacking it would misalign or mislabel columns).
    """
    X_all, y_all, w_all = [], [], []
    features_used = []

    for f in files:
        data = load_tree_data(f, regime)
        if data is None:
            continue

        features = [k for k in data.keys() if k != "weight"]
        if not X_all:
            # First file with data defines the column order for everything.
            features_used = features
        elif set(features) != set(features_used):
            raise ValueError(
                f"Feature mismatch in {os.path.basename(f)} for regime {regime}: "
                f"expected {sorted(features_used)}, found {sorted(features)}"
            )

        # Index by the reference order, not this file's own key order.
        X = np.column_stack([data[k] for k in features_used])
        y = np.full(X.shape[0], label)
        w = data["weight"] * weight_scale

        X_all.append(X)
        y_all.append(y)
        w_all.append(w)

        print(f"[DEBUG] Loaded {X.shape[0]} samples from {os.path.basename(f)} with label {label}")

    if not X_all:
        return None, None, None, []

    return (
        np.concatenate(X_all),
        np.concatenate(y_all),
        np.concatenate(w_all),
        features_used,
    )

# --- Main Loop ---
# For each regime: load signal and background, split 50/50 into train/test,
# write diagnostic plots, tune an XGBoost classifier with a grid search, and
# save ROC, overtraining, feature-importance and cut-scan outputs plus the model.
for regime in regimes:
    print(f"\n[INFO] Training BDT for regime: {regime}")

    sig_files = [os.path.join(input_dir, fname) for fname in signal_files_map[regime]]
    X_sig, y_sig, w_sig, features = load_data(sig_files, regime, 1)
    X_bkg, y_bkg, w_bkg, _ = load_data(background_files, regime, 0)
    if X_sig is None or X_bkg is None:
        print(f"[WARNING] Missing data for regime {regime}. Skipping.")
        continue

    X = np.concatenate([X_sig, X_bkg])
    y = np.concatenate([y_sig, y_bkg])
    w = np.concatenate([w_sig, w_bkg])

    # NOTE(review): no random_state is passed, so the split (and therefore
    # every number below) presumably changes from run to run — confirm
    # whether a fixed seed is wanted.
    X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
        X, y, w, test_size=0.5, stratify=y
    )

    df_train = pd.DataFrame(X_train, columns=features)
    df_test = pd.DataFrame(X_test, columns=features)
    df_train['label'] = y_train
    df_test['label'] = y_test

    # Correlation matrices (annotated with the coefficient values), one per
    # train/test x signal/background subset.
    for tag, df in zip(["train_sig", "train_bkg", "test_sig", "test_bkg"],
                       [df_train[df_train.label == 1], df_train[df_train.label == 0],
                        df_test[df_test.label == 1], df_test[df_test.label == 0]]):
        corr = df[features].corr()
        plt.figure(figsize=(10, 8))
        sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", square=True, cbar=True,
                    annot_kws={"size": 8})
        plt.title(f"Correlation Matrix ({regime} - {tag})")
        plt.tight_layout()
        plt.savefig(f"{output_dir}/corr_{regime}_{tag}.pdf")
        plt.close()

    # Input variable plots: unit-normalised signal vs. background shape of
    # every feature, using the full (pre-split) samples.
    for i, feature in enumerate(features):
        plt.figure()
        plt.hist(X_sig[:, i], bins=100, density=True, alpha=0.5, label="Signal", color="blue")
        plt.hist(X_bkg[:, i], bins=100, density=True, alpha=0.5, label="Background", color="red")
        plt.xlabel(feature)
        plt.ylabel("Normalized Events")
        plt.title(f"{feature} ({regime})")
        plt.legend()
        plt.grid()
        plt.tight_layout()
        plt.savefig(f"{output_dir}/feature_{regime}_{feature}.pdf")
        plt.close()

    # Hyper-parameter grid scanned by GridSearchCV (3-fold CV, ROC AUC score).
    param_grid = {
        "max_depth": [1, 2, 3, 4],
        "learning_rate": [0.01, 0.05, 0.08, 0.09, 0.1],
        "n_estimators": [100, 200, 300],
        "subsample": [0.5, 0.8, 1.0],
        "colsample_bytree": [0.6, 0.8, 1.0],
        "tree_method": ["hist"],
        "gamma": [0, 0.5, 1.0],
        "alpha": [0, 0.1, 1],
        "reg_lambda": [1, 2]
    }

    model = xgb.XGBClassifier(objective="binary:logistic", eval_metric="logloss")
    grid = GridSearchCV(model, param_grid, scoring="roc_auc", cv=3, verbose=1, n_jobs=-1)
    grid.fit(X_train, y_train, sample_weight=w_train)

    # From here on `model` is the best estimator found by the grid search.
    model = grid.best_estimator_
    print(f"[INFO] Best parameters: {grid.best_params_}")

    # Signal-class probability is used as the BDT score.
    y_pred_train = model.predict_proba(X_train)[:, 1]
    y_pred_test  = model.predict_proba(X_test)[:, 1]

    # ROC curve (weighted) for train and test
    fpr_train, tpr_train, _ = roc_curve(y_train, y_pred_train, sample_weight=w_train)
    fpr_test, tpr_test, _   = roc_curve(y_test, y_pred_test, sample_weight=w_test)
    auc_train = auc(fpr_train, tpr_train)
    auc_test = auc(fpr_test, tpr_test)

    plt.figure()
    plt.plot(fpr_train, tpr_train, label=f"Train AUC = {auc_train:.3f}")
    plt.plot(fpr_test, tpr_test, label=f"Test AUC = {auc_test:.3f}")
    plt.plot([0, 1], [0, 1], "k--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({regime})")
    plt.legend()
    plt.grid()
    plt.savefig(f"{output_dir}/roc_{regime}.pdf")
    plt.close()

    # Overtraining check: two-sample KS test between the train and test score
    # distributions, separately for signal and background (unweighted).
    ks_sig = ks_2samp(y_pred_train[y_train == 1], y_pred_test[y_test == 1])
    ks_bkg = ks_2samp(y_pred_train[y_train == 0], y_pred_test[y_test == 0])

    bins = np.linspace(0, 1, 100)
    plt.figure()
    plt.hist(y_pred_train[y_train == 1], bins=bins, density=True, histtype='step', label="Sig train", color="blue")
    plt.hist(y_pred_test[y_test == 1], bins=bins, density=True, histtype='step', linestyle='--', label="Sig test", color="blue")
    plt.hist(y_pred_train[y_train == 0], bins=bins, density=True, histtype='step', label="Bkg train", color="red")
    plt.hist(y_pred_test[y_test == 0], bins=bins, density=True, histtype='step', linestyle='--', label="Bkg test", color="red")
    plt.title(f"Overtraining — KS p: Sig={ks_sig.pvalue:.3f}, Bkg={ks_bkg.pvalue:.3f}")
    plt.xlabel("BDT Score")
    plt.ylabel("Density")
    plt.legend()
    plt.grid()
    plt.savefig(f"{output_dir}/overtraining_{regime}.pdf")
    plt.close()

    # Variable importance ranking (with actual feature names)
    importances = model.feature_importances_
    importance_df = pd.DataFrame({
        "Feature": features,
        "Importance": importances
    }).sort_values(by="Importance", ascending=False)

    # Save as bar chart
    plt.figure(figsize=(10, 6))
    sns.barplot(x="Importance", y="Feature", data=importance_df, palette="viridis")
    plt.title(f"Feature Importance ({regime})")
    plt.tight_layout()
    plt.savefig(f"{output_dir}/feature_importance_{regime}.pdf")
    plt.close()

    # Save as CSV (optional)
    importance_df.to_csv(f"{output_dir}/feature_importance_{regime}.csv", index=False)

    # Efficiency, Rejection, Significance as a function of the score cut.
    # Reverse cumulative sums give the weighted yield with score >= bin edge.
    s_hist, _ = np.histogram(y_pred_test[y_test == 1], bins=bins, weights=w_test[y_test == 1])
    b_hist, _ = np.histogram(y_pred_test[y_test == 0], bins=bins, weights=w_test[y_test == 0])
    s_cumsum = s_hist[::-1].cumsum()[::-1]
    b_cumsum = b_hist[::-1].cumsum()[::-1]

    efficiency = s_cumsum / s_cumsum[0]
    rejection = 1 - b_cumsum / b_cumsum[0]

    # S/sqrt(S+B) is undefined where no weighted events survive the cut
    # (S+B <= 0). Evaluate it only where it is defined and leave the other
    # bins as NaN so they show up as gaps in the plot; this avoids the 0/0
    # RuntimeWarning and keeps NaN from poisoning the reported maximum.
    total_yield = s_cumsum + b_cumsum
    populated = total_yield > 0
    significance = np.full(total_yield.shape, np.nan)
    significance[populated] = s_cumsum[populated] / np.sqrt(total_yield[populated])
    max_significance = np.nanmax(significance) if populated.any() else float("nan")

    plt.figure()
    plt.plot(bins[:-1], efficiency, label="Signal Efficiency")
    plt.plot(bins[:-1], rejection, label="Background Rejection")
    plt.plot(bins[:-1], significance, label="Significance (S/√(S+B))")
    plt.xlabel("BDT Cut")
    plt.ylabel("Metric")
    plt.title(f"Performance Metrics vs. BDT Cut ({regime})")
    plt.legend()
    plt.grid()
    plt.savefig(f"{output_dir}/eff_rej_significance_{regime}.pdf")
    plt.close()

    model.save_model(f"{output_dir}/bdt_model_{regime}.json")
    print(f"[DONE] {regime} — AUC: {auc_test:.3f}, Max Significance: {max_significance:.2f}")
