Skip to content

Commit

Permalink
support single dimension min/max
Browse files Browse the repository at this point in the history
  • Loading branch information
aminnj committed Jan 20, 2021
1 parent 0e72566 commit d1352e0
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 62 deletions.
14 changes: 10 additions & 4 deletions pdroot/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,17 @@ def visit_Call(self, node):
if hasattr(node.func, "id"):
name = node.func.id
if name in ["min", "max", "sum", "mean", "length"]:
if name == "length":
node.func.id = "ak.count"
else:
if len(node.args) == 1:
if name == "length":
node.func.id = "count"
node.func.id = "ak." + node.func.id
node.keywords.append(ast.keyword("axis", ast.Constant(-1)))
node.keywords.append(ast.keyword("axis", ast.Constant(-1)))
elif (len(node.args) == 2) and name in ["min", "max"]:
node.func.id = {"min": "np.minimum", "max": "np.maximum"}[name]
else:
raise Exception(
f"Unsupported function '{name}' with {len(node.args)} arguments."
)
self.generic_visit(node)
return node

Expand Down
2 changes: 1 addition & 1 deletion pdroot/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def read_root(
treename = find_tree_name(f)
if treename is None:
raise RuntimeError(
"`treename` must be specified. File contains keys: {treenames}"
f"`treename` must be specified. File contains keys: {f.keys()}"
)

executor = None
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from codecs import open
from os import path

__version__ = "1.6.6"
__version__ = "1.6.7"

here = path.abspath(path.dirname(__file__))

Expand Down
119 changes: 63 additions & 56 deletions tests/test_draw.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pdroot import tree_draw
from pdroot.draw import tree_draw
from pdroot.readwrite import awkward1_arrays_to_dataframe
from pdroot.draw import tree_draw_to_array

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -101,64 +100,72 @@ def test_draw_to_hist2d(df_jagged):
assert h.integral == 4


@pytest.mark.parametrize(
"varexp,sel,expected",
[
("Jet_pt", "", [42.0, 15.0, 10.5, 11.5, 50.0, 5.0]),
("Jet_pt", "abs(Jet_eta) > 1 and MET_pt > 10", [42.0, 11.5]),
("Jet_pt", "MET_pt > 40", [42, 15, 10.5, 11.5]),
("MET_pt", "MET_pt > 40", [46.5, 82.0]),
("Jet_pt", "Jet_pt > 40", [42.0, 50.0]),
("MET_pt", "", [46.5, 30.0, 82.0, 8.9]),
("Jet_pt", "", [42.0, 15.0, 10.5, 11.5, 50.0, 5.0]),
("Jet_eta + 1", "Jet_pt > 40 and MET_pt > 40", [-1.2]),
("Jet_eta", "Jet_pt > 40 and MET_pt > 40", [-2.2]),
("sum(Jet_pt)", "MET_pt < 10", [50 + 5]),
("length(Jet_pt)", "MET_pt < 10", [2]),
("length(Jet_pt)", "", [3, 0, 1, 2]),
("mean(Jet_pt)", "", [1.0 / 3 * (42 + 15 + 10.5), 11.5, 0.5 * (50 + 5)]),
("min(Jet_pt)", "", [10.5, 11.5, 5.0]),
("max(abs(Jet_eta))", "MET_pt > 80", [1.5]),
("max(abs(Jet_eta))", "", [2.2, 1.5, 3.0]),
("Jet_pt[0]:Jet_pt[1]", "MET_pt > 40", [[42, 15]]),
("Jet_pt[0]:Jet_pt[1]", "", [[42, 15], [50, 5]]),
("Jet_pt[2]", "", [10.5]),
("Jet_pt[Jet_pt>25]", "", [42, 50]),
("sum(Jet_pt[abs(Jet_eta)<2.0])", "", [15 + 10.5, 0.0, 11.5, 50.0]),
("sum(Jet_pt>10)", "MET_pt>40", [3, 1]),
("np.exp(sum(Jet_pt>10))", "MET_pt>40", [np.exp(3), np.exp(1)]),
("Jet_pt", "(14. < Jet_pt) & (Jet_pt < 16.)", [15]),
("Jet_pt", "(14. < Jet_pt) and (Jet_pt < 16.)", [15]),
("Jet_pt", "(14. < Jet_pt < 16.)", [15]),
("not MET_pt>40", "", [False, True, False, True]),
("Jet_pt", "~(14. < Jet_pt < 16.)", [42, 10.5, 11.5, 50, 5]),
("Jet_pt", "not(14. < Jet_pt < 16.)", [42, 10.5, 11.5, 50, 5]),
("(MET_pt>30) and True", "", [True, False, True, False]),
("(MET_pt>30) or True", "", [True, True, True, True]),
("(MET_pt>30) and False", "", [False, False, False, False]),
("(MET_pt>30) or False", "", [True, False, True, False]),
("length(Jet_pt) == 2 and sum(Jet_pt) > 40", "", [False, False, False, True]),
("sum(Jet_pt[(Jet_pt>40) and abs(Jet_eta)<2.4])", "MET_pt > 40", [42, 0]),
("sum(Jet_pt[(Jet_pt>40) and abs(Jet_eta)<2.4])", "", [42, 0, 0, 50]),
("sum(Jet_pt[2:3])", "MET_pt > 40", [10.5, 0.0]),
("sum(Jet_pt[:2])", "", [42 + 15, 0, 11.5, 50 + 5]),
(
"Jet_pt:Jet_eta",
"MET_pt > 40.",
[[42, -2.2], [15, 0.4], [10.5, 0.5], [11.5, 1.5]],
),
(
"(MET_pt>40) and sum((Jet_pt>40) and (abs(Jet_eta)<2.4)) >= 1",
"",
[True, False, False, False],
),
],
)
# Parametrized inputs for test_draw. Each entry is a (varexp, sel, expected)
# tuple: `varexp` is the expression string to draw, `sel` is the selection
# string ("" in the entries that apply no selection), and `expected` is the
# list of values the resulting array is compared against.
cases = [
("Jet_pt", "", [42.0, 15.0, 10.5, 11.5, 50.0, 5.0]),
("Jet_pt", "abs(Jet_eta) > 1 and MET_pt > 10", [42.0, 11.5]),
("Jet_pt", "MET_pt > 40", [42, 15, 10.5, 11.5]),
("MET_pt", "MET_pt > 40", [46.5, 82.0]),
("Jet_pt", "Jet_pt > 40", [42.0, 50.0]),
("MET_pt", "", [46.5, 30.0, 82.0, 8.9]),
# NOTE(review): the next entry is identical to the first entry of this list,
# so it is a redundant duplicate case.
("Jet_pt", "", [42.0, 15.0, 10.5, 11.5, 50.0, 5.0]),
("Jet_eta + 1", "Jet_pt > 40 and MET_pt > 40", [-1.2]),
("Jet_eta", "Jet_pt > 40 and MET_pt > 40", [-2.2]),
# One-argument sum/length/mean/min/max calls.
("sum(Jet_pt)", "MET_pt < 10", [50 + 5]),
("length(Jet_pt)", "MET_pt < 10", [2]),
("length(Jet_pt)", "", [3, 0, 1, 2]),
("mean(Jet_pt)", "", [1.0 / 3 * (42 + 15 + 10.5), 11.5, 0.5 * (50 + 5)]),
("min(Jet_pt)", "", [10.5, 11.5, 5.0]),
("max(abs(Jet_eta))", "MET_pt > 80", [1.5]),
("max(abs(Jet_eta))", "", [2.2, 1.5, 3.0]),
# For colon-separated expressions ("a:b") the expected value is a list of
# two-element lists.
("Jet_pt[0]:Jet_pt[1]", "MET_pt > 40", [[42, 15]]),
("Jet_pt[0]:Jet_pt[1]", "", [[42, 15], [50, 5]]),
("Jet_pt[2]", "", [10.5]),
("Jet_pt[Jet_pt>25]", "", [42, 50]),
("sum(Jet_pt[abs(Jet_eta)<2.0])", "", [15 + 10.5, 0.0, 11.5, 50.0]),
("sum(Jet_pt>10)", "MET_pt>40", [3, 1]),
("np.exp(sum(Jet_pt>10))", "MET_pt>40", [np.exp(3), np.exp(1)]),
# The same 14 < Jet_pt < 16 window written with `&`, with `and`, and as a
# chained comparison, followed by negations written with `~` and `not`.
("Jet_pt", "(14. < Jet_pt) & (Jet_pt < 16.)", [15]),
("Jet_pt", "(14. < Jet_pt) and (Jet_pt < 16.)", [15]),
("Jet_pt", "(14. < Jet_pt < 16.)", [15]),
("not MET_pt>40", "", [False, True, False, True]),
("Jet_pt", "~(14. < Jet_pt < 16.)", [42, 10.5, 11.5, 50, 5]),
("Jet_pt", "not(14. < Jet_pt < 16.)", [42, 10.5, 11.5, 50, 5]),
# Comparisons combined with literal True/False via `and` / `or`.
("(MET_pt>30) and True", "", [True, False, True, False]),
("(MET_pt>30) or True", "", [True, True, True, True]),
("(MET_pt>30) and False", "", [False, False, False, False]),
("(MET_pt>30) or False", "", [True, False, True, False]),
("length(Jet_pt) == 2 and sum(Jet_pt) > 40", "", [False, False, False, True]),
("sum(Jet_pt[(Jet_pt>40) and abs(Jet_eta)<2.4])", "MET_pt > 40", [42, 0]),
("sum(Jet_pt[(Jet_pt>40) and abs(Jet_eta)<2.4])", "", [42, 0, 0, 50]),
("sum(Jet_pt[2:3])", "MET_pt > 40", [10.5, 0.0]),
("sum(Jet_pt[:2])", "", [42 + 15, 0, 11.5, 50 + 5]),
("Jet_pt[Jet_pt>40][0] + Jet_eta[2]", "", [42.5]),
# min/max called with two arguments, including one-argument calls nested
# inside them.
("min(min(Jet_pt), min(Jet_eta))", "MET_pt > 40", [-2.2, 1.5]),
("min(min(Jet_pt), min(Jet_eta))", "", [-2.2, 1.5, -3]),
("max(min(Jet_pt), min(Jet_eta))", "", [10.5, 11.5, 5]),
("min(MET_pt, MET_pt+1)", "", [46.5, 30, 82, 8.9]),
("max(MET_pt, MET_pt+1)", "", [47.5, 31, 83, 9.9]),
("max(min(Jet_pt), MET_pt*2)", "", [93, 164, 17.8]),
(
"Jet_pt:Jet_eta",
"MET_pt > 40.",
[[42, -2.2], [15, 0.4], [10.5, 0.5], [11.5, 1.5]],
),
(
"(MET_pt>40) and sum((Jet_pt>40) and (abs(Jet_eta)<2.4)) >= 1",
"",
[True, False, False, False],
),
]


# Runs once per (varexp, sel, expected) entry in `cases`: draws `varexp` with
# selection `sel` from the `df_jagged` fixture and checks the result against
# `expected` with np.testing.assert_allclose.
@pytest.mark.parametrize("varexp,sel,expected", cases)
def test_draw(df_jagged, varexp, sel, expected):
# NOTE(review): `x` is assigned twice in a row, so the first result is
# discarded. On this diff page the tree_draw_to_array(...) line appears to be
# the removed (pre-commit) side and the tree_draw(..., to_array=True) line its
# replacement; confirm against the actual file.
x = tree_draw_to_array(df_jagged, varexp, sel)
x = tree_draw(df_jagged, varexp, sel, to_array=True)
# Convert both sides to NumPy arrays before the comparison.
x = np.array(x)
y = np.array(expected)
np.testing.assert_allclose(x, y)


# Allow running this test module directly as a script: invokes pytest on this
# file with --capture=no, which turns off pytest's output capturing.
if __name__ == "__main__":
pytest.main(["--capture=no", __file__])

0 comments on commit d1352e0

Please sign in to comment.