In [None]:
import syft as sy

In [None]:
# Create a named Syft test worker and take its root client; every later cell
# talks to the domain through `domain_client`.
# NOTE(review): `reset=True` presumably wipes state left by earlier runs under
# this name — confirm against the Syft docs for this version.
worker = sy.Worker.named("test-domain-1", reset=True)
domain_client = worker.root_client

In [None]:
from jax import random
from flax import linen as nn
# Single seed for the whole notebook. Both asserted checksums below (the data
# sum and the model-output sum) depend on this exact value.
SEED = 42
key = random.PRNGKey(SEED)

In [None]:
# Synthetic batch: 4 samples of shape 28x28x1, drawn uniformly from [0, 1)
# (jax.random.uniform's default range) using `key`.
train_data = random.uniform(key, (4, 28, 28, 1))

In [None]:
# Debug output: show the rounded pixel sum that the next cell asserts on.
print("round(train_data.sum())", round(train_data.sum()))

In [None]:
# Reproducibility check on the synthetic batch. The JAX scalar is converted to a
# Python float first (as the final assertion on `result` does), so the assert
# tests a plain bool rather than the truthiness of a JAX array. A failure also
# reports the actual value.
assert round(float(train_data.sum())) == 1602, (
    f"unexpected data checksum: {float(train_data.sum())}"
)

In [None]:
# Wrap the local JAX array in a Syft ActionObject so it can be stored on the
# domain and referenced by `train.id`.
train = sy.ActionObject.from_obj(train_data)

In [None]:
# Inspect the wrapper: type of the wrapped data, its Syft id, and its shape
# (`shape` is presumably forwarded to the wrapped array — confirm).
type(train.syft_action_data), train.id, train.shape

In [None]:
# Store the data ActionObject on the domain via the action service, so the
# code's input policy and the remote call can refer to it by `train.id`.
# NOTE(review): the return value is discarded unchecked; if `set` reports
# failure via its return value instead of raising, a failed upload goes unnoticed.
_ = domain_client.api.services.action.set(train)

In [None]:
class MLP(nn.Module):
    """Two-layer perceptron: flatten -> Dense(128) -> ReLU -> Dense(out_dims).

    Attributes:
        out_dims: width of the output layer.
    """

    out_dims: int

    @nn.compact
    def __call__(self, x):
        # Keep the leading (batch) axis and merge every other axis into features.
        flat = x.reshape((x.shape[0], -1))
        # Flax auto-names submodules in creation order (Dense_0, Dense_1), so the
        # hidden layer must be built before the output layer to keep parameter
        # names stable.
        hidden = nn.relu(nn.Dense(128)(flat))
        return nn.Dense(self.out_dims)(hidden)


model = MLP(out_dims=10)

In [None]:
# Initialise the model parameters by tracing the model on the wrapped batch.
# NOTE(review): this reuses the same PRNG `key` that generated `train_data`.
# JAX recommends splitting keys rather than reusing them, but changing it here
# would change the output checksum asserted later.
weights = model.init(key, train.syft_action_data)

In [None]:
# Wrap the initialised parameters in a Syft ActionObject so they can be stored
# on the domain and referenced by `w.id`.
w = sy.ActionObject.from_obj(weights)

In [None]:
# Inspect the wrapper: type of the wrapped parameters and its Syft id.
type(w.syft_action_data), w.id

In [None]:
# Store the weights ActionObject on the domain so it can be referenced by `w.id`.
# NOTE(review): as with the data upload, the return value is discarded unchecked.
_ = domain_client.api.services.action.set(w)

In [None]:
@sy.syft_function(input_policy=sy.ExactMatch(weights=w.id, data=train.id),
                  output_policy=sy.SingleExecutionExactOutput())
def train_mlp(weights, data):
    """Apply the MLP defined below with parameters `weights` to `data`.

    The body is deliberately self-contained (its own import and its own copy of
    the MLP class). It is presumably executed on the domain from its source,
    where notebook globals are unavailable.

    NOTE(review): despite the name, this only runs a forward pass (`apply`);
    no parameters are updated.
    """
    from flax import linen as nn

    # Must mirror the notebook-level MLP exactly (layer order and widths) so the
    # parameter tree in `weights` matches.
    class MLP(nn.Module):
        out_dims: int

        @nn.compact
        def __call__(self, x):
            flat = x.reshape((x.shape[0], -1))
            hidden = nn.relu(nn.Dense(128)(flat))
            return nn.Dense(self.out_dims)(hidden)

    return MLP(out_dims=10).apply(weights, data)

In [None]:
# Run the decorated function locally on the raw (unwrapped) objects to get a
# reference result. The domain-side run later is checked against the same value.
# NOTE(review): assumes the object returned by `sy.syft_function` is callable
# locally with the original keyword arguments — confirm for this Syft version.
output = train_mlp(weights=weights, data=train_data)

In [None]:
# Debug output: show the rounded local output sum that the next cell asserts on.
print("round(output.sum(), 2)", round(output.sum(), 2))

In [None]:
# The local forward pass must reproduce the reference checksum -3.24. An
# absolute tolerance of 0.005 accepts the same interval as "rounds to -3.24 at
# 2 decimals". It avoids exact float equality on a rounded float32 JAX value,
# which is fragile near rounding boundaries.
output_total = float(output.sum())
assert abs(output_total - (-3.24)) <= 5e-3, f"unexpected output checksum: {output_total}"

In [None]:
# Submit `train_mlp` to the domain as a code-execution request; the bare
# `request` on the last line displays the request object.
request = domain_client.api.services.code.request_code_execution(train_mlp)
request

In [None]:
# Approve the pending request so the code can be executed on the domain
# (presumably using the root client's permissions).
request.approve()

In [None]:
# Drop the client's cached API object and rebuild it. Presumably this is so that
# the newly approved `train_mlp` endpoint used in the next cell becomes visible.
# NOTE(review): this relies on the private `_api` attribute; prefer a public
# refresh method if this Syft version provides one.
domain_client._api = None
_ = domain_client.api

In [None]:
# Execute the approved code on the domain. Pass the ids of the stored
# ActionObjects rather than the data itself; these are the same ids named in
# the function's ExactMatch input policy.
result = domain_client.api.services.code.train_mlp(weights=w.id, data=train.id)

In [None]:
# Display the domain-side result.
result

In [None]:
# Debug output: show the rounded domain-side output sum that the next cell asserts on.
print("round(float(result.sum()), 2)", round(float(result.sum()), 2))

In [None]:
# The domain-side run must reproduce the same reference checksum (-3.24) as the
# local run. An absolute tolerance of 0.005 accepts the same interval as "rounds
# to -3.24 at 2 decimals", without relying on exact float equality after rounding.
result_total = float(result.sum())
assert abs(result_total - (-3.24)) <= 5e-3, f"unexpected result checksum: {result_total}"