In [1]:
import syft as sy

In [2]:
# Start a syft Worker acting as a domain node (NodeType.DOMAIN, per the output)
# and grab its root client, which is used for every call below.
# NOTE(review): `reset=True` presumably clears state left by a previous run — confirm.
worker = sy.Worker.named("test-domain-1", reset=True)
domain_client = worker.root_client

> Starting Worker: test-domain-1 - 7bca415d13ed1ec841f0d0aede098dbb - NodeType.DOMAIN - [<class 'syft.core.node.new.user_service.UserService'>, <class 'syft.core.node.new.metadata_service.MetadataService'>, <class 'syft.core.node.new.action_service.ActionService'>, <class 'syft.core.node.new.test_service.TestService'>, <class 'syft.core.node.new.dataset_service.DatasetService'>, <class 'syft.core.node.new.user_code_service.UserCodeService'>, <class 'syft.core.node.new.request_service.RequestService'>, <class 'syft.core.node.new.data_subject_service.DataSubjectService'>, <class 'syft.core.node.new.network_service.NetworkService'>, <class 'syft.core.node.new.message_service.MessageService'>, <class 'syft.core.node.new.project_service.ProjectService'>, <class 'syft.core.node.new.data_subject_member_service.DataSubjectMemberService'>]


In [3]:
from jax import random
from flax import linen as nn
# Fixed seed: the sums asserted later in the notebook depend on this exact key.
SEED = 42
key = random.PRNGKey(SEED)



In [4]:
# Uniform random sample batch with shape (4, 28, 28, 1), drawn from the seeded key.
train_data = random.uniform(key, (4, 28, 28, 1))

In [5]:
# Sanity-check the seeded data. Cast to a Python float before rounding so the
# comparison is plain int == int, independent of the jax array dtype and of
# how jax arrays implement round().
assert round(float(train_data.sum())) == 1602

In [6]:
# Wrap the raw jax array in a syft ActionObject (an AnyActionObject, per the output below).
train = sy.ActionObject.from_obj(train_data)

In [7]:
# The wrapper type, as opposed to the raw array type shown next.
type(train)

syft.core.node.new.action_object.AnyActionObject

In [33]:
# The unwrapped local jax array.
type(train_data)

jaxlib.xla_extension.DeviceArray

In [8]:
# The wrapped payload is still the jax array, the wrapper carries a syft UID, and
# `.shape` on the wrapper returns the wrapped array's shape (see output).
type(train.syft_action_data), train.id, train.shape

(jaxlib.xla_extension.DeviceArray,
 <UID: d7d7c9c7a9d94b589fae1275de67c96c>,
 (4, 28, 28, 1))

In [9]:
# Upload the wrapped data through the node's action service; the `.id` of the
# returned object is what the input policy of `train_mlp` pins below.
train_domain_obj = domain_client.api.services.action.set(train)

In [10]:
class MLP(nn.Module):
    """Small perceptron: flatten -> Dense(128) -> ReLU -> Dense(out_dims).

    Attributes:
        out_dims: width of the final Dense layer, i.e. the number of outputs per row.
    """

    out_dims: int

    @nn.compact
    def __call__(self, x):
        # Keep the leading (batch) axis and flatten everything else so the
        # Dense layers receive a 2-D input.
        batch = x.shape[0]
        hidden = nn.Dense(128)(x.reshape((batch, -1)))
        hidden = nn.relu(hidden)
        return nn.Dense(self.out_dims)(hidden)

model = MLP(out_dims=10)

In [38]:
# Initialise the MLP parameters using the sample batch to infer input shapes.
# NOTE(review): `key` was already used to sample `train_data`; JAX recommends
# splitting keys (random.split) for independent draws. Left as-is because the
# asserted output sums below depend on this exact key usage.
weights = model.init(key, train_data)

In [13]:
# Wrap the flax parameters (a FrozenDict, per the output below) in a syft ActionObject.
w = sy.ActionObject.from_obj(weights)

In [14]:
# Wrapper type for the parameters.
type(w)

syft.core.node.new.action_object.AnyActionObject

In [15]:
# Wrapped payload type and the wrapper's UID.
type(w.syft_action_data), w.id

(flax.core.frozen_dict.FrozenDict, <UID: 4c9291c106244b84abc635f8809217fd>)

In [16]:
# Upload the wrapped parameters; the returned object's `.id` is pinned by the input policy below.
weight_domain_obj = domain_client.api.services.action.set(w)

In [17]:
# ExactMatch ties each argument to the id of an object already uploaded to the
# node. NOTE(review): SingleExecutionExactOutput presumably allows only one run
# of the approved code — confirm against the syft policy docs.
@sy.syft_function(input_policy=sy.ExactMatch(weights=weight_domain_obj.id, data=train_domain_obj.id),
                  output_policy=sy.SingleExecutionExactOutput())
def train_mlp(weights, data):
    """Apply a freshly defined MLP to ``data`` using the parameters ``weights``.

    Despite the name, this is a single forward pass (``model.apply``); no
    parameters are updated.

    Args:
        weights: flax parameter collection, e.g. the result of ``MLP.init``.
        data: input batch; flattened to (batch, -1) before the first Dense layer.

    Returns:
        The output of ``model.apply``: the final Dense layer's activations
        (``out_dims=10`` columns).
    """
    # Imported inside the body, presumably so the function is self-contained
    # when its source is executed on the domain node — TODO confirm.
    from flax import linen as nn

    # NOTE(review): copy of the notebook-level MLP; it must build the same
    # parameter tree as the model that produced `weights` via `model.init`.
    class MLP(nn.Module):
        out_dims: int

        @nn.compact
        def __call__(self, x):
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(128)(x)
            x = nn.relu(x)
            x = nn.Dense(self.out_dims)(x)
            return x

    model = MLP(out_dims=10)
    # Forward pass only.
    output = model.apply(weights, data)
    return output

