Skip to content

Commit

Permalink
A hack around lebrice#276
Browse files Browse the repository at this point in the history
  • Loading branch information
anivegesana committed Aug 11, 2023
1 parent 77e4ff5 commit aa0e1a2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
10 changes: 7 additions & 3 deletions simple_parsing/helpers/serialization/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init_subclass__(
register_decoding_fn(cls, cls.from_dict)

def to_dict(
self, dict_factory: type[dict] = dict, recurse: bool = True, save_dc_types: bool = False
self, dict_factory: type[dict] = dict, recurse: bool = True, save_dc_types: int | bool = False
) -> dict:
"""Serializes this dataclass to a dict.
Expand Down Expand Up @@ -597,7 +597,7 @@ def save(
obj: Any,
path: str | Path,
format: FormatExtension | None = None,
save_dc_types: bool = False,
save_dc_types: int | bool = False,
**kwargs,
) -> None:
"""Save the given dataclass or dictionary to the given file."""
Expand Down Expand Up @@ -688,7 +688,7 @@ def to_dict(
dc: DataclassT,
dict_factory: type[dict] = dict,
recurse: bool = True,
save_dc_types: bool = False,
save_dc_types: int | bool = False,
) -> dict:
"""Serializes this dataclass to a dict.
Expand Down Expand Up @@ -720,6 +720,10 @@ def to_dict(
else:
d[DC_TYPE_KEY] = module + "." + class_name

# Decrement save_dc_types if it is an int
if save_dc_types is not True and save_dc_types > 0:
save_dc_types -= 1

for f in fields(dc):
name = f.name
value = getattr(dc, name)
Expand Down
9 changes: 9 additions & 0 deletions simple_parsing/helpers/subgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing_extensions import TypeAlias

from simple_parsing.helpers.serialization.serializable import to_dict
from simple_parsing.utils import DataclassT, is_dataclass_instance, is_dataclass_type

logger = get_logger(__name__)
Expand Down Expand Up @@ -112,6 +113,14 @@ def subgroups(
metadata["subgroup_default"] = default
metadata["subgroup_dataclass_types"] = {}

def _encoding_fn(value: Any) -> dict:
"""Custom encoding function that will simply represent the value as the
the key in the dict rather than the value itself.
"""
return to_dict(value, save_dc_types=1)

kwargs.setdefault("encoding_fn", _encoding_fn)

subgroup_dataclass_types: dict[Key, type[DataclassT]] = {}
choices = subgroups.keys()

Expand Down
18 changes: 13 additions & 5 deletions simple_parsing/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from collections import defaultdict
from logging import getLogger
from pathlib import Path
from typing import Any, Callable, Sequence, Type, overload
from typing import Any, Callable, Sequence, Type, cast, overload

from simple_parsing.helpers.subgroups import SubgroupKey
from simple_parsing.wrappers.dataclass_wrapper import DataclassWrapperType

from . import utils
from .conflicts import ConflictResolution, ConflictResolver
from .help_formatter import SimpleHelpFormatter
from .helpers.serialization.serializable import read_file
from .helpers.serialization.serializable import DC_TYPE_KEY, from_dict, read_file
from .utils import (
Dataclass,
DataclassT,
Expand Down Expand Up @@ -646,7 +646,6 @@ def _resolve_subgroups(
if subgroup_field.subgroup_default is dataclasses.MISSING:
assert argument_options["required"]
else:
assert argument_options["default"] is subgroup_field.subgroup_default
assert not is_dataclass_instance(argument_options["default"])

# TODO: Do we really need to care about this "SUPPRESS" stuff here?
Expand Down Expand Up @@ -674,7 +673,7 @@ def _resolve_subgroups(
# here.
subgroup_dict = subgroup_field.subgroup_choices
chosen_subgroup_key: SubgroupKey = getattr(parsed_args, dest)
assert chosen_subgroup_key in subgroup_dict
assert isinstance(chosen_subgroup_key, dict) or chosen_subgroup_key in subgroup_dict

# Changing the default value of the (now parsed) field for the subgroup choice,
# just so it shows (default: {chosen_subgroup_key}) on the command-line.
Expand All @@ -687,7 +686,11 @@ def _resolve_subgroups(
f"{chosen_subgroup_key!r}"
)

default_or_dataclass_fn = subgroup_dict[chosen_subgroup_key]
if isinstance(chosen_subgroup_key, dict):
default_or_dataclass_fn = from_dict(cast(Type[Dataclass], None), chosen_subgroup_key)
else:
default_or_dataclass_fn = subgroup_dict[chosen_subgroup_key]

if is_dataclass_instance(default_or_dataclass_fn):
# The chosen value in the subgroup dict is a frozen dataclass instance.
default = default_or_dataclass_fn
Expand Down Expand Up @@ -1124,6 +1127,11 @@ def _create_dataclass_instance(
# None.
# TODO: (BUG!) This doesn't distinguish the case where the defaults are passed via the
# command-line from the case where no arguments are passed at all!
dc_type = constructor_args.pop(DC_TYPE_KEY, None)
if dc_type is not None:
from simple_parsing.helpers.serialization.serializable import _locate
constructor = _locate(dc_type)

if wrapper.optional and wrapper.default is None:
for field_wrapper in wrapper.fields:

Expand Down

0 comments on commit aa0e1a2

Please sign in to comment.