-
Notifications
You must be signed in to change notification settings - Fork 4.5k
/
io.py
656 lines (507 loc) 路 20.1 KB
/
io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
from collections import OrderedDict
import errno
import glob
from hashlib import md5
from io import StringIO
import json
import os
import sys
from pathlib import Path
import re
from typing import Any, Dict, List, Optional, Text, Type, Union
import warnings
import random
import string
import portalocker
from ruamel import yaml as yaml
from ruamel.yaml import RoundTripRepresenter, YAMLError
from ruamel.yaml.constructor import DuplicateKeyError, BaseConstructor, ScalarNode
from rasa.shared.constants import (
DEFAULT_LOG_LEVEL,
ENV_LOG_LEVEL,
NEXT_MAJOR_VERSION_FOR_DEPRECATIONS,
CONFIG_SCHEMA_FILE,
MODEL_CONFIG_SCHEMA_FILE,
)
from rasa.shared.exceptions import (
FileIOException,
FileNotFoundException,
YamlSyntaxException,
RasaException,
)
import rasa.shared.utils.validation
# Encoding used for every file read/write performed by this module.
DEFAULT_ENCODING = "utf-8"
# YAML specification version handed to the ruamel parser (see `read_yaml`).
YAML_VERSION = (1, 2)
class bcolors:
    """ANSI terminal escape sequences used to colorize console output."""

    HEADER = "\033[95m"  # bright magenta
    OKBLUE = "\033[94m"  # bright blue
    OKGREEN = "\033[92m"  # bright green
    WARNING = "\033[93m"  # bright yellow
    FAIL = "\033[91m"  # bright red
    ENDC = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
def wrap_with_color(*args: Any, color: Text) -> Text:
    """Join `args` with spaces and wrap the result in the given ANSI color code."""
    joined = " ".join(str(arg) for arg in args)
    return f"{color}{joined}{bcolors.ENDC}"
def raise_warning(
    message: Text,
    category: Optional[Type[Warning]] = None,
    docs: Optional[Text] = None,
    **kwargs: Any,
) -> None:
    """Emit a `warnings.warn` with sensible defaults and a colored warning msg.

    Args:
        message: The warning message.
        category: The warning category (e.g. `DeprecationWarning`).
        docs: Optional URL with more information, appended to the message.
        kwargs: Forwarded to `warnings.warn` (e.g. `stacklevel`).
    """
    original_formatter = warnings.formatwarning

    def should_show_source_line() -> bool:
        # For the most common categories we suppress the source line to keep
        # console output short; an explicit `stacklevel` re-enables it.
        if "stacklevel" not in kwargs:
            if category == UserWarning or category is None:
                return False
            if category == FutureWarning:
                return False
        return True

    def formatwarning(
        message: Union[Warning, Text],
        category: Type[Warning],
        filename: Text,
        lineno: int,
        line: Optional[Text] = None,
    ) -> Text:
        """Function to format a warning the standard way."""
        if not should_show_source_line():
            if docs:
                line = f"More info at {docs}"
            else:
                line = ""

        formatted_message = original_formatter(
            message, category, filename, lineno, line
        )
        return wrap_with_color(formatted_message, color=bcolors.WARNING)

    if "stacklevel" not in kwargs:
        # try to set useful defaults for the most common warning categories
        if category == DeprecationWarning:
            kwargs["stacklevel"] = 3
        elif category in (UserWarning, FutureWarning):
            kwargs["stacklevel"] = 2

    warnings.formatwarning = formatwarning
    try:
        warnings.warn(message, category=category, **kwargs)
    finally:
        # Bug fix: always restore the global formatter, even when the warning
        # is escalated to an exception by a `warnings` filter ("error" action).
        # Previously a raising `warnings.warn` left the colored formatter
        # installed process-wide.
        warnings.formatwarning = original_formatter
