This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
from_params.py
623 lines (532 loc) · 25.2 KB
/
from_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
import collections.abc
from copy import deepcopy
from pathlib import Path
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Mapping,
Set,
Tuple,
Type,
TypeVar,
Union,
)
import inspect
import logging
from allennlp.common.checks import ConfigurationError
from allennlp.common.lazy import Lazy
from allennlp.common.params import Params
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type variable standing for "some subclass of FromParams"; it is used to annotate
# the constructors and `from_params` methods defined below.
T = TypeVar("T", bound="FromParams")
# Sentinel meaning "this parameter has no default value": if a function parameter
# has no default value specified, this is what the inspect module returns.
_NO_DEFAULT = inspect.Parameter.empty
def takes_arg(obj, arg: str) -> bool:
    """
    Reports whether `obj` accepts a parameter named `arg`.

    For a class, the signature of its `__init__` is inspected (so `"self"` counts as
    a parameter). For a function or method, the signature of the object itself is
    inspected. Any other kind of object raises a `ConfigurationError`.
    """
    if inspect.isclass(obj):
        target = obj.__init__
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        target = obj
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    return arg in inspect.signature(target).parameters
def takes_kwargs(obj) -> bool:
    """
    Reports whether `obj` accepts arbitrary keyword arguments, i.e. whether its
    signature contains a `**kwargs`-style parameter.

    As in `takes_arg`, a class is inspected through its `__init__`, while a function
    or method is inspected directly. Any other kind of object raises a
    `ConfigurationError`.
    """
    if inspect.isclass(obj):
        target = obj.__init__
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        target = obj
    else:
        raise ConfigurationError(f"object {obj} is not callable")

    for parameter in inspect.signature(target).parameters.values():
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:  # type: ignore
            return True
    return False
def can_construct_from_params(type_: Type) -> bool:
    """
    Reports whether a value annotated with `type_` can be built from parameters.

    This is true for the primitives `str`, `int`, `float` and `bool`, for `Lazy[...]`
    annotations, and for anything exposing a `from_params` attribute. A generic
    annotation without `from_params` qualifies only if every one of its type
    arguments does.
    """
    if type_ in (str, int, float, bool):
        return True

    origin = getattr(type_, "__origin__", None)
    if origin == Lazy:
        return True

    has_from_params = hasattr(type_, "from_params")
    if origin and not has_from_params:
        # Generic container: constructible only if all of its type arguments are.
        return all(can_construct_from_params(arg) for arg in getattr(type_, "__args__"))
    return has_from_params
def is_base_registrable(cls) -> bool:
    """
    Checks whether `cls` is a "base" `Registrable` class: it is a subclass of
    `Registrable`, and none of the other classes in its method resolution order
    (apart from `Registrable` itself) is a subclass of `Registrable`.

    In other words, classes that reach `Registrable` only through another
    `Registrable` subclass return `False`.
    """
    from allennlp.common.registrable import Registrable  # import here to avoid circular imports

    if not issubclass(cls, Registrable):
        return False

    # Skip the first MRO entry, which is `cls` itself.
    ancestors = inspect.getmro(cls)[1:]
    return not any(
        ancestor is not Registrable and issubclass(ancestor, Registrable)
        for ancestor in ancestors
    )
def remove_optional(annotation: type):
    """
    Strips `None` out of a union annotation.

    `Optional[X]` annotations are actually represented as `Union[X, NoneType]`.
    For our purposes, the "Optional" part is not interesting, so here we throw it
    away: `Optional[X]` becomes `X`, and `Union[X, Y, None]` becomes `Union[X, Y]`.

    Unions written with the PEP 604 `X | None` syntax (Python 3.10+) are handled as
    well. The result is always expressed with `typing.Union`, which is the form the
    rest of this module checks for.

    Any annotation that is not a union is returned unchanged.
    """
    import types  # local import: only needed to recognise PEP 604 unions

    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", ())

    # `X | Y` builds a `types.UnionType`, which carries `__args__` but (on the Python
    # versions where it is a distinct type) no `__origin__`. `types.UnionType` does
    # not exist before Python 3.10, hence the `getattr`.
    pep604_union = getattr(types, "UnionType", None)
    is_pep604 = pep604_union is not None and isinstance(annotation, pep604_union)

    if origin == Union or is_pep604:
        return Union[tuple(arg for arg in args if arg is not type(None))]  # noqa: E721
    return annotation
