fancy version of the configuration wizard (#1344)
* initial version

* progress

* more features

* progress

* clean up

* add docstring

* provide type annotations for Dicts and Lists

* rewrite

* rewrite the whole thing

* fix bugs

* fix tests + final touches

* add tutorial

* fix broken link

* fix typo
joelgrus committed Jun 12, 2018
1 parent 6739c31 commit d844a9a
Showing 11 changed files with 1,604 additions and 199 deletions.
136 changes: 121 additions & 15 deletions allennlp/common/configuration.py
@@ -6,9 +6,11 @@
from typing import NamedTuple, Optional, Any, List, TypeVar, Generic, Type, Dict, Union, Sequence, Tuple
import inspect
import importlib
import json
import re

import torch
from numpydoc.docscrape import NumpyDocString

from allennlp.common import Registrable, JsonDict
from allennlp.data.dataset_readers import DatasetReader
@@ -17,6 +19,7 @@
from allennlp.models.model import Model
from allennlp.modules.seq2seq_encoders import _Seq2SeqWrapper
from allennlp.modules.seq2vec_encoders import _Seq2VecWrapper
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.initializers import Initializer
from allennlp.nn.regularizers import Regularizer
from allennlp.training.optimizers import Optimizer as AllenNLPOptimizer
@@ -57,6 +60,37 @@ def full_name(cla55: Optional[type]) -> str:
return _remove_prefix(f"{cla55.__module__}.{cla55.__name__}")


def json_annotation(cla55: Optional[type]):
# Special case to handle None:
if cla55 is None:
return {'origin': '?'}

# Hack because e.g. typing.Union isn't a type.
if isinstance(cla55, type) and issubclass(cla55, Initializer) and cla55 != Initializer:
init_fn = cla55()._init_function
return {'origin': f"{init_fn.__module__}.{init_fn.__name__}"}

origin = getattr(cla55, '__origin__', None)
args = getattr(cla55, '__args__', ())

# Special handling for compound types
if origin == Dict:
key_type, value_type = args
return {'origin': "Dict", 'args': [json_annotation(key_type), json_annotation(value_type)]}
elif origin in (Tuple, List, Sequence):
return {'origin': _remove_prefix(str(origin)), 'args': [json_annotation(arg) for arg in args]}
elif origin == Union:
# Special special case to handle optional types:
if len(args) == 2 and args[-1] == type(None):
return {'origin': json_annotation(args[0])}
else:
return {'origin': "Union", 'args': [json_annotation(arg) for arg in args]}
elif cla55 == Ellipsis:
return {'origin': "..."}
else:
return {'origin': _remove_prefix(f"{cla55.__module__}.{cla55.__name__}")}
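
To make the nesting concrete, here is a rough sketch (not part of the diff) of what this helper returns for two common annotations, assuming Python 3.6-era typing semantics where Dict[str, int].__origin__ is typing.Dict, and assuming _remove_prefix only strips an "allennlp." prefix:

from typing import Dict, Optional

json_annotation(Dict[str, int])
# {'origin': 'Dict',
#  'args': [{'origin': 'builtins.str'}, {'origin': 'builtins.int'}]}

json_annotation(Optional[int])   # Union[int, None] collapses to its first argument
# {'origin': {'origin': 'builtins.int'}}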


class ConfigItem(NamedTuple):
"""
Each ``ConfigItem`` represents a single entry in a configuration JsonDict.
@@ -67,12 +101,29 @@ class ConfigItem(NamedTuple):
comment: str = ''

def to_json(self) -> JsonDict:
return {
"annotation": full_name(self.annotation),
"default_value": str(self.default_value),
"comment": self.comment
json_dict = {
"name": self.name,
"annotation": json_annotation(self.annotation),
}

if is_configurable(self.annotation):
json_dict["configurable"] = True

if self.default_value != _NO_DEFAULT:
try:
# Ugly check that default value is actually serializable
json.dumps(self.default_value)
json_dict["defaultValue"] = self.default_value
except TypeError:
print(f"unable to json serialize {self.default_value}, using None instead")
json_dict["defaultValue"] = None


if self.comment:
json_dict["comment"] = self.comment

return json_dict
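
For illustration only, a guess at the blob a single item serializes to; the field names follow the attributes used above, and the values are invented:

item = ConfigItem(name="dropout", annotation=float, default_value=0.2, comment="dropout probability")
item.to_json()
# {'name': 'dropout',
#  'annotation': {'origin': 'builtins.float'},
#  'defaultValue': 0.2,
#  'comment': 'dropout probability'}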


T = TypeVar("T")

@@ -92,15 +143,13 @@ def __repr__(self) -> str:
return f"Config({self.items})"

def to_json(self) -> JsonDict:
item_dict: JsonDict = {
item.name: item.to_json()
for item in self.items
}
blob: JsonDict = {'items': [item.to_json() for item in self.items]}

if self.typ3:
item_dict["type"] = self.typ3
#items.insert(0, {"name": "type", "type": self.typ3})
blob["type"] = self.typ3

return item_dict
return blob
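
The new blob keeps items as an ordered list instead of keying them by name, presumably so the wizard front end can preserve parameter order. A hedged sketch, reusing the hypothetical item from above and an arbitrary type name:

Config([item], typ3="lstm").to_json()
# {'items': [item.to_json()], 'type': 'lstm'}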


# ``None`` is sometimes the default value for a function parameter,
@@ -135,6 +184,33 @@ def _get_config_type(cla55: type) -> Optional[str]:

return None

def _docspec_comments(obj) -> Dict[str, str]:
"""
Inspect the docstring and get the comments for each parameter.
"""
# Sometimes our docstring is on the class, and sometimes it's on the initializer,
# so we've got to check both.
class_docstring = getattr(obj, '__doc__', None)
init_docstring = getattr(obj.__init__, '__doc__', None) if hasattr(obj, '__init__') else None

docstring = class_docstring or init_docstring or ''

doc = NumpyDocString(docstring)
params = doc["Parameters"]
comments: Dict[str, str] = {}

for line in params:
# It looks like when there's not a space after the parameter name,
# numpydocstring parses it incorrectly.
name_bad = line[0]
name = name_bad.split(":")[0]

# Sometimes the line has 3 fields, sometimes it has 4 fields.
comment = "\n".join(line[-1])

comments[name] = comment

return comments
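
A hypothetical class (name and parameters invented) to show what _docspec_comments is expected to pull out of a numpydoc-style docstring:

class FakeEncoder:
    """
    Parameters
    ----------
    input_dim : int
        The dimension of the input vectors.
    dropout : float, optional (default = 0.0)
        Dropout probability applied to the output.
    """

_docspec_comments(FakeEncoder)
# {'input_dim': 'The dimension of the input vectors.',
#  'dropout': 'Dropout probability applied to the output.'}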

def _auto_config(cla55: Type[T]) -> Config[T]:
"""
@@ -143,8 +219,8 @@ def _auto_config(cla55: Type[T]) -> Config[T]:
"""
typ3 = _get_config_type(cla55)

# Don't include self
names_to_ignore = {"self"}
# Don't include self, or vocab
names_to_ignore = {"self", "vocab"}

# Hack for RNNs
if cla55 in [torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU]:
@@ -160,6 +236,7 @@ def _auto_config(cla55: Type[T]) -> Config[T]:
names_to_ignore.add("tensor")

argspec = inspect.getfullargspec(function_to_inspect)
comments = _docspec_comments(cla55)

items: List[ConfigItem] = []

@@ -175,11 +252,16 @@ def _auto_config(cla55: Type[T]) -> Config[T]:
if name in names_to_ignore:
continue
annotation = argspec.annotations.get(name)
comment = comments.get(name)

# Don't include Model, the only place you'd specify that is top-level.
if annotation == Model:
continue

# Don't include DataIterator, the only place you'd specify that is top-level.
if annotation == DataIterator:
continue

# Don't include params for an Optimizer
if torch.optim.Optimizer in getattr(cla55, '__bases__', ()) and name == "params":
continue
@@ -192,7 +274,15 @@ def _auto_config(cla55: Type[T]) -> Config[T]:
if cla55 == Trainer and annotation == torch.optim.Optimizer:
annotation = AllenNLPOptimizer

items.append(ConfigItem(name, annotation, default))
# Hack in embedding num_embeddings as optional (it can be inferred from the pretrained file)
if cla55 == Embedding and name == "num_embeddings":
default = None

items.append(ConfigItem(name, annotation, default, comment))

# More hacks, Embedding
if cla55 == Embedding:
items.insert(1, ConfigItem("pretrained_file", str, None))

return Config(items, typ3=typ3)
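
A brief sketch of the combined effect of the two Embedding hacks above, assuming num_embeddings is the first constructor argument (as the Embedding docstring further down suggests):

config = _auto_config(Embedding)
# config.items[0] is num_embeddings, with its default forced to None
# config.items[1] is the injected pretrained_file entry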

@@ -216,10 +306,26 @@ def render_config(config: Config, indent: str = "") -> str:
"}\n"
])

def is_configurable(obj) -> bool:

def _remove_optional(typ3: type) -> type:
origin = getattr(typ3, '__origin__', None)
args = getattr(typ3, '__args__', None)

if origin == Union and len(args) == 2 and args[-1] == type(None):
return _remove_optional(args[0])
else:
return typ3

def is_configurable(typ3: type) -> bool:
# Throw out optional:
typ3 = _remove_optional(typ3)

# Anything with a from_params method is itself configurable.
# So are regularizers even though they don't.
return hasattr(obj, 'from_params') or obj == Regularizer
return any([
hasattr(typ3, 'from_params'),
typ3 == Regularizer,
])
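
A few hedged examples of how these two helpers compose; the DatasetReader case assumes it exposes a from_params classmethod, as elsewhere in AllenNLP:

_remove_optional(Optional[Regularizer])   # -> Regularizer
is_configurable(Optional[DatasetReader])  # -> True, it has a from_params method
is_configurable(Optional[int])            # -> False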

def _render(item: ConfigItem, indent: str = "") -> str:
"""
2 changes: 1 addition & 1 deletion allennlp/modules/token_embedders/embedding.py
@@ -38,7 +38,7 @@ class Embedding(TokenEmbedder):
Parameters
----------
num_embeddings :, int:
num_embeddings : int:
Size of the dictionary of embeddings (vocabulary size).
embedding_dim : int
The size of each embedding vector.
