Skip to content

Commit

Permalink
fix issue 1
Browse files Browse the repository at this point in the history
  • Loading branch information
aminnj committed Jan 12, 2021
1 parent 9ae2d5f commit e06d247
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 16 deletions.
4 changes: 2 additions & 2 deletions pdroot/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from yahist import Hist1D, Hist2D

from .parse import variables_in_expr, nops_in_expr, to_ak_expr
from .parse import variables_in_expr, nops_in_expr, to_ak_expr, split_expr_on_free_colon

def tree_draw(df, varexp, sel="", **kwargs):
"""
Expand All @@ -41,7 +41,7 @@ def tree_draw_to_array(df, varexp, sel=""):
globalmask = eval(to_ak_expr(sel))

dims = []
for ve in varexp.split(":"):
for ve in split_expr_on_free_colon(varexp):

vals = eval(to_ak_expr(ve))
if sel:
Expand Down
53 changes: 40 additions & 13 deletions pdroot/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ class Transformer(ast.NodeTransformer):

# "and" -> "&"
def visit_And(self, node):
ast.NodeVisitor.generic_visit(self, node)
self.generic_visit(node)
return ast.BitAnd()

# "or" -> "|"
def visit_Or(self, node):
ast.NodeVisitor.generic_visit(self, node)
self.generic_visit(node)
return ast.BitOr()

# "not" -> "~"
def visit_Not(self, node):
ast.NodeVisitor.generic_visit(self, node)
self.generic_visit(node)
return ast.Invert()

# "a < b < c" -> "(a < b) and (b < c)"
Expand Down Expand Up @@ -115,36 +115,63 @@ def visit_Call(self, node):
else:
node.func.id = "ak." + node.func.id
node.keywords.append(ast.keyword("axis", ast.Constant(-1)))
ast.NodeVisitor.generic_visit(self, node)
self.generic_visit(node)
return node


# "x[2]" -> "ak.pad_none(x, 3)[:, 2]"
def visit_Subscript(self, node):
if isinstance(node.slice.value, (ast.Constant, ast.Num)):
index = node.slice.value.n
value = node.value
value = ast.Call(func=ast.Name("ak.pad_none"), args=[value, ast.Constant(index+1)], keywords=[])
valid_slice = False
for attr in ["value", "upper", "lower", "step"]:
if isinstance(getattr(node.slice, attr, None), (ast.Constant, ast.Num)): valid_slice = True
if valid_slice:
if hasattr(node.slice, "value"):
index = node.slice.value.n
value = ast.Call(func=ast.Name("ak.pad_none"), args=[node.value, ast.Constant(index+1)], keywords=[])
dimslice = ast.Constant(index)
elif hasattr(node.slice, "upper"):
upper = node.slice.upper.n
value = ast.Call(func=ast.Name("ak.pad_none"), args=[node.value, ast.Constant(upper+1)], keywords=[])
dimslice = node.slice
node = ast.Subscript(
value=value,
slice=ast.ExtSlice(dims=[
ast.Slice(lower=None, upper=None, step=None),
ast.Constant(index)
dimslice,
]),
ctx=ast.Load()
)
ast.NodeVisitor.generic_visit(self, node)
self.generic_visit(node)
return node


def to_ak_expr(expr):
def to_ak_expr(expr, transformer=Transformer()):
"""
turns
expr = "sum(Jet_pt[abs(Jet_eta)>4.])"
into
expr = "ak.sum(Jet_pt[abs(Jet_eta) > 4.0], axis=-1)"
"""
parsed = ast.parse(expr)
Transformer().visit(parsed)
transformer.visit(parsed)
source = astor.to_source(parsed).strip()
return source

def split_expr_on_free_colon(expr):
"""
When splitting on : for the purpose of drawing in 2D,
a simple expr.split(":") won't work if it picks a slice,
so we find a colon which has an equal number of open
and close parentheses/brackets before it.
Input: "sum(Jet_pt[:2]):Jet_eta"
Output: ("sum(Jet_pt[:2])", "Jet_eta")
"""
n_enclosure = 0
for ic, c in enumerate(expr):
if c == "[": n_enclosure += 10
elif c == "]": n_enclosure -= 10
elif c == "(": n_enclosure += 1
elif c == ")": n_enclosure -= 1
elif (c == ":") and (n_enclosure == 0):
return expr[:ic], expr[ic+1:]
return [expr]
9 changes: 8 additions & 1 deletion tests/draw_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,13 @@ def test_jitdraw_1d(self):

class DrawJaggedTest(unittest.TestCase):

def drawclose(self, varexp, sel, y):
def drawclose(self, varexp, sel, y, verbose=False):
x = tree_draw_to_array(self.df, varexp, sel)
x = np.array(x)
y = np.array(y)
if verbose:
print("true", x)
print("test", y)
self.assertEqual(x.shape, y.shape)
self.assertTrue(np.allclose(x, y))

Expand Down Expand Up @@ -140,6 +143,10 @@ def test_negation(self):
self.drawclose("Jet_pt", "not(14. < Jet_pt < 16.)", [42, 10.5, 11.5, 50, 5])
self.drawclose("Jet_pt", "~(14. < Jet_pt < 16.)", [42, 10.5, 11.5, 50, 5])

def test_slicing(self):
self.drawclose("sum(Jet_pt[:2])", "", [42+15, 0, 11.5, 50+5])
self.drawclose("sum(Jet_pt[2:3])", "MET_pt > 40", [10.5, 0.])



if __name__ == "__main__":
Expand Down

0 comments on commit e06d247

Please sign in to comment.