def write_text_file(
    content: Text,
    file_path: Union[Text, Path],
    encoding: Text = DEFAULT_ENCODING,
    append: bool = False,
) -> None:
    """Write text content to a file.

    Args:
        content: The content to write.
        file_path: The path to which the content should be written.
        encoding: The encoding which should be used.
        append: Whether to append to the file or to truncate the file.
    """
    open_mode = "a" if append else "w"
    with open(file_path, open_mode, encoding=encoding) as handle:
        handle.write(content)
def read_file(filename: Union[Text, Path], encoding: Text = DEFAULT_ENCODING) -> Any:
    """Return the full text content of a file.

    Raises:
        FileNotFoundException: If the file does not exist.
        FileIOException: If the file cannot be decoded with `encoding`.
    """
    try:
        with open(filename, encoding=encoding) as file_handle:
            return file_handle.read()
    except FileNotFoundError:
        raise FileNotFoundException(
            f"Failed to read file, '{os.path.abspath(filename)}' does not exist."
        )
    except UnicodeDecodeError:
        raise FileIOException(
            f"Failed to read file '{os.path.abspath(filename)}', "
            f"could not read the file using {encoding} to decode "
            f"it. Please make sure the file is stored with this "
            f"encoding."
        )
def read_json_file(filename: Union[Text, Path]) -> Any:
    """Parse a file's content as JSON and return the resulting object.

    Raises:
        FileIOException: If the content is not valid JSON.
    """
    raw = read_file(filename)
    try:
        return json.loads(raw)
    except ValueError as e:
        raise FileIOException(
            f"Failed to read json from '{os.path.abspath(filename)}'. Error: {e}"
        )
def list_directory(path: Text) -> List[Text]:
    """Return all files and folders under `path`, excluding hidden entries.

    If the path points to a file, returns that file. The walk is recursive,
    so files at any depth are returned.
    """
    if not isinstance(path, str):
        raise ValueError(
            f"`resource_name` must be a string type. Got `{type(path)}` instead"
        )

    if os.path.isfile(path):
        return [path]

    if os.path.isdir(path):
        entries: List[Text] = []
        for base, dirs, files in os.walk(path, followlinks=True):
            # Sort (after dropping hidden files) so the order is stable
            # across runs; filtering first does not change the sorted order.
            visible_files = sorted(
                (f for f in files if not f.startswith(".")),
                key=_filename_without_prefix,
            )
            entries.extend(os.path.join(base, f) for f in visible_files)
            # Hidden directories are excluded from the results but still
            # descended into by `os.walk` (dirs is not pruned).
            entries.extend(
                os.path.join(base, d) for d in dirs if not d.startswith(".")
            )
        return entries

    raise ValueError(f"Could not locate the resource '{os.path.abspath(path)}'.")
def list_files(path: Text) -> List[Text]:
    """Return all files under `path`, excluding hidden files.

    If the path points to a file, returns that file.
    """
    return [entry for entry in list_directory(path) if os.path.isfile(entry)]
def _filename_without_prefix(file: Text) -> Text:
"""Splits of a filenames prefix until after the first ``_``."""
return "_".join(file.split("_")[1:])
def list_subdirectories(path: Text) -> List[Text]:
    """Return all non-hidden folders directly under `path`.

    If the path points to a file, returns an empty list.
    """
    candidates = glob.glob(os.path.join(path, "*"))
    return [candidate for candidate in candidates if os.path.isdir(candidate)]
def deep_container_fingerprint(
    obj: Union[List[Any], Dict[Any, Any], Any], encoding: Text = DEFAULT_ENCODING
) -> Text:
    """Calculate a stable hash for (possibly nested) containers.

    Keys and values are fingerprinted recursively. For dictionaries the hash
    is independent of key order; for lists the element order matters.

    Args:
        obj: dictionary or list to be hashed.
        encoding: encoding used for dumping objects as strings

    Returns:
        hash of the container.
    """
    if isinstance(obj, dict):
        return get_dictionary_fingerprint(obj, encoding)
    if isinstance(obj, list):
        return get_list_fingerprint(obj, encoding)

    # Objects may provide their own fingerprint; fall back to hashing str(obj).
    fingerprint_method = getattr(obj, "fingerprint", None)
    if callable(fingerprint_method):
        return fingerprint_method()
    return get_text_hash(str(obj), encoding)
