-
Notifications
You must be signed in to change notification settings - Fork 120
/
util.py
170 lines (139 loc) · 5.02 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""Utility functions
"""
from typing import List, Optional, Any
import awkward
import hashlib
import numpy
import numba
import coffea
from rich.progress import (
Progress,
BarColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
Column,
ProgressColumn,
Text,
)
ak = awkward
np = numpy
nb = numba
import lz4.frame
import cloudpickle
import warnings
from functools import partial
def load(filename):
    """Read an object back from an lz4-compressed cloudpickle file.

    ``filename`` is passed to ``lz4.frame.open``; the decompressed stream is
    unpickled with ``cloudpickle.load`` and the resulting object is returned.

    NOTE: unpickling can execute arbitrary code, so only load trusted files.
    """
    with lz4.frame.open(filename) as stream:
        return cloudpickle.load(stream)
def save(output, filename):
    """Write a coffea object (or a collection of them) to disk.

    Any picklable object is accepted: it is serialized with
    ``cloudpickle.dumps`` and written through an lz4 frame opened on
    ``filename`` in ``"wb"`` mode. Suggested suffix: ``.coffea``
    """
    with lz4.frame.open(filename, "wb") as sink:
        # Pickle inside the ``with`` so the file is opened before serialization,
        # matching the order in which failures would surface.
        sink.write(cloudpickle.dumps(output))
def _hex(string):
try:
return string.hex()
except AttributeError:
return "".join("{:02x}".format(ord(c)) for c in string)
def _ascii(maybebytes):
try:
return maybebytes.decode("ascii")
except AttributeError:
return maybebytes
def _hash(items):
# python 3.3 salts hash(), we want it to persist across processes
x = hashlib.md5(bytes(";".join(str(x) for x in items), "ascii"))
return int(x.hexdigest()[:16], base=16)
def _ensure_flat(array, allow_missing=False):
    """Normalize an array to a flat numpy array or raise ValueError

    Parameters
    ----------
        array
            An ``ak.Array`` or ``numpy.ndarray`` whose awkward type must be
            ``N * primitive`` (or ``N * ?primitive`` when ``allow_missing``).
        allow_missing : bool
            Accept an option-of-primitive item type as well; the flag is also
            forwarded to ``ak.to_numpy``.

    Returns
    -------
        The input itself when it is already a numpy array, otherwise the
        result of ``ak.to_numpy(array, allow_missing=allow_missing)``.

    Raises
    ------
        ValueError
            If ``array`` is not a numpy/awkward array or its type is not flat.
    """
    if not isinstance(array, (ak.Array, numpy.ndarray)):
        # Use !r formatting instead of ``"%r" % array``: the % operator would
        # unpack a tuple argument and fail with TypeError (or print the wrong
        # thing) rather than raising the intended ValueError.
        raise ValueError(f"Expected a numpy or awkward array, received: {array!r}")

    aktype = ak.type(array)
    if not isinstance(aktype, ak.types.ArrayType):
        raise ValueError(f"Expected an array type, received: {aktype!r}")

    # The item type must be a primitive, optionally wrapped in an option type
    isprimitive = isinstance(aktype.type, ak.types.PrimitiveType)
    isoptionprimitive = isinstance(aktype.type, ak.types.OptionType) and isinstance(
        aktype.type.type, ak.types.PrimitiveType
    )
    if allow_missing and not (isprimitive or isoptionprimitive):
        raise ValueError(
            "Expected an array of type N * primitive or N * ?primitive, "
            f"received: {aktype!r}"
        )
    if not (allow_missing or isprimitive):
        raise ValueError(
            f"Expected an array of type N * primitive, received: {aktype!r}"
        )

    if isinstance(array, ak.Array):
        array = ak.to_numpy(array, allow_missing=allow_missing)
    return array
def _exception_chain(exc: BaseException) -> List[BaseException]:
"""Retrieves the entire exception chain as a list."""
ret = []
while isinstance(exc, BaseException):
ret.append(exc)
exc = exc.__cause__
return ret
class SpeedColumn(ProgressColumn):
    """Progress column that displays a task's speed as formatted text."""

    def __init__(self, fmt: str = ".1f", table_column: Optional[Column] = None):
        # Format spec applied to the speed value in ``render``
        self.fmt = fmt
        super().__init__(table_column=table_column)

    def render(self, task: Any) -> Text:
        """Return the task speed as styled text, or ``"?"`` when unknown."""
        # Prefer the final speed once available, otherwise the live estimate
        rate = task.finished_speed or task.speed
        label = "?" if rate is None else f"{rate:{self.fmt}}"
        return Text(label, style="progress.data.speed")
