In [1]:
import os
import urllib
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import roc_curve
from sklearn.model_selection import train_test_split

In [2]:
filename = "atlas-higgs-challenge-2014-v2.csv.gz"
url = "http://opendata.cern.ch/record/328/files/atlas-higgs-challenge-2014-v2.csv.gz"

In [3]:
if not os.path.exists(filename):
    urllib.request.urlretrieve(url, filename)
df = pd.read_csv(filename)

In [4]:
feat_columns = [col for col in df.columns if col[:3] in ["DER", "PRI"]]

X = df.loc[:, feat_columns]
y = df['Label']
weight = df['Weight']
(
    X_train,
    X_test,
    y_train,
    y_test,
    weight_train,
    weight_test,
) = train_test_split(
    X.to_numpy(),
    (y == "s").to_numpy(),
    weight.to_numpy(),
    test_size=0.33,
    random_state=42
)

In [5]:
def ams(s, b):
    """
    Approximate median significance, as defined in Higgs Kaggle challenge

    The number 10, added to the background yield, is a regularization term to decrease the variance of the AMS.
    """
    return np.sqrt(2 * ((s + b + 10) * np.log(1 + s / (b + 10)) - s))

In [6]:
sumw = df.groupby("Label").Weight.sum()
nsig_tot = sumw["s"]
nbkg_tot = sumw["b"]

In [7]:
class_weight = np.array([
    len(y_train) / weight_train[y_train==0].sum(),
    len(y_train) / weight_train[y_train==1].sum(),
])


In [8]:
model_gbc_weighted = HistGradientBoostingClassifier(class_weight={0.0:class_weight[0], 1.0:class_weight[1]})
model_gbc_weighted.fit(X_train, y_train, sample_weight=weight_train)
model_gbc_weighted.score(X_test, y_test, sample_weight=weight_test)

0.8372740002551992

In [10]:
p_test_weighted = model_gbc_weighted.predict_proba(X_test)[:,1]
roc_gbc_weighted = roc_curve(y_test, p_test_weighted, sample_weight=weight_test)

amsses = [ams(tpr*nsig_tot, fpr*nbkg_tot) for tpr, fpr in zip(roc_gbc_weighted[1], roc_gbc_weighted[0])]
max_ams = max(amsses)
max_thr = roc_gbc_weighted[2][amsses.index(max_ams)]
print(f"Maximum AMS of {max_ams:.3f} at a threshold of {max_thr:.3f}.")

Maximum AMS of 3.579 at a threshold of 0.947.
