"""
Various utilities that don't fit anywhere else.
"""
import importlib
import json
import logging
import os
import pkgutil
import random
import subprocess
import sys
from contextlib import contextmanager
from itertools import islice, zip_longest
from logging import Filter
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)
import numpy
import spacy
import torch
import torch.distributed as dist
from spacy.cli.download import download as spacy_download
from spacy.language import Language as SpacyModelType
from allennlp.common.checks import log_pytorch_version_info
from allennlp.common.params import Params
from allennlp.common.tqdm import Tqdm
try:
import resource
except ImportError:
# resource doesn't exist on Windows systems
resource = None
logger = logging.getLogger(__name__)
JsonDict = Dict[str, Any]
# If you want to have start and/or end symbols for any reason in your code, we recommend you use
# these, to have a common place to import from. Also, it's important for some edge cases in how
# data is processed for these symbols to be lowercase, not uppercase (because we have code that
# will lowercase tokens for you in some circumstances, and we need this symbol to not change in
# those cases).
START_SYMBOL = "@start@"
END_SYMBOL = "@end@"
PathType = Union[os.PathLike, str]
T = TypeVar("T")
ContextManagerFunctionReturnType = Generator[T, None, None]
def sanitize(x: Any) -> Any:
"""
Sanitize turns PyTorch and Numpy types into basic Python types so they
can be serialized into JSON.
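    For example (note that `None` is sanitized to the string "None"):

    ```
    >>> sanitize({"tokens": torch.tensor([1, 2]), "score": numpy.float64(0.5)})
    {'tokens': [1, 2], 'score': 0.5}
    ```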
"""
# Import here to avoid circular references
from allennlp.data.tokenizers.token import Token
if isinstance(x, (str, float, int, bool)):
# x is already serializable
return x
elif isinstance(x, torch.Tensor):
# tensor needs to be converted to a list (and moved to cpu if necessary)
return x.cpu().tolist()
elif isinstance(x, numpy.ndarray):
# array needs to be converted to a list
return x.tolist()
elif isinstance(x, numpy.number):
# NumPy numbers need to be converted to Python numbers
return x.item()
elif isinstance(x, dict):
# Dicts need their values sanitized
return {key: sanitize(value) for key, value in x.items()}
elif isinstance(x, numpy.bool_):
# Numpy bool_ need to be converted to python bool.
return bool(x)
elif isinstance(x, (spacy.tokens.Token, Token)):
# Tokens get sanitized to just their text.
return x.text
elif isinstance(x, (list, tuple)):
# Lists and Tuples need their values sanitized
return [sanitize(x_i) for x_i in x]
elif x is None:
return "None"
elif hasattr(x, "to_json"):
return x.to_json()
else:
raise ValueError(
f"Cannot sanitize {x} of type {type(x)}. "
"If this is your own custom class, add a `to_json(self)` method "
"that returns a JSON-like object."
)
def group_by_count(iterable: List[Any], count: int, default_value: Any) -> List[List[Any]]:
"""
Takes a list and groups it into sublists of size `count`, using `default_value` to pad the
    list at the end if its length is not divisible by `count`.
For example:
```
>>> group_by_count([1, 2, 3, 4, 5, 6, 7], 3, 0)
[[1, 2, 3], [4, 5, 6], [7, 0, 0]]
```
This is a short method, but it's complicated and hard to remember as a one-liner, so we just
make a function out of it.
"""
    return [
        list(group) for group in zip_longest(*[iter(iterable)] * count, fillvalue=default_value)
    ]
A = TypeVar("A")
def lazy_groups_of(iterable: Iterable[A], group_size: int) -> Iterator[List[A]]:
"""
Takes an iterable and batches the individual instances into lists of the
specified size. The last list may be smaller if there are instances left over.
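    For example:

    ```
    >>> list(lazy_groups_of(range(5), 2))
    [[0, 1], [2, 3], [4]]
    ```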
"""
iterator = iter(iterable)
while True:
s = list(islice(iterator, group_size))
if len(s) > 0:
yield s
else:
break
def pad_sequence_to_length(
sequence: List,
desired_length: int,
default_value: Callable[[], Any] = lambda: 0,
padding_on_right: bool = True,
) -> List:
"""
    Takes a list of objects and pads it to the desired length, returning the padded list. The
    original list is not modified.
# Parameters
sequence : List
A list of objects to be padded.
desired_length : int
Maximum length of each sequence. Longer sequences are truncated to this length, and
shorter ones are padded to it.
default_value: Callable, default=lambda: 0
Callable that outputs a default value (of any type) to use as padding values. This is
a lambda to avoid using the same object when the default value is more complex, like a
list.
padding_on_right : bool, default=True
When we add padding tokens (or truncate the sequence), should we do it on the right or
the left?
# Returns
padded_sequence : List
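    For example:

    ```
    >>> pad_sequence_to_length([1, 2, 3], 5)
    [1, 2, 3, 0, 0]
    >>> pad_sequence_to_length([1, 2, 3], 5, padding_on_right=False)
    [0, 0, 1, 2, 3]
    ```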
"""
# Truncates the sequence to the desired length.
if padding_on_right:
padded_sequence = sequence[:desired_length]
else:
padded_sequence = sequence[-desired_length:]
# Continues to pad with default_value() until we reach the desired length.
pad_length = desired_length - len(padded_sequence)
# This just creates the default value once, so if it's a list, and if it gets mutated
# later, it could cause subtle bugs. But the risk there is low, and this is much faster.
values_to_pad = [default_value()] * pad_length
if padding_on_right:
padded_sequence = padded_sequence + values_to_pad
else:
padded_sequence = values_to_pad + padded_sequence
return padded_sequence
def add_noise_to_dict_values(dictionary: Dict[A, float], noise_param: float) -> Dict[A, float]:
"""
    Returns a new dictionary with noise added to every value in `dictionary`. The noise is
    uniformly distributed within `noise_param` percent of the value for every value in the
    dictionary.
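    For example, with `noise_param = 0.1`, a value of 10.0 is replaced by a value drawn
    uniformly from the interval [9.0, 11.0].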
"""
new_dict = {}
for key, value in dictionary.items():
noise_value = value * noise_param
noise = random.uniform(-noise_value, noise_value)
new_dict[key] = value + noise
return new_dict
def namespace_match(pattern: str, namespace: str) -> bool:
"""
Matches a namespace pattern against a namespace string. For example, `*tags` matches
`passage_tags` and `question_tags` and `tokens` matches `tokens` but not
`stemmed_tokens`.
"""
if pattern[0] == "*" and namespace.endswith(pattern[1:]):
return True
elif pattern == namespace:
return True
return False
def prepare_environment(params: Params):
"""
    Sets random seeds for reproducible experiments. This may not work as expected
    if you use this from within a Python project in which you have already imported PyTorch.
    If you use the scripts/run_model.py entry point to train models with this library,
    your experiments should be reasonably reproducible. If you are using this from your own
    project, you will want to call this function before importing PyTorch. Complete determinism
    is very difficult to achieve with libraries doing optimized linear algebra, due to massively
    parallel execution, which is exacerbated by using GPUs.
# Parameters
params: Params object or dict, required.
A `Params` object or dict holding the json parameters.
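    For example (a minimal sketch; `random_seed`, `numpy_seed` and `pytorch_seed` are the
    keys this function pops, with the defaults shown in the code below):

    ```
    prepare_environment(Params({"random_seed": 42, "numpy_seed": 42, "pytorch_seed": 42}))
    ```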
"""
seed = params.pop_int("random_seed", 13370)
numpy_seed = params.pop_int("numpy_seed", 1337)
torch_seed = params.pop_int("pytorch_seed", 133)
if seed is not None:
random.seed(seed)
if numpy_seed is not None:
numpy.random.seed(numpy_seed)
if torch_seed is not None:
torch.manual_seed(torch_seed)
# Seed all GPUs with the same seed if available.
if torch.cuda.is_available():
torch.cuda.manual_seed_all(torch_seed)
log_pytorch_version_info()
class FileFriendlyLogFilter(Filter):
"""
TQDM and requests use carriage returns to get the training line to update for each batch
without adding more lines to the terminal output. Displaying those in a file won't work
    correctly, so we'll just make sure that each batch shows up on its own line.
"""
def filter(self, record):
if "\r" in record.msg:
record.msg = record.msg.replace("\r", "")
if not record.msg or record.msg[-1] != "\n":
record.msg += "\n"
return True
class WorkerLogFilter(Filter):
def __init__(self, rank=-1):
super().__init__()
self._rank = rank
def filter(self, record):
if self._rank != -1:
record.msg = f"Rank {self._rank} | {record.msg}"
return True
def prepare_global_logging(
serialization_dir: str, file_friendly_logging: bool, rank: int = 0, world_size: int = 1
) -> None:
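    """
    Sets up logging to the console and to log files inside `serialization_dir`. In the
    distributed setting (`world_size > 1`) each worker writes to its own
    `stdout_worker{rank}.log` / `stderr_worker{rank}.log` files, and only the master
    (rank 0) worker logs to the console.
    """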
# If we don't have a terminal as stdout,
# force tqdm to be nicer.
if not sys.stdout.isatty():
file_friendly_logging = True
Tqdm.set_slower_interval(file_friendly_logging)
# Handlers for stdout/err logging
output_stream_log_handler = logging.StreamHandler(sys.stdout)
error_stream_log_handler = logging.StreamHandler(sys.stderr)
if world_size == 1:
        # This is the non-distributed case, so we stick to the older
        # log file names.
output_file_log_handler = logging.FileHandler(
filename=os.path.join(serialization_dir, "stdout.log")
)
error_file_log_handler = logging.FileHandler(
filename=os.path.join(serialization_dir, "stderr.log")
)
else:
# Create log files with worker ids
output_file_log_handler = logging.FileHandler(
filename=os.path.join(serialization_dir, f"stdout_worker{rank}.log")
)
error_file_log_handler = logging.FileHandler(
filename=os.path.join(serialization_dir, f"stderr_worker{rank}.log")
)
# This adds the worker's rank to messages being logged to files.
        # This will help when combining multiple worker log files using the `less` command.
worker_filter = WorkerLogFilter(rank)
output_file_log_handler.addFilter(worker_filter)
error_file_log_handler.addFilter(worker_filter)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
root_logger = logging.getLogger()
    # Remove any handlers already set on the root logger.
    # Not doing this would result in duplicate log messages
    # being printed to the console.
if len(root_logger.handlers) > 0:
for handler in root_logger.handlers:
root_logger.removeHandler(handler)
# file handlers need to be handled for tqdm's \r char
file_friendly_log_filter = FileFriendlyLogFilter()
if os.environ.get("ALLENNLP_DEBUG"):
LEVEL = logging.DEBUG
else:
level_name = os.environ.get("ALLENNLP_LOG_LEVEL")
LEVEL = logging._nameToLevel.get(level_name, logging.INFO)
if rank == 0:
# stdout/stderr handlers are added only for the
# master worker. This is to avoid cluttering the console
# screen with too many log messages from all workers.
output_stream_log_handler.setFormatter(formatter)
error_stream_log_handler.setFormatter(formatter)
output_stream_log_handler.setLevel(LEVEL)
error_stream_log_handler.setLevel(logging.ERROR)
if file_friendly_logging:
output_stream_log_handler.addFilter(file_friendly_log_filter)
error_stream_log_handler.addFilter(file_friendly_log_filter)
root_logger.addHandler(output_stream_log_handler)
root_logger.addHandler(error_stream_log_handler)
output_file_log_handler.addFilter(file_friendly_log_filter)
error_file_log_handler.addFilter(file_friendly_log_filter)
output_file_log_handler.setFormatter(formatter)
error_file_log_handler.setFormatter(formatter)
output_file_log_handler.setLevel(LEVEL)
error_file_log_handler.setLevel(logging.ERROR)
root_logger.addHandler(output_file_log_handler)
root_logger.addHandler(error_file_log_handler)
root_logger.setLevel(LEVEL)
LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool], SpacyModelType] = {}
def get_spacy_model(
spacy_model_name: str, pos_tags: bool, parse: bool, ner: bool
) -> SpacyModelType:
"""
In order to avoid loading spacy models a whole bunch of times, we'll save references to them,
keyed by the options we used to create the spacy model, so any particular configuration only
gets loaded once.
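    For example (a sketch; on first use this will download the model if it is not
    already installed):

    ```
    nlp = get_spacy_model("en_core_web_sm", pos_tags=False, parse=False, ner=False)
    doc = nlp("AllenNLP uses spacy for tokenization.")
    ```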
"""
options = (spacy_model_name, pos_tags, parse, ner)
if options not in LOADED_SPACY_MODELS:
disable = ["vectors", "textcat"]
if not pos_tags:
disable.append("tagger")
if not parse:
disable.append("parser")
if not ner:
disable.append("ner")
try:
spacy_model = spacy.load(spacy_model_name, disable=disable)
except OSError:
logger.warning(
f"Spacy models '{spacy_model_name}' not found. Downloading and installing."
)
spacy_download(spacy_model_name)
# Import the downloaded model module directly and load from there
spacy_model_module = __import__(spacy_model_name)
spacy_model = spacy_model_module.load(disable=disable) # type: ignore
LOADED_SPACY_MODELS[options] = spacy_model
return LOADED_SPACY_MODELS[options]
@contextmanager
def pushd(new_dir: PathType, verbose: bool = False) -> ContextManagerFunctionReturnType[None]:
"""
    Changes the current directory to the given path.
    This method is intended to be used with `with`, so after its usage, the current
    directory will be set back to its previous value.
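    For example:

    ```
    with pushd("/tmp"):
        ...  # the current working directory is /tmp here
    # and we are back in the original directory here
    ```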
"""
previous_dir = os.getcwd()
if verbose:
logger.info(f"Changing directory to {new_dir}") # type: ignore
os.chdir(new_dir)
try:
yield
finally:
if verbose:
logger.info(f"Changing directory back to {previous_dir}")
os.chdir(previous_dir)
@contextmanager
def push_python_path(path: PathType) -> ContextManagerFunctionReturnType[None]:
"""
Prepends the given path to `sys.path`.
    This method is intended to be used with `with`, so after its usage, the given path
    will be removed from `sys.path`.
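    For example (with a hypothetical project directory):

    ```
    with push_python_path("path/to/project"):
        ...  # modules under path/to/project are importable here
    ```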
"""
# In some environments, such as TC, it fails when sys.path contains a relative path, such as ".".
path = Path(path).resolve()
path = str(path)
sys.path.insert(0, path)
try:
yield
finally:
# Better to remove by value, in case `sys.path` was manipulated in between.
sys.path.remove(path)
def import_module_and_submodules(package_name: str) -> None:
"""
Import all submodules under the given package.
Primarily useful so that people using AllenNLP as a library
can specify their own custom packages and have their custom
classes get loaded and registered.
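    For example (with a hypothetical package name):

    ```
    import_module_and_submodules("my_project.my_models")
    ```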
"""
importlib.invalidate_caches()
# For some reason, python doesn't always add this by default to your path, but you pretty much
# always want it when using `--include-package`. And if it's already there, adding it again at
# the end won't hurt anything.
with push_python_path("."):
# Import at top level
module = importlib.import_module(package_name)
path = getattr(module, "__path__", [])
path_string = "" if not path else path[0]
# walk_packages only finds immediate children, so need to recurse.
for module_finder, name, _ in pkgutil.walk_packages(path):
# Sometimes when you import third-party libraries that are on your path,
# `pkgutil.walk_packages` returns those too, so we need to skip them.
if path_string and module_finder.path != path_string:
continue
subpackage = f"{package_name}.{name}"
import_module_and_submodules(subpackage)
def peak_memory_mb() -> float:
"""
Get peak memory usage for this process, as measured by
max-resident-set size:
https://unix.stackexchange.com/questions/30940/getrusage-system-call-what-is-maximum-resident-set-size
Only works on OSX and Linux, returns 0.0 otherwise.
"""
if resource is None or sys.platform not in ("linux", "darwin"):
return 0.0
    # TODO(joelgrus): For whatever reason, our pinned version 0.521 of mypy does not like the
    # next line, but later versions (e.g. 0.530) are fine with it. Once we get that
# figured out, remove the type: ignore.
peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss # type: ignore
if sys.platform == "darwin":
# On OSX the result is in bytes.
return peak / 1_000_000
else:
# On Linux the result is in kilobytes.
return peak / 1_000
def gpu_memory_mb() -> Dict[int, int]:
"""
Get the current GPU memory usage.
Based on https://discuss.pytorch.org/t/access-gpu-memory-usage-in-pytorch/3192/4
# Returns
`Dict[int, int]`
Keys are device ids as integers.
Values are memory usage as integers in MB.
Returns an empty `dict` if GPUs are not available.
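    For example, on a machine with two GPUs this might return something like
    `{0: 1422, 1: 0}`, where the values are whatever `nvidia-smi` reports at call time.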
"""
try:
result = subprocess.check_output(
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
encoding="utf-8",
)
gpu_memory = [int(x) for x in result.strip().split("\n")]
return {gpu: memory for gpu, memory in enumerate(gpu_memory)}
except FileNotFoundError:
# `nvidia-smi` doesn't exist, assume that means no GPU.
return {}
except: # noqa
# Catch *all* exceptions, because this memory check is a nice-to-have
# and we'd never want a training run to fail because of it.
logger.warning(
"unable to check gpu_memory_mb() due to occasional failure, continuing", exc_info=True
)
return {}
def ensure_list(iterable: Iterable[A]) -> List[A]:
"""
An Iterable may be a list or a generator.
This ensures we get a list without making an unnecessary copy.
"""
if isinstance(iterable, list):
return iterable
else:
return list(iterable)
def is_lazy(iterable: Iterable[A]) -> bool:
"""
Checks if the given iterable is lazy,
which here just means it's not a list.
"""
return not isinstance(iterable, list)
def log_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> None:
frozen_parameter_names, tunable_parameter_names = get_frozen_and_tunable_parameter_names(model)
logger.info("The following parameters are Frozen (without gradient):")
for name in frozen_parameter_names:
logger.info(name)
logger.info("The following parameters are Tunable (with gradient):")
for name in tunable_parameter_names:
logger.info(name)
def get_frozen_and_tunable_parameter_names(
model: torch.nn.Module,
) -> Tuple[Iterable[str], Iterable[str]]:
frozen_parameter_names = (
name for name, parameter in model.named_parameters() if not parameter.requires_grad
)
tunable_parameter_names = (
name for name, parameter in model.named_parameters() if parameter.requires_grad
)
return frozen_parameter_names, tunable_parameter_names
def dump_metrics(file_path: Optional[str], metrics: Dict[str, Any], log: bool = False) -> None:
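    """
    Serializes `metrics` as pretty-printed JSON, writes it to `file_path` if one is
    given, and logs it at INFO level if `log` is `True`.
    """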
metrics_json = json.dumps(metrics, indent=2)
if file_path:
with open(file_path, "w") as metrics_file:
metrics_file.write(metrics_json)
if log:
logger.info("Metrics: %s", metrics_json)
def flatten_filename(file_path: str) -> str:
return file_path.replace("/", "_SLASH_")
def is_master(
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    num_procs_per_node: Optional[int] = None,
) -> bool:
"""
Checks if the process is a "master" of its node in a distributed process group. If a
process group is not initialized, this returns `True`.
# Parameters
global_rank : int ( default = None )
Global rank of the process if in a distributed process group. If not
given, rank is obtained using `torch.distributed.get_rank()`
world_size : int ( default = None )
Number of processes in the distributed group. If not
given, this is obtained using `torch.distributed.get_world_size()`
    num_procs_per_node : int ( default = None )
        Number of GPU processes running per node.
"""
distributed = dist.is_available() and dist.is_initialized()
    # In the non-distributed case, a "master" process doesn't make any sense, so
    # instead of raising an error we simply return True, which makes things less
    # painful.
if not distributed:
return True
if global_rank is None:
global_rank = dist.get_rank()
if world_size is None:
world_size = dist.get_world_size()
if num_procs_per_node is None and os.environ:
num_procs_per_node = int(os.environ.get("ALLENNLP_PROCS_PER_NODE", world_size))
# rank == 0 would do in a single-node multi-GPU setup. However,
# in a multi-node case, every node has a logical master and hence
# the mod(%) op.
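    # For example (assuming world_size is a multiple of num_procs_per_node): with
    # world_size=4 and num_procs_per_node=2 (two nodes, two processes each), ranks
    # 0 and 2 are the two nodes' masters.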
    return global_rank % (world_size // num_procs_per_node) == 0
def is_distributed() -> bool:
"""
Checks if the distributed process group is available and has been initialized
"""
return dist.is_available() and dist.is_initialized()
def sanitize_wordpiece(wordpiece: str) -> str:
"""
Sanitizes wordpieces from BERT, RoBERTa or ALBERT tokenizers.
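    For example:

    ```
    >>> sanitize_wordpiece("##ing")  # BERT-style continuation piece
    'ing'
    >>> sanitize_wordpiece("Ġhello")  # RoBERTa / GPT-2 style piece
    'hello'
    >>> sanitize_wordpiece("▁hello")  # ALBERT / SentencePiece style piece
    'hello'
    ```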
"""
if wordpiece.startswith("##"):
return wordpiece[2:]
elif wordpiece.startswith("Ġ"):
return wordpiece[1:]
elif wordpiece.startswith("▁"):
return wordpiece[1:]
else:
return wordpiece