/
misc_utils.py
292 lines (227 loc) · 8.39 KB
/
misc_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import copy
import functools
import hashlib
import inspect
import json
import math
import os
import random
import subprocess
import urllib
import urllib.request
from collections import Counter
from contextlib import contextmanager
from typing import Sequence, List, Optional, Tuple, Hashable
import filelock
import numpy as np
import torch
from scipy.special import comb
from allenact.utils.system import get_logger
# The "Tableau 10" categorical color palette as (R, G, B) tuples with
# channel values in [0, 255] — handy for giving visually distinct colors
# to distinct categories when plotting.
TABLEAU10_RGB = (
    (31, 119, 180),
    (255, 127, 14),
    (44, 160, 44),
    (214, 39, 40),
    (148, 103, 189),
    (140, 86, 75),
    (227, 119, 194),
    (127, 127, 127),
    (188, 189, 34),
    (23, 190, 207),
)
def multiprocessing_safe_download_file_from_url(url: str, save_path: str):
    """Download `url` to `save_path`, safely callable from many processes.

    A file lock at `save_path + ".lock"` serializes access so that only one
    process performs the download; any process that acquires the lock after
    the file exists skips the download.
    """
    lock_path = save_path + ".lock"
    with filelock.FileLock(lock_path):
        if os.path.isfile(save_path):
            get_logger().debug(f"{save_path} exists - skipping download.")
        else:
            get_logger().info(f"Downloading file from {url} to {save_path}.")
            urllib.request.urlretrieve(url, save_path)
def experimental_api(to_decorate):
    """Decorate a function to note that it is part of the experimental API."""
    warned_already = False
    full_name = f"{inspect.getmodule(to_decorate).__name__}.{to_decorate.__qualname__}"
    if to_decorate.__name__ == "__init__":
        # Attribute the warning to the class itself rather than its `__init__`.
        full_name = full_name.replace(".__init__", "")

    @functools.wraps(to_decorate)
    def wrapper(*args, **kwargs):
        nonlocal warned_already
        if not warned_already:
            # Warn only on the first call so logs aren't flooded.
            get_logger().warning(
                f"'{full_name}' is a part of AllenAct's experimental API."
                f" This means: (1) there are likely bugs present and (2)"
                f" we may remove/change this functionality without warning."
                f" USE AT YOUR OWN RISK.",
            )
            warned_already = True
        return to_decorate(*args, **kwargs)

    return wrapper
def deprecated(to_decorate):
    """Decorate a function to note that it has been deprecated."""
    warned_already = False
    full_name = f"{inspect.getmodule(to_decorate).__name__}.{to_decorate.__qualname__}"
    if to_decorate.__name__ == "__init__":
        # Attribute the warning to the class itself rather than its `__init__`.
        full_name = full_name.replace(".__init__", "")

    @functools.wraps(to_decorate)
    def wrapper(*args, **kwargs):
        nonlocal warned_already
        if not warned_already:
            # Warn only on the first call so logs aren't flooded.
            get_logger().warning(
                f"'{full_name}' has been deprecated and will soon be removed from AllenAct's API."
                f" Please discontinue your use of this function.",
            )
            warned_already = True
        return to_decorate(*args, **kwargs)

    return wrapper
class NumpyJSONEncoder(json.JSONEncoder):
    """JSON encoder for numpy objects.

    Based off the stackoverflow answer by Jie Yang here: https://stackoverflow.com/a/57915246.
    The license for this code is [BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/).
    """

    def default(self, obj):
        """Convert numpy scalars/arrays into JSON-serializable Python objects.

        Falls back to `json.JSONEncoder.default` (which raises `TypeError`)
        for any unsupported type.
        """
        if isinstance(obj, np.void):
            return None
        elif isinstance(obj, np.bool_):
            # Fix: `np.bool` (an alias of the builtin `bool`) was deprecated
            # in NumPy 1.20 and removed in 1.24, so the old check raised
            # AttributeError on modern NumPy and, even before that, never
            # matched actual numpy boolean scalars. `np.bool_` is the numpy
            # boolean scalar type.
            return bool(obj)
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(NumpyJSONEncoder, self).default(obj)
@contextmanager
def tensor_print_options(**print_opts):
    """Temporarily apply `print_opts` to both torch and numpy printing.

    The previous settings are restored on exit (for torch, only the options
    that were overridden are restored; numpy's full option dict is restored).
    """
    saved_torch_opts = copy.deepcopy(torch._tensor_str.PRINT_OPTS)
    saved_np_opts = np.get_printoptions()
    try:
        torch.set_printoptions(**print_opts)
        np.set_printoptions(**print_opts)
        yield None
    finally:
        restored = {opt: getattr(saved_torch_opts, opt) for opt in print_opts}
        torch.set_printoptions(**restored)
        np.set_printoptions(**saved_np_opts)
