-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
yaml_parse.py
477 lines (407 loc) · 16 KB
/
yaml_parse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
"""Support code for YAML parsing of experiment descriptions."""
import yaml
from pylearn2.utils import serial
from pylearn2.utils.exc import reraise_as
from pylearn2.utils.string_utils import preprocess
from pylearn2.utils.call_check import checked_call
from pylearn2.utils.string_utils import match
from collections import namedtuple
import logging
import warnings
import re
from theano.compat import six
SCIENTIFIC_NOTATION_REGEXP = r'^[\-\+]?(\d+\.?\d*|\d*\.?\d+)?[eE][\-\+]?\d+$'
is_initialized = False
additional_environ = None
logger = logging.getLogger(__name__)
# Lightweight container for initial YAML evaluation.
#
# This is intended as a robust, forward-compatible intermediate representation
# for either internal consumption or external consumption by another tool e.g.
# hyperopt.
#
# We've included a slot for positionals just in case, though they are
# unsupported by the instantiation mechanism as yet.
BaseProxy = namedtuple('BaseProxy',
                       'callable positionals keywords yaml_src')


class Proxy(BaseProxy):
    """
    An intermediate representation between the initial YAML parse and
    object instantiation.

    Parameters
    ----------
    callable : callable
        The function/class invoked to instantiate this node.
    positionals : iterable
        Placeholder for future positional-argument (`*args`) support.
    keywords : dict-like
        Mapping from keyword names to arguments (`**kwargs`); values may
        themselves be `Proxy` objects, possibly nested inside `dict` or
        `list` containers. Keys must be strings that are valid Python
        variable names.
    yaml_src : str
        The YAML source that created this node, if available.

    Notes
    -----
    Serves as a robust, forward-compatible intermediate representation
    for internal use or for external tools such as hyperopt.

    This class exists mainly to replace `BaseProxy`'s tuple-based
    `__hash__`, which would fail on unhashable members (e.g. the
    `keywords` dict).
    """
    __slots__ = []

    def __hash__(self):
        """
        Hash on object identity so that unhashable namedtuple members
        never need to be hashed themselves.
        """
        return hash(id(self))
def do_not_recurse(value):
    """
    Marker function used to wrap an unpickled object that must not be
    recursively expanded. The instantiation parser recognizes and
    respects this symbol; as an implementation it is simply a no-op
    that hands back its argument.

    Parameters
    ----------
    value : object
        The value to be returned.

    Returns
    -------
    value : object
        The same object passed in as an argument.
    """
    return value
def _instantiate_proxy_tuple(proxy, bindings=None):
    """
    Helper function for `_instantiate` that handles objects of the `Proxy`
    class.

    Parameters
    ----------
    proxy : Proxy object
        The `Proxy` object to instantiate.
    bindings : dict, optional
        A dictionary mapping previously instantiated `Proxy` objects
        to their instantiated values. If `None`, a fresh empty
        dictionary is used.

    Returns
    -------
    obj : object
        The result object from recursively instantiating the object DAG.
    """
    # The signature advertises bindings as optional; honour that instead of
    # raising TypeError on the `in` test below when it is omitted.
    if bindings is None:
        bindings = {}
    if proxy in bindings:
        return bindings[proxy]
    else:
        # Respect do_not_recurse by just un-packing it (same as calling).
        if proxy.callable == do_not_recurse:
            obj = proxy.keywords['value']
        else:
            # TODO: add (requested) support for positionals (needs to be added
            # to checked_call also).
            if len(proxy.positionals) > 0:
                raise NotImplementedError('positional arguments not yet '
                                          'supported in proxy instantiation')
            kwargs = dict((k, _instantiate(v, bindings))
                          for k, v in six.iteritems(proxy.keywords))
            obj = checked_call(proxy.callable, kwargs)
        try:
            obj.yaml_src = proxy.yaml_src
        except AttributeError:  # Some classes won't allow this.
            pass
        # Memoize so aliased (&anchor/*ref) proxies resolve to one object.
        bindings[proxy] = obj
        return bindings[proxy]
