Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use ak.merge_union_of_records to generate input data format #1017

Merged
merged 11 commits into from
Feb 13, 2024
6 changes: 2 additions & 4 deletions src/coffea/dataset_tools/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,8 @@ def slice_files(fileset: FilesetSpec, theslice: Any = slice(None)) -> FilesetSpe

def _default_filter(name_and_spec):
    """Return True when a file's spec reports at least one entry.

    Parameters
    ----------
    name_and_spec : tuple
        A ``(file name, file spec)`` pair; the spec is a mapping that must
        contain a ``"num_entries"`` key.

    Returns
    -------
    bool
        ``False`` when ``num_entries`` is ``None`` or not positive,
        ``True`` otherwise.

    Raises
    ------
    KeyError
        If the spec has no ``"num_entries"`` key.
    """
    # The file name is not needed for the default decision.
    _, spec = name_and_spec
    # The entry count replaces the earlier steps-based emptiness check.
    num_entries = spec["num_entries"]
    return num_entries is not None and num_entries > 0


def filter_files(
Expand Down
90 changes: 71 additions & 19 deletions src/coffea/dataset_tools/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def get_steps(
"file": arg.file,
"object_path": arg.object_path,
"steps": out_steps,
"num_entries": num_entries,
"uuid": out_uuid,
"form": form_json,
"form_hash_md5": form_hash,
Expand All @@ -159,6 +160,7 @@ def get_steps(
"file": "junk",
"object_path": "junk",
"steps": [[0, 0]],
"num_entries": 0,
"uuid": "junk",
"form": "junk",
"form_hash_md5": "junk",
Expand Down Expand Up @@ -187,12 +189,14 @@ class UprootFileSpec:
@dataclass
class CoffeaFileSpec(UprootFileSpec):
    """File specification carrying the per-file fields produced by preprocessing."""

    # List of two-element integer lists; presumably [start, stop] entry
    # ranges -- TODO confirm against get_steps.
    steps: list[list[int]]
    # Entry count recorded for the file alongside its steps.
    num_entries: int
    # Identifier string recorded for the file.
    uuid: str


@dataclass
class CoffeaFileSpecOptional(CoffeaFileSpec):
    """Variant of ``CoffeaFileSpec`` whose derived fields may still be ``None``."""

    # Same fields as the parent class, re-annotated to admit None.
    steps: list[list[int]] | None
    # Spelled exactly like the parent's field so that it overrides the
    # annotation instead of introducing a second, misspelled required field.
    num_entries: int | None
    uuid: str | None


Expand Down Expand Up @@ -235,11 +239,14 @@ def _normalize_file_info(file_info):
None if not isinstance(maybe_finfo, dict) else maybe_finfo.get("uuid", None)
)
this_file = normed_files[ifile]
this_file += (3 - len(this_file)) * (None,) + (maybe_uuid,)
this_file += (4 - len(this_file)) * (None,) + (maybe_uuid,)
nsmith- marked this conversation as resolved.
Show resolved Hide resolved
normed_files[ifile] = this_file
return normed_files


# Field names that mark a "trivially filled" file: preprocess leaves a file out
# of the form-union determination when its form's fields are exactly this set.
_trivial_file_fields = {"run", "luminosityBlock", "event"}


def preprocess(
fileset: FilesetSpecOptional,
step_size: None | int = None,
Expand Down Expand Up @@ -295,7 +302,7 @@ def preprocess(
files_to_preprocess = {}
for name, info in fileset.items():
norm_files = _normalize_file_info(info)
fields = ["file", "object_path", "steps", "uuid"]
fields = ["file", "object_path", "steps", "num_entries", "uuid"]
ak_norm_files = awkward.from_iter(norm_files)
ak_norm_files = awkward.Array(
{field: ak_norm_files[str(ifield)] for ifield, field in enumerate(fields)}
Expand All @@ -322,51 +329,96 @@ def preprocess(

for name, processed_files in all_processed_files.items():
processed_files_without_forms = processed_files[
["file", "object_path", "steps", "uuid"]
["file", "object_path", "steps", "num_entries", "uuid"]
]

forms = processed_files[["form", "form_hash_md5"]][
forms = processed_files[["file", "form", "form_hash_md5", "num_entries"]][
~awkward.is_none(processed_files.form_hash_md5)
]

_, unique_forms_idx = numpy.unique(
forms.form_hash_md5.to_numpy(), return_index=True
)

dict_forms = []
for form in forms[unique_forms_idx].form:
dict_form = awkward.forms.from_json(decompress_form(form)).to_dict()
fields = dict_form.pop("fields")
dict_form["contents"] = {
field: content for field, content in zip(fields, dict_form["contents"])
}
dict_forms.append(dict_form)
dataset_forms = []
unique_forms = forms[unique_forms_idx]
for thefile, formstr, num_entries in zip(
unique_forms.file, unique_forms.form, unique_forms.num_entries
):
# skip trivially filled or empty files
form = awkward.forms.from_json(decompress_form(formstr))
if num_entries >= 0 and set(form.fields) != _trivial_file_fields:
dataset_forms.append(form)
else:
warnings.warn(
f"{thefile} has fields {form.fields} and num_entries={num_entries} "
"and has been skipped during form-union determination. You will need "
"to skip this file when processing. You can either manually remove it "
"or, if it is an empty file, dynamically remove it with the function "
"dataset_tools.filter_files which takes the output of preprocess and "
", by default, removes empty files each dataset in a fileset."
)

union_form = {}
union_array = None
union_form_jsonstr = None
while len(dict_forms):
form = dict_forms.pop()
union_form.update(form)
if len(union_form) > 0:
union_form_jsonstr = awkward.forms.from_dict(union_form).to_json()
while len(dataset_forms):
new_array = awkward.Array(dataset_forms.pop().length_zero_array())
if union_array is None:
union_array = new_array
else:
union_array = awkward.to_packed(
awkward.merge_union_of_records(
awkward.concatenate([union_array, new_array]), axis=0
)
)
union_array.layout.parameters.update(new_array.layout.parameters)
if union_array is not None:
union_form = union_array.layout.form

for icontent, content in enumerate(union_form.contents):
if isinstance(content, awkward.forms.IndexedOptionForm):
if (
not isinstance(content.content, awkward.forms.NumpyForm)
or content.content.primitive != "bool"
):
raise ValueError(
"IndexedOptionArrays can only contain NumpyArrays of "
"bools in mergers of flat-tuple-like schemas!"
)
parameters = (
content.content.parameters.copy()
if content.content.parameters is not None
else {}
)
# re-create the IndexedOptionForm with the parameters of the lower-level array
union_form.contents[icontent] = awkward.forms.IndexedOptionForm(
content.index,
content.content,
parameters=parameters,
form_key=content.form_key,
)

union_form_jsonstr = union_form.to_json()

files_available = {
item["file"]: {
"object_path": item["object_path"],
"steps": item["steps"],
"num_entries": item["num_entries"],
"uuid": item["uuid"],
}
for item in awkward.drop_none(processed_files_without_forms).to_list()
}

files_out = {}
for proc_item, orig_item in zip(
processed_files.to_list(), all_ak_norm_files[name].to_list()
processed_files_without_forms.to_list(), all_ak_norm_files[name].to_list()
):
item = orig_item if proc_item is None else proc_item
files_out[item["file"]] = {
"object_path": item["object_path"],
"steps": item["steps"],
"num_entries": item["num_entries"],
"uuid": item["uuid"],
}

Expand Down
15 changes: 10 additions & 5 deletions src/coffea/nanoevents/mapping/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def preload_column_source(self, uuid, path_in_source, source):
self._cache[key] = source

@abstractmethod
def get_column_handle(self, columnsource, name, allow_missing):
    """Return a handle for the column called ``name`` in ``columnsource``.

    Parameters
    ----------
    columnsource
        The source object obtained for a given uuid and tree path.
    name
        Name of the column to look up.
    allow_missing : bool
        True when the request originates from a ``!loadallowmissing`` node,
        i.e. the column is permitted to be absent from the source.
    """

@abstractmethod
def extract_column(self, columnhandle, start, stop, allow_missing, **kwargs):
    """Read the entries ``start``..``stop`` of a column from its handle.

    Parameters
    ----------
    columnhandle
        Handle previously returned by ``get_column_handle``.
    start, stop
        Entry range to read.
    allow_missing : bool
        Same flag that was passed to ``get_column_handle`` for this column.
    **kwargs
        Backend-specific options (the caller passes ``use_ak_forth``).
    """

@classmethod
Expand All @@ -87,16 +87,21 @@ def __getitem__(self, key):
elif node == "!skip":
skip = True
continue
elif node == "!load":
elif node.startswith("!load"):
handle_name = stack.pop()
if self._access_log is not None:
self._access_log.append(handle_name)
allow_missing = node == "!loadallowmissing"
handle = self.get_column_handle(
self._column_source(uuid, treepath), handle_name
self._column_source(uuid, treepath), handle_name, allow_missing
)
stack.append(
self.extract_column(
handle, start, stop, use_ak_forth=self._use_ak_forth
handle,
start,
stop,
allow_missing,
use_ak_forth=self._use_ak_forth,
)
)
elif node.startswith("!"):
Expand Down
47 changes: 44 additions & 3 deletions src/coffea/nanoevents/mapping/uproot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings

import awkward
import numpy
import uproot

from coffea.nanoevents.mapping.base import BaseSourceMapping, UUIDOpener
Expand Down Expand Up @@ -48,6 +49,22 @@ def _lazify_form(form, prefix, docstr=None):
)
if parameters:
form["parameters"] = parameters
elif form["class"] == "IndexedOptionArray":
if (
form["content"]["class"] != "NumpyArray"
or form["content"]["primitive"] != "bool"
):
raise ValueError(
"Only boolean NumpyArrays can be created dynamically if "
"missing in file!"
)
assert prefix.endswith("!load")
form["form_key"] = quote(prefix + "allowmissing,!index")
nsmith- marked this conversation as resolved.
Show resolved Hide resolved
form["content"] = _lazify_form(
form["content"], prefix + "allowmissing,!content", docstr=docstr
)
if parameters:
form["parameters"] = parameters
elif form["class"] == "RecordArray":
newfields, newcontents = [], []
for field, value in zip(form["fields"], form["contents"]):
Expand Down Expand Up @@ -151,21 +168,45 @@ def preload_column_source(self, uuid, path_in_source, source):
key = self.key_root() + tuple_to_key((uuid, path_in_source))
self._cache[key] = source

def get_column_handle(self, columnsource, name, allow_missing):
    """Look up column ``name`` in ``columnsource``.

    Parameters
    ----------
    columnsource
        Container supporting ``in`` and item access by column name.
    name
        Column name to fetch.
    allow_missing : bool
        When True, an absent column yields ``None`` instead of an error.

    Returns
    -------
    The value of ``columnsource[name]``, or ``None`` when the column is
    absent and ``allow_missing`` is True.
    """
    if allow_missing and name not in columnsource:
        return None
    # With allow_missing False, a missing column surfaces as whatever
    # error the source's item access raises.
    return columnsource[name]

def extract_column(
    self, columnhandle, start, stop, allow_missing, use_ak_forth=True
):
    """Read entries ``start``..``stop`` of a column, tolerating absent columns.

    Parameters
    ----------
    columnhandle
        Handle returned by ``get_column_handle``; ``None`` means the column
        was not found in the file.
    start, stop : int
        Entry range to read.
    allow_missing : bool
        When True the result is wrapped in an ``IndexedOptionArray``; a
        ``None`` handle then produces an array of ``stop - start`` missing
        values over an empty boolean content.
    use_ak_forth : bool
        Value stored on the interpretation's ``_forth`` attribute before
        reading.

    Raises
    ------
    RuntimeError
        If ``columnhandle`` is ``None`` while ``allow_missing`` is False.
    """
    if columnhandle is None:
        if not allow_missing:
            raise RuntimeError(
                "Received columnhandle of None when missing column in file is not allowed!"
            )
        # Column absent from this file: every index is -1, so every entry
        # is None; the content is an empty boolean array.
        return awkward.contents.IndexedOptionArray(
            awkward.index.Index64(numpy.full(stop - start, -1, dtype=numpy.int64)),
            awkward.contents.NumpyArray(numpy.array([], dtype=bool)),
        )

    interp = columnhandle.interpretation
    interp._forth = use_ak_forth

    # make sure uproot is single-core since our calling context might not be
    the_array = columnhandle.array(
        interp,
        entry_start=start,
        entry_stop=stop,
        decompression_executor=uproot.source.futures.TrivialExecutor(),
        interpretation_executor=uproot.source.futures.TrivialExecutor(),
    )

    if allow_missing:
        # Column present: wrap it with an identity index so the layout has
        # the same option type as the missing-column case above.
        the_array = awkward.contents.IndexedOptionArray(
            awkward.index.Index64(numpy.arange(stop - start, dtype=numpy.int64)),
            awkward.contents.NumpyArray(the_array),
        )

    return the_array

def __len__(self):
    """Return the difference between the ``_stop`` and ``_start`` attributes."""
    last, first = self._stop, self._start
    return last - first

Expand Down