# TMVA Classification Example in Python

<img src="tmva_logo.gif" height="20%" width="20%">

This notebook is based on the ROOT TMVA tutorial.

Original author: Lorenzo Moneta | Edited by Sitong An for CMSDAS2019@Pisa

The Toolkit for Multivariate Data Analysis with ROOT (TMVA) is a ROOT-integrated project that provides a machine learning environment. For traditional machine learning methods such as Boosted Decision Trees (BDTs), TMVA is widely used in particle physics analyses.

This exercise shows an example of TMVA classification on toy data in Python, using a BDT and a shallow neural network (MLP), to give you a glimpse of how TMVA is used in analyses. For the same example in C++, and for more content on TMVA and ROOT-based multivariate analyses, go to [ https://swan.web.cern.ch/content/machine-learning ].

The TMVA Users Guide [ https://github.com/root-project/root/blob/master/documentation/tmva/UsersGuide/TMVAUsersGuide.pdf ] is also a helpful reference for the meaning of the method parameters.

In [1]:
import ROOT
from ROOT import TMVA

Welcome to JupyROOT 6.12/07


## Declare Factory

Create the Factory object. The MVA methods you want to use are chosen later, when you book them.

The factory is the main TMVA object you interact with. Its constructor takes the following arguments:

 - The first argument is the base name of all the output weight files

 - The second argument is the output file for the training results
  
 - The third argument is an option string defining the general configuration of the TMVA session. For example, all TMVA output can be suppressed by removing the "!" (not) in front of the "Silent" flag in the option string, i.e. writing "Silent" instead of "!Silent".
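
The option strings passed to the factory (and later to the DataLoader and `BookMethod`) follow one convention: colon-separated tokens that are either `Key=Value` pairs or boolean flags, where a leading `!` negates the flag. The sketch below illustrates that convention in plain Python. `parse_tmva_options` is a hypothetical helper written for this example, not TMVA's own option parser.

```python
def parse_tmva_options(option_string):
    """Split a TMVA-style option string into a dict (illustration only).

    'Key=Value' tokens map to their string value, bare flags map to True
    and flags prefixed with '!' map to False.
    """
    options = {}
    for token in option_string.split(":"):
        token = token.strip()
        if not token:
            continue  # tolerate empty tokens, e.g. from a trailing ':'
        if "=" in token:
            key, value = token.split("=", 1)
            options[key] = value
        elif token.startswith("!"):
            options[token[1:]] = False
        else:
            options[token] = True
    return options


# The option string used for the factory in this notebook
opts = parse_tmva_options(
    "!V:ROC:!Silent:Color:!DrawProgressBar:AnalysisType=Classification")
print(opts["Silent"], opts["AnalysisType"])  # False Classification
```

Writing `Silent` instead of `!Silent` flips that entry to `True`, which is the switch described in the last bullet.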

In [2]:
ROOT.TMVA.Tools.Instance()
## Initialise the Python interface, needed only for PyMVA methods
TMVA.PyMethodBase.PyInitialize();


outputFile = ROOT.TFile.Open("ClassificationOutput.root", "RECREATE")

factory = ROOT.TMVA.Factory("TMVA_Classification", outputFile,
                      "!V:ROC:!Silent:Color:!DrawProgressBar:AnalysisType=Classification" )

## Input Data

Now we define the input data file and retrieve the ROOT TTree objects holding the signal and background events.

In [3]:
inputFileName = "data/tmva_data.root"

inputFile = ROOT.TFile.Open( inputFileName )

# retrieve input trees

signalTree     = inputFile.Get("sig_tree")
backgroundTree = inputFile.Get("bkg_tree")


In [4]:
inputFile.ls()

TFile**		data/tmva_data.root	
 TFile*		data/tmva_data.root	
  OBJ: TTree	sig_tree	Signal Tree : 0 at: 0x9d6ba00
  OBJ: TTree	bkg_tree	Background Tree : 0 at: 0x9f02d40
  KEY: TTree	sig_tree;1	Signal Tree
  KEY: TTree	bkg_tree;1	Background Tree


We can also print a summary of the signal tree:

In [5]:
signalTree.Print()

******************************************************************************
*Tree    :sig_tree  : Signal Tree                                            *
*Entries :    10000 : Total =         1141446 bytes  File  Size =    1000730 *
*        :          : Tree compression factor =   1.13                       *
******************************************************************************
*Br    0 :lepton_pT : lepton_pT/F                                            *
*Entries :    10000 : Total  Size=      40761 bytes  File Size  =      30836 *
*Baskets :        1 : Basket Size=      32000 bytes  Compression=   1.04     *
*............................................................................*
*Br    1 :lepton_eta : lepton_eta/F                                          *
*Entries :    10000 : Total  Size=      40768 bytes  File Size  =      29658 *
*Baskets :        1 : Basket Size=      32000 bytes  Compression=   1.08     *
*...................................................

## Declare DataLoader(s)

The next step is to declare the DataLoader, which deals with the input data and variables.

First we add the signal and background trees to the data loader, then we define the input variables used for the MVA training. Note that you can also use expressions of the tree branches as input variables, not only plain branch names.

In [6]:
loader = ROOT.TMVA.DataLoader("dataset")

### global event weights per tree (see below for setting event-wise weights)
signalWeight     = 1.0
backgroundWeight = 1.0
   
### You can add an arbitrary number of signal or background trees
loader.AddSignalTree    ( signalTree,     signalWeight     )
loader.AddBackgroundTree( backgroundTree, backgroundWeight )

## Define input variables 

loader.AddVariable("m_jj")
loader.AddVariable("m_jjj")
loader.AddVariable("m_lv")
loader.AddVariable("m_jlv")
loader.AddVariable("m_bb")
loader.AddVariable("m_wbb")
loader.AddVariable("m_wwbb")

DataSetInfo              : [dataset] : Added class "Signal"
                         : Add Tree sig_tree of type Signal with 10000 events
DataSetInfo              : [dataset] : Added class "Background"
                         : Add Tree bkg_tree of type Background with 10000 events


## Set Up Dataset(s)

Set up the DataLoader by splitting the events into training and test samples.
Here we use a random split with a fixed number of training events (7000 per class); the remaining events form the test sample.
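
A random split with a fixed number of training events can be pictured as shuffling the event indices and taking the first `nTrain` of them for training. Since no `nTest_*` value is given in the next cell, the remaining events end up in the test sample, matching the 3000 testing events per class in the booking output further down. The function `random_split` and its `seed` argument are hypothetical, written only for this illustration; TMVA performs the split internally.

```python
import random


def random_split(n_events, n_train, seed=None):
    """Randomly assign event indices to a training and a test sample.

    Illustrates SplitMode=Random with a fixed number of training events:
    every event not picked for training goes to the test sample.
    """
    if not 0 <= n_train <= n_events:
        raise ValueError("n_train must lie between 0 and n_events")
    indices = list(range(n_events))
    random.Random(seed).shuffle(indices)  # a fixed seed gives a reproducible split
    return sorted(indices[:n_train]), sorted(indices[n_train:])


# 10000 signal events, 7000 of them for training (as in nTrain_Signal=7000)
train, test = random_split(10000, 7000, seed=42)
print(len(train), len(test))  # 7000 3000
```

Fixing the seed makes the split reproducible between runs.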


In [7]:

## Apply additional cuts on the signal and background samples (can be different)
mycuts = ROOT.TCut("")   ## for example:  mycuts = ROOT.TCut("abs(var1)<0.5 && abs(var2-0.5)<1")
mycutb = ROOT.TCut("")   ## for example:  mycutb = ROOT.TCut("abs(var1)<0.5")


loader.PrepareTrainingAndTestTree( mycuts, mycutb,
                                  "nTrain_Signal=7000:nTrain_Background=7000:SplitMode=Random:"
                                   "NormMode=NumEvents:!V" )

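`NormMode=NumEvents` renormalises the event weights used in training so that, for each class separately, the average weight is 1 (the sum of weights equals the number of events). The sketch below shows that arithmetic in plain Python. `normalise_num_events` is a hypothetical helper, not a TMVA function.

```python
def normalise_num_events(weights_by_class):
    """Rescale event weights so that each class has an average weight of 1.

    Sketch of the NormMode=NumEvents idea: for every class separately,
    the sum of weights is set equal to the number of events.
    """
    normalised = {}
    for cls, weights in weights_by_class.items():
        total = sum(weights)
        if total <= 0:
            raise ValueError(f"non-positive total weight for class {cls!r}")
        scale = len(weights) / total
        normalised[cls] = [w * scale for w in weights]
    return normalised


w = normalise_num_events({"Signal": [2.0, 2.0, 4.0], "Background": [0.5] * 4})
print(w["Signal"])      # [0.75, 0.75, 1.5]
print(w["Background"])  # [1.0, 1.0, 1.0, 1.0]
```

With the unit global weights chosen earlier (`signalWeight = backgroundWeight = 1.0`), this renormalisation leaves every weight at 1. It only matters once event-wise or per-tree weights differ.
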
## Booking Methods

Here we book the TMVA methods: a BDT and a standard MLP (a shallow neural network).
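
The BDT booked in the next cell chooses its node splits with `SeparationType=GiniIndex` and combines its trees with `BoostType=AdaBoost` and `AdaBoostBeta=0.5`. The plain-Python sketch below shows these two ingredients, following the description in the TMVA Users Guide: the Gini index p(1-p) of a node with signal purity p, and one AdaBoost step in which misclassified events are reweighted by alpha = ((1-err)/err)^beta and the tree votes with weight ln(alpha). All function names are hypothetical. The sketch ignores details such as bagging (`UseBaggedBoost`) and the `nCuts` grid of candidate cut values.

```python
import math


def gini(n_sig, n_bkg):
    """Gini index p*(1-p) of a node, where p is its signal purity."""
    total = n_sig + n_bkg
    if total == 0:
        return 0.0
    p = n_sig / total
    return p * (1.0 - p)


def split_gain(parent, left, right):
    """Decrease of the event-weighted Gini index achieved by a split.

    Each argument is an (n_sig, n_bkg) pair; a larger gain is a better cut.
    """
    def weighted(node):
        return (node[0] + node[1]) * gini(*node)
    return weighted(parent) - weighted(left) - weighted(right)


def adaboost_update(weights, misclassified, beta=0.5):
    """One AdaBoost reweighting step for the tree just trained.

    err is the weighted fraction of misclassified events. Those events
    are scaled by alpha = ((1 - err) / err) ** beta, then all weights are
    renormalised to the original total. Returns (new_weights, ln(alpha)),
    the latter being the tree's weight in the final vote.
    """
    total = sum(weights)
    err = sum(w for w, bad in zip(weights, misclassified) if bad) / total
    if not 0.0 < err < 0.5:
        raise ValueError("this sketch requires 0 < err < 0.5")
    alpha = ((1.0 - err) / err) ** beta
    boosted = [w * alpha if bad else w for w, bad in zip(weights, misclassified)]
    scale = total / sum(boosted)
    return [w * scale for w in boosted], math.log(alpha)


print(gini(5, 5))                          # 0.25
print(split_gain((5, 5), (5, 0), (0, 5)))  # 2.5
new_weights, tree_weight = adaboost_update([1.0] * 4, [True, False, False, False])
print(new_weights[0] > new_weights[1])     # True
```

A smaller `AdaBoostBeta` shrinks alpha towards 1, so each tree changes the event weights less and the boosting proceeds more slowly.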

In [8]:
## Boosted Decision Trees
factory.BookMethod(loader,ROOT.TMVA.Types.kBDT, "BDT",
                   "!V:NTrees=200:MinNodeSize=2.5%:MaxDepth=2:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:"
                   "BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20" )

## Multi-Layer Perceptron (Neural Network)
factory.BookMethod(loader, ROOT.TMVA.Types.kMLP, "MLP",
                   "!H:!V:NeuronType=tanh:VarTransform=N:NCycles=100:HiddenLayers=N+5:TestRate=5:!UseRegulator" );

Factory                  : Booking method: BDT
                         : 
DataSetFactory           : [dataset] : Number of events in input trees
                         : 
                         : 
                         : Number of training and testing events
                         : ---------------------------------------------------------------------------
                         : Signal     -- training events            : 7000
                         : Signal     -- testing events             : 3000
                         : Signal     -- training and testing events: 10000
                         : Background -- training events            : 7000
                         : Background -- testing events             : 3000
                         : Background -- training and testing events: 10000
                         : 
DataSetInfo              : Correlation matrix (Signal):
                         : -----------------------------------------------------------
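
The MLP booked above uses `VarTransform=N`, `NeuronType=tanh` and `HiddenLayers=N+5`. The `N` transformation rescales each input variable linearly onto [-1, 1]. In the `HiddenLayers` string, `N` stands for the number of input variables, so with the seven variables declared here `N+5` requests a single hidden layer of 12 neurons. The helpers below are hypothetical and only illustrate these conventions. The handling of a constant input is a choice made for this sketch.

```python
import math


def normalise(values):
    """Map values linearly onto [-1, 1] using their minimum and maximum."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0] * len(values)  # constant input: put it at the centre
    return [2.0 * (v - lo) / (hi - lo) - 1.0 for v in values]


def hidden_layer_sizes(spec, n_vars):
    """Interpret a HiddenLayers string such as 'N+5' or 'N,N-1'.

    Comma-separated entries give one hidden layer each; 'N' is replaced by
    the number of input variables. Only N, N+k, N-k and plain integers
    are handled in this sketch.
    """
    sizes = []
    for part in spec.split(","):
        part = part.strip()
        if part.startswith("N"):
            offset = part[1:]
            sizes.append(n_vars + (int(offset) if offset else 0))
        else:
            sizes.append(int(part))
    return sizes


def tanh_neuron(inputs, weights, bias):
    """Output of one hidden neuron with a tanh activation (NeuronType=tanh)."""
    return math.tanh(sum(x * w for x, w in zip(inputs, weights)) + bias)


print(hidden_layer_sizes("N+5", 7))  # [12]
print(normalise([0.0, 5.0, 10.0]))   # [-1.0, 0.0, 1.0]
```

Keeping the inputs within [-1, 1] matches the range where tanh is not saturated, which is one common motivation for normalising inputs before training a tanh network.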

## Using scikit-learn (optional)

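We can also book some scikit-learn classifiers into the TMVA factory through PyMVA. The corresponding cell is an inactive (raw) cell: select it by clicking in the margin to its left and press Y to turn it into a code cell. To deactivate it again, press R to turn it back into a raw cell.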

## Train Methods

In [9]:
factory.TrainAllMethods();

Factory                  : Train all methods
Factory                  : [dataset] : Create Transformation "I" with events from all classes.
                         : 
                         : Transformation, Variable selection : 
                         : Input : variable 'm_jj' <---> Output : variable 'm_jj'
                         : Input : variable 'm_jjj' <---> Output : variable 'm_jjj'
                         : Input : variable 'm_lv' <---> Output : variable 'm_lv'
                         : Input : variable 'm_jlv' <---> Output : variable 'm_jlv'
                         : Input : variable 'm_bb' <---> Output : variable 'm_bb'
                         : Input : variable 'm_wbb' <---> Output : variable 'm_wbb'
                         : Input : variable 'm_wwbb' <---> Output : variable 'm_wwbb'
TFHandler_Factory        : Variable        Mean        RMS   [        Min        Max ]
                         : -----------------------------------------------------------
 

## Test all methods

Here we test all methods using the test data set.

In [10]:
factory.TestAllMethods();   

Factory                  : Test all methods
Factory                  : Test method: BDT for Classification performance
                         : 
BDT                      : [dataset] : Evaluation of BDT on testing sample (6000 events)
                         : Elapsed time for evaluation of 6000 events: 0.04 sec       
Factory                  : Test method: MLP for Classification performance
                         : 
MLP                      : [dataset] : Evaluation of MLP on testing sample (6000 events)
                         : Elapsed time for evaluation of 6000 events: 0.00959 sec       


## Evaluate all methods

Here we evaluate all methods and compare their performance, computing efficiencies, ROC curves, etc., using both the training and testing data sets. Several histograms are produced, which can be examined with the TMVAGui or directly in the output file.
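
The ROC curve drawn by TMVA plots background rejection (1 - background efficiency) against signal efficiency as the cut on the classifier output is varied, and its area corresponds to the ROC integral quoted in the evaluation summary. The sketch below computes both from two lists of classifier scores, using unweighted events and a trapezoidal integral. `roc_points` and `roc_integral` are hypothetical helpers, and the scores are toy numbers. TMVA's own implementation uses event weights and may differ in binning and interpolation.

```python
def roc_points(sig_scores, bkg_scores):
    """(signal efficiency, background rejection) for every cut on the score.

    An event is accepted when its score is >= the cut. The first point
    corresponds to a cut above all scores (nothing accepted).
    """
    cuts = sorted(set(sig_scores) | set(bkg_scores), reverse=True)
    points = [(0.0, 1.0)]
    for cut in cuts:
        eff_s = sum(s >= cut for s in sig_scores) / len(sig_scores)
        eff_b = sum(b >= cut for b in bkg_scores) / len(bkg_scores)
        points.append((eff_s, 1.0 - eff_b))
    return points


def roc_integral(points):
    """Area under the background-rejection vs signal-efficiency curve
    (trapezoidal rule); 1.0 is perfect separation, 0.5 is random guessing."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += 0.5 * (x1 - x0) * (y0 + y1)
    return area


sig = [0.9, 0.8, 0.6, 0.4]  # toy classifier outputs for signal events
bkg = [0.7, 0.3, 0.2, 0.1]  # toy classifier outputs for background events
print(roc_integral(roc_points(sig, bkg)))  # 0.875
```

For this toy input, 0.875 equals the fraction of (signal, background) pairs in which the signal event scores higher (14 of 16). This is the usual probabilistic reading of the ROC integral.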

In [11]:
factory.EvaluateAllMethods();

Factory                  : Evaluate all methods
Factory                  : Evaluate classifier: BDT
                         : 
BDT                      : [dataset] : Loop over test events and fill histograms with classifier response...
                         : 
TFHandler_BDT            : Variable        Mean        RMS   [        Min        Max ]
                         : -----------------------------------------------------------
                         :     m_jj:     1.0454    0.68951   [    0.16034     17.681 ]
                         :    m_jjj:     1.0296    0.38670   [    0.38832     8.8785 ]
                         :     m_lv:     1.0529    0.16718   [    0.49665     2.9841 ]
                         :    m_jlv:     1.0096    0.39897   [    0.40002     6.3362 ]
                         :     m_bb:    0.97642    0.53416   [   0.078035     7.2843 ]
                         :    m_wbb:     1.0353    0.36429   [    0.40515     4.9710 ]
                         :   m_

## Plot ROC Curve
We enable JavaScript (JSROOT) visualisation for the plots.

In [12]:
%jsroot on

In [13]:
c1 = factory.GetROCCurve(loader);
c1.Draw();


## Close the output file to save all output information (evaluation results of the methods)

In [14]:
outputFile.Close();