Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix warnings in dataset transform + privateattr check #8570

Merged
merged 11 commits into from Mar 14, 2024
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/action/action_object.py
Expand Up @@ -232,6 +232,7 @@ class ActionObjectPointer:
"__repr_str__", # pydantic
"__repr_args__", # pydantic
"__post_init__", # syft
"__validate_private_attrs__", # syft
"id", # syft
"to_mongo", # syft 🟡 TODO 23: Add composeable / inheritable object passthrough attrs
"__attr_searchable__", # syft
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/code/user_code.py
Expand Up @@ -118,7 +118,7 @@ class UserCodeStatusCollection(SyncableSyftObject):
status_dict: dict[NodeIdentity, tuple[UserCodeStatus, str]] = {}
user_code_link: LinkedObject

def get_diffs(self, ext_obj: Any) -> list[AttrDiff]:
def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]:
# relative
from ...service.sync.diff_state import AttrDiff

Expand Down
10 changes: 5 additions & 5 deletions packages/syft/src/syft/service/dataset/dataset.py
Expand Up @@ -756,17 +756,17 @@ def create_and_store_twin(context: TransformContext) -> TransformContext:


def infer_shape(context: TransformContext) -> TransformContext:
if context.output is not None and context.output["shape"] is None:
if context.output is None:
raise ValueError(f"{context}'s output is None. No transformation happened")
if context.output["shape"] is None:
if context.obj is not None and not _is_action_data_empty(context.obj.mock):
context.output["shape"] = get_shape_or_len(context.obj.mock)
else:
print(f"{context}'s output is None. No transformation happened")
return context


def set_data_subjects(context: TransformContext) -> TransformContext | SyftError:
if context.output is None:
return SyftError(f"{context}'s output is None. No transformation happened")
raise ValueError(f"{context}'s output is None. No transformation happened")
if context.node is None:
return SyftError(
"f{context}'s node is None, please log in. No trasformation happened"
Expand Down Expand Up @@ -796,7 +796,7 @@ def add_default_node_uid(context: TransformContext) -> TransformContext:
if context.output["node_uid"] is None and context.node is not None:
context.output["node_uid"] = context.node.id
else:
print(f"{context}'s output is None. No transformation happened.")
raise ValueError(f"{context}'s output is None. No transformation happened")
return context


Expand Down
Expand Up @@ -151,7 +151,7 @@ def add_msg_creation_time(context: TransformContext) -> TransformContext:
if context.output is not None:
context.output["created_at"] = DateTime.now()
else:
print(f"{context}'s output is None. No transformation happened.")
raise ValueError(f"{context}'s output is None. No transformation happened")
return context


Expand Down
9 changes: 7 additions & 2 deletions packages/syft/src/syft/service/policy/policy.py
Expand Up @@ -441,11 +441,16 @@ def apply_output(

class UserOutputPolicy(OutputPolicy):
__canonical_name__ = "UserOutputPolicy"

# Do not validate private attributes of user-defined policies, User annotations can
# contain any type and throw a NameError when resolving.
__validate_private_attrs__ = False
pass


class UserInputPolicy(InputPolicy):
__canonical_name__ = "UserInputPolicy"
__validate_private_attrs__ = False
pass


Expand Down Expand Up @@ -572,7 +577,7 @@ def generate_unique_class_name(context: TransformContext) -> TransformContext:
unique_name = f"{service_class_name}_{context.credentials}_{code_hash}"
context.output["unique_name"] = unique_name
else:
print(f"{context}'s output is None. No transformation happened.")
raise ValueError(f"{context}'s output is None. No transformation happened")

return context

Expand Down Expand Up @@ -696,7 +701,7 @@ def compile_code(context: TransformContext) -> TransformContext:
+ context.output["parsed_code"]
)
else:
print(f"{context}'s output is None. No transformation happened.")
raise ValueError(f"{context}'s output is None. No transformation happened")

return context

Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/types/syft_object.py
Expand Up @@ -419,6 +419,7 @@ def make_id(cls, values: Any) -> Any:
__attr_custom_repr__: ClassVar[list[str] | None] = (
None # show these in html repr of an object
)
__validate_private_attrs__: ClassVar[bool] = True

def __syft_get_funcs__(self) -> list[tuple[str, Signature]]:
funcs = print_type_cache[type(self)]
Expand Down Expand Up @@ -577,11 +578,14 @@ def __post_init__(self) -> None:
pass

def _syft_set_validate_private_attrs_(self, **kwargs: Any) -> None:
if not self.__validate_private_attrs__:
return
# Validate and set private attributes
# https://github.com/pydantic/pydantic/issues/2105
annotations = typing.get_type_hints(self.__class__, localns=locals())
for attr, decl in self.__private_attributes__.items():
value = kwargs.get(attr, decl.get_default())
var_annotation = self.__annotations__.get(attr)
var_annotation = annotations.get(attr)
if value is not PydanticUndefined:
if var_annotation is not None:
# Otherwise validate value against the variable annotation
Expand Down