Skip to content

Commit

Permalink
add print for input varibales/sampels to TMVA trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed May 2, 2024
1 parent 73d2851 commit ae645ba
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 12 deletions.
6 changes: 3 additions & 3 deletions ostap/tools/tests/test_tools_tmva.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def test_tmva () :
( ROOT.TMVA.Types.kKNN , 'KNN' , "H:!V:nkNN=20:ScaleFrac=0.8:SigmaFact=1.0:Kernel=Gaus:UseKernel=F:UseWeight=T:!Trim" ) ,
##
] ,
variables = [ 'var1' , 'var2' , 'var3' ] , ## Variables for training
signal = tSignal , ## ``Signal'' sample
background = tBkg , ## ``Background'' sample
variables = [ 'var1' , 'var2' , 'var3' ] , ## Variables for training
signal = tSignal , ## `Signal' sample
background = tBkg , ## `Background' sample
verbose = True ,
signal_train_fraction = 0.75 ,
background_train_fraction = 0.75 ,
Expand Down
28 changes: 21 additions & 7 deletions ostap/tools/tmva.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,8 +1249,8 @@ def __train ( self ) :
# The table
# =================================================================

## if self.verbose :
if 1 < 2 :
if self.verbose :
## if 1 < 2 :

rows = [ ( 'Item' , 'Value' ) ]

Expand Down Expand Up @@ -1330,8 +1330,23 @@ def __train ( self ) :
table = T.table ( rows , title = title , prefix = "# " , alignment = "lw" )
self.logger.info ( "%s\n%s" % ( title , table ) )



if self.verbose :

import ostap.trees.trees
import ostap.trees.cuts

stitle = 'Input Signal variables'
sc = ROOT.TCut ( self.signal_cuts )
if self.signal_weight : sc *= self.signal_weight
tS = self.signal.table2 ( vv , title = stitle , cuts = sc , prefix = '# ' )
self.logger.info ( '%s\n%s' % ( stitle , tS ) )

btitle = 'Input Background variables'
bc = ROOT.TCut ( self.background_cuts )
if self.background_weight : bc *= self.background_weight
tB = self.background.table2 ( vv , title = btitle , cuts = bc , prefix = '# ' )
self.logger.info ( '%s\n%s' % ( btitle , tB ) )

bo = self.bookingoptions.split (':')
bo.sort()
if self.verbose : self.logger.info ( 'Book TMVA-factory %s ' % bo )
Expand All @@ -1352,7 +1367,7 @@ def __train ( self ) :
for v in self.signal_vars : avars.add ( v )
for v in self.background_vars : avars.add ( v )
avars = sorted ( avars )

all_vars = []
## for v in self.variables :
for v in avars :
Expand All @@ -1367,7 +1382,6 @@ def __train ( self ) :
if isinstance ( vv , str ) : vv = ( vv , 'F' )
all_vars.append ( vv[0] )
dataloader.AddSpectator ( *vv )
#

if self.verbose : self.logger.info ( "Loading 'Signal' sample" )
dataloader.AddTree ( self.signal , 'Signal' , 1.0 , ROOT.TCut ( self. signal_cuts ) )
Expand Down Expand Up @@ -1523,7 +1537,7 @@ def __train ( self ) :
def makePlots ( self , name = None , output = None , ) :
"""Make selected standard TMVA plots"""

self.logger.warning ( "makePlots: method is disbaled!" )
self.logger.warning ( "makePlots: method is (temporarily?) disabled!" )
return

name = name if name else self.name
Expand Down
5 changes: 3 additions & 2 deletions ostap/trees/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,8 +1432,9 @@ def _rt_table_1_ ( tree ,

bbs = tuple ( sorted ( variables ) )

if hasattr ( tree , 'pstatVar' ) : bbstats = tree.pstatVar ( bbs , cuts , *args )
else : bbstats = tree. statVar ( bbs , cuts , *args )
if hasattr ( tree , 'fstatVar' ) : bbstats = tree.fstatVar ( bbs , cuts , *args )
elif hasattr ( tree , 'pstatVar' ) : bbstats = tree.pstatVar ( bbs , cuts , *args )
else : bbstats = tree. statVar ( bbs , cuts , *args )

from ostap.stats.counters import WSE
if isinstance ( bbstats , WSE ) : bbstats = { bbs[0] : bbstats }
Expand Down

0 comments on commit ae645ba

Please sign in to comment.