Skip to content
This repository has been archived by the owner on May 2, 2022. It is now read-only.

Commit

Permalink
Merge branch 'master-20180516-01-branch'
Browse files Browse the repository at this point in the history
  • Loading branch information
TaiSakuma committed May 16, 2018
2 parents 7f9f584 + 1c1b4ae commit bdddeae
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 5 deletions.
1 change: 0 additions & 1 deletion alphatwirl/roottree/Branch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Tai Sakuma <tai.sakuma@gmail.com>
from .BranchAddressManager import BranchAddressManager

##__________________________________________________________________||
class Branch(object):
Expand Down
34 changes: 33 additions & 1 deletion alphatwirl/roottree/BranchAddressManager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Tai Sakuma <tai.sakuma@gmail.com>
import array

import ROOT

from .inspect import is_ROOT_null_pointer

##__________________________________________________________________||
Expand Down Expand Up @@ -98,6 +100,10 @@ def inspectLeaf(tree, bname):
leafcount = leaf.GetLeafCount()
isArray = not is_ROOT_null_pointer(leafcount)

countmax = None
if isArray:
countmax = _get_countmax(leafcount, tree, bname)

return dict(
name=leaf.GetName(),
ROOTtype=leaf.GetTypeName(),
Expand All @@ -106,7 +112,33 @@ def inspectLeaf(tree, bname):
countname=leafcount.GetName() if isArray else None,
countROOTtype=leafcount.GetTypeName() if isArray else None,
countarraytype=typedic[leafcount.GetTypeName()] if isArray else None,
countmax=leafcount.GetMaximum() if isArray else None
countmax=countmax,
)

def _get_countmax(leafcount, tree, bname):
# If the tree is a chain, leafcount.GetMaximum() only returns the
# maximum in the current file. The `countmax` needs to be the
# maximum in all files. Not very efficient or elegant, the current
# implementation opens all files in the chain and finds the
# maximum.

try:
tobjarray_files = tree.GetListOfFiles()
except AttributeError:
# the tree is not a chain
return leafcount.GetMaximum()

if 0 == tobjarray_files.GetLast():
# the chain has only one file
return leafcount.GetMaximum()

filepahts = (f.GetTitle() for f in tobjarray_files)
files = (ROOT.TFile.Open(p) for p in filepahts)
trees = (f.Get(tree.GetName()) for f in files)
leaves = (t.GetLeaf(bname) for t in trees)
leafcounts = (l.GetLeafCount() for l in leaves)
countmaxs = (l.GetMaximum() for l in leafcounts)
countmax = max(countmaxs)
return countmax

##__________________________________________________________________||
2 changes: 1 addition & 1 deletion alphatwirl/roottree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .Events import Events
from .BranchAddressManager import BranchAddressManager
from .Branch import Branch
from .EventBuilderConfig import EventBuilderConfig

Expand All @@ -16,5 +15,6 @@
from .BEventBuilder import BEventBuilder
from .BranchBuilder import BranchBuilder
from .EventBuilder import EventBuilder
from .BranchAddressManager import BranchAddressManager
from .BranchAddressManagerForVector import BranchAddressManagerForVector
from .inspect import inspect_tree
56 changes: 56 additions & 0 deletions tests/unit/roottree/create_sample_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env python
# Tai Sakuma <tai.sakuma@gmail.com>

import array
import ROOT

##__________________________________________________________________||
def main():

file_name = 'sample_chain_01.root'
content_list = [
[10, 20, 30],
[24, 5]
]
create_file(file_name, content_list)

file_name = 'sample_chain_02.root'
content_list = [
[3, 10],
[5, 8, 32, 15, 2],
[22, 11],
]
create_file(file_name, content_list)

file_name = 'sample_chain_03.root'
content_list = [
[2, 7],
[10, 100],
]
create_file(file_name, content_list)

def create_file(name, contents):

f = ROOT.TFile(name, 'recreate')
t = ROOT.TTree('tree', 'sample tree')

max_nvar = 128;
nvar = array.array('i', [0])
var = array.array('i', max_nvar*[0])

