Skip to content

Commit

Permalink
support strong check for variables
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <this@zyc.ai>
  • Loading branch information
ZhiyuanChen committed Jun 21, 2023
1 parent ad2ef51 commit e4d179f
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 12 deletions.
5 changes: 5 additions & 0 deletions chanfig/flat_dict.py
Expand Up @@ -278,6 +278,11 @@ def __delattr__(self, name: Any) -> None:
def __missing__(self, name: Any) -> Any: # pylint: disable=R1710
raise KeyError(name)

def validate(self) -> None:
for value in self.values():
if isinstance(value, Variable):
value.validate()

def getattr(self, name: str, default: Any = Null) -> Any:
r"""
Get attribute of `FlatDict`.
Expand Down
6 changes: 6 additions & 0 deletions chanfig/nested_dict.py
Expand Up @@ -24,6 +24,7 @@
from .default_dict import DefaultDict
from .flat_dict import PathStr
from .utils import _K, _V, Null, apply, apply_
from .variable import Variable

if TYPE_CHECKING:
from torch import device as TorchDevice
Expand Down Expand Up @@ -397,6 +398,11 @@ def delete(self, name: Any) -> None:
raise KeyError(name) from None
super().delete(name)

def validate(self) -> None:
for value in self.all_values():
if isinstance(value, Variable):
value.validate()

def pop(self, name: Any, default: Any = Null) -> Any:
r"""
Pop value from `NestedDict`.
Expand Down
4 changes: 0 additions & 4 deletions chanfig/utils.py
Expand Up @@ -23,8 +23,6 @@

from yaml import SafeDumper, SafeLoader

from .variable import Variable

PathStr = Union[PathLike, str, bytes]
File = Union[PathStr, IO]

Expand Down Expand Up @@ -150,8 +148,6 @@ class JsonEncoder(JSONEncoder):
"""

def default(self, o: Any) -> Any:
if isinstance(o, Variable):
return o.value
if hasattr(o, "__json__"):
return o.__json__()
return super().default(o)
Expand Down
69 changes: 61 additions & 8 deletions chanfig/variable.py
Expand Up @@ -16,10 +16,12 @@

from contextlib import contextmanager
from copy import copy
from typing import Any, Callable, List, Mapping, Optional
from typing import Any, Callable, Generic, List, Mapping, Optional

from .utils import _V, Null

class Variable:

class Variable(Generic[_V]):
r"""
Mutable wrapper for immutable objects.
Expand Down Expand Up @@ -70,10 +72,25 @@ class Variable:
"""

wrap_type: bool = True
storage: List[Any]

def __init__(self, value) -> None:
self.storage = [value]
_storage: List[Any]
_type: Optional[type] = None
_choices: Optional[List] = None
_validator: Optional[Callable] = None
_required: Optional[bool] = False

def __init__( # pylint: disable=R0913
self,
value: Optional[Any] = Null,
type: Optional[type] = None, # pylint: disable=W0622
choices: Optional[List] = None,
validator: Optional[Callable] = None,
required: Optional[bool] = False,
) -> None:
self._storage = [value]
self._type = type
self._choices = choices
self._validator = validator
self._required = required

@property # type: ignore
def __class__(self) -> type:
Expand All @@ -85,15 +102,48 @@ def value(self) -> Any:
Fetch the object wrapped in `Variable`.
"""

return self.storage[0]
return self._storage[0]

@value.setter
def value(self, value) -> None:
r"""
Assign value to the object wrapped in `Variable`.
"""

self.storage[0] = self._get_value(value)
self.validate(value)
self._storage[0] = self._get_value(value)

@property
def storage(self) -> List[Any]:
r"""
Storage of `Variable`.
"""

return self._storage

@storage.setter
def storage(self, *args, **kwargs) -> None:
raise AttributeError("Cannot set storage.")

def validate(self, *args) -> None:
r"""
Validate if the value is valid.
"""

if len(args) == 0:
value = self.value
elif len(args) == 1:
value = args[0]
else:
raise ValueError("Too many arguments.")
if self._required and value is Null:
raise ValueError("Value is required.")
if self._type is not None and not isinstance(value, self._type):
raise TypeError(f"Value {value} is not of type {self._type}.")
if self._choices is not None and value not in self._choices:
raise ValueError(f"Value {value} is not in choices {self._choices}.")
if self._validator is not None and not self._validator(value):
raise ValueError(f"Value {value} is not valid.")

@property
def dtype(self) -> type:
Expand Down Expand Up @@ -384,5 +434,8 @@ def __repr__(self):
def __str__(self):
return self.value if isinstance(self, str) else str(self.value)

def __json__(self):
return self.value

def __contains__(self, name):
return name in self.value
70 changes: 70 additions & 0 deletions tests/test_variable.py
@@ -0,0 +1,70 @@
from pytest import raises

from chanfig import Variable


class Test:
str_var = Variable("CHANFIG", str, validator=lambda x: x.isupper(), choices=["CHANFIG", "CHANG", "LIU"])
int_var = Variable(0, int, validator=lambda x: x > 0, choices=[1, 2, 3])
float_var = Variable(1e-2, float, validator=lambda x: 0.0 <= x < 1.0, choices=[1e-2, 3e-3, 5e-4])
complex_var = Variable(1 + 2j, complex, validator=lambda x: x.real > 0.0, choices=[1 + 2j, 3 + 4j, 5 + 6j])
bool_var = Variable(True, bool)
required_var = Variable(required=True)

def test_str(self):
assert self.str_var.value == "CHANFIG"
self.str_var.value = "CHANG"
assert self.str_var.value == "CHANG"
self.str_var.set("LIU")
assert self.str_var.value == "LIU"
with raises(TypeError):
self.str_var.value = 0
with raises(ValueError):
self.str_var.value = "chang"
with raises(ValueError):
self.str_var.value = "FAIL"

def test_int(self):
assert self.int_var.value == 0
self.int_var.value = 1
assert self.int_var.value == 1
self.int_var.set(2)
assert self.int_var.value == 2
with raises(TypeError):
self.int_var.value = 1.0
with raises(ValueError):
self.int_var.value = 4
with raises(ValueError):
self.int_var.value = -1

def test_float(self):
assert self.float_var.value == 1e-2
self.float_var.value = 3e-3
assert self.float_var.value == 3e-3
self.float_var.set(5e-4)
assert self.float_var.value == 5e-4
with raises(TypeError):
self.float_var.value = 0
with raises(ValueError):
self.float_var.value = 0.4
with raises(ValueError):
self.float_var.value = -1.0

def test_complex(self):
assert self.complex_var.value == 1 + 2j
self.complex_var.value = 3 + 4j
assert self.complex_var.value == 3 + 4j
self.complex_var.set(5 + 6j)
assert self.complex_var.value == 5 + 6j
with raises(TypeError):
self.complex_var.value = 1
with raises(ValueError):
self.complex_var.value = 7 + 8j
with raises(ValueError):
self.complex_var.value = -1 + 2j

def test_required(self):
with raises(ValueError):
self.required_var.validate()
self.required_var.set("valid")
self.required_var.validate()

0 comments on commit e4d179f

Please sign in to comment.