def _instantiate(proxy, bindings=None):
    """
    Instantiate a (hierarchy of) Proxy object(s).

    Parameters
    ----------
    proxy : object
        A `Proxy` object or list/dict/literal. Strings are run through
        `preprocess`.
    bindings : dict, optional
        A dictionary mapping previously instantiated `Proxy` objects
        to their instantiated values.

    Returns
    -------
    obj : object
        The result object from recursively instantiating the object DAG.

    Notes
    -----
    This should not be considered part of the stable, public API.
    """
    bindings = {} if bindings is None else bindings
    if isinstance(proxy, Proxy):
        return _instantiate_proxy_tuple(proxy, bindings)
    if isinstance(proxy, dict):
        # Recurse on the keys too, for backward compatibility.
        # Is the key instantiation feature ever actually used, by anyone?
        return dict((_instantiate(key, bindings), _instantiate(val, bindings))
                    for key, val in six.iteritems(proxy))
    if isinstance(proxy, list):
        return [_instantiate(element, bindings) for element in proxy]
    # In the future it might be good to consider a dict argument that provides
    # a type->callable mapping for arbitrary transformations like this.
    if isinstance(proxy, six.string_types):
        return preprocess(proxy)
    return proxy
def load(stream, environ=None, instantiate=True, **kwargs):
    """
    Loads a YAML configuration from a string or file-like object.

    Parameters
    ----------
    stream : str or object
        Either a string containing valid YAML or a file-like object
        supporting the .read() interface.
    environ : dict, optional
        A dictionary used for ${FOO} substitutions in addition to
        environment variables. If a key appears both in `os.environ`
        and this dictionary, the value in this dictionary is used.
    instantiate : bool, optional
        If `False`, do not actually instantiate the objects but instead
        produce a nested hierarchy of `Proxy` objects.

    Returns
    -------
    graph : dict or object
        The dictionary or object (if the top-level element specified
        a Python object to instantiate), or a nested hierarchy of
        `Proxy` objects.

    Notes
    -----
    Other keyword arguments are passed on to `yaml.load`.
    """
    global is_initialized
    global additional_environ
    if not is_initialized:
        initialize()
    additional_environ = environ

    # Accept raw YAML text directly; otherwise drain the file-like object.
    if isinstance(stream, six.string_types):
        text = stream
    else:
        text = stream.read()

    # NOTE(review): yaml.load with the default loader can construct arbitrary
    # Python objects -- only feed it trusted configuration sources.
    proxy_graph = yaml.load(text, **kwargs)
    return _instantiate(proxy_graph) if instantiate else proxy_graph
def load_path(path, environ=None, instantiate=True, **kwargs):
    """
    Convenience function for loading a YAML configuration from a file.

    Parameters
    ----------
    path : str
        The path to the file to load on disk.
    environ : dict, optional
        A dictionary used for ${FOO} substitutions in addition to
        environment variables. If a key appears both in `os.environ`
        and this dictionary, the value in this dictionary is used.
    instantiate : bool, optional
        If `False`, do not actually instantiate the objects but instead
        produce a nested hierarchy of `Proxy` objects.

    Returns
    -------
    graph : dict or object
        The dictionary or object (if the top-level element specified
        a Python object to instantiate), or a nested hierarchy of
        `Proxy` objects.

    Notes
    -----
    Other keyword arguments are passed on to `yaml.load`.
    """
    with open(path, 'r') as f:
        # f.read() replaces the needless ''.join(f.readlines()).
        content = f.read()
    # This is apparently here to avoid the odd instance where a file gets
    # loaded as Unicode instead (see 03f238c6d). It's rare instance where
    # basestring is not the right call.
    if not isinstance(content, str):
        raise AssertionError("Expected content to be of type str, got " +
                             str(type(content)))
    return load(content, instantiate=instantiate, environ=environ, **kwargs)
