Skip to content

Commit

Permalink
Merge pull request #8732 from OpenMined/add_user_to_context
Browse files Browse the repository at this point in the history
Add UserView to TwinAPI context
  • Loading branch information
IonesioJunior committed Apr 23, 2024
2 parents b05d753 + d86009e commit 6918621
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions packages/syft/src/syft/service/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ...types.uid import UID
from ..context import AuthedServiceContext
from ..response import SyftError
from ..user.user import UserView

NOT_ACCESSIBLE_STRING = "N / A"

Expand All @@ -48,6 +49,7 @@ class TwinAPIAuthedContext(AuthedServiceContext):
__canonical_name__ = "AuthedServiceContext"
__version__ = SYFT_OBJECT_VERSION_1

user: UserView | None = None
settings: dict[str, Any] | None = None
code: HelperFunctionSet | None = None
state: dict[Any, Any] | None = None
Expand All @@ -58,7 +60,8 @@ class TwinAPIContextView(SyftObject):
__canonical_name__ = "TwinAPIContextView"
__version__ = SYFT_OBJECT_VERSION_1

__repr_attrs__ = ["settings", "state"]
__repr_attrs__ = ["settings", "state", "user"]
user: UserView
settings: dict[str, Any]
state: dict[Any, Any]

Expand Down Expand Up @@ -193,6 +196,9 @@ def build_internal_context(

helper_function_set = HelperFunctionSet(helper_function_dict)

user_service = context.node.get_service("userservice")
user = user_service.get_current_user(context)

return TwinAPIAuthedContext(
credentials=context.credentials,
role=context.role,
Expand All @@ -204,6 +210,7 @@ def build_internal_context(
settings=self.settings or {},
code=helper_function_set,
state=self.state or {},
user=user,
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -333,7 +340,7 @@ class CreateTwinAPIEndpoint(BaseTwinAPIEndpoint):
@serializable()
class TwinAPIEndpoint(SyncableSyftObject):
# version
__canonical_name__ = "TwinAPIEndpoint"
__canonical_name__: str = "TwinAPIEndpoint"
__version__ = SYFT_OBJECT_VERSION_1

def __init__(self, **kwargs: Any) -> None:
Expand Down Expand Up @@ -555,7 +562,7 @@ def code_string(context: TransformContext) -> TransformContext:

@transform(TwinAPIAuthedContext, TwinAPIContextView)
def twin_api_context_to_twin_api_context_view() -> list[Callable]:
return [keep(["state", "settings"])]
return [keep(["state", "settings", "user"])]


@transform(CreateTwinAPIEndpoint, TwinAPIEndpoint)
Expand Down

0 comments on commit 6918621

Please sign in to comment.