def get_dictionary_fingerprint(
    dictionary: Dict[Any, Any], encoding: Text = DEFAULT_ENCODING
) -> Text:
    """Calculate the fingerprint of a dictionary, independent of key order.

    Keys and values may themselves be dicts, lists, or anything that can be
    dumped as a string.

    Args:
        dictionary: dictionary to be hashed
        encoding: encoding used for dumping objects as strings

    Returns:
        The hash of the dictionary
    """
    fingerprinted = {
        deep_container_fingerprint(key, encoding): deep_container_fingerprint(
            value, encoding
        )
        for key, value in dictionary.items()
    }
    # sort_keys makes the JSON dump — and therefore the hash — order-independent
    return get_text_hash(json.dumps(fingerprinted, sort_keys=True), encoding)
def get_list_fingerprint(
    elements: List[Any], encoding: Text = DEFAULT_ENCODING
) -> Text:
    """Calculate a fingerprint for a list (element order affects the result).

    Args:
        elements: list to fingerprint
        encoding: encoding used for dumping objects as strings

    Returns:
        the fingerprint of the list
    """
    element_hashes = [deep_container_fingerprint(item, encoding) for item in elements]
    return get_text_hash(json.dumps(element_hashes), encoding)
def get_text_hash(text: Text, encoding: Text = DEFAULT_ENCODING) -> Text:
    """Return the hex md5 digest of `text` encoded with `encoding`."""
    encoded = text.encode(encoding)
    # deepcode ignore InsecureHash: Not used for a cryptographic purpose
    return md5(encoded).hexdigest()  # nosec
def json_to_string(obj: Any, **kwargs: Any) -> Text:
    """Dump a JSON-serializable object to a string.

    Args:
        obj: JSON-serializable object.
        kwargs: serialization options. Defaults to 2 space indentation
            and disabled escaping of non-ASCII characters.

    Returns:
        The object serialized to JSON, as a string.
    """
    # Fill in defaults only where the caller did not override them.
    kwargs.setdefault("indent", 2)
    kwargs.setdefault("ensure_ascii", False)
    return json.dumps(obj, **kwargs)
def fix_yaml_loader() -> None:
    """Ensure that any string read by yaml is represented as unicode."""

    def construct_yaml_str(self: BaseConstructor, node: ScalarNode) -> Any:
        # Override the default string handling function
        # to always return unicode objects
        return self.construct_scalar(node)

    # Register the constructor on both the plain and the safe loader.
    for loader_cls in (yaml.Loader, yaml.SafeLoader):
        loader_cls.add_constructor("tag:yaml.org,2002:str", construct_yaml_str)
def replace_environment_variables() -> None:
    """Enable yaml loader to process the environment variables in the yaml."""
    # e.g. ${USER_NAME}, ${PASSWORD}
    env_var_pattern = re.compile(r"^(.*)\$\{(.*)\}(.*)$")
    yaml.Resolver.add_implicit_resolver("!env_var", env_var_pattern, None)

    def env_var_constructor(loader: BaseConstructor, node: ScalarNode) -> Text:
        """Process environment variables found in the YAML."""
        raw_value = loader.construct_scalar(node)
        expanded = os.path.expandvars(raw_value)

        # Tokens that still start with `$` after expansion (and were present
        # verbatim in the input) point at unset environment variables.
        unresolved = [
            token
            for token in expanded.split()
            if token.startswith("$") and token in raw_value
        ]
        if unresolved:
            raise RasaException(
                f"Error when trying to expand the "
                f"environment variables in '{raw_value}'. "
                f"Please make sure to also set these "
                f"environment variables: '{unresolved}'."
            )
        return expanded

    yaml.SafeConstructor.add_constructor("!env_var", env_var_constructor)