def try_to_import(tag_suffix):
    """
    Import and return the object named by a dotted path.

    Parameters
    ----------
    tag_suffix : str
        A dotted path such as ``package.module.Thing``: everything up to
        the last component names the module to import, and the full path
        is then evaluated to obtain the object.

    Returns
    -------
    obj : object
        The object the dotted path evaluates to.
    """
    components = tag_suffix.split('.')
    modulename = '.'.join(components[:-1])
    # Import into an explicit namespace so the eval() calls below can see
    # the module binding regardless of how exec() interacts with function
    # locals (implicit-locals exec stops persisting under PEP 667 / 3.13).
    namespace = {}
    try:
        exec('import %s' % modulename, namespace)
    except ImportError as e:
        # We know it's an ImportError, but is it an ImportError related to
        # this path,
        # or did the module we're importing have an unrelated ImportError?
        # and yes, this test can still have false positives, feel free to
        # improve it
        pieces = modulename.split('.')
        str_e = str(e)
        # Does the error message mention any component of the module path?
        # (The previous version tested piece.find(str(e)), i.e. searched
        # for the whole message inside each component -- backwards, and
        # almost never true.)
        found = any(piece in str_e for piece in pieces)
        if found:
            # The yaml file is probably to blame.
            # Report the problem with the full module path from the YAML
            # file
            reraise_as(ImportError("Could not import %s; ImportError was %s" %
                                   (modulename, str_e)))
        else:
            # Import each prefix of the path in turn to pinpoint which
            # component is actually unimportable.
            pcomponents = components[:-1]
            assert len(pcomponents) >= 1
            j = 1
            while j <= len(pcomponents):
                modulename = '.'.join(pcomponents[:j])
                try:
                    exec('import %s' % modulename, namespace)
                except Exception:
                    base_msg = 'Could not import %s' % modulename
                    if j > 1:
                        modulename = '.'.join(pcomponents[:j - 1])
                        base_msg += ' but could import %s' % modulename
                    reraise_as(ImportError(base_msg + '. Original exception: '
                                           + str(e)))
                j += 1
    try:
        obj = eval(tag_suffix, namespace)
    except AttributeError as e:
        try:
            # Try to figure out what the wrong field name was
            # If we fail to do it, just fall back to giving the usual
            # attribute error
            pieces = tag_suffix.split('.')
            module = '.'.join(pieces[:-1])
            field = pieces[-1]
            candidates = dir(eval(module, namespace))
            msg = ('Could not evaluate %s. ' % tag_suffix +
                   'Did you mean ' + match(field, candidates) + '? ' +
                   'Original error was ' + str(e))
        except Exception:
            warnings.warn("Attempt to decipher AttributeError failed")
            reraise_as(AttributeError('Could not evaluate %s. ' % tag_suffix +
                                      'Original error was ' + str(e)))
        reraise_as(AttributeError(msg))
    return obj
def initialize():
    """
    Initialize the configuration system by installing YAML handlers.
    Automatically done on first call to load() specified in this file.
    """
    global is_initialized
    # Install the custom multi-constructors (tag prefix -> callback).
    for prefix, callback in (('!obj:', multi_constructor_obj),
                             ('!pkl:', multi_constructor_pkl),
                             ('!import:', multi_constructor_import)):
        yaml.add_multi_constructor(prefix, callback)
    # Exact-tag constructors.
    yaml.add_constructor('!import', constructor_import)
    yaml.add_constructor('!float', constructor_float)
    # Resolve unquoted scientific-notation scalars to !float implicitly.
    yaml.add_implicit_resolver('!float',
                               re.compile(SCIENTIFIC_NOTATION_REGEXP))
    is_initialized = True
