In [None]:
SYFT_VERSION = ">=0.8.2.b0,<0.9"
# Double-quoted pip requirement (e.g. "syft>=0.8.2.b0,<0.9"); the quotes
# presumably keep the shell from treating `>` / `<` as redirections.
package_string = '"syft' + SYFT_VERSION + '"'
# %pip install {package_string} -q

In [None]:
# third party
import haiku as hk
import jax
from jax import random

# syft absolute
import syft as sy

# Check the installed syft against the pinned range above
# (presumably raises or warns on a mismatch — confirm against sy.requires).
sy.requires(SYFT_VERSION)

In [None]:
# Launch the dev-mode domain node used by this test. port="auto" presumably
# lets the launcher pick a free port — confirm. The final cell lands the node
# when its type is "python".
node = sy.orchestra.launch(name="test-domain-1", port="auto", dev_mode=True)

In [None]:
# Log in to the node with hard-coded credentials.
# NOTE(review): these look like the default dev credentials of a freshly
# launched node — acceptable for a throwaway test node only; never reuse them
# against a real deployment.
domain_client = node.login(email="info@openmined.org", password="changethis")

In [None]:
# Fixed seed so the data, the initial weights and the asserted sums are reproducible.
SEED = 42
key = jax.random.PRNGKey(SEED)

In [None]:
# Synthetic input batch of shape (4, 28, 28, 1) drawn with jax.random.uniform.
train_shape = (4, 28, 28, 1)
train_data = random.uniform(key, shape=train_shape)

In [None]:
# Regression guard: with the fixed seed the synthetic data sums to ~1602.
assert round(float(train_data.sum())) == 1602

In [None]:
# Wrap the raw jax array in a syft ActionObject so it can be uploaded to the domain below.
train = sy.ActionObject.from_obj(train_data)

In [None]:
# Inspect the wrapper: type of the underlying data, its syft id, and the shape
# read through the ActionObject.
type(train.syft_action_data), train.id, train.shape

In [None]:
# Upload the training data to the domain; the returned object's id is what
# train_mlp's ExactMatch input policy refers to.
train_domain_obj = domain_client.api.services.action.set(train)

In [None]:
class MLP(hk.Module):
    """Two-layer perceptron: flatten -> Linear(128) -> ReLU -> Linear(out_dims).

    Args:
        out_dims: width of the final linear layer (number of output values per sample).
        name: optional module name, forwarded to ``hk.Module``.
    """

    def __init__(self, out_dims, name=None):
        super().__init__(name=name)
        self.out_dims = out_dims

    def __call__(self, x):
        # Collapse every axis after the first so the first layer sees (batch, features).
        flat = x.reshape((x.shape[0], -1))
        # Layers are created in the same order as before (128-wide first), which
        # keeps Haiku's generated parameter names stable.
        hidden = jax.nn.relu(hk.Linear(128)(flat))
        return hk.Linear(self.out_dims)(hidden)


def _forward_fn_linear1(x):
    """Forward function for ``hk.transform``: apply a 10-output MLP to ``x``."""
    return MLP(out_dims=10)(x)


# Transformed model: `model.init` builds the parameters used below.
model = hk.transform(_forward_fn_linear1)

In [None]:
# Initialise the MLP parameters from the sample batch.
# NOTE(review): `key` is the same PRNG key that generated train_data; JAX
# convention is to split keys (random.split) instead of reusing them. Changing
# it would alter the output sums asserted later, so it is left as-is.
weights = model.init(key, train.syft_action_data)

In [None]:
# Guard that the Haiku parameters are a plain dict before wrapping them in an ActionObject.
assert isinstance(weights, dict)

In [None]:
# Wrap the parameter dict in a syft ActionObject for upload.
w = sy.ActionObject.from_obj(weights)

In [None]:
# Inspect the wrapped weights: type of the underlying data and its syft id.
type(w.syft_action_data), w.id

In [None]:
# Upload the weights; the returned object's id feeds train_mlp's input policy.
weight_domain_obj = domain_client.api.services.action.set(w)

In [None]:
@sy.syft_function(
    input_policy=sy.ExactMatch(weights=weight_domain_obj.id, data=train_domain_obj.id),
    output_policy=sy.SingleExecutionExactOutput(),
)
def train_mlp(weights, data):
    """Run one forward pass of the MLP over ``data`` with the given ``weights``.

    Despite the name, no training happens: the function only calls
    ``apply`` and returns the network output.

    The model is redefined here instead of reusing the notebook-level ``MLP``.
    Presumably this is because syft executes the function body in isolation on
    the domain, where notebook globals are unavailable — hence the local
    imports too. The class name and layer creation order must match the
    notebook-level ``MLP`` so the parameter keys in ``weights`` line up.
    """
    # third party
    import haiku as hk
    import jax

    class MLP(hk.Module):
        def __init__(self, out_dims, name=None):
            super().__init__(name=name)
            self.out_dims = out_dims

        def __call__(self, x):
            flat = x.reshape((x.shape[0], -1))
            hidden = jax.nn.relu(hk.Linear(128)(flat))
            return hk.Linear(self.out_dims)(hidden)

    def _forward(x):
        return MLP(out_dims=10)(x)

    forward = hk.transform(_forward)
    return forward.apply(params=weights, x=data, rng=jax.random.PRNGKey(42))

In [None]:
# Call the syft function directly from the notebook with the uploaded objects
# (before any approval) and fetch its result.
# NOTE(review): presumably this executes locally rather than on the domain —
# confirm against the sy.syft_function docs.
pointer = train_mlp(weights=weight_domain_obj, data=train_domain_obj)
output = pointer.get()

In [None]:
# Sanity-check the locally computed output. Convert to a Python float first so
# round() operates on a plain scalar, matching the post-approval check below.
output_total = float(output.sum())
assert round(output_total, 2) == -0.86, f"unexpected output sum {output_total}"

In [None]:
# Submit train_mlp to the domain for execution approval and display the request object.
request = domain_client.code.request_code_execution(train_mlp)
request

In [None]:
# Approve the pending request so domain_client.code.train_mlp can be executed below.
request.approve()

In [None]:
# Force the client to rebuild its API. Clearing the cached private `_api` and
# reading `.api` again presumably refetches the endpoints, so the newly approved
# train_mlp shows up under domain_client.code. Relies on a private attribute.
domain_client._api = None
_ = domain_client.api

In [None]:
# Execute the approved code on the domain. Pass the ids of the uploaded objects,
# i.e. exactly the ids train_mlp's ExactMatch input policy was built from,
# rather than the ids of the local ActionObject wrappers.
result_ptr = domain_client.code.train_mlp(
    weights=weight_domain_obj.id, data=train_domain_obj.id
)

In [None]:
# Fetch the value behind the pointer returned by the domain execution.
result = result_ptr.get()

In [None]:
# The domain result must match the locally computed output (sum ~ -0.86).
result_total = float(result.sum())
assert round(result_total, 2) == -0.86

In [None]:
# Tear down the node, but only when its type is "python"; other node types
# are not landed by this notebook.
if node.node_type.value == "python":
    node.land()