Skip to content

Commit

Permalink
Merge pull request #843 from lgray/ressurect_triton_tests
Browse files Browse the repository at this point in the history
fix: adjustments to callable wrap to deal with typetracers in nested python structures
  • Loading branch information
lgray committed Jun 22, 2023
2 parents bd283f5 + 3c42daf commit 0a72c8b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"awkward>=2.2.3",
"uproot>=5.0.7",
"dask[array]>=2023.4.0",
"dask-awkward>=2023.6.1",
"dask-awkward>=2023.6.3",
"dask-histogram>=2023.6.0",
"correctionlib>=2.0.0",
"pyarrow>=6.0.0",
Expand Down
41 changes: 23 additions & 18 deletions src/coffea/ml_tools/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,31 @@ def __init__(self, method_map, default_conv=None):
if self.default_conv is None:
self.default_conv = self.unrecognized

def convert(self, arg):
def _convert(self, arg, maybe_backends):
if isinstance(arg, Dict):
return dict({key: self.convert(val) for key, val in arg.items()})
return dict(
{key: self.convert(val, maybe_backends) for key, val in arg.items()}
)
elif isinstance(arg, (List, Set, Tuple)):
return arg.__class__(self.convert(val) for val in arg)
return arg.__class__(self.convert(val, maybe_backends) for val in arg)
else:
for itype, call in self.method_map.items():
if isinstance(arg, itype):
if maybe_backends is not None and itype is awkward.highlevel.Array:
maybe_backends.add(awkward.backend(arg))
return call(arg)

return self.default_conv(arg)

def convert(self, arg, maybe_backends=None):
out = self._convert(arg, maybe_backends)
return out

def __call__(self, *args, **kwargs) -> Tuple:
return self.convert(args), self.convert(kwargs)
backends = set()
out_args = self.convert(args, backends)
out_kwargs = self.convert(kwargs, backends)
return (out_args, out_kwargs), backends

@staticmethod
def no_action(x):
Expand Down Expand Up @@ -223,7 +234,7 @@ def _call_awkward(self, *args, **kwargs):
then numpy_to_awkward conversion.
"""
ak_args, ak_kwargs = self.prepare_awkward(*args, **kwargs)
np_args, np_kwargs = self._ak_to_np_(*ak_args, **ak_kwargs)
(np_args, np_kwargs), _ = self._ak_to_np_(*ak_args, **ak_kwargs)
np_rets = self._call_numpy(*np_args, **np_kwargs)
np_rets = self._np_to_ak_.convert(np_rets)
return self.postprocess_awkward(np_rets, *args, **kwargs)
Expand Down Expand Up @@ -308,29 +319,23 @@ def __call__(self, *args):
# arrays
ak_args, ak_kwargs = self.args_to_pair(*args)

if self.get_backend(*args) == "typetracer":
# Length-0 conversion will not work! Must use length-1 method.
conv = container_converter(
{
awkward.Array: lambda x: awkward.Array(
x.layout.form.length_one_array(highlevel=False),
behavior=x.behavior,
)
}
)
conv = container_converter(
{awkward.Array: awkward.typetracer.length_one_if_typetracer},
default_conv=container_converter.no_action,
)

ak_args, ak_kwargs = conv(*ak_args, **ak_kwargs)
(ak_args, ak_kwargs), backends = conv(*ak_args, **ak_kwargs)

# Converting to numpy
np_args, np_kwargs = numpy_call_wrapper._ak_to_np_(
(np_args, np_kwargs), _ = numpy_call_wrapper._ak_to_np_(
*ak_args, **ak_kwargs
)
out = self.wrapper._call_numpy(*np_args, **np_kwargs)
out = self.wrapper._np_to_ak_.convert(out)

# Additional packing
out = pack_ret_array(out)
if self.get_backend(*args) == "typetracer":
if "typetracer" in backends:
out = awkward.Array(
out.layout.to_typetracer(forget_length=True),
behavior=out.behavior,
Expand Down
3 changes: 0 additions & 3 deletions tests/test_ml_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def my_pad(arr):
}


@pytest.mark.skip(
reason="triton requires nested args / kwargs in dask_awkward, not there yet"
)
def test_triton():
_ = pytest.importorskip("tritonclient")

Expand Down

0 comments on commit 0a72c8b

Please sign in to comment.