# Patch the ruamel loaders once at module import time so every subsequent
# YAML read benefits from unicode string handling and env-var expansion.
fix_yaml_loader()
replace_environment_variables()
def read_yaml(content: Text, reader_type: Union[Text, List[Text]] = "safe") -> Any:
    """Parse YAML from a text.

    Args:
        content: A text containing yaml content.
        reader_type: Reader type to use. By default "safe" will be used.

    Raises:
        ruamel.yaml.parser.ParserError: If there was an error when parsing the YAML.
    """
    if _is_ascii(content):
        # Required to make sure emojis are correctly parsed
        content = (
            content.encode("utf-8")
            .decode("raw_unicode_escape")
            .encode("utf-16", "surrogatepass")
            .decode("utf-16")
        )

    parser = yaml.YAML(typ=reader_type)
    parser.version = YAML_VERSION  # type: ignore[assignment]
    parser.preserve_quotes = True  # type: ignore[assignment]

    # An empty document parses to `None`; normalize that to an empty dict.
    return parser.load(content) or {}
def _is_ascii(text: Text) -> bool:
return all(ord(character) < 128 for character in text)
def read_yaml_file(
    filename: Union[Text, Path], reader_type: Union[Text, List[Text]] = "safe"
) -> Union[List[Any], Dict[Text, Any]]:
    """Parse a YAML file.

    Args:
        filename: The path to the file which should be read.
        reader_type: Reader type to use. By default "safe" will be used.

    Returns:
        Parsed content of the file.

    Raises:
        YamlSyntaxException: If the file content cannot be parsed as YAML.
    """
    file_content = read_file(filename, DEFAULT_ENCODING)
    try:
        return read_yaml(file_content, reader_type)
    except (YAMLError, DuplicateKeyError) as e:
        raise YamlSyntaxException(filename, e)
def write_yaml(
    data: Any,
    target: Union[Text, Path, StringIO],
    should_preserve_key_order: bool = False,
) -> None:
    """Serialize `data` as YAML into a file or a stream.

    Args:
        data: The data to write.
        target: The path to the file which should be written or a stream object
        should_preserve_key_order: Whether to force preserve key order in `data`.
    """
    _enable_ordered_dict_yaml_dumping()

    if should_preserve_key_order:
        data = convert_to_ordered_dict(data)

    yaml_writer = yaml.YAML()
    # no wrap lines
    yaml_writer.width = YAML_LINE_MAX_WIDTH  # type: ignore[assignment]
    # use `null` to represent `None`
    yaml_writer.representer.add_representer(
        type(None),
        lambda self, _: self.represent_scalar("tag:yaml.org,2002:null", "null"),
    )

    if isinstance(target, StringIO):
        yaml_writer.dump(data, target)
    else:
        with Path(target).open("w", encoding=DEFAULT_ENCODING) as outfile:
            yaml_writer.dump(data, outfile)
# Maximum line width handed to the ruamel dumper in `write_yaml`; large enough
# to effectively disable line wrapping. NOTE: defined after `write_yaml`,
# which is fine because the function only reads it at call time.
YAML_LINE_MAX_WIDTH = 4096
def is_key_in_yaml(file_path: Union[Text, Path], *keys: Text) -> bool:
    """Check whether any of `keys` appears as a root-level key of the yaml file.

    Arguments:
        file_path: path to the yaml file
        keys: keys to look for

    Returns:
        `True` if at least one of the keys is found, `False` otherwise.

    Raises:
        FileNotFoundException: if the file cannot be found.
    """
    # `str.startswith` accepts a tuple, so one call per line covers all keys.
    key_prefixes = tuple(f"{key}:" for key in keys)
    try:
        with open(file_path, encoding=DEFAULT_ENCODING) as file:
            return any(line.lstrip().startswith(key_prefixes) for line in file)
    except FileNotFoundError:
        raise FileNotFoundException(
            f"Failed to read file, '{os.path.abspath(file_path)}' does not exist."
        )
