Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_zarr method to context #540

Merged
merged 23 commits into from
Oct 14, 2021
Merged
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 50 additions & 0 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,56 @@ def get_df(self, run_id: ty.Union[str, tuple, list],
f"array fields. Please use get_array.")
raise


def get_zarr(self, run_ids, targets, storage='./strax_temp_data',
             progress_bar=False, overwrite=True, **kwargs):
    """Get persistent arrays using zarr.

    This is useful when loading large amounts of data that cannot fit
    in memory; zarr is very compatible with dask.
    Targets are loaded into separate arrays and runs are merged.
    The data is added to any existing data in the storage location.

    :param run_ids: (Iterable) Run ids you wish to load.
    :param targets: (Iterable) targets to load.
    :param storage: (str, optional) fsspec path to store array.
        Defaults to './strax_temp_data'.
    :param progress_bar: (boolean, optional) passed on to get_iter.
    :param overwrite: (boolean, optional) whether to overwrite existing
        arrays for targets at given path. If False, runs that are not
        yet listed in an existing array are appended to it.
    :param kwargs: passed on to get_iter. Their deterministic hash is
        part of the group path, so different kwargs use different groups.

    :returns zarr.Group: zarr group containing the persistent arrays
        available at the storage location after loading the requested
        data. The runs loaded into a given array can be seen in the
        array .attrs['RUNS'] field.
    """
    import zarr
    context_hash = self._context_hash()
    kwargs_hash = strax.deterministic_hash(kwargs)
    # Open in append mode: mode='w' would erase everything already in
    # the store, which makes overwrite=False and the "already loaded"
    # check below meaningless.
    root = zarr.open_group(storage, mode='a')
    # require_group returns the existing group (or creates it) instead
    # of raising when it is already there.
    group = root.require_group(context_hash + '/' + kwargs_hash)
    for target in strax.to_str_tuple(targets):
        if overwrite and target in group:
            # Start from a clean array (and clean RUNS attrs) for this target
            del group[target]
        zarray = group[target] if target in group else None
        # Append after whatever is already stored for this target
        idx = 0 if zarray is None else zarray.size
        inserted = {}
        for run_id in strax.to_str_tuple(run_ids):
            if run_id in inserted:
                # Same run requested twice in this call
                continue
            if zarray is not None and run_id in zarray.attrs.get('RUNS', {}):
                # Already present in the stored array
                continue
            key = self.key_for(run_id, target)
            inserted[run_id] = dict(start_idx=idx,
                                    end_idx=idx,
                                    lineage_hash=key.lineage_hash)
            for chunk in self.get_iter(run_id, target,
                                       progress_bar=progress_bar,
                                       **kwargs):
                end_idx = idx + chunk.data.size
                if zarray is None:
                    # numpy describes titled fields as (title, name)
                    # tuples; keep only the name. Untitled fields have a
                    # plain string name, which is kept as is.
                    dtype = [
                        (d[0][1] if isinstance(d[0], tuple) else d[0],)
                        + d[1:]
                        for d in chunk.dtype.descr]
                    zarray = group.create_dataset(target,
                                                  shape=end_idx,
                                                  dtype=dtype)
                else:
                    zarray.resize(end_idx)
                zarray[idx:end_idx] = chunk.data
                idx = end_idx
                inserted[run_id]['end_idx'] = end_idx
        if zarray is not None:
            # zarray is still None if nothing was loaded and nothing was
            # stored before; there is then no array to annotate.
            zarray.attrs['RUNS'] = dict(zarray.attrs.get('RUNS', {}),
                                        **inserted)
    return group

def key_for(self, run_id, target):
"""
Get the DataKey for a given run and a given target plugin. The
Expand Down