Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make AbstractNode nonoptional #8667

Merged
merged 5 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 4 additions & 7 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# stdlib
import importlib
from typing import Any
from typing import cast

# third party
import numpy as np
Expand All @@ -10,7 +9,6 @@
from result import Result

# relative
from ...abstract_node import AbstractNode
from ...node.credentials import SyftVerifyKey
from ...serde.serializable import serializable
from ...types.datetime import DateTime
Expand Down Expand Up @@ -59,7 +57,7 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any:
if not isinstance(data, np.ndarray):
data = np.array(data)
# cast here since we are sure that AuthedServiceContext has a node
context.node = cast(AbstractNode, context.node)

np_obj = NumpyArrayObject(
dtype=data.dtype,
shape=data.shape,
Expand Down Expand Up @@ -127,7 +125,7 @@ def _set(
action_object = action_object.private
else:
action_object = action_object.mock
context.node = cast(AbstractNode, context.node)

action_object.syft_point_to(context.node.id)
return Ok(action_object)
return result.err()
Expand Down Expand Up @@ -267,7 +265,7 @@ def get_pointer(
self, context: AuthedServiceContext, uid: UID
) -> Result[ActionObjectPointer, str]:
"""Get a pointer from the action store"""
context.node = cast(AbstractNode, context.node)

result = self.store.get_pointer(
uid=uid, credentials=context.credentials, node_uid=context.node.id
)
Expand Down Expand Up @@ -443,7 +441,7 @@ def set_result_to_store(
output_readers = []

read_permission = ActionPermission.READ
context.node = cast(AbstractNode, context.node)

result_action_object._set_obj_location_(
context.node.id,
context.credentials,
Expand Down Expand Up @@ -659,7 +657,6 @@ def execute(
# relative
from .plan import Plan

context.node = cast(AbstractNode, context.node)
if action.action_type == ActionType.CREATEOBJECT:
result_action_object = Ok(action.create_object)
# print(action.create_object, "already in blob storage")
Expand Down
10 changes: 1 addition & 9 deletions packages/syft/src/syft/service/blob_storage/service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# stdlib
from pathlib import Path
from typing import cast

# third party
import requests

# relative
from ...abstract_node import AbstractNode
from ...serde.serializable import serializable
from ...service.action.action_object import ActionObject
from ...store.blob_storage import BlobRetrieval
Expand Down Expand Up @@ -88,8 +86,6 @@ def mount_azure(
return SyftError(message=res.value)
remote_profile = res.ok()

context.node = cast(AbstractNode, context.node)

seaweed_config = context.node.blob_storage_client.config
# we cache this here such that we can use it when reading a file from azure
# from the remote_name
Expand Down Expand Up @@ -204,7 +200,6 @@ def read(
message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it"
)

context.node = cast(AbstractNode, context.node)
with context.node.blob_storage_client.connect() as conn:
res: BlobRetrieval = conn.read(
obj.location, obj.type_, bucket_name=obj.bucket_name
Expand All @@ -222,7 +217,6 @@ def read(
def allocate(
self, context: AuthedServiceContext, obj: CreateBlobStorageEntry
) -> BlobDepositType | SyftError:
context.node = cast(AbstractNode, context.node)
with context.node.blob_storage_client.connect() as conn:
secure_location = conn.allocate(obj)

Expand Down Expand Up @@ -305,7 +299,7 @@ def mark_write_complete(
)
if result.is_err():
return SyftError(message=f"{result.err()}")
context.node = cast(AbstractNode, context.node)

with context.node.blob_storage_client.connect() as conn:
result = conn.complete_multipart_upload(obj, etags)

Expand All @@ -324,8 +318,6 @@ def delete(
message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it"
)

context.node = cast(AbstractNode, context.node)

try:
with context.node.blob_storage_client.connect() as conn:
file_unlinked_result = conn.delete(obj.location)
Expand Down
8 changes: 3 additions & 5 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from typing_extensions import Self

# relative
from ...abstract_node import AbstractNode
from ...abstract_node import NodeType
from ...client.api import APIRegistry
from ...client.api import NodeIdentity
Expand Down Expand Up @@ -199,7 +198,6 @@ def denied(self) -> bool:
return False

def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus:
context.node = cast(AbstractNode, context.node)
if context.node.node_type == NodeType.ENCLAVE:
keys = {status for status, _ in self.status_dict.values()}
if len(keys) == 1 and UserCodeStatus.APPROVED in keys:
Expand Down Expand Up @@ -532,8 +530,8 @@ def get_output_history(
return SyftError(
message="Execution denied, Please wait for the code to be approved"
)
node = cast(AbstractNode, context.node)
output_service = cast(OutputService, node.get_service("outputservice"))

output_service = cast(OutputService, context.node.get_service("outputservice"))
return output_service.get_by_user_code_id(context, self.id)

def store_as_history(
Expand All @@ -550,7 +548,7 @@ def store_as_history(
)

output_ids = filter_only_uids(outputs)
context.node = cast(AbstractNode, context.node)

output_service = context.node.get_service("outputservice")
output_service = cast(OutputService, output_service)
execution_result = output_service.create(
Expand Down
9 changes: 1 addition & 8 deletions packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from result import Result

# relative
from ...abstract_node import AbstractNode
from ...abstract_node import NodeType
from ...client.enclave_client import EnclaveClient
from ...serde.serializable import serializable
Expand Down Expand Up @@ -144,8 +143,6 @@ def _request_code_execution_inner(
message="The code to be submitted (name and content) already exists"
)

context.node = cast(AbstractNode, context.node)

worker_pool_service = context.node.get_service("SyftWorkerPoolService")
pool_result = worker_pool_service._get_worker_pool(
context,
Expand Down Expand Up @@ -262,7 +259,6 @@ def load_user_code(self, context: AuthedServiceContext) -> None:
def get_results(
self, context: AuthedServiceContext, inp: UID | UserCode
) -> list[UserCode] | SyftError:
context.node = cast(AbstractNode, context.node)
uid = inp.id if isinstance(inp, UserCode) else inp
code_result = self.stash.get_by_uid(context.credentials, uid=uid)

Expand Down Expand Up @@ -340,7 +336,7 @@ def is_execution_on_owned_args_allowed(
) -> bool | SyftError:
if context.role == ServiceRole.ADMIN:
return True
context.node = cast(AbstractNode, context.node)

user_service = context.node.get_service("userservice")
current_user = user_service.get_current_user(context=context)
return current_user.mock_execution_permission
Expand All @@ -349,7 +345,6 @@ def keep_owned_kwargs(
self, kwargs: dict[str, Any], context: AuthedServiceContext
) -> dict[str, Any] | SyftError:
"""Return only the kwargs that are owned by the user"""
context.node = cast(AbstractNode, context.node)

action_service = context.node.get_service("actionservice")

Expand Down Expand Up @@ -474,7 +469,6 @@ def _call(
return can_execute.to_result() # type: ignore

# Execute the code item
context.node = cast(AbstractNode, context.node)

action_service = context.node.get_service("actionservice")

Expand Down Expand Up @@ -541,7 +535,6 @@ def _call(
def has_code_permission(
self, code_item: UserCode, context: AuthedServiceContext
) -> SyftSuccess | SyftError:
context.node = cast(AbstractNode, context.node)
if not (
context.credentials == context.node.verify_key
or context.credentials == code_item.user_verify_key
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# stdlib
from typing import cast

# relative
from ...abstract_node import AbstractNode
from ...node.credentials import SyftVerifyKey
from ...serde.serializable import serializable
from ...store.document_store import DocumentStore
Expand Down Expand Up @@ -45,7 +43,6 @@ def submit_version(
code: SubmitUserCode | UserCode,
comment: str | None = None,
) -> SyftSuccess | SyftError:
context.node = cast(AbstractNode, context.node)
user_code_service = context.node.get_service("usercodeservice")
if isinstance(code, SubmitUserCode):
result = user_code_service._submit(context=context, code=code)
Expand Down Expand Up @@ -126,7 +123,7 @@ def fetch_histories_for_user(
result = self.stash.get_by_verify_key(
credentials=context.credentials, user_verify_key=user_verify_key
)
context.node = cast(AbstractNode, context.node)

user_code_service = context.node.get_service("usercodeservice")

def get_code(uid: UID) -> UserCode | SyftError:
Expand Down Expand Up @@ -170,7 +167,6 @@ def get_histories_for_current_user(
def get_history_for_user(
self, context: AuthedServiceContext, email: str
) -> CodeHistoriesDict | SyftError:
context.node = cast(AbstractNode, context.node)
user_service = context.node.get_service("userservice")
result = user_service.stash.get_by_email(
credentials=context.credentials, email=email
Expand All @@ -195,7 +191,6 @@ def get_histories_group_by_user(
return SyftError(message=result.err())
code_histories: list[CodeHistory] = result.ok()

context.node = cast(AbstractNode, context.node)
user_service = context.node.get_service("userservice")
result = user_service.stash.get_all(context.credentials)
if result.is_err():
Expand Down Expand Up @@ -228,7 +223,6 @@ def get_by_func_name_and_user_email(
user_email: str,
user_id: UID,
) -> list[CodeHistory] | SyftError:
context.node = cast(AbstractNode, context.node)
user_service = context.node.get_service("userservice")
user_verify_key = user_service.user_verify_key(user_email)

Expand Down
8 changes: 3 additions & 5 deletions packages/syft/src/syft/service/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# stdlib
from typing import Any
from typing import cast

# third party
from typing_extensions import Self
Expand All @@ -24,7 +23,7 @@ class NodeServiceContext(Context, SyftObject):
__version__ = SYFT_OBJECT_VERSION_2

id: UID | None = None # type: ignore[assignment]
node: AbstractNode | None = None
node: AbstractNode


class AuthedServiceContext(NodeServiceContext):
Expand All @@ -48,7 +47,6 @@ def with_credentials(self, credentials: SyftVerifyKey, role: ServiceRole) -> Sel
return AuthedServiceContext(credentials=credentials, role=role, node=self.node)

def as_root_context(self) -> Self:
self.node = cast(AbstractNode, self.node)
return AuthedServiceContext(
credentials=self.node.verify_key, role=ServiceRole.ADMIN, node=self.node
)
Expand All @@ -71,15 +69,15 @@ class UnauthedServiceContext(NodeServiceContext):
__version__ = SYFT_OBJECT_VERSION_2

login_credentials: UserLoginCredentials
node: AbstractNode | None = None
node: AbstractNode
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This the critical change, the rest is cleanup.

role: ServiceRole = ServiceRole.NONE


class ChangeContext(SyftBaseObject):
__canonical_name__ = "ChangeContext"
__version__ = SYFT_OBJECT_VERSION_2

node: AbstractNode | None = None
node: AbstractNode
approving_user_credentials: SyftVerifyKey | None = None
requesting_user_credentials: SyftVerifyKey | None = None
extra_kwargs: dict = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# stdlib
from typing import cast

# third party
from result import Result

# relative
from ...abstract_node import AbstractNode
from ...node.credentials import SyftVerifyKey
from ...serde.serializable import serializable
from ...store.document_store import BaseUIDStoreStash
Expand Down Expand Up @@ -72,7 +70,6 @@ def add(
) -> SyftSuccess | SyftError:
"""Register a data subject."""

context.node = cast(AbstractNode, context.node)
member_relationship_add = context.node.get_service_method(
DataSubjectMemberService.add
)
Expand Down Expand Up @@ -109,7 +106,6 @@ def get_all(self, context: AuthedServiceContext) -> list[DataSubject] | SyftErro
def get_members(
self, context: AuthedServiceContext, data_subject_name: str
) -> list[DataSubject] | SyftError:
context.node = cast(AbstractNode, context.node)
get_relatives = context.node.get_service_method(
DataSubjectMemberService.get_relatives
)
Expand Down
14 changes: 4 additions & 10 deletions packages/syft/src/syft/service/dataset/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,10 @@ def add(
)
if result.is_err():
return SyftError(message=str(result.err()))
if context.node is not None:
return SyftSuccess(
message=f"Dataset uploaded to '{context.node.name}'. "
f"To see the datasets uploaded by a client on this node, use command `[your_client].datasets`"
)
else:
return SyftSuccess(
message="Dataset uploaded not to a node."
"To see the datasets uploaded by a client on this node, use command `[your_client].datasets`"
)
return SyftSuccess(
message=f"Dataset uploaded to '{context.node.name}'. "
f"To see the datasets uploaded by a client on this node, use command `[your_client].datasets`"
)

@service_method(
path="dataset.get_all",
Expand Down
3 changes: 0 additions & 3 deletions packages/syft/src/syft/service/enclave/enclave_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ def send_user_code_inputs_to_enclave(
def propagate_inputs_to_enclave(
user_code: UserCode, context: ChangeContext
) -> SyftSuccess | SyftError:
if context.node is None:
return SyftError(message=f"context {context}'s node is None")

if isinstance(user_code.enclave_metadata, EnclaveMetadata):
# TODO 🟣 Restructure url it work for local mode host.docker.internal

Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/job/job_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import cast

# relative
from ...abstract_node import AbstractNode
from ...node.worker_settings import WorkerSettings
from ...serde.serializable import serializable
from ...store.document_store import DocumentStore
Expand Down Expand Up @@ -117,8 +116,6 @@ def restart(
if res.is_err():
return SyftError(message=res.err())

context.node = cast(AbstractNode, context.node)

job = res.ok()
job.status = JobStatus.CREATED
self.update(context=context, job=job)
Expand Down Expand Up @@ -228,7 +225,6 @@ def add_read_permission_job_for_code_owner(
def add_read_permission_log_for_code_owner(
self, context: AuthedServiceContext, log_id: UID, user_code: UserCode
) -> Any:
context.node = cast(AbstractNode, context.node)
log_service = context.node.get_service("logservice")
log_service = cast(LogService, log_service)
return log_service.stash.add_permission(
Expand All @@ -245,7 +241,6 @@ def add_read_permission_log_for_code_owner(
def create_job_for_user_code_id(
self, context: AuthedServiceContext, user_code_id: UID
) -> Job | SyftError:
context.node = cast(AbstractNode, context.node)
job = Job(
id=UID(),
node_uid=context.node.id,
Expand Down