Skip to content

Commit

Permalink
Merge pull request #8599 from OpenMined/fix_api_bugs
Browse files Browse the repository at this point in the history
Fix api bugs
  • Loading branch information
shubham3121 committed Mar 19, 2024
2 parents 2933921 + 847b84f commit 32e0f11
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 15 deletions.
31 changes: 19 additions & 12 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,20 @@ def unwrap_and_migrate_annotation(annotation: Any, object_versions: dict) -> Any
return migrated_annotation[0]


def result_needs_api_update(api_call_result: Any) -> bool:
# relative
from ..service.request.request import Request
from ..service.request.request import UserCodeStatusChange

if isinstance(api_call_result, Request) and any(
isinstance(x, UserCodeStatusChange) for x in api_call_result.changes
):
return True
if isinstance(api_call_result, SyftSuccess) and api_call_result.require_api_update:
return True
return False


@instrument
@serializable(
attrs=[
Expand Down Expand Up @@ -741,23 +755,16 @@ def make_call(self, api_call: SyftAPICall) -> Result:

if isinstance(result, OkErr):
if result.is_ok():
res = result.ok()
# we update the api when we create objects that change it
self.update_api(res)
return res
result = result.ok()
else:
return result.err()
result = result.err()
# we update the api when we create objects that change it
self.update_api(result)
return result

def update_api(self, api_call_result: Any) -> None:
# TODO: hacky stuff with typing and imports to prevent circular imports
# relative
from ..service.request.request import Request
from ..service.request.request import UserCodeStatusChange

if isinstance(api_call_result, Request) and any(
isinstance(x, UserCodeStatusChange) for x in api_call_result.changes
):
if result_needs_api_update(api_call_result):
if self.refresh_api_callback is not None:
self.refresh_api_callback()

Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def submit(
result = self._submit(context=context, code=code)
if result.is_err():
return SyftError(message=str(result.err()))
return SyftSuccess(message="User Code Submitted")
return SyftSuccess(message="User Code Submitted", require_api_update=True)

def _submit(
self, context: AuthedServiceContext, code: UserCode | SubmitUserCode
Expand Down
5 changes: 4 additions & 1 deletion packages/syft/src/syft/service/request/request_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..service import TYPE_TO_SERVICE
from ..service import service_method
from ..user.user import UserView
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from ..user.user_roles import GUEST_ROLE_LEVEL
from ..user.user_service import UserService
from .request import Change
Expand Down Expand Up @@ -105,7 +106,9 @@ def submit(
print("Failed to submit Request", e)
raise e

@service_method(path="request.get_all", name="get_all")
@service_method(
path="request.get_all", name="get_all", roles=DATA_SCIENTIST_ROLE_LEVEL
)
def get_all(self, context: AuthedServiceContext) -> list[Request] | SyftError:
result = self.stash.get_all(context.credentials)
if result.is_err():
Expand Down
7 changes: 6 additions & 1 deletion packages/syft/src/syft/service/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
class SyftResponseMessage(SyftBaseModel):
message: str
_bool: bool = True
require_api_update: bool = False

def __bool__(self) -> bool:
return self._bool

def __eq__(self, other: Any) -> bool:
if isinstance(other, SyftResponseMessage):
return self.message == other.message and self._bool == other._bool
return (
self.message == other.message
and self._bool == other._bool
and self.require_api_update == other.require_api_update
)
return self._bool == other

def __repr__(self) -> str:
Expand Down

0 comments on commit 32e0f11

Please sign in to comment.