Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sync jobs #8746

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 17 additions & 1 deletion packages/syft/src/syft/service/api/api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,23 @@ def call(
)
if isinstance(custom_endpoint, SyftError):
return custom_endpoint
return Ok(custom_endpoint.exec(context, *args, **kwargs))

exec_result = custom_endpoint.exec(context, *args, **kwargs)

if isinstance(exec_result, SyftError):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets add types for these in our Tech Debt cycle, the user should be able to return UserSyftErrors that we can pass back as "good" results.

Also making it so they don't need to import syft as sy or other things in their code would be good.

return Ok(exec_result)

action_obj = ActionObject.from_obj(exec_result)
action_service = cast(ActionService, context.node.get_service(ActionService))
result = action_service.set_result_to_store(
context=context,
result_action_object=action_obj,
has_result_read_permission=True,
)
if result.is_err():
return SyftError(message=f"Failed to set result to store: {result.err()}")

return Ok(result.ok())

@service_method(path="api.call_public", name="call_public", roles=GUEST_ROLE_LEVEL)
def call_public(
Expand Down
8 changes: 6 additions & 2 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def resolve(self) -> Any | SyftNotReady:

def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: # type: ignore
dependencies = []
if self.result is not None:
if self.result is not None and isinstance(self.result, ActionObject):
dependencies.append(self.result.id.id)

if self.log_id:
Expand Down Expand Up @@ -864,7 +864,11 @@ def get_by_result_id(
else:
res = res.ok()
# beautiful query
res = [x for x in res if x.result is not None and x.result.id.id == res_id]
res = [
x
for x in res
if isinstance(x.result, ActionObject) and x.result.id.id == res_id
]
if len(res) == 0:
return Ok(None)
elif len(res) > 1:
Expand Down
25 changes: 19 additions & 6 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def undo(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]:

def __repr_syft_nested__(self) -> str:
return f"Apply <b>{self.apply_permission_type}</b> to \
<i>{self.linked_obj.object_type.__canonical_name__}:{self.linked_obj.object_uid.short()}</i>"
<i>{self.linked_obj.object_type.__canonical_name__}:{self.linked_obj.object_uid.short()}</i>."


@serializable()
Expand Down Expand Up @@ -435,6 +435,19 @@ def _repr_html_(self) -> Any:

"""

@property
def html_description(self) -> str:
desc = " ".join([x.__repr_syft_nested__() for x in self.changes])
# desc = desc.replace('\n', '')
# desc = desc.replace('<br>', '\n')
desc = desc.replace(". ", ".\n\n")
desc = desc.replace("<b>", "")
desc = desc.replace("</b>", "")
desc = desc.replace("<i>", "")
desc = desc.replace("</i>", "")

return desc

def _coll_repr_(self) -> dict[str, str | dict[str, str]]:
if self.status == RequestStatus.APPROVED:
badge_color = "badge-green"
Expand All @@ -452,7 +465,7 @@ def _coll_repr_(self) -> dict[str, str | dict[str, str]]:
]

return {
"Description": " ".join([x.__repr_syft_nested__() for x in self.changes]),
"Description": self.html_description,
"Requested By": "\n".join(user_data),
"Status": status_badge,
}
Expand Down Expand Up @@ -1202,15 +1215,15 @@ def __repr_syft_nested__(self) -> str:
f"Request to change <b>{self.code.service_func_name}</b> "
f"(Pool Id: <b>{self.code.worker_pool_name}</b>) "
)
msg += "to permission <b>RequestStatus.APPROVED</b>"
msg += "to permission <strong>RequestStatus.APPROVED.</strong>"
if self.nested_solved:
if self.link.nested_codes == {}: # type: ignore
msg += ". No nested requests"
msg += "No nested requests."
else:
msg += ".<br><br>This change requests the following nested functions calls:<br>"
msg += "<br><br>This change requests the following nested functions calls:<br>"
msg += self.nested_repr()
else:
msg += ". Nested Requests not resolved"
msg += "Nested Requests not resolved."
return msg

def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str:
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def get_all_syncable_items(
elif isinstance(obj, Job) and obj.result is not None:
if isinstance(obj.result, ActionObject):
obj.result = obj.result.as_empty()
action_object_ids.add(obj.result.id)
action_object_ids.add(obj.result.id)

for uid in action_object_ids:
action_object = context.node.get_service("actionservice").get(
Expand Down