def md5_hash_str_as_int(to_hash: str):
    """Return the MD5 digest of `to_hash` as a 128-bit integer."""
    digest_hex = hashlib.md5(to_hash.encode()).hexdigest()
    return int(digest_hex, 16)
def get_git_diff_of_project() -> Tuple[str, str]:
    """Return `(short_sha, diff)` for the current git checkout.

    `short_sha` comes from `git describe --always`; `diff` is the output of
    `git diff` against that revision. Must be run inside a git repository.
    """
    describe_output = subprocess.check_output(["git", "describe", "--always"])
    short_sha = describe_output.decode("utf-8").strip()
    diff = subprocess.check_output(["git", "diff", short_sha]).decode("utf-8")
    return short_sha, diff
class HashableDict(dict):
    """A dictionary which is hashable so long as all of its values are
    hashable.

    A HashableDict object will allow setting / deleting of items until
    the first time that `__hash__()` is called on it after which
    attempts to set or delete items will throw `RuntimeError`
    exceptions. Keys must be mutually sortable, as a sorted canonical
    representation is used for hashing and equality.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set to True the first time `__hash__` succeeds, freezing the dict.
        self._hash_has_been_called = False

    def __key(self):
        # Canonical, order-independent (key, value) tuple used by both
        # `__hash__` and `__eq__`.
        return tuple((k, self[k]) for k in sorted(self))

    def __hash__(self):
        # Fix: compute the hash *before* setting the frozen flag so that a
        # failed attempt (e.g. a TypeError from an unhashable value) does not
        # permanently lock the dict against mutation.
        key_hash = hash(self.__key())
        self._hash_has_been_called = True
        return key_hash

    def __eq__(self, other):
        # Fix: the previous implementation unconditionally called
        # `other.__key()` (name-mangled to `_HashableDict__key`), which
        # raised AttributeError when `other` was a plain dict or any
        # non-HashableDict object.
        if isinstance(other, HashableDict):
            return self.__key() == other.__key()
        if isinstance(other, dict):
            return dict.__eq__(self, other)
        return NotImplemented

    def __setitem__(self, *args, **kwargs):
        if not self._hash_has_been_called:
            return super(HashableDict, self).__setitem__(*args, **kwargs)
        raise RuntimeError("Cannot set item in HashableDict after having called hash.")

    def __delitem__(self, *args, **kwargs):
        if not self._hash_has_been_called:
            return super(HashableDict, self).__delitem__(*args, **kwargs)
        raise RuntimeError(
            "Cannot delete item in HashableDict after having called hash."
        )
def partition_sequence(seq: Sequence, parts: int) -> List:
    """Split `seq` into `parts` contiguous chunks whose sizes differ by at
    most one; the earlier chunks receive the extra elements."""
    assert 0 < parts, f"parts [{parts}] must be greater > 0"
    assert parts <= len(seq), f"parts [{parts}] > len(seq) [{len(seq)}]"
    base_size, extra = divmod(len(seq), parts)
    partitions = []
    start = 0
    for part_index in range(parts):
        # The first `extra` chunks each absorb one leftover element.
        end = start + base_size + (1 if part_index < extra else 0)
        partitions.append(seq[start:end])
        start = end
    return partitions
def uninterleave(seq: Sequence, parts: int) -> List:
    """Undo an interleaving of `parts` streams: element `k` of `seq` is
    assigned to sub-list `k % parts`, preserving order within each sub-list."""
    assert 0 < parts <= len(seq)
    total = len(seq)
    return [
        [seq[index] for index in range(offset, total, parts)]
        for offset in range(parts)
    ]
@functools.lru_cache(maxsize=10000)
def cached_comb(n: int, m: int):
    """Memoized `scipy.special.comb(n, m)` (n choose m)."""
    return comb(n, m)
def expected_max_of_subset_statistic(vals: List[float], m: int):
    """Exact expected maximum of a uniformly random size-`m` subset of `vals`.

    Sums, over each distinct value `v`, `v * P(max of the subset == v)` where
    the probability is the count of size-`m` subsets whose maximum equals `v`
    divided by `comb(len(vals), m)`. Values are rounded to 8 decimal places
    before being grouped into (value, multiplicity) pairs.

    # Parameters

    vals : The population of values.
    m : Subset size; must satisfy `m <= len(vals)`.

    # Returns

    The expected maximum as a float.
    """
    n = len(vals)
    assert m <= n
    # Group equal (rounded) values with their multiplicities, sorted ascending.
    vals_and_counts = list(Counter([round(val, 8) for val in vals]).items())
    vals_and_counts.sort()
    count_so_far = 0  # number of population elements <= the current value
    logdenom = math.log(comb(n, m))  # log of the total number of size-m subsets
    expected_max = 0.0
    for val, num_occurances_of_val in vals_and_counts:
        count_so_far += num_occurances_of_val
        if count_so_far < m:
            # Fewer than m elements are <= val, so no size-m subset can
            # have its maximum equal to val.
            continue
        # Count subsets whose maximum is exactly `val`: choose i >= 1 copies
        # of `val` and the remaining m - i elements from the values strictly
        # smaller than `val` (there are count_so_far - num_occurances_of_val
        # of those).
        count_where_max = 0
        for i in range(1, min(num_occurances_of_val, m) + 1):
            count_where_max += cached_comb(num_occurances_of_val, i) * cached_comb(
                count_so_far - num_occurances_of_val, m - i
            )
        # Accumulate val * P(max == val); the ratio is computed in log space,
        # presumably to avoid overflow for large binomial coefficients.
        expected_max += val * math.exp(math.log(count_where_max) - logdenom)
    return expected_max
def bootstrap_max_of_subset_statistic(
    vals: List[float], m: int, reps=1000, seed: Optional[int] = None
):
    """Bootstrap `expected_max_of_subset_statistic`.

    Resamples `vals` with replacement `reps` times and collects the subset-max
    statistic for each resample. If `seed` is given, the global `random` state
    is saved, seeded, and restored afterwards so callers are unaffected.
    """
    saved_state = None
    if seed is not None:
        saved_state = random.getstate()
        random.seed(seed)
    estimates = [
        expected_max_of_subset_statistic(random.choices(vals, k=len(vals)), m)
        for _ in range(reps)
    ]
    if seed is not None:
        random.setstate(saved_state)
    return estimates
def rand_float(low: float, high: float, shape):
    """Uniform random samples in `[low, high)` with the given `shape`.

    `shape` may be a tuple of dimensions or a bare int (number of samples).
    """
    assert low <= high
    try:
        return low + (high - low) * np.random.rand(*shape)
    except TypeError:
        # `shape` was a bare int, not an unpackable sequence of dims.
        return low + (high - low) * np.random.rand(shape)
def all_unique(seq: Sequence[Hashable]):
    """Return True iff no element of `seq` appears more than once."""
    # A set collapses duplicates, so the lengths match exactly when all
    # elements are distinct.
    return len(set(seq)) == len(seq)
def all_equal(s: Sequence):
    """Return True iff every element of `s` equals the first element
    (vacuously True for sequences of length <= 1)."""
    if len(s) <= 1:
        return True
    first = s[0]
    for other in s[1:]:
        if not (first == other):
            return False
    return True
def prepare_locals_for_super(local_vars, args_name="args", kwargs_name="kwargs"):
    """Turn a `locals()` dict into keyword arguments for a `super().__init__` call.

    Drops `self` and any name containing a double underscore. If a
    `kwargs`-style entry is present, its contents are merged with the
    remaining locals (explicit locals win on key collisions) and the merged
    dict is returned. Positional `*args` are not supported.
    """
    assert (
        args_name not in local_vars
    ), "`prepare_locals_for_super` does not support {}.".format(args_name)
    filtered = {
        k: v for k, v in local_vars.items() if k != "self" and "__" not in k
    }
    if kwargs_name not in filtered:
        return filtered
    merged = filtered.pop(kwargs_name)
    merged.update(filtered)
    return merged
def partition_limits(num_items: int, num_parts: int):
    """Return `num_parts + 1` evenly spaced integer boundaries from 0 to
    `num_items` inclusive, suitable for slicing into near-equal parts."""
    boundaries = np.linspace(0, num_items, num_parts + 1, endpoint=True)
    return np.round(boundaries).astype(np.int32).tolist()