# Haiku Level 0 Data Scientist Experience - Chapter 2
## Part 2 - New account registration and code execution requests

Link to the original Haiku tutorial: https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html

In [None]:
# Import the necessary libraries
import syft as sy
# NOTE(review): presumably checks that the installed syft version satisfies
# the ">=0.8-beta" specifier before the rest of the notebook runs — confirm.
sy.requires(">=0.8-beta")
import jax
import jax.numpy as jnp
import haiku as hk

In [None]:
# Get a handle on the "test-domain-1" node and take its client.
node = sy.orchestra.launch(name="test-domain-1")
guest_domain_client = node.client

# Register the Data Scientist account on the domain ...
guest_domain_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)
# ... then log in with the same e-mail/password pair.
guest_domain_client.login(email="jane@caltech.edu", password="abc123")

In [None]:
# Create the functions to submit for code execution
# ATTENTION: every library a function uses must be IMPORTED INSIDE the function body!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def limitations_example():
    """Show how ``jax.eval_shape`` interacts with side-effecting Haiku code.

    Three independent scenarios are run in order:

    1. ``jax.eval_shape`` is applied to a function that calls
       ``hk.get_parameter`` and the function is then called again; an
       ``UnexpectedTracerError`` raised during ``init`` is caught and a
       message is printed.
    2. The parameter is fetched in the transformed function and handed to a
       side-effect-free ``w @ x`` function, which is what ``jax.eval_shape``
       receives; ``init`` and ``apply`` are both run.
    3. ``jax.eval_shape`` is applied directly to an ``hk.nets.MLP``; an
       ``UnexpectedTracerError`` raised during ``init`` is caught and the
       same message is printed.

    Takes no arguments and returns ``None``; the only output is whatever the
    scenarios print.
    """
    # All libraries are imported inside the function body, as this notebook
    # requires for functions submitted for code execution.
    import jax
    import jax.numpy as jnp
    import haiku as hk

    def init_and_report_tracer_error(fn):
        """Transform ``fn``, run its ``init`` and report a leaked tracer."""
        init, _ = hk.transform(fn)
        try:
            init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
        except jax.errors.UnexpectedTracerError:
            print("UnexpectedTracerError: applied JAX transform to side effecting function")

    # --- Scenario 1: eval_shape over a function that creates a parameter ---
    def side_effecting_net(x):  # inside of a hk.transform, this is still side-effecting
        w = hk.get_parameter("w", (2, 2), init=jnp.ones)
        return w @ x

    def eval_shape_side_effecting(x):
        # Only the call matters here; its result was never used.
        jax.eval_shape(side_effecting_net, x)  # eval_shape on side-effecting function
        return side_effecting_net(x)           # UnexpectedTracerError!

    init_and_report_tracer_error(eval_shape_side_effecting)

    # --- Scenario 2: parameter created outside the function given to eval_shape ---
    def pure_net(w, x):  # no side effects!
        return w @ x

    def eval_shape_pure(x):
        w = hk.get_parameter("w", (3, 2), init=jnp.ones)
        output_shape = jax.eval_shape(pure_net, w, x)  # pure_net is side-effect free
        return output_shape, pure_net(w, x)

    key = jax.random.PRNGKey(777)
    x = jnp.ones((2, 3))
    init, apply = hk.transform(eval_shape_pure)
    params = init(key, x)
    apply(params, key, x)

    # --- Scenario 3: eval_shape applied directly to a Haiku module ---
    def eval_shape_mlp(x):
        net = hk.nets.MLP([300, 100])
        output_shape = jax.eval_shape(net, x)
        return output_shape, net(x)

    init_and_report_tracer_error(eval_shape_mlp)
        
   
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def haiku_lift_example():
    """Apply ``jax.eval_shape`` to a Haiku MLP via a nested transform and ``hk.lift``.

    Inside the outer transformed function an ``hk.nets.MLP([300, 100])`` is
    transformed a second time; the inner ``init`` is wrapped with
    ``hk.lift(..., name="inner")`` and the inner ``apply`` is the function
    handed to ``jax.eval_shape``. The outer function is then initialised and
    applied on a ``(100, 100)`` array of ones, and the shape of every entry in
    the resulting parameter tree is printed.

    Takes no arguments and returns ``None``.
    """
    # All libraries are imported inside the function body, as this notebook
    # requires for functions submitted for code execution.
    import jax
    import jax.numpy as jnp
    import haiku as hk

    def eval_shape_net(x):
        net = hk.nets.MLP([300, 100])                # still side-effecting
        inner_init, inner_apply = hk.transform(net)  # nested transform
        # register parameters in outer module scope with name "inner"
        inner_params = hk.lift(inner_init, name="inner")(hk.next_rng_key(), x)
        # inner_apply is a functionally pure function and can be transformed!
        output_shape = jax.eval_shape(inner_apply, inner_params, hk.next_rng_key(), x)
        out = net(x)
        return out, output_shape

    rng = jax.random.PRNGKey(777)
    inputs = jnp.ones((100, 100))

    init, apply = hk.transform(eval_shape_net)
    params = init(rng, inputs)
    apply(params, rng, inputs)

    # A bare expression is only displayed at the end of a notebook cell; inside
    # a function its value would be dropped, so print the shape tree instead.
    print(jax.tree_util.tree_map(lambda leaf: leaf.shape, params))
    
   
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def jax_transform_example():
    """Use ``hk.eval_shape`` on a Haiku MLP inside ``hk.transform``.

    The transformed function builds an ``hk.nets.MLP([300, 100])``, asks
    ``hk.eval_shape`` for its output shape and then calls the network. The
    transformed function is initialised and applied once on a ``(100, 100)``
    array of ones.

    Takes no arguments and returns ``None``.
    """
    # All libraries are imported inside the function body, as this notebook
    # requires for functions submitted for code execution.
    import jax
    import jax.numpy as jnp
    import haiku as hk

    def eval_shape_net(x):
        net = hk.nets.MLP([300, 100])         # still side-effecting
        output_shape = hk.eval_shape(net, x)  # hk.eval_shape threads through the Haiku state for you
        out = net(x)
        return out, output_shape

    rng = jax.random.PRNGKey(777)
    inputs = jnp.ones((100, 100))

    init, apply = hk.transform(eval_shape_net)
    params = init(rng, inputs)
    # The forward pass is run for its own sake; its result is not used.
    apply(params, rng, inputs)


