From fdfbd139b31057bc493a3a6c0b39c0325131c4c1 Mon Sep 17 00:00:00 2001 From: Rehno Lindeque Date: Mon, 1 Nov 2021 12:29:14 -0400 Subject: [PATCH] util.ImmutableValidatedObject: Support Optional[ImmutableValidatedObject] --- nixops/util.py | 56 ++++++++++++++++++++++++----------------- tests/unit/test_util.py | 16 +++++++++++- 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/nixops/util.py b/nixops/util.py index 6b6a1a51d..ee90b088a 100644 --- a/nixops/util.py +++ b/nixops/util.py @@ -18,6 +18,7 @@ import typeguard import inspect import shlex +import typing from typing import ( Callable, List, @@ -33,6 +34,8 @@ TypeVar, Generic, Iterable, + Sequence, + Type, ) import nixops.util @@ -132,7 +135,7 @@ def __init__(self, *args: ImmutableValidatedObject, **kwargs): kw = {} for arg in args: if not isinstance(arg, ImmutableValidatedObject): - raise TypeError("Arg not a Immutablevalidatedobject instance") + raise TypeError("Arg not a ImmutableValidatedObject instance") kw.update(dict(arg)) kw.update(kwargs) @@ -143,30 +146,37 @@ def __init__(self, *args: ImmutableValidatedObject, **kwargs): continue anno.update(x.__annotations__) - def _transform_value(key: Any, value: Any) -> Any: - ann = anno.get(key) - + def _transform_value(value: Any, value_type: Optional[Type]) -> Any: # Untyped, pass through - if not ann: + if not value_type: return value - if inspect.isclass(ann) and issubclass(ann, ImmutableValidatedObject): - value = ann(**value) - - # Support Sequence[ImmutableValidatedObject] - if isinstance(value, tuple) and not isinstance(ann, str): - new_value = [] - for v in value: - for subann in ann.__args__: - if inspect.isclass(subann) and issubclass( - subann, ImmutableValidatedObject - ): - new_value.append(subann(**v)) - else: - new_value.append(v) - value = tuple(new_value) - - typeguard.check_type(key, value, ann) + # Support ImmutableValidatedObject + if ( + isinstance(value, Mapping) + and inspect.isclass(value_type) + and issubclass(value_type, ImmutableValidatedObject) + ): + value = value_type(**value) + + type_origin = typing.get_origin(value_type) # type: ignore[attr-defined] + type_args = tuple(set(typing.get_args(value_type)) - {type(None)}) # type: ignore[attr-defined] + if ( + type_origin is not None + and len(type_args) == 1 + and inspect.isclass(type_args[0]) + and issubclass(type_args[0], ImmutableValidatedObject) + ): + # Support Sequence[ImmutableValidatedObject] + if isinstance(value, Sequence) and issubclass(tuple, type_origin): + value = tuple(_transform_value(v, type_args[0]) for v in value) + + # Support Optional[ImmutableValidatedObject] + if type_origin is Union: + if value is not None: + value = _transform_value(value, type_args[0]) + + typeguard.check_type(key, value, value_type) return value @@ -181,7 +191,7 @@ def _transform_value(key: Any, value: Any) -> Any: # is set this attribute is set on self before __init__ is called default = getattr(self, key) if hasattr(self, key) else None value = kw.get(key, default) - setattr(self, key, _transform_value(key, value)) + setattr(self, key, _transform_value(value, anno.get(key))) self._frozen = True diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 40e308f44..f28cea0ef 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -1,4 +1,4 @@ -from typing import Sequence +from typing import Sequence, Optional import json from nixops.logger import Logger from io import StringIO @@ -115,3 +115,17 @@ class WithSequence(util.ImmutableValidatedObject): subs: Sequence[SubResource] WithSequence(subs=[SubResource(x=1), SubResource(x=2)]) + + # Test Optional[ImmutableValidatedObject] + class WithOptional(util.ImmutableValidatedObject): + sub: Optional[SubResource] + sub_none: Optional[SubResource] + subs: Optional[Sequence[SubResource]] + subs_none: Optional[Sequence[SubResource]] + + WithOptional( + sub=SubResource(x=0), + sub_none=None, + subs=[SubResource(x=1), SubResource(x=2)], + subs_none=None, + )