# Haiku Level 0 Data Scientist Experience - Chapter 3
## Part 2 - New account registration and code execution requests

Link to the original Haiku tutorial: https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8-beta")
import jax
import jax.numpy as jnp
import haiku as hk

In [None]:
# Launch a local test domain, register a new data-scientist account on it and
# log in with that account.
node = sy.orchestra.launch(name="test-domain-1")
guest_domain_client = node.client
guest_domain_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)
# Keep the client returned by `login`; the original cell discarded it.
# NOTE(review): in the PySyft 0.8 client API `login` returns the authenticated
# client, so without this rebinding the later cells would keep talking to the
# domain as the anonymous guest. Confirm against the installed syft version.
guest_domain_client = guest_domain_client.login(email="jane@caltech.edu", password="abc123")

In [None]:
# Create a function for code execution.
# ATTENTION: every library the function uses must be imported INSIDE the
# function body!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def example(ds_train, ds_test):
    """Train a Haiku MLP through jax2tf + Sonnet on ``ds_train``; evaluate on ``ds_test``.

    The Haiku network is converted to a TensorFlow callable with ``jax2tf`` and
    wrapped in a Sonnet module. Its parameters thereby become ``tf.Variable``s
    that a Sonnet Adam optimizer updates.

    Args:
        ds_train: training data. Presumably a ``tf.data.Dataset`` of
            ``(image, label)`` pairs with ``uint8`` images whose flattened size
            is 28 * 28 * 1 (MNIST-like). Confirm against the domain's assets.
        ds_test: evaluation data, assumed to have the same structure as
            ``ds_train``.

    Returns:
        dict with ``"final_loss"`` (training loss of the last step, ``nan`` if
        no step ran) and ``"test_accuracy"`` (fraction of test examples whose
        arg-max logit equals the label, ``nan`` if ``ds_test`` is empty).
    """
    import haiku as hk
    import jax
    import jax.numpy as jnp
    from jax.experimental import jax2tf
    import sonnet as snt
    import tensorflow as tf
    import tree

    # Training hyper-parameters (same values as the original tutorial code).
    num_steps = 6001          # steps 0..6000, so the final step is logged too
    log_every = 1000
    batch_size = 100
    shuffle_buffer = 60000
    learning_rate = 1e-3
    weight_decay = 1e-4
    autotune = tf.data.AUTOTUNE

    def forward(x):
        net = hk.nets.MLP([300, 100, 10])
        return net(x)

    forward = hk.transform(forward)

    # Initialise the Haiku parameters from a dummy batch of one flattened image.
    rng = jax.random.PRNGKey(42)
    dummy_batch = jnp.ones([1, 28 * 28 * 1])
    params = forward.init(rng, dummy_batch)

    def create_variable(path, value):
        # Haiku module paths contain '~' (e.g. "mlp/~/linear_0"), which TF does
        # not accept in variable names, so it is replaced with '_'.
        name = '/'.join(map(str, path)).replace('~', '_')
        return tf.Variable(value, name=name)

    class JaxModule(snt.Module):
        """Sonnet module that evaluates a Haiku ``apply_fn`` through jax2tf."""

        def __init__(self, params, apply_fn, name=None):
            super().__init__(name=name)
            # Mirror the Haiku parameter tree as tf.Variables so Sonnet tracks them.
            self._params = tree.map_structure_with_path(create_variable, params)
            # The MLP uses no randomness, so the apply-time rng is always None.
            self._apply = jax2tf.convert(lambda p, x: apply_fn(p, None, x))
            self._apply = tf.autograph.experimental.do_not_convert(self._apply)

        def __call__(self, inputs):
            return self._apply(self._params, inputs)

    def normalize_img(image, label):
        """Normalizes images: `uint8` -> `float32`."""
        image = tf.cast(image, tf.float32) / 255.
        return image, label

    ds_train = (
        ds_train.map(normalize_img, num_parallel_calls=autotune)
        .cache()
        .shuffle(shuffle_buffer)
        .batch(batch_size)
        .repeat()
        .prefetch(autotune)
    )
    ds_test = (
        ds_test.map(normalize_img, num_parallel_calls=autotune)
        .batch(batch_size)
        .cache()
        .prefetch(autotune)
    )

    # Build the model once (the original created an unused duplicate first).
    net = JaxModule(params, forward.apply)
    opt = snt.optimizers.Adam(learning_rate)

    # `jit_compile` replaces the deprecated `experimental_compile` argument.
    @tf.function(jit_compile=True, autograph=False)
    def train_step(images, labels):
        """Performs one optimizer step on a single mini-batch."""
        with tf.GradientTape() as tape:
            images = snt.flatten(images)
            logits = net(images)
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                  labels=labels)
            loss = tf.reduce_mean(loss)
            # The only variables read under the tape are the model parameters;
            # add L2 weight decay on them.
            variables = tape.watched_variables()
            loss += weight_decay * sum(map(tf.nn.l2_loss, variables))

        grads = tape.gradient(loss, variables)
        opt.apply(grads, variables)
        return loss

    def accuracy(ds):
        """Return the fraction of examples in ``ds`` whose arg-max logit equals the label."""
        correct = 0
        total = 0
        for images, labels in ds:
            predictions = tf.argmax(net(snt.flatten(images)), axis=-1)
            # argmax yields int64; align the label dtype before comparing.
            labels = tf.cast(labels, predictions.dtype)
            correct += int(tf.math.count_nonzero(tf.equal(predictions, labels)))
            total += int(tf.size(labels))
        return correct / total if total else float("nan")

    loss = None
    for step, (images, labels) in enumerate(ds_train.take(num_steps)):
        loss = train_step(images, labels)
        if step % log_every == 0:
            print(f"Step {step}: {loss.numpy()}")

    # Return the results so the approved execution produces a usable output;
    # the original prepared `ds_test` but never evaluated on it.
    final_loss = float(loss.numpy()) if loss is not None else float("nan")
    return {"final_loss": final_loss, "test_accuracy": accuracy(ds_test)}