def infer_params(
    cls: Type[T], constructor: Union[Callable[..., T], Callable[[T], None]] = None
) -> Dict[str, Any]:
    """
    Collects the parameters that can be passed when constructing `cls`.

    The signature of `constructor` (defaulting to `cls.__init__`) is inspected and
    returned as a name -> `inspect.Parameter` mapping, with any `*args` parameter
    removed. `self` and a `**kwargs` parameter, if present, are kept in the mapping.

    If the constructor takes `**kwargs`, the parameters of the first superclass of
    `cls` that inherits from `FromParams` are inferred recursively and merged in;
    on a name clash the subclass's parameter wins.
    """
    if constructor is None:
        constructor = cls.__init__

    own_parameters = {}
    saw_var_keyword = False
    for key, parameter in inspect.signature(constructor).parameters.items():
        if parameter.kind == parameter.VAR_POSITIONAL:
            # `*args` is never exposed as a constructible parameter.
            continue
        if parameter.kind == parameter.VAR_KEYWORD:
            saw_var_keyword = True
        own_parameters[key] = parameter

    if not saw_var_keyword:
        return own_parameters

    # "mro" is "method resolution order". Its first entry is the class itself, so we
    # skip it and take the first superclass that inherits from FromParams.
    parent = next(
        (candidate for candidate in cls.mro()[1:] if issubclass(candidate, FromParams)), None
    )
    inherited = infer_params(parent) if parent else {}

    # Subclass parameters overwrite superclass ones.
    return {**inherited, **own_parameters}
def create_kwargs(
    constructor: Callable[..., T], cls: Type[T], params: Params, **extras
) -> Dict[str, Any]:
    """
    Builds the keyword arguments to pass to `constructor` in order to create a `cls`.

    Each parameter found by `infer_params` is matched against an entry of `params`
    and instantiated according to its type annotation (possibly through a
    `from_params` method). Entries are popped from `params` as they are consumed.

    Values provided in `extras` are used as-is; for instance, you might provide an
    existing `Vocabulary` this way.

    If the constructor does not accept `**kwargs`, `params` must be empty once every
    parameter has been processed. If it does, whatever is left in `params` is
    forwarded untouched.
    """
    constructed: Dict[str, Any] = {}
    allows_extra_keys = False

    for key, parameter in infer_params(cls, constructor).items():
        # Skip "self". Nothing forces the first parameter to be called "self", so
        # this check is fragile in theory, but any other name is asking for trouble.
        if key == "self":
            continue

        if parameter.kind == parameter.VAR_KEYWORD:
            # A `**kwargs` parameter means two things. First, we assume the kwargs are
            # handed to the superclass, whose constructor arguments were already
            # merged in by `infer_params`. Second, leftover keys are allowed: rather
            # than crashing on them we pass them through and trust the caller.
            allows_extra_keys = True
            continue

        # Must be recorded before the value is popped from `params` below.
        was_in_params = key in params

        # Compound annotations such as typing.Dict[str, int] keep their `__origin__`
        # and `__args__`; only the Optional wrapper is removed here.
        value = pop_and_construct_arg(
            cls.__name__,
            key,
            remove_optional(parameter.annotation),
            parameter.default,
            params,
            **extras,
        )

        # A parameter that was not explicitly set and merely came back as its own
        # default is left out. Keeping it can clash with **kwargs in corner cases:
        # e.g. the default yields lazy=False for a dataset reader inside **kwargs
        # while a particular reader hard-codes lazy=True, so the superclass would
        # receive both values.
        if was_in_params or value is not parameter.default:
            constructed[key] = value

    if not allows_extra_keys:
        params.assert_empty(cls.__name__)
    else:
        constructed.update(params)
    return constructed
