# Train a Classifier

In this notebook we train a Gradient Boosting Decision Tree (GBDT) classifier using the implementation provided by the [LightGBM](https://lightgbm.readthedocs.io/en/latest/) package.

#### Index<a name="index"></a>
1. [Import Packages](#imports)
2. [Load Features](#loadFeatures)
3. [Generate Classifier](#generateClassifier)
    1. [Untrained Classifier](#createClassifier)
    2. [Train Classifier](#trainClassifier)
    3. [Save the Classifier Instance](#saveClassifier)

## 1. Import Packages<a name="imports"></a>

In [1]:
import os
import pickle
import sys
import time

In [2]:
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats

In [3]:
from snmachine import snclassifier
from snmachine.utils.plasticc_pipeline import get_directories, load_dataset

In [4]:
import warnings
warnings.simplefilter('always', DeprecationWarning)

In [5]:
%config Completer.use_jedi = False  # enable autocomplete

## 2. Load Features<a name="loadFeatures"></a>

First, **write** the path to the folder that contains the features and the labels of the events (`path_saved_features`). These quantities were calculated and saved in [5_feature_extraction](5_feature_extraction.ipynb).

### 2.1. Features Path<a name="pathFeatures"></a>

**<font color=Orange>A)</font>** Obtain path from folder structure.

If you created a folder structure, you can obtain the path from there. **Write** the name of the folder in `analysis_name`. 

In [6]:
analysis_name = 'example_dataset_aug' 

In [7]:
folder_path = '../snmachine/example_data'

directories = get_directories(folder_path, analysis_name) 
path_saved_features = directories['features_directory']

**<font color=Orange>B)</font>** Directly **write** where you saved the files.

```python
folder_path = '../snmachine/example_data'
path_saved_features = folder_path
```

### 2.2. Load<a name="load"></a>

Then, load the features and labels.

In [8]:
X = pd.read_pickle(os.path.join(path_saved_features, 'features.pckl'))  # features
y = pd.read_pickle(os.path.join(path_saved_features, 'data_labels.pckl'))  # class label of each event

**<font color=Orange>A)</font>** If the dataset is not augmented, skip **<font color=Orange>B)</font>**.


**<font color=Orange>B)</font>** If the dataset is augmented, load the augmented dataset.

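To avoid information leaks during the classifier optimisation, all synthetic events that the training set augmentation derived from the same original event must be placed in the same cross-validation fold. Otherwise, near-copies of one event could appear in both the training and the validation folds, which would inflate the validation score.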
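
The fold constraint described above can be sketched in plain Python. This is only an illustration and not how `snmachine` builds its folds: the function `grouped_folds`, the list `original_event` and the event names in it are hypothetical. Each entry of `original_event` names the original event from which a (possibly synthetic) event was derived.

```python
from collections import Counter


def grouped_folds(group_ids, n_folds):
    """Assign a fold index to every event so that events sharing a
    group id (here: the same original event) share a fold."""
    sizes = Counter(group_ids)
    fold_sizes = [0] * n_folds
    fold_of_group = {}
    # Greedy balancing: place the largest groups first, always into the
    # fold that currently holds the fewest events.
    for group in sorted(sizes, key=lambda g: (-sizes[g], g)):
        fold = fold_sizes.index(min(fold_sizes))
        fold_of_group[group] = fold
        fold_sizes[fold] += sizes[group]
    return [fold_of_group[g] for g in group_ids]


# Hypothetical example: 8 events derived from 4 original events.
original_event = ['ev1', 'ev1', 'ev1', 'ev2', 'ev2', 'ev3', 'ev4', 'ev4']
folds = grouped_folds(original_event, n_folds=2)

for event, fold in zip(original_event, folds):
    print(event, '-> fold', fold)
```

No original event is split across folds, so a validation fold never contains a synthetic copy of an event seen in training. If you use scikit-learn directly, `sklearn.model_selection.GroupKFold` provides this kind of group-aware splitting.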

First, **write** in `data_file_name` the name of the file where your dataset is saved.

In this notebook we use the dataset saved in [4_augment_data](4_augment_data.ipynb).

In [9]:
data_file_name = 'example_dataset_aug.pckl'

Then, load the augmented dataset.

In [10]:
data_path = os.path.join(folder_path, data_file_name)
dataset = load_dataset(data_path)

Opening from binary pickle
Dataset loaded from pickle file as: <snmachine.sndata.PlasticcData object at 0x7f09837546a0>


In [11]:
metadata = dataset.metadata

## 3. Generate Classifier<a name="generateClassifier"></a>

### 3.1. Untrained Classifier<a name="createClassifier"></a>

Start by creating a classifier. For that **choose**: 

- classifier type: `snmachine` contains the following classifiers
    * [LightGBM](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html?highlight=classifier) classifier - `snclassifier.LightGBMClassifier`
    * Boosted decision trees - `snclassifier.BoostDTClassifier`
    * Boosted random forests - `snclassifier.BoostRFClassifier`
    * K-nearest neighbors vote - `snclassifier.KNNClassifier`
    * Support vector machine - `snclassifier.SVMClassifier`
    * Multi-layer perceptron (neural network) - `snclassifier.NNClassifier`
    * Random forest - `snclassifier.RFClassifier`
    * Decision tree - `snclassifier.DTClassifier`
    * Gaussian Naive Bayes - `snclassifier.NBClassifier`
- `random_seed`: this allows reproducible results (**<font color=green>optional</font>**).
- `classifier_name`: name under which the classifier is saved (**<font color=green>optional</font>**).
- `**kwargs`: optional keywords to pass arguments into the underlying classifier; see the docstring in each classifier for more information (**<font color=green>optional</font>**).

Here we chose a LightGBM classifier.

In [12]:
classifier_instance = snclassifier.LightGBMClassifier(classifier_name='our_classifier', random_seed=42)

Created classifier of type: LGBMClassifier(random_state=42).



### 3.2. Train Classifier<a name="trainClassifier"></a>

We can now either train the classifier created above as it is, or optimise its hyperparameters first. In general, it is important to optimise the classifier hyperparameters.

If you do not want to optimise the classifier, **run** **<font color=Orange>A)</font>**.

**<font color=Orange>A)</font>** Train unoptimised classifier.

```python
# Fit the underlying LightGBM estimator with its current hyperparameters.
classifier_instance.classifier.fit(X, y)
```

If you want to optimise the classifier, run **<font color=Orange>B)</font>**.

**<font color=Orange>B)</font>** Optimise and train classifier.

For that, **choose**:
- `param_grid`: parameter grid that maps each hyperparameter name to the list of values to try. If none is provided, the code uses a default parameter grid. (**<font color=green>optional</font>**)
- `scoring`: metric used to evaluate the predictions on the validation sets.
    * `snmachine` contains the `'auc'` and the PLAsTiCC `'logloss'` custom metrics. For more details about these, see `snclassifier.auc_score` and `snclassifier.logloss_score`, respectively.
    * Additionally, you can choose a different metric from the list in [Scikit-learn](https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter) or create your own (see [`sklearn.model_selection._search.GridSearchCV`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html) for details).
- `number_cv_folds`: number of folds for cross-validation. By default it is 5. (**<font color=green>optional</font>**)
- `metadata`: metadata of the events with which to train the classifier. This ensures all synthetic events generated by the training set augmentation that were derived from the same original event are placed in the same cross-validation fold. (**<font color=green>optional</font>**)
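
To give an idea of what a PLAsTiCC-style `'logloss'` measures: in the challenge definition it is a class-weighted log loss, where the mean log-probability assigned to the true class is computed per class and the per-class values are then combined with class weights. The sketch below follows that definition and is not a copy of `snclassifier.logloss_score`. The function name `weighted_logloss`, the uniform default weights and the toy class labels are assumptions made for illustration.

```python
import math
from collections import defaultdict


def weighted_logloss(y_true, probabilities, class_weights=None, eps=1e-15):
    """Class-weighted log loss.

    y_true: list with the true label of each event.
    probabilities: list of dicts mapping label -> predicted probability.
    class_weights: dict label -> weight (assumed uniform if omitted).
    """
    classes = sorted(set(y_true))
    if class_weights is None:
        class_weights = {label: 1.0 for label in classes}
    log_sum = defaultdict(float)
    counts = defaultdict(int)
    for label, probs in zip(y_true, probabilities):
        # Clip to avoid log(0) for overconfident wrong predictions.
        p = min(max(probs[label], eps), 1 - eps)
        log_sum[label] += math.log(p)
        counts[label] += 1
    # Average within each class first, so large classes do not dominate.
    numerator = sum(class_weights[c] * log_sum[c] / counts[c] for c in classes)
    return -numerator / sum(class_weights[c] for c in classes)


# Toy example with three hypothetical classes and uninformative predictions.
y_toy = ['Ia', 'Ia', 'II', 'Ibc']
uniform = [{'Ia': 1 / 3, 'II': 1 / 3, 'Ibc': 1 / 3} for _ in y_toy]
uniform_loss = weighted_logloss(y_toy, uniform)
print(uniform_loss, math.log(3))
```

A classifier that assigns equal probability to three classes scores `ln(3)`; lower values are better. Because the average is taken per class before weighting, a rare class contributes as much as a common one when the weights are equal.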

In [13]:
param_grid = {'learning_rate': [0.1, 0.25, 0.5]}

classifier_instance.optimise(X, y, param_grid=param_grid, scoring='logloss', 
                             number_cv_folds=5, metadata=metadata)

Cross-validation for an augmented dataset.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 724
[LightGBM] [Info] Number of data points in the train set: 47, number of used features: 42
[LightGBM] [Info] Start training from score -1.211090
[LightGBM] [Info] Start training from score -1.452252
[LightGBM] [Info] Start training from score -0.759105
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 769
[LightGBM] [Info] Number of data points in the train set: 50, number of used features: 42
[LightGBM] [Info] Start training from score -1.139434
[LightGBM] [Info] Start training from score -1.514128
[LightGBM] [Info] Start training from score -0.776529
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 742
[LightGBM] [Info] Number of data points in the train set: 48, number of used features: 42
[LightGBM] [Info] Start training from score -1.232144
[LightGBM] [Info] Start training fro
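
A note on reading this log: the three `Start training from score` values printed for each fit are consistent with the natural logarithm of the class fractions in that training fold, which is how a multiclass LightGBM model is usually initialised. The class counts below are not read from the dataset; they are inferred so that they add up to the 47 training events of the first fit and reproduce its logged values.

```python
import math

# Inferred (hypothetical) per-class counts for the first fit in the log.
class_counts = [14, 11, 22]
n_train = sum(class_counts)  # number of data points in the train set

# Initial score of each class: log of its fraction in the training fold.
start_scores = [math.log(count / n_train) for count in class_counts]

# Values printed by LightGBM for the first fit.
logged_scores = [-1.211090, -1.452252, -0.759105]

for computed, logged in zip(start_scores, logged_scores):
    print(f'{computed:.6f}  (logged: {logged:.6f})')
```

The scores change slightly from one fit to the next because each cross-validation split leaves out a different fold, so the class fractions in the training part differ.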

The classifier is optimised and its optimised hyperparameters are:

In [14]:
classifier_instance.classifier

In [15]:
classifier_instance.grid_search.best_params_

{'learning_rate': 0.1}
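
The selected value is one of the three candidates in `param_grid`. As a rough count of the work behind this result, a grid search evaluates every combination of the listed values on every cross-validation fold. The sketch below enumerates the candidates with the standard library only; `expand_grid` is a hypothetical helper and not part of `snmachine` or scikit-learn.

```python
from itertools import product

param_grid = {'learning_rate': [0.1, 0.25, 0.5]}
number_cv_folds = 5


def expand_grid(grid):
    """List every combination of hyperparameter values in the grid."""
    names = sorted(grid)
    value_lists = [grid[name] for name in names]
    return [dict(zip(names, values)) for values in product(*value_lists)]


candidates = expand_grid(param_grid)
n_cv_fits = len(candidates) * number_cv_folds

print(candidates)
print('cross-validation fits:', n_cv_fits)
```

With three candidates and five folds this gives 15 cross-validation fits. Scikit-learn's `GridSearchCV` additionally refits the best candidate on the full training set by default. Adding a second hyperparameter with, say, four values would multiply the number of candidates by four, so the grid is best kept small on large datasets.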

In [16]:
classifier_instance.classifier_name

'our_classifier'

### 3.3. Save the Classifier Instance<a name="saveClassifier"></a>

**Write** in `path_saved_classifier` the path to the folder in which to save the trained classifier instance.

In [17]:
path_saved_classifier = directories['classifications_directory']

Save the classifier instance (which includes the grid search used to optimise the classifier).

In [18]:
classifier_instance.save_classifier(path_saved_classifier)

Classifier saved in ../snmachine/example_data/example_dataset_aug/classifications/our_classifier.pck .
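
To use the classifier later, the saved file has to be read back. The sketch below shows a generic save and reload round trip with `pickle`, assuming the `.pck` file is an ordinary pickle, which this notebook does not state explicitly. A plain dictionary stands in for the classifier instance and a temporary folder stands in for `path_saved_classifier`; both are placeholders for illustration.

```python
import os
import pickle
import tempfile

# Stand-in for the trained classifier instance (values taken from above).
stand_in_classifier = {
    'classifier_name': 'our_classifier',
    'best_params': {'learning_rate': 0.1},
}

with tempfile.TemporaryDirectory() as tmp_dir:  # placeholder folder
    file_path = os.path.join(tmp_dir, 'our_classifier.pck')

    # Save the object to disk.
    with open(file_path, 'wb') as output_file:
        pickle.dump(stand_in_classifier, output_file)

    # Read it back.
    with open(file_path, 'rb') as input_file:
        restored = pickle.load(input_file)

print(restored['classifier_name'], restored['best_params'])
```

Unpickling a real classifier instance requires `snmachine` and `lightgbm` to be importable in the session that loads it. Only unpickle files from sources you trust, since loading a pickle can execute arbitrary code.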


[Go back to top.](#index)

**Previous notebook:** [5_feature_extraction](5_feature_extraction.ipynb)

**Next notebook:** [7_classify_test](7_classify_test.ipynb)