def convert_to_ordered_dict(obj: Any) -> Any:
    """Recursively convert `obj` to an `OrderedDict`.

    Args:
        obj: Object to convert.

    Returns:
        An `OrderedDict` with all nested dictionaries converted if `obj` is a
        dictionary, otherwise the object itself.
    """
    if isinstance(obj, OrderedDict):
        return obj
    if isinstance(obj, list):
        # recurse into list elements
        return [convert_to_ordered_dict(item) for item in obj]
    if isinstance(obj, dict):
        # recurse into values; insertion order of keys is preserved
        return OrderedDict(
            (key, convert_to_ordered_dict(value)) for key, value in obj.items()
        )
    # anything else passes through unchanged
    return obj
def _enable_ordered_dict_yaml_dumping() -> None:
    """Ensure that `OrderedDict`s are dumped so that the order of keys is respected."""
    # Register `OrderedDict` with the round-trip representer so instances
    # serialize as plain YAML mappings (idempotent; safe to call repeatedly).
    yaml.add_representer(
        OrderedDict,
        RoundTripRepresenter.represent_dict,
        representer=RoundTripRepresenter,
    )
def is_logging_disabled() -> bool:
    """Returns `True` if log level is set to WARNING or ERROR, `False` otherwise."""
    configured_level = os.environ.get(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL)
    return configured_level in {"ERROR", "WARNING"}
def create_directory_for_file(file_path: Union[Text, Path]) -> None:
    """Creates any missing parent directories of this file path."""
    parent_directory = os.path.dirname(file_path)
    create_directory(parent_directory)
def dump_obj_as_json_to_file(filename: Union[Text, Path], obj: Any) -> None:
    """Serialize `obj` as a JSON string (2-space indent) and write it to a file."""
    serialized = json.dumps(obj, ensure_ascii=False, indent=2)
    write_text_file(serialized, filename)
def dump_obj_as_yaml_to_string(
    obj: Any, should_preserve_key_order: bool = False
) -> Text:
    """Serialize a python object to a YAML string.

    Args:
        obj: The object to dump. Has to be serializable.
        should_preserve_key_order: Whether to force preserve key order in `data`.

    Returns:
        The object converted to a YAML string.
    """
    out = StringIO()
    write_yaml(obj, out, should_preserve_key_order=should_preserve_key_order)
    return out.getvalue()
def create_directory(directory_path: Text) -> None:
    """Creates a directory and its super paths.

    Succeeds even if the path already exists.
    """
    try:
        os.makedirs(directory_path)
    except OSError as err:
        # Tolerate an already-existing path (EEXIST); re-raise anything else.
        # NOTE: intentionally NOT `exist_ok=True`, which would still raise
        # when the existing path is a file rather than a directory.
        if err.errno != errno.EEXIST:
            raise
def raise_deprecation_warning(
    message: Text,
    warn_until_version: Text = NEXT_MAJOR_VERSION_FOR_DEPRECATIONS,
    docs: Optional[Text] = None,
    **kwargs: Any,
) -> None:
    """Thin wrapper around `raise_warning()` for deprecation warnings.

    Requires a version until which we'll warn, and after which the support
    for the feature will be removed.
    """
    if warn_until_version not in message:
        message = f"{message} (will be removed in {warn_until_version})"

    if "stacklevel" not in kwargs:
        # need the correct stacklevel now
        kwargs["stacklevel"] = 3

    # we're raising a `FutureWarning` instead of a `DeprecationWarning` because
    # we want these warnings to be visible in the terminal of our users
    # https://docs.python.org/3/library/warnings.html#warning-categories
    raise_warning(message, FutureWarning, docs, **kwargs)
