Skip to content

Commit

Permalink
[lebrice#241] fixes the issue where defaults from the config path wer…
Browse files Browse the repository at this point in the history
…e not being passed when using subparsers (field_wrapper.py). Additionally fixes an issue where subdataclasses were requiring every value to be defaulted in the config path, instead of falling back to the default in the dataclass definition if it wasn't (dataclass_wrapper.py)
  • Loading branch information
aliounis committed Apr 17, 2023
1 parent f4dee55 commit e5360e4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
11 changes: 9 additions & 2 deletions simple_parsing/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,15 @@ def defaults(self) -> list[DataclassT | dict[str, Any] | None | Literal[argparse
self._defaults = []
for default in self.parent.defaults:
if default not in (None, argparse.SUPPRESS):
default = getattr(default, self.name)
self._defaults.append(default)
# we need to check here if the default has been provided.
# If not we'll use the default_value option function
if hasattr(default, self.name):
default = getattr(default, self.name)
else:
default = utils.default_value(self._field)
if default is MISSING:
continue
self._defaults.append(default)
else:
default_field_value = utils.default_value(self._field)
if default_field_value is MISSING:
Expand Down
7 changes: 6 additions & 1 deletion simple_parsing/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,12 @@ def add_subparsers(self, parser: ArgumentParser):
# Just for typing correctness, as we didn't explicitly change
# the return type of subparsers.add_parser method.)
subparser = cast("ArgumentParser", subparser)
subparser.add_arguments(dataclass_type, dest=self.dest)
# we need to propagate the defaults down to the sub dataclass if they've been set.
# there may need to be some error handling here in case the use has specified the wrong values for the default.
if isinstance(self.default, dict) and self.default.get(subcommand, None) is not None:
subparser.add_arguments(dataclass_type, dest=self.dest, default=dataclass_type(**self.default[subcommand]))
else:
subparser.add_arguments(dataclass_type, dest=self.dest)

def equivalent_argparse_code(self):
arg_options = self.arg_options.copy()
Expand Down

0 comments on commit e5360e4

Please sign in to comment.