def rich_bar():
    """Create the ``rich`` ``Progress`` bar layout used for progress reporting.

    Columns, left to right: task description, percentage, bar,
    completed/total counter, then a bracketed group with elapsed time,
    remaining time and speed (with the task's ``unit`` field). Auto refresh
    is disabled.
    """
    columns = [
        TextColumn("[bold blue]{task.description}", justify="right"),
        "[progress.percentage]{task.percentage:>3.0f}%",
        BarColumn(bar_width=None),
        TextColumn(
            "[bold blue][progress.completed]{task.completed}/{task.total}",
            justify="right",
        ),
    ]
    # Timing and throughput group, rendered as "[elapsed<remaining|speed unit/s]"
    columns += [
        "[",
        TimeElapsedColumn(),
        "<",
        TimeRemainingColumn(),
        "|",
        SpeedColumn(".1f"),
        TextColumn("[progress.data.speed]{task.fields[unit]}/s", justify="right"),
        "]",
    ]
    return Progress(*columns, auto_refresh=False)
# lifted from awkward - https://github.com/scikit-hep/awkward-1.0/blob/5fe31a916bf30df6c2ea10d4094f6f1aefcf3d0c/src/awkward/_util.py#L47-L61 # noqa
# As in awkward, deprecations are raised as errors when coffea.deprecations_as_errors is set
def deprecate(exception, version, date=None):
    """Signal a deprecation, either as an error or as a ``FutureWarning``.

    When ``coffea.deprecations_as_errors`` is truthy, ``exception`` is raised
    as-is. Otherwise a ``FutureWarning`` is issued naming the coffea
    ``version`` (and optional target ``date`` string) at which the behavior
    becomes an error, followed by the exception's type name and message.
    """
    if coffea.deprecations_as_errors:
        raise exception

    target = "" if date is None else " (target date: " + date + ")"
    template = """In coffea version {0}{1}, this will be an error.
(Set coffea.deprecations_as_errors = True to get a stack trace now.)
{2}: {3}"""
    warnings.warn(
        template.format(version, target, type(exception).__name__, str(exception)),
        FutureWarning,
    )
# re-nest a record array into a ListArray
def awkward_rewrap(arr, like_what, gfunc):
    """Rebuild ``arr``'s data inside the layout structure of ``like_what``.

    ``gfunc`` is bound with ``data=arr.layout`` and applied recursively over
    the layout of ``like_what``; the resulting layout is wrapped using the
    behavior taken from ``like_what``.
    """
    behavior = awkward._util.behaviorof(like_what)
    getfunction = partial(gfunc, data=arr.layout)
    rebuilt = awkward._util.recursively_apply(
        awkward.operations.convert.to_layout(like_what), getfunction
    )
    return awkward._util.wrap(rebuilt, behavior=behavior)
# We assume that the first record array encountered corresponds to the flattened data
def rewrap_recordarray(layout, depth, data):
    """Return a callable yielding ``data`` when ``layout`` is a RecordArray.

    For any other layout node ``None`` is returned. ``depth`` is accepted but
    not used.
    """
    is_record = isinstance(layout, awkward.layout.RecordArray)
    return (lambda: data) if is_record else None