def create_extras(cls: Type[T], extras: Dict[str, Any]) -> Dict[str, Any]:
    """
    Filters `extras` down to the entries that `cls.from_params` can accept. If `cls`
    has no `from_params`, the entries that `cls` itself can accept are kept instead.

    When the inspected callable takes `**kwargs`, `extras` is returned as-is (the
    same object). Otherwise a new dict is returned holding only the entries whose
    key is a named parameter of the callable.
    """
    # In some rare cases, we get a registered subclass that does _not_ have a
    # from_params method (this happens with Activations, for instance, where we
    # register pytorch modules directly). As a bit of a hack, the class constructor
    # is then what we inspect to decide which extras to pass on.
    target = cls.from_params if hasattr(cls, "from_params") else cls  # type: ignore

    if takes_kwargs(target):
        # The callable accepts **kwargs, so everything must be passed along.
        # For example, `BasicTextFieldEmbedder.from_params` requires a Vocabulary
        # object, but `TextFieldEmbedder.from_params` does not.
        return extras

    # Otherwise only supply the entries that are actual arguments; any additional
    # ones would cause a TypeError.
    return {key: value for key, value in extras.items() if takes_arg(target, key)}
def pop_and_construct_arg(
    class_name: str, argument_name: str, annotation: Type, default: Any, params: Params, **extras
) -> Any:
    """
    Does the work of actually constructing an individual argument for
    [`create_kwargs`](./#create_kwargs).

    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them. The information we get for that parameter is its name,
    its type annotation, and its default value; we also get the full set of `Params` for
    constructing the object (which we may mutate), and any `extras` that the constructor might
    need.

    We take the type annotation and default value here separately, instead of using an
    `inspect.Parameter` object directly, so that we can handle `Union` types using recursion on
    this method, trying the different annotation types in the union in turn.

    The value is resolved in this order:

    1. If the name is in `extras` and not in `params`, the extra is returned as-is. If it is
       in both, a warning is logged and `params` wins.
    2. Otherwise, if the entry in `params` is a `Params` containing a `"_pretrained"` key, the
       module is extracted from the referenced archive and returned. A `ConfigurationError`
       is raised if it is not an instance of `annotation`.
    3. Otherwise the entry is popped from `params` (falling back to `default` when there is
       one) and handed to `construct_arg`. A popped value of `None` is returned directly.
    """
    from allennlp.models.archival import load_archive  # import here to avoid circular imports

    # We used `argument_name` as the method argument to avoid conflicts with 'name' being a key in
    # `extras`, which isn't _that_ unlikely. Now that we are inside the method, we can switch back
    # to using `name`.
    name = argument_name

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        if name not in params:
            return extras[name]
        logger.warning(
            f"Parameter {name} for class {class_name} was found in both "
            "**extras and in params. Using the specification found in params, "
            "but you probably put a key in a config file that you didn't need, "
            "and if it is different from what we get from **extras, you might "
            "get unexpected behavior."
        )
    # Next case is when argument should be loaded from pretrained archive.
    elif (
        name in params
        and isinstance(params.get(name), Params)
        and "_pretrained" in params.get(name)
    ):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result

    # The sentinel is compared by identity: `!=` would call the default value's own
    # `__ne__`, which for array-like defaults may return something that is not a bool.
    if default is _NO_DEFAULT:
        # No default available: the key is required, so let `params.pop` report a
        # missing key in whatever way it does.
        popped_params = params.pop(name)
    else:
        popped_params = params.pop(name, default)

    if popped_params is None:
        return None

    return construct_arg(class_name, name, popped_params, annotation, default, **extras)
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    **extras,
) -> Any:
    """
    Builds one argument value from `popped_params`, guided by its type `annotation`.

    `class_name` and `argument_name` are used for error messages; `argument_name` is also
    extended with `.key` / `.index` suffixes to name the elements of containers in recursive
    calls.

    The annotation is dispatched on in this order:

    - anything with a `from_params` attribute: built through `annotation.from_params`
      (a bare string is treated as `{"type": <string>}`);
    - `int` / `bool`, `str` (also accepting `Path`) and `float` (also accepting `int`):
      type-checked, raising `TypeError` on a mismatch;
    - `Mapping` / `Dict`, `Tuple`, `Set`: each element is constructed recursively, provided
      the element types can themselves be constructed from params;
    - `Union`: each member type is tried in turn and the first success is returned;
      a `ConfigurationError` is raised if none succeeds;
    - `Lazy`: a `Lazy` wrapper around the inner type is returned;
    - `Iterable` / `List`: each element is constructed recursively into a list;
    - anything else: the value is passed through, with `Params` converted to a dict.
    """
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    # Compared by identity so that the default's own `__ne__` is never invoked.
    optional = default is not _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.
            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                popped_params = Params({"type": popped_params})
            elif isinstance(popped_params, dict):
                popped_params = Params(popped_params)
            return annotation.from_params(params=popped_params, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just check the type of the value
    # and cast it.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(f"Expected {argument_name} to be a {annotation.__name__}.")
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if type(popped_params) is str or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(f"Expected {argument_name} to be a string.")
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting
    # structure.
    elif (
        origin in {collections.abc.Mapping, Mapping, Dict, dict}
        and len(args) == 2
        and can_construct_from_params(args[-1])
    ):
        value_cls = annotation.__args__[-1]

        if not isinstance(popped_params, Mapping):
            raise TypeError(
                f"Expected {argument_name} to be a Mapping (probably a dict or a Params object)."
            )

        # The element name is built with an f-string so that non-string keys do not
        # raise a TypeError while merely composing the name.
        return {
            key: construct_arg(
                str(value_cls),
                f"{argument_name}.{key}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            for key, value_params in popped_params.items()
        }

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        # Each position has its own annotation; `zip` pairs them up and stops at the
        # shorter of the two sequences.
        return tuple(
            construct_arg(
                str(value_cls),
                f"{argument_name}.{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            for i, (value_cls, value_params) in enumerate(zip(annotation.__args__, popped_params))
        )

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        return {
            construct_arg(
                str(value_cls),
                f"{argument_name}.{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            for i, value_params in enumerate(popped_params)
        }

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one
        # that succeeds.
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}"
        )

    elif origin == Lazy:
        if popped_params is default:
            return default

        value_cls = args[0]
        subextras = create_extras(value_cls, extras)
        # NOTE(review): the keyword is spelled `contructor_extras`; it is kept as-is because it
        # has to match the parameter name in `Lazy.__init__` - confirm against that class.
        return Lazy(value_cls, params=deepcopy(popped_params), contructor_extras=subextras)  # type: ignore

    # For any other kind of iterable, we will just assume that a list is good enough, and treat
    # it the same as List. This condition needs to be at the end, so we don't catch other kinds
    # of Iterables with this branch.
    elif (
        origin in {collections.abc.Iterable, Iterable, List, list}
        and len(args) == 1
        and can_construct_from_params(args[0])
    ):
        value_cls = annotation.__args__[0]
        return [
            construct_arg(
                str(value_cls),
                f"{argument_name}.{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            for i, value_params in enumerate(popped_params)
        ]

    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict()
        return popped_params
class FromParams:
    """
    Mixin that provides a `from_params` classmethod. It exists as a separate base class
    because some classes that are not `Registrable` should still be instantiable from
    params.
    """

    @classmethod
    def from_params(
        cls: Type[T],
        params: Params,
        constructor_to_call: Callable[..., T] = None,
        constructor_to_inspect: Union[Callable[..., T], Callable[[T], None]] = None,
        **extras,
    ) -> T:
        """
        The automatic implementation of `from_params`. Every subclass of `FromParams`
        (or of `Registrable`, which itself subclasses `FromParams`) inherits it. It
        covers the "obvious" way of building an object from params: pop the parameters
        off and hand them to the constructor under the same names.

        If you need more complex logic, override this method in your class.

        `constructor_to_call` and `constructor_to_inspect` support a bit of redirection.
        A particular `@classmethod` may be registered on a class as the constructor for
        a registered name, which lets a single class (e.g. `Vocabulary`) be built in
        several ways under different names. To support that we need, besides the class
        being built (`cls`), the callable whose signature describes the arguments
        (`constructor_to_inspect`) and the callable to invoke once the arguments are
        ready (`constructor_to_call`). The two are the same object when a `@classmethod`
        is the constructor, and differ for the default constructor (where `__init__` is
        inspected but `cls()` is called).

        `None` params yield `None`; a bare string is treated as `{"type": <string>}`;
        anything else that is not a `Params` raises a `ConfigurationError`.
        """

        from allennlp.common.registrable import Registrable  # import here to avoid circular imports

        logger.debug(
            f"instantiating class {cls} from params {getattr(params, 'params', params)} "
            f"and extras {set(extras.keys())}"
        )

        if params is None:
            return None

        if isinstance(params, str):
            params = Params({"type": params})

        if not isinstance(params, Params):
            raise ConfigurationError(
                "from_params was passed a `params` object that was not a `Params`. This probably "
                "indicates malformed parameters in a configuration file, where something that "
                "should have been a dictionary was actually a list, or something else. "
                f"This happened when constructing an object of type {cls}."
            )

        registered_subclasses = Registrable._registry.get(cls)

        if is_base_registrable(cls) and registered_subclasses is None:
            # NOTE(mattg): There are some potential corner cases in this logic if you have nested
            # Registrable types.  We don't currently have any of those, but if we ever get them,
            # adding some logic to check `constructor_to_call` should solve the issue.  Not
            # bothering to add that unnecessary complexity for now.
            raise ConfigurationError(
                "Tried to construct an abstract Registrable base class that has no registered "
                "concrete types. This might mean that you need to use --include-package to get "
                "your concrete classes actually registered."
            )

        if registered_subclasses is None or constructor_to_call:
            # Either `cls` has nothing registered under it, or the caller already chose
            # the constructor: build `cls` directly by turning params and extras into
            # keyword arguments. See the docstring for the role of the two constructors.
            inspect_target = constructor_to_inspect or cls.__init__
            call_target = constructor_to_call or cls

            if inspect_target == object.__init__:
                # This class does not have an explicit constructor, so don't give it any
                # kwargs. Without this logic, create_kwargs would look at object.__init__,
                # see that it takes *args and **kwargs, and look for those.
                kwargs: Dict[str, Any] = {}
                params.assert_empty(cls.__name__)
            else:
                # This class has a constructor, so create kwargs for it.
                inspect_target = cast(Callable[..., T], inspect_target)
                kwargs = create_kwargs(inspect_target, cls, params, **extras)

            return call_target(**kwargs)  # type: ignore

        # Otherwise `cls` has registered subclasses and no constructor was chosen yet:
        # resolve the requested "type" and delegate to it.
        # We know `cls` inherits from Registrable, so we'll use a cast to make mypy happy.
        as_registrable = cast(Type[Registrable], cls)
        default_to_first_choice = as_registrable.default_implementation is not None
        choice = params.pop_choice(
            "type",
            choices=as_registrable.list_available(),
            default_to_first_choice=default_to_first_choice,
        )
        subclass, constructor_name = as_registrable.resolve_class_name(choice)

        # See the docstring for an explanation of what's going on here.
        if constructor_name:
            constructor_to_inspect = cast(Callable[..., T], getattr(subclass, constructor_name))
            constructor_to_call = constructor_to_inspect
        else:
            constructor_to_inspect = subclass.__init__
            constructor_to_call = subclass  # type: ignore

        if not hasattr(subclass, "from_params"):
            # In some rare cases, we get a registered subclass that does _not_ have a
            # from_params method (this happens with Activations, for instance, where we
            # register pytorch modules directly).  This is a bit of a hack to make those work,
            # instead of adding a `from_params` method for them somehow.  We just trust that
            # you've done the right thing in passing your parameters, and nothing else needs to
            # be recursively constructed.
            return subclass(**params)  # type: ignore

        # We want to call subclass.from_params.
        extras = create_extras(subclass, extras)
        # mypy can't follow the typing redirection that we do, so we explicitly cast here.
        retyped_subclass = cast(Type[T], subclass)
        return retyped_subclass.from_params(
            params=params,
            constructor_to_call=constructor_to_call,
            constructor_to_inspect=constructor_to_inspect,
            **extras,
        )