# BIP Framework training example 

In [None]:
using Pkg
Pkg.activate("..")
using BIPs

In [None]:
using Statistics
using Pkg.Artifacts
using DelimitedFiles


Let's begin by loading the dataset. It contains three splits:
* **train**: the training set with 1M jets
* **validation**: the validation set with 400k jets
* **test**: the test set with another 400k jets, used later to report the results


In [None]:
dataset_path = "../../../DataLake/raw"
# dataset_path = "/Users/ortner/datasets/toptagging"

train_data_path = dataset_path*"/train.h5"
val_data_path = dataset_path*"/val.h5"

### Reading the data

To read the datasets, we call the `read_data` function with the `"TQ"` tag, which selects the TopQuark format:

In [None]:
train_jets, train_labels = BIPs.read_data("TQ", train_data_path)
train_labels = [b == 1.0 for b in train_labels]  # convert numeric labels to Bool
print("Number of entries in the training data: ", length(train_jets))

In [None]:
val_jets, val_labels = BIPs.read_data("TQ", val_data_path)
val_labels = [b == 1.0 for b in val_labels]  # convert numeric labels to Bool
print("Number of entries in the validation data: ", length(val_jets))

Let's examine what one of the jets looks like. Each entry is the four-momentum $(E, p_x, p_y, p_z)$ of one detected particle.

However, to compute the embeddings, the jets must first be converted to a format the framework can use. The function `data2hyp` converts each detected four-momentum to the jet basis, i.e. $(\tilde p_T, \cos(\theta), \sin(\theta), \tilde y, E_T)$.

In [None]:
train_transf_jets = data2hyp(train_jets)
val_transf_jets = data2hyp(val_jets)
println("Transformed jets")
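
To make the coordinates above more concrete, here is a minimal sketch of standard collider kinematics for a single four-momentum. It is not the implementation of `data2hyp`: the helper name `particle_kinematics` is hypothetical, the tilde quantities of the jet basis are defined relative to the jet (which this sketch does not attempt), and identifying the angle below with the $\theta$ above is an assumption. It only shows how a transverse momentum, an azimuthal angle and a rapidity follow from $(E, p_x, p_y, p_z)$.

```julia
# Hypothetical helper, not part of BIPs: standard kinematic quantities
# for one particle given its four-momentum (E, px, py, pz).
function particle_kinematics(E, px, py, pz)
    pT = hypot(px, py)                   # transverse momentum
    phi = atan(py, px)                   # azimuthal angle in the transverse plane
    y = 0.5 * log((E + pz) / (E - pz))   # rapidity (requires E > |pz|)
    return (pT = pT, cosphi = cos(phi), sinphi = sin(phi), y = y)
end

# A massless particle with E = 5 and momentum (3, 0, 4).
k = particle_kinematics(5.0, 3.0, 0.0, 4.0)
println(k)
```

Returning the cosine and sine instead of the angle itself avoids the discontinuity of the angle at $\pm\pi$.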

### The embeddings

Once the jets are converted to the jet basis, it is time to embed them using the *Invariant Polynomials*.

The function `build_ip` efficiently allocates the sparse basis, while `bip_data` computes the invariant representation of each jet.

In [None]:
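# Build the embedding function; length(specs) is the number of features it returns
f_bip, specs = build_ip(order=4, levels=8)

# Evaluate the embedding on every jet, one row per jet
function bip_data(dataset_jets)
    storage = zeros(length(dataset_jets), length(specs))
    for i = 1:length(dataset_jets)
        storage[i, :] = f_bip(dataset_jets[i])
    end
    storage[:, 2:end]  # drop the first feature column
end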

In [None]:
train_embedded_jets = bip_data(train_transf_jets)
println("Embedded train jets correctly")
val_embedded_jets = bip_data(val_transf_jets)
println("Embedded validation jets correctly")

## Test data

In [None]:
test_data_path = "../../../DataLake/raw/test.h5"
test_jets, test_labels = BIPs.read_data("TQ", test_data_path)
test_labels = [b == 1.0 for b in test_labels]  # convert numeric labels to Bool
test_transf_jets = data2hyp(test_jets)
test_embedded_jets = bip_data(test_transf_jets)
print("Embedded test jets correctly")

In [None]:
# Standardize every column to zero mean and unit variance, using the statistics of A itself
scale(A) = (A .- mean(A, dims=1)) ./ std(A, dims=1)

##
X_train, y_train = scale(train_embedded_jets), train_labels
X_test, y_test = scale(test_embedded_jets), test_labels
X_val, y_val = scale(val_embedded_jets), val_labels
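
The cell above standardizes each split with its own mean and standard deviation. A common alternative is to compute the statistics on the training set only and reuse them for the validation and test sets, so that all three splits go through the same transformation. The sketch below illustrates that variant on toy data; `fit_scaler` and `apply_scaler` are hypothetical helper names, not part of BIPs, and the toy matrices merely stand in for the embedded jets.

```julia
using Statistics

# Hypothetical helpers: fit column-wise statistics on one matrix,
# then apply them to any matrix with the same number of columns.
fit_scaler(A) = (mu = mean(A, dims=1), sigma = std(A, dims=1))
apply_scaler(A, s) = (A .- s.mu) ./ s.sigma

# Toy data standing in for the embedded jets (rows = jets, columns = features).
A_train = [1.0 10.0; 2.0 20.0; 3.0 30.0]
A_val = [2.0 40.0; 4.0 0.0]

s = fit_scaler(A_train)             # statistics come from the training split only
Z_train = apply_scaler(A_train, s)
Z_val = apply_scaler(A_val, s)      # validation data reuses the training statistics
```

With the notebook's variables this would read `s = fit_scaler(train_embedded_jets)` followed by `apply_scaler` on each of the three embedded sets. A column with zero variance would produce a division by zero in either approach.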


# CLASSIFIERS

### MLP

In [None]:
using PyCall
@pyimport sklearn.neural_network as sk_nn

In [None]:
mlp = sk_nn.MLPClassifier(verbose=true, max_iter=2000, hidden_layer_sizes=(200,100), n_iter_no_change=50)

In [None]:
mlp.fit(X_train, train_labels)

In [None]:
mlp.score(X_val, val_labels)

In [None]:
mlp_test_scores = mlp.predict_proba(X_test)
writedlm( "../foo/mlp_test_probas2.csv",  mlp_test_scores, ',')


### XGBoost

In [None]:
using XGBoost
function fit_xgb_clf(X_train, y_train, X_val, y_val)
    dtrain = DMatrix(X_train, label=y_train)
    dtest = DMatrix(X_val, label=y_val)  # validation split, shown as "test" in the training log
    param = ["silent" => 0, "subsample" => 0.5, "colsample_bytree" => 0.5, "eta" => 0.05]
    # Evaluate the metrics on both splits during training
    watchlist = [(dtest, "test"), (dtrain, "train")]
    # Up to 500 boosting rounds, with early stopping requested after 50 rounds
    bst = xgboost(X_train, 500, label=y_train, param=param, objective = "binary:logistic",
        watchlist=watchlist, metrics=["logloss", "auc", "error"],
        early_stopping_rounds=50, verbose_eval=10, seed=137)

    bst
end

In [None]:
model = fit_xgb_clf(X_train, y_train, X_val, y_val)

In [None]:
xgb_probas = predict(model, X_test)

In [None]:
writedlm( "../foo/xgb_test_probas2.csv",  xgb_probas, ',')
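
The saved probabilities still have to be turned into figures of merit on the test set. As a minimal sketch, the helpers below compute the accuracy at a fixed threshold and the ROC AUC from a vector of predicted probabilities and a vector of `Bool` labels. Both function names are hypothetical and the toy inputs are invented for illustration; nothing here is part of BIPs or XGBoost.

```julia
# Hypothetical helpers for scoring predicted probabilities against Bool labels.

# Fraction of correct predictions at a given probability threshold.
accuracy(probas, labels; threshold=0.5) =
    count((probas .> threshold) .== labels) / length(labels)

# ROC AUC as the probability that a random positive is ranked above a random
# negative (ties count one half). The cost is O(n_pos * n_neg), so this
# pairwise form is only practical for small inputs.
function roc_auc(probas, labels)
    pos = probas[labels]      # scores of the positive class
    neg = probas[.!labels]    # scores of the negative class
    wins = sum(p > n ? 1.0 : (p == n ? 0.5 : 0.0) for p in pos, n in neg)
    return wins / (length(pos) * length(neg))
end

# Toy example: five predictions and their true labels.
probas = [0.9, 0.8, 0.3, 0.6, 0.1]
labels = [true, true, false, false, true]
println(accuracy(probas, labels))
println(roc_auc(probas, labels))
```

With the notebook's variables the accuracy would be obtained as `accuracy(xgb_probas, y_test)`. For the full test set a rank-based AUC would be preferable to the pairwise loop shown here.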