In [None]:
# Run each example once in this notebook, in definition order
for example_fn in (limitations_example, haiku_lift_example, jax_transform_example):
    example_fn()

In [None]:
# Submit the functions for code execution
# One request_code_execution call per decorated function, made through the
# client's code service.
# NOTE(review): presumably each call creates a request that the Data Owner has
# to approve before the code can run on the domain — confirm.
guest_domain_client.api.services.code.request_code_execution(limitations_example)
guest_domain_client.api.services.code.request_code_execution(haiku_lift_example)
guest_domain_client.api.services.code.request_code_execution(jax_transform_example)

In [None]:
# Call limitations_example through the domain's code service.
# NOTE(review): presumably run before the Data Owner has approved the request,
# to show what an unapproved call returns — confirm.
guest_domain_client.api.services.code.limitations_example()

### Go to the Data Owner Notebook for Part 2!

## Part 3 - Downloading the Results

In [None]:
# Clear the client's private `_api` attribute, then access `.api` again; the
# returned value is deliberately discarded.
# NOTE(review): presumably this forces the client to rebuild its API object so
# that code approved by the Data Owner becomes visible — confirm.
guest_domain_client._api = None
_ = guest_domain_client.api

In [None]:
# Call limitations_example through the code service again and keep what it
# returns for the cells below.
result = guest_domain_client.api.services.code.limitations_example()

In [None]:
# Display whatever get_result() returns as this cell's output.
# NOTE(review): presumably the value returned by the remotely executed
# function — confirm against the syft result type.
result.get_result()

In [None]:
# Print whatever get_stderr() returns for this execution.
# NOTE(review): presumably the stderr captured while the function ran on the
# domain — confirm.
print(result.get_stderr())