t.Branch('nvar', nvar, 'nvar/I')
t.Branch('var', var, 'var[nvar]/I')

for c in contents:
nvar[0] = len(c)
for i, v in enumerate(c):
var[i] = v
t.Fill()

t.Write()

##__________________________________________________________________||
if __name__ == '__main__':
main()

##__________________________________________________________________||
Binary file added tests/unit/roottree/sample_chain_01.root
Binary file not shown.
Binary file added tests/unit/roottree/sample_chain_02.root
Binary file not shown.
Binary file added tests/unit/roottree/sample_chain_03.root
Binary file not shown.
63 changes: 63 additions & 0 deletions tests/unit/roottree/test_BEvents_sample_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Tai Sakuma <tai.sakuma@gmail.com>
import os
import sys
import pytest

has_no_ROOT = False
try:
import ROOT
except ImportError:
has_no_ROOT = True

if not has_no_ROOT:
from alphatwirl.roottree import BEvents as Events
from alphatwirl.roottree import Branch

##__________________________________________________________________||
pytestmark = pytest.mark.skipif(has_no_ROOT, reason="has no ROOT")

##__________________________________________________________________||
@pytest.fixture()
def chain():
input_file_names = [
'sample_chain_01.root',
'sample_chain_02.root',
'sample_chain_03.root',
]
input_paths = [
os.path.join(os.path.dirname(os.path.realpath(__file__)), n)
for n in input_file_names
]
tree_name = 'tree'
chain = ROOT.TChain(tree_name)
for p in input_paths:
chain.Add(p)
yield chain

@pytest.fixture()
def events(chain):
yield Events(chain)

def test_event(events):

content_list = [
# file 1
[10, 20, 30],
[24, 5],

# file 2
[3, 10],
[5, 8, 32, 15, 2],
[22, 11],

# file 3
[2, 7],
[10, 100],
]

for i, c in enumerate(content_list):
event = events[i]
assert len(c) == event.nvar[0]
assert c == list(event.var)

##__________________________________________________________________||
13 changes: 12 additions & 1 deletion tests/unit/roottree/test_BranchAddressManager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from alphatwirl.roottree import BranchAddressManager
import unittest

##__________________________________________________________________||
hasROOT = False
try:
import ROOT
hasROOT = True
except ImportError:
pass

if hasROOT:
from alphatwirl.roottree import BranchAddressManager

##__________________________________________________________________||
class MockLeaf(object):
def __init__(self, name, typename, leafcount = None, maximum = None):
Expand Down Expand Up @@ -39,6 +49,7 @@ def SetBranchAddress(self, name, address):
self.branchaddress.append((name, address))

##__________________________________________________________________||
@unittest.skipUnless(hasROOT, "has no ROOT")
class TestBranchAddressManager(unittest.TestCase):


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_import_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_with_pandas():
@pytest.mark.skipif(has_no_ROOT, reason="has no ROOT")
def test_with_ROOT():
assert inspect.isclass(alphatwirl.roottree.BEvents)
assert inspect.isclass(alphatwirl.roottree.BranchAddressManager)
assert inspect.isclass(alphatwirl.roottree.BranchAddressManagerForVector)
assert inspect.isclass(alphatwirl.heppyresult.EventBuilder)
assert inspect.isclass(alphatwirl.heppyresult.EventBuilderConfigMaker)
Expand Down Expand Up @@ -57,7 +58,6 @@ def test_classes():
assert inspect.isclass(alphatwirl.configure.TableConfigCompleter)
assert inspect.isclass(alphatwirl.configure.TableFileNameComposer)
assert inspect.isclass(alphatwirl.roottree.Branch)
assert inspect.isclass(alphatwirl.roottree.BranchAddressManager)
assert inspect.isclass(alphatwirl.roottree.Events)
assert inspect.isclass(alphatwirl.heppyresult.Analyzer)
assert inspect.isclass(alphatwirl.heppyresult.Component)
Expand Down

0 comments on commit bdddeae

Please sign in to comment.