Run locally (forward pass only — `train_mlp` applies the model; it does not update the weights)

In [39]:
# Call the decorated function directly with the raw local objects (not syft ids);
# it returns a jax array (see output below).
output = train_mlp(weights=weights, data=train_data)

In [40]:
# Check the local result. Convert to a Python float before rounding, as the
# remote check does, so this is a plain float == float comparison that holds
# whether jax runs in float32 or float64.
assert round(float(output.sum()), 2) == -3.24

In [41]:
# Local result; compare with the remote result further down.
output

DeviceArray([[ 0.14943041, -0.36854096, -0.64575584, -0.38621526,
              -0.28981561,  0.14723957,  0.35607396,  0.898455  ,
              -0.46983801,  0.21583178],
             [-0.36093625, -0.0785419 , -0.41703793, -0.82913101,
               0.06887782,  0.079618  ,  0.22278813,  0.55593109,
              -0.53083418, -0.0054186 ],
             [-0.31463861,  0.0295174 , -0.62358003, -0.08584507,
              -0.24341324, -0.17701984,  0.3985397 ,  0.67374497,
              -0.14091304,  0.0577738 ],
             [-0.3278211 , -0.35691213, -0.77101191, -0.52124855,
               0.10943515, -0.01648953,  0.27638874,  0.55057775,
              -0.11716184,  0.05130892]], dtype=float64)

Remote execution (request approval for `train_mlp`, then run it on the domain node)

In [21]:
# Submit the decorated function for execution on the node; the returned Request
# starts in RequestStatus.PENDING (see output).
request = domain_client.api.services.code.request_code_execution(train_mlp)
request

```python
class Request:
  id: str = 7ea0885963f345e6858c52c5bad1838c
  requesting_user_verify_key: str = aec6ea4dfc049ceacaeeebc493167a88a200ddc367b1fa32da652444b635d21f
  approving_user_verify_key: str = None
  request_time: str = 2023-03-15 01:34:19
  approval_time: str = None
  status: str = RequestStatus.PENDING
  node_uid: str = 7bca415d13ed1ec841f0d0aede098dbb
  request_hash: str = "0596f097d5277da75a5c8d6d9b22f80c1e71f107b441c1b61da5fddfe2c0c83a"
  changes: str = [syft.core.node.new.request.UserCodeStatusChange]

```

In [22]:
# Approve the pending request. The same root client both submitted and approves
# it here, which presumably only makes sense in a local test setup.
request.approve()

In [23]:
# NOTE(review): clears a private attribute so `domain_client.api` is fetched
# again; presumably needed so the newly approved `train_mlp` endpoint appears
# under `services.code`. Check whether a public refresh method exists.
domain_client._api = None
_ = domain_client.api

In [24]:
# Run the approved code on the node. Pass the ids of the uploaded objects,
# i.e. exactly the ids the ExactMatch input policy of `train_mlp` was built
# with, rather than relying on `action.set` preserving the local wrappers' ids.
result = domain_client.api.services.code.train_mlp(
    weights=weight_domain_obj.id, data=train_domain_obj.id
)

In [25]:
# Remote result; it matches the local `output` shown earlier.
result

DeviceArray([[ 0.14943041, -0.36854096, -0.64575584, -0.38621526,
              -0.28981561,  0.14723957,  0.35607396,  0.898455  ,
              -0.46983801,  0.21583178],
             [-0.36093625, -0.0785419 , -0.41703793, -0.82913101,
               0.06887782,  0.079618  ,  0.22278813,  0.55593109,
              -0.53083418, -0.0054186 ],
             [-0.31463861,  0.0295174 , -0.62358003, -0.08584507,
              -0.24341324, -0.17701984,  0.3985397 ,  0.67374497,
              -0.14091304,  0.0577738 ],
             [-0.3278211 , -0.35691213, -0.77101191, -0.52124855,
               0.10943515, -0.01648953,  0.27638874,  0.55057775,
              -0.11716184,  0.05130892]], dtype=float64)

In [26]:
# The remote forward pass must reproduce the local sum (to two decimals).
assert -3.24 == round(float(result.sum()), 2)

IGNORING: got action result id=<UID: 11c806acf70f4af385fee0738dcdd09e> syft_parent_id=None syft_action_data=DeviceArray(-3.23658824, dtype=float64) syft_node_uid=<UID: 7bca415d13ed1ec841f0d0aede098dbb> syft_internal_type=typing.Any syft_passthrough_attrs=[] syft_dont_wrap_attrs=[]
result doesnt have a syft_node_uid attr


In [27]:
# NOTE(review): leftover inspection cell — the same type was already shown via
# `type(w.syft_action_data)` above; consider removing.
type(weights)

flax.core.frozen_dict.FrozenDict

In [28]:
# NOTE(review): leftover inspection cell — repeats an earlier `type(w)` cell; consider removing.
type(w)

syft.core.node.new.action_object.AnyActionObject

In [29]:
# NOTE(review): leftover inspection cell — `train.id` was already displayed above; consider removing.
train.id

<UID: d7d7c9c7a9d94b589fae1275de67c96c>