def read_validated_yaml(
    filename: Union[Text, Path],
    schema: Text,
    reader_type: Union[Text, List[Text]] = "safe",
) -> Any:
    """Validate a YAML file's content against a schema and return it parsed.

    Args:
        filename: The path to the file which should be read.
        schema: The path to the schema file which should be used for validating the
            file content.
        reader_type: Reader type to use. By default "safe" will be used.

    Returns:
        The parsed file content.

    Raises:
        YamlValidationException: In case the model configuration doesn't match the
            expected schema.
    """
    raw_content = read_file(filename)
    rasa.shared.utils.validation.validate_yaml_schema(raw_content, schema)
    return read_yaml(raw_content, reader_type)
def read_config_file(
    filename: Union[Path, Text], reader_type: Union[Text, List[Text]] = "safe"
) -> Dict[Text, Any]:
    """Parse a yaml configuration file. Content needs to be a dictionary.

    Args:
        filename: The path to the file which should be read.
        reader_type: Reader type to use. By default "safe" will be used.

    Raises:
        YamlValidationException: In case file content is not a `Dict`.

    Returns:
        Parsed config file.
    """
    config = read_validated_yaml(filename, CONFIG_SCHEMA_FILE, reader_type)
    return config
def read_model_configuration(filename: Union[Path, Text]) -> Dict[Text, Any]:
    """Parse a model configuration file.

    Args:
        filename: The path to the file which should be read.

    Raises:
        YamlValidationException: In case the model configuration doesn't match the
            expected schema.

    Returns:
        Parsed config file.
    """
    parsed = read_validated_yaml(filename, MODEL_CONFIG_SCHEMA_FILE)
    return parsed
def is_subdirectory(path: Text, potential_parent_directory: Text) -> bool:
    """Checks if `path` is a subdirectory of `potential_parent_directory`.

    Args:
        path: Path to a file or directory.
        potential_parent_directory: Potential parent directory.

    Returns:
        `True` if `path` is `potential_parent_directory` or nested inside it.
    """
    if path is None or potential_parent_directory is None:
        return False

    path = os.path.abspath(path)
    potential_parent_directory = os.path.abspath(potential_parent_directory)

    # Bug fix: the previous substring check (`parent in path`) wrongly matched
    # sibling directories sharing a prefix (e.g. `/a/bc` for parent `/a/b`)
    # and any occurrence of the parent string in the middle of the path.
    # `os.path.commonpath` compares whole path components instead.
    try:
        return (
            os.path.commonpath([path, potential_parent_directory])
            == potential_parent_directory
        )
    except ValueError:
        # raised e.g. for paths on different drives (Windows) — not nested
        return False
def random_string(length: int) -> Text:
    """Returns a random string of given length."""
    # Alphabet: uppercase ASCII letters plus digits.
    alphabet = string.ascii_uppercase + string.digits
    return "".join(random.choices(alphabet, k=length))
def handle_print_blocking(output: Text) -> None:
    """Handle print blocking (BlockingIOError) by getting the STDOUT lock.

    Args:
        output: Text to be printed to STDOUT.
    """
    # Locking again to obtain STDOUT with a lock.
    # NOTE(review): `portalocker.Lock` is usually given a file path; passing
    # the open stdout stream relies on portalocker accepting file objects —
    # confirm against the installed portalocker version.
    with portalocker.Lock(sys.stdout) as lock:
        if sys.platform == "win32":
            # colorama is used to fix a regression where colors can not be printed on
            # windows. https://github.com/RasaHQ/rasa/issues/7053
            from colorama import AnsiToWin32

            # Rebind `lock` to colorama's wrapped stream so ANSI codes render.
            lock = AnsiToWin32(lock).stream

        # flush immediately so output is visible even if stdout is buffered
        print(output, file=lock, flush=True)