In [None]:
# Test our function locally
# NOTE(review): none of the four names below is defined in this notebook as
# shown. The only syft function defined above is `example`, and it requires
# `ds_train` and `ds_test` arguments. As written, this cell raises NameError.
# Confirm which functions (and which local mock datasets) were intended.
first_example()
stateful_inference_example()
haiku_nets_example()
hk_next_rng_key_example()

In [None]:
# Submit the function for code execution.
# `example` is the only syft function defined in this notebook. The previously
# submitted names (first_example, stateful_inference_example,
# haiku_nets_example, hk_next_rng_key_example) are not defined here, so that
# cell failed with NameError before any request reached the domain.
guest_domain_client.api.services.code.request_code_execution(example)

In [None]:
# Try to run the submitted function through the domain's code service before
# the Data Owner has acted on the request. Presumably the returned object
# reports that the request is still pending; confirm against syft.
# NOTE(review): `first_example` is not defined in this notebook (only `example`
# is). Confirm the intended function name and whether it must be called with
# the asset arguments.
guest_domain_client.api.services.code.first_example()

### Go to the Data Owner Notebook for Part 2!

## Part 3 - Downloading the Results

In [None]:
# Clear the client's cached `_api` and access `.api` again so it is rebuilt.
# Presumably this picks up the Data Owner's approval made after the API was
# first loaded; confirm against the syft client implementation.
guest_domain_client._api = None
_ = guest_domain_client.api

In [None]:
# Call the (presumably now approved) function through the domain's code service
# and keep the returned object.
# NOTE(review): `haiku_nets_example` is not defined in this notebook (only
# `example` is). Confirm the intended function name.
result = guest_domain_client.api.services.code.haiku_nets_example()

In [None]:
# Display the value from `get_result()` (presumably the function's return
# value). As the cell's last expression, the notebook shows it.
result.get_result()

In [None]:
# Print the text from `get_stderr()`, presumably the stderr captured during the
# remote execution, to help debug failures.
print(result.get_stderr())