feat: preprocess a dataset's base form and hint it to uproot #978

Merged 8 commits · Jan 10, 2024
Changes from 7 commits
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -38,7 +38,7 @@ classifiers = [
]
dependencies = [
"awkward>=2.5.1",
"uproot>=5.2.0",
"uproot>=5.2.1",
"dask[array]>=2023.4.0",
"dask-awkward>=2024.1.0",
"dask-histogram>=2023.10.0",
8 changes: 8 additions & 0 deletions src/coffea/dataset_tools/apply_processor.py
@@ -3,6 +3,7 @@
import copy
from typing import Any, Callable, Dict, Hashable, List, Set, Tuple, Union

import awkward
import dask.base
import dask_awkward

@@ -14,6 +15,7 @@
)
from coffea.nanoevents import BaseSchema, NanoAODSchema, NanoEventsFactory
from coffea.processor import ProcessorABC
from coffea.util import decompress_form

DaskOutputBaseType = Union[
dask.base.DaskMethodsMixin,
@@ -58,11 +60,15 @@ def apply_to_dataset(
report : dask_awkward.Array, optional
        The file access report for running the analysis on the input dataset. Needs to be computed simultaneously with the analysis to be accurate.
"""
maybe_base_form = dataset.get("form", None)
if maybe_base_form is not None:
maybe_base_form = awkward.forms.from_json(decompress_form(maybe_base_form))
files = dataset["files"]
events = NanoEventsFactory.from_root(
files,
metadata=metadata,
schemaclass=schemaclass,
known_base_form=maybe_base_form,
uproot_options=uproot_options,
).events()

@@ -113,6 +119,8 @@ def apply_to_fileset(
report = {}
for name, dataset in fileset.items():
metadata = copy.deepcopy(dataset.get("metadata", {}))
if metadata is None:
metadata = {}
metadata.setdefault("dataset", name)
dataset_out = apply_to_dataset(
data_manipulation, dataset, schemaclass, metadata, uproot_options
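The pieces above compose as follows — a minimal sketch, assuming defaults for the other preprocessing options; the processor (my_processor) and the file path are hypothetical:

from coffea.dataset_tools import apply_to_fileset, preprocess
from coffea.nanoevents import NanoAODSchema

fileset = {"my_dataset": {"files": {"/path/to/file.root": "Events"}}}

# With calculate_form=True, preprocess stores a gzip+base64-compressed union form
# per dataset under the "form" key of the returned specs.
available, updated = preprocess(fileset, calculate_form=True)

# apply_to_dataset decompresses that form and forwards it to
# NanoEventsFactory.from_root as known_base_form, so building the task graph
# does not need to open any input files again.
out = apply_to_fileset(my_processor, available, schemaclass=NanoAODSchema)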
88 changes: 82 additions & 6 deletions src/coffea/dataset_tools/preprocess.py
@@ -1,16 +1,20 @@
from __future__ import annotations

import copy
import gzip
import hashlib
import math
from dataclasses import dataclass
from typing import Any, Dict, Hashable
from typing import Any, Callable, Dict, Hashable

import awkward
import dask
import dask_awkward
import numpy
import uproot

from coffea.util import compress_form


def _get_steps(
normed_files: awkward.Array | dask_awkward.Array,
@@ -19,6 +23,7 @@ def _get_steps(
recalculate_seen_steps: bool = False,
skip_bad_files: bool = False,
file_exceptions: Exception | Warning | tuple[Exception | Warning] = (OSError,),
calculate_form: bool = False,
) -> awkward.Array | dask_awkward.Array:
"""
Given a list of normalized file and object paths (defined in uproot), determine the steps for each file according to the supplied processing options.
@@ -36,8 +41,10 @@
of only recalculating the steps if the uuid has changed.
skip_bad_files: bool, False
Instead of failing, catch exceptions specified by file_exceptions and return null data.
file_exceptions: Exception | Warning | tuple[Exception | Warning], default (FileNotFoundError, OSError)
file_exceptions: Exception | Warning | tuple[Exception | Warning], default (OSError,)
What exceptions to catch when skipping bad files.
calculate_form: bool, default False
Extract the form of the TTree from the file so we can skip opening files later.

Returns
-------
@@ -61,6 +68,16 @@
tree = the_file[arg.object_path]
num_entries = tree.num_entries

form_json = None
form_hash = None
if calculate_form:
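            # uproot.dask builds a delayed array whose awkward form describes
            # every interpretable branch; the JSON-encoded form is hashed (for
            # deduplication across files later in preprocess) and gzip-compressed
            # for compact storage.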
form_bytes = (
uproot.dask(tree, ak_add_doc=True).layout.form.to_json().encode("utf-8")
)

form_hash = hashlib.blake2b(form_bytes).hexdigest()[:16]
form_json = gzip.compress(form_bytes)

target_step_size = num_entries if maybe_step_size is None else maybe_step_size

file_uuid = str(the_file.file.uuid)
@@ -101,13 +118,22 @@
"object_path": arg.object_path,
"steps": out_steps,
"uuid": out_uuid,
"form": form_json,
"form_hash": form_hash,
}
)

if len(array) == 0:
array = awkward.Array(
[
{"file": "junk", "object_path": "junk", "steps": [[]], "uuid": "junk"},
{
"file": "junk",
"object_path": "junk",
"steps": [[]],
"uuid": "junk",
"form": "junk",
"form_hash": "junk",
},
None,
]
)
@@ -147,12 +173,14 @@ class DatasetSpecOptional:
dict[str, str] | list[str] | dict[str, UprootFileSpec | CoffeaFileSpecOptional]
)
metadata: dict[Hashable, Any] | None
form: str | None


@dataclass
class DatasetSpec:
files: dict[str, CoffeaFileSpec]
metadata: dict[Hashable, Any] | None
form: str | None


FilesetSpecOptional = Dict[str, DatasetSpecOptional]
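
# For illustration only: a preprocessed DatasetSpec rendered as JSON looks
# roughly like the following (all values hypothetical):
# {
#     "files": {
#         "/path/to/file.root": {
#             "object_path": "Events",
#             "steps": [[0, 50000], [50000, 100000]],
#             "uuid": "a1b2c3d4-...",
#         }
#     },
#     "metadata": {"dataset": "my_dataset"},
#     "form": "<gzip+base64-compressed JSON form>",
# }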
@@ -167,6 +195,8 @@ def preprocess(
files_per_batch: int = 1,
skip_bad_files: bool = False,
file_exceptions: Exception | Warning | tuple[Exception | Warning] = (OSError,),
calculate_form: bool = False,
scheduler: None | Callable | str = None,
) -> tuple[FilesetSpec, FilesetSpecOptional]:
"""
Given a list of normalized file and object paths (defined in uproot), determine the steps for each file according to the supplied processing options.
@@ -187,7 +217,10 @@
Instead of failing, catch exceptions specified by file_exceptions and return null data.
    file_exceptions: Exception | Warning | tuple[Exception | Warning], default (OSError,)
What exceptions to catch when skipping bad files.

calculate_form: bool, default False
Extract the form of the TTree from each file in each dataset, creating the union of the forms over the dataset.
scheduler: None | Callable | str, default None
        Specifies the scheduler that dask should use to execute the preprocessing task graph.

Returns
-------
out_available : FilesetSpec
@@ -229,18 +262,48 @@
recalculate_seen_steps=recalculate_seen_steps,
skip_bad_files=skip_bad_files,
file_exceptions=file_exceptions,
calculate_form=calculate_form,
)

all_processed_files = dask.compute(files_to_preprocess)[0]
(all_processed_files,) = dask.compute(files_to_preprocess, scheduler=scheduler)
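    # dask.compute returns one result per collection passed in; unpacking the
    # 1-tuple keeps the dict structure, and scheduler=None defers to dask's
    # default executor for this graph.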

for name, processed_files in all_processed_files.items():
not_form = processed_files[["file", "object_path", "steps", "uuid"]]

forms = processed_files[["form", "form_hash"]][
~awkward.is_none(processed_files.form_hash)
]

_, unique_forms_idx = numpy.unique(
forms.form_hash.to_numpy(), return_index=True
)

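        # One representative per unique form hash: each distinct form is
        # decompressed and parsed from JSON exactly once below.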
dict_forms = []
for form in forms[unique_forms_idx].form:
dict_form = awkward.forms.from_json(
gzip.decompress(form).decode("utf-8")
).to_dict()
fields = dict_form.pop("fields")
dict_form["contents"] = {
field: content for field, content in zip(fields, dict_form["contents"])
}
dict_forms.append(dict_form)

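        # Note: dict.update merges shallowly; if multiple distinct forms survive
        # deduplication, the last one popped wins for any overlapping top-level
        # key, including the re-keyed "contents" dict.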
union_form = {}
union_form_jsonstr = None
while len(dict_forms):
form = dict_forms.pop()
union_form.update(form)
if len(union_form) > 0:
union_form_jsonstr = awkward.forms.from_dict(union_form).to_json()

files_available = {
item["file"]: {
"object_path": item["object_path"],
"steps": item["steps"],
"uuid": item["uuid"],
}
for item in awkward.drop_none(processed_files).to_list()
for item in awkward.drop_none(not_form).to_list()
}

files_out = {}
@@ -257,4 +320,17 @@
out_updated[name]["files"] = files_out
out_available[name]["files"] = files_available

compressed_union_form = None
if union_form_jsonstr is not None:
compressed_union_form = compress_form(union_form_jsonstr)
out_updated[name]["form"] = compressed_union_form
out_available[name]["form"] = compressed_union_form
else:
out_updated[name]["form"] = None
out_available[name]["form"] = None

if "metadata" not in out_updated:
out_updated[name]["metadata"] = None
out_available[name]["metadata"] = None

return out_available, out_updated
4 changes: 4 additions & 0 deletions src/coffea/nanoevents/factory.py
@@ -246,6 +246,7 @@ def from_root(
iteritems_options={},
use_ak_forth=True,
delayed=True,
known_base_form=None,
):
"""Quickly build NanoEvents from a root file

@@ -281,6 +282,8 @@
Toggle using awkward_forth to interpret branches in root file.
delayed:
Nanoevents will use dask as a backend to construct a delayed task graph representing your analysis.
known_base_form:
            If the base form of the input file is known ahead of time, we can skip opening a single file to determine the form and parse its metadata.
"""

if treepath is not uproot._util.unset and not isinstance(
@@ -328,6 +331,7 @@
ak_add_doc=True,
filter_branch=_remove_not_interpretable,
steps_per_file=steps_per_file,
known_base_form=known_base_form,
**uproot_options,
)

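For direct use of the new hint — a sketch, with a hypothetical file path and a "form" entry taken from a preprocessed dataset as produced above:

import awkward

from coffea.nanoevents import NanoAODSchema, NanoEventsFactory
from coffea.util import decompress_form

base_form = awkward.forms.from_json(decompress_form(dataset["form"]))
events = NanoEventsFactory.from_root(
    {"/path/to/file.root": "Events"},
    schemaclass=NanoAODSchema,
    known_base_form=base_form,
).events()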
14 changes: 14 additions & 0 deletions src/coffea/util.py
@@ -1,6 +1,8 @@
"""Utility functions

"""
import binascii
import gzip
import hashlib
from typing import Any, List, Optional

@@ -195,3 +197,15 @@ def rewrap_recordarray(layout, depth, data):
if isinstance(layout, awkward.layout.RecordArray):
return lambda: data
return None


# shorthand for compressing forms
def compress_form(formjson):
return binascii.b2a_base64(
gzip.compress(formjson.encode("utf-8")), newline=False
).decode("ascii")


# shorthand for decompressing forms
def decompress_form(form_compressedb64):
return gzip.decompress(binascii.a2b_base64(form_compressedb64)).decode("utf-8")