Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UserView to TwinAPI context #8732

Merged
merged 3 commits into from
Apr 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 10 additions & 3 deletions packages/syft/src/syft/service/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ...types.uid import UID
from ..context import AuthedServiceContext
from ..response import SyftError
from ..user.user import UserView

NOT_ACCESSIBLE_STRING = "N / A"

Expand All @@ -48,6 +49,7 @@ class TwinAPIAuthedContext(AuthedServiceContext):
__canonical_name__ = "AuthedServiceContext"
__version__ = SYFT_OBJECT_VERSION_1

user: UserView | None = None
settings: dict[str, Any] | None = None
code: HelperFunctionSet | None = None
state: dict[Any, Any] | None = None
Expand All @@ -58,7 +60,8 @@ class TwinAPIContextView(SyftObject):
__canonical_name__ = "TwinAPIContextView"
__version__ = SYFT_OBJECT_VERSION_1

__repr_attrs__ = ["settings", "state"]
__repr_attrs__ = ["settings", "state", "user"]
user: UserView
settings: dict[str, Any]
state: dict[Any, Any]

Expand Down Expand Up @@ -193,6 +196,9 @@ def build_internal_context(

helper_function_set = HelperFunctionSet(helper_function_dict)

user_service = context.node.get_service("userservice")
user = user_service.get_current_user(context)

return TwinAPIAuthedContext(
credentials=context.credentials,
role=context.role,
Expand All @@ -204,6 +210,7 @@ def build_internal_context(
settings=self.settings or {},
code=helper_function_set,
state=self.state or {},
user=user,
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -333,7 +340,7 @@ class CreateTwinAPIEndpoint(BaseTwinAPIEndpoint):
@serializable()
class TwinAPIEndpoint(SyncableSyftObject):
# version
__canonical_name__ = "TwinAPIEndpoint"
__canonical_name__: str = "TwinAPIEndpoint"
__version__ = SYFT_OBJECT_VERSION_1

def __init__(self, **kwargs: Any) -> None:
Expand Down Expand Up @@ -555,7 +562,7 @@ def code_string(context: TransformContext) -> TransformContext:

@transform(TwinAPIAuthedContext, TwinAPIContextView)
def twin_api_context_to_twin_api_context_view() -> list[Callable]:
return [keep(["state", "settings"])]
return [keep(["state", "settings", "user"])]


@transform(CreateTwinAPIEndpoint, TwinAPIEndpoint)
Expand Down