###############################################################################
# Callbacks used by PyYAML
def multi_constructor_obj(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!obj:" tag is encountered.

    See PyYAML documentation for details on the call signature.
    """
    yaml_src = yaml.serialize(node)
    # Run the duplicate-key-rejecting mapping pass first (result discarded;
    # this call exists only to raise on repeated keys).
    construct_mapping(node)
    mapping = loader.construct_mapping(node)
    assert hasattr(mapping, 'keys')
    assert hasattr(mapping, 'values')
    # Every keyword name must be a string (it becomes a **kwargs key).
    for key in mapping.keys():
        if not isinstance(key, six.string_types):
            raise TypeError("Received non string object (%s) as "
                            "key in mapping." % str(key))
    if '.' in tag_suffix:
        callable = try_to_import(tag_suffix)
    else:
        # TODO: I'm not sure how this was ever working without eval().
        callable = eval(tag_suffix)
    return Proxy(callable=callable, yaml_src=yaml_src, positionals=(),
                 keywords=mapping)
def multi_constructor_pkl(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!pkl:" tag is encountered.

    Unpickles the file named by the node's (preprocessed) string and
    wraps it in a do-not-recurse `Proxy`.
    """
    global additional_environ
    # A non-empty suffix means the filename was glued directly onto the tag.
    if tag_suffix not in ("", u""):
        raise AssertionError('Expected tag_suffix to be "" but it is "'
                             + tag_suffix +
                             '": Put space between !pkl: and the filename.')
    filename = loader.construct_yaml_str(node)
    obj = serial.load(preprocess(filename, additional_environ))
    return Proxy(callable=do_not_recurse, positionals=(),
                 keywords={'value': obj}, yaml_src=yaml.serialize(node))
def multi_constructor_import(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!import:" tag is encountered.

    The tag suffix itself carries the dotted path to import.
    """
    if '.' in tag_suffix:
        return try_to_import(tag_suffix)
    raise yaml.YAMLError("!import: tag suffix contains no '.'")
def constructor_import(loader, node):
    """
    Callback used by PyYAML when a "!import <str>" tag is encountered.

    This tag expects a (quoted) string as argument: the dotted path of
    the object to import.
    """
    dotted_path = loader.construct_scalar(node)
    if '.' in dotted_path:
        return try_to_import(dotted_path)
    raise yaml.YAMLError("import tag suffix contains no '.'")
def constructor_float(loader, node):
    """
    Callback used by PyYAML when a "!float <str>" tag is encountered.

    This tag expects a (quoted) string as argument, which is converted
    to a Python float.
    """
    return float(loader.construct_scalar(node))
def construct_mapping(node, deep=False):
    """
    Construct a dict from a YAML mapping node, rejecting duplicate keys.

    This is a modified version of yaml.BaseConstructor.construct_mapping
    in which a repeated key raises a ConstructorError instead of silently
    keeping the last occurrence.

    Parameters
    ----------
    node : yaml.nodes.MappingNode
        The mapping node to convert.
    deep : bool, optional
        Present for signature compatibility with PyYAML's
        construct_mapping; keys and values are always constructed
        shallowly here.

    Returns
    -------
    mapping : dict
        The constructed dictionary.
    """
    if not isinstance(node, yaml.nodes.MappingNode):
        const = yaml.constructor
        message = "expected a mapping node, but found"
        raise const.ConstructorError(None, None,
                                     "%s %s " % (message, node.id),
                                     node.start_mark)
    mapping = {}
    constructor = yaml.constructor.BaseConstructor()
    for key_node, value_node in node.value:
        key = constructor.construct_object(key_node, deep=False)
        try:
            hash(key)
        except TypeError as exc:
            const = yaml.constructor
            # The key's mark is passed as ConstructorError's problem_mark
            # argument; previously both exc and the mark were fed to a
            # single-%s format string, which itself raised a TypeError.
            reraise_as(const.ConstructorError("while constructing a mapping",
                                              node.start_mark,
                                              "found unacceptable key (%s)"
                                              % exc,
                                              key_node.start_mark))
        if key in mapping:
            const = yaml.constructor
            raise const.ConstructorError("while constructing a mapping",
                                         node.start_mark,
                                         "found duplicate key (%s)" % key)
        value = constructor.construct_object(value_node, deep=False)
        mapping[key] = value
    return mapping
if __name__ == "__main__":
initialize()
# Demonstration of how to specify objects, reference them
# later in the configuration, etc.
yamlfile = """{
"corruptor" : !obj:pylearn2.corruption.GaussianCorruptor &corr {
"corruption_level" : 0.9
},
"dae" : !obj:pylearn2.models.autoencoder.DenoisingAutoencoder {
"nhid" : 20,
"nvis" : 30,
"act_enc" : null,
"act_dec" : null,
"tied_weights" : true,
# we could have also just put the corruptor definition here
"corruptor" : *corr
}
}"""
# yaml.load can take a string or a file object
loaded = yaml.load(yamlfile)
logger.info(loaded)
# These two things should be the same object
assert loaded['corruptor'] is loaded['dae'].corruptor