# Haiku Level 0 Data Scientist Experience - Chapter 1
## Part 2 - New account registration and code execution requests

Link to the original Haiku tutorial: https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8-beta")
import jax
import jax.numpy as jnp
import haiku as hk

In [None]:
# Launch the domain node, take its client, register a new user on it and then
# log in with the same credentials.
node = sy.orchestra.launch(name="test-domain-1")
guest_domain_client = node.client
guest_domain_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)
guest_domain_client.login(
    email="jane@caltech.edu",
    password="abc123",
)

In [None]:
# Create the functions to submit for code execution
# ATTENTION: EVERY LIBRARY A FUNCTION USES MUST BE IMPORTED INSIDE THAT FUNCTION'S BODY!

@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def first_example():
    """Walk through the basics of a hand-written Haiku linear layer.

    Defines a custom linear module, transforms it into init/apply functions,
    then prints the initial parameters, several forward passes (with and
    without an RNG key at apply time) and the output obtained after adding 1
    to every parameter.  Nothing is returned; all results are printed.
    """
    # Every library used by the function is imported inside its own body.
    import jax
    import jax.numpy as jnp
    import haiku as hk
    import numpy as np

    class MyLinear1(hk.Module):
        """Linear layer computing ``dot(x, w) + b`` with a bias initialised to ones."""

        def __init__(self, output_size, name=None):
            super().__init__(name=name)
            self.output_size = output_size

        def __call__(self, inputs):
            fan_in = inputs.shape[-1]
            fan_out = self.output_size
            weights = hk.get_parameter(
                "w",
                shape=[fan_in, fan_out],
                dtype=inputs.dtype,
                # Truncated-normal initialiser scaled by 1/sqrt(fan_in).
                init=hk.initializers.TruncatedNormal(1. / np.sqrt(fan_in)),
            )
            bias = hk.get_parameter(
                "b",
                shape=[fan_out],
                dtype=inputs.dtype,
                init=jnp.ones,
            )
            return jnp.dot(inputs, weights) + bias

    def _forward_fn_linear1(x):
        # `x` is supplied by keyword from init/apply below, so its name is fixed.
        return MyLinear1(output_size=2)(x)

    forward_linear1 = hk.transform(_forward_fn_linear1)
    print(forward_linear1)

    rng_key = jax.random.PRNGKey(42)
    single_row = jnp.array([[1., 2., 3.]])
    two_rows = jnp.array([[4., 5., 6.], [7., 8., 9.]])

    init_params = forward_linear1.init(rng=rng_key, x=single_row)
    print(init_params)

    def _apply_with_rng(batch):
        # Forward pass through the rng-taking transform with the initial params.
        return forward_linear1.apply(params=init_params, rng=rng_key, x=batch)

    output_1 = _apply_with_rng(single_row)
    # The forward pass is non-stochastic, so the same input gives the same output.
    output_2 = _apply_with_rng(single_row)
    output_3 = _apply_with_rng(two_rows)

    print(f'Output 1 : {output_1}')
    print(f'Output 2 (same as output 1): {output_2}')
    print(f'Output 3 : {output_3}')

    # Same forward function, wrapped so that apply no longer takes an rng argument.
    forward_without_rng = hk.without_apply_rng(hk.transform(_forward_fn_linear1))
    rng_free_params = forward_without_rng.init(rng=rng_key, x=single_row)
    output = forward_without_rng.apply(params=rng_free_params, x=single_row)
    print(f'Output without random key in forward pass \n {output}')

    # Add 1 to every leaf of the parameter tree and run the forward pass again.
    mutated_params = jax.tree_util.tree_map(lambda leaf: leaf + 1., rng_free_params)
    print(f'Mutated params \n : {mutated_params}')
    mutated_output = forward_without_rng.apply(params=mutated_params, x=single_row)
    print(f'Output with mutated params \n {mutated_output}')

    
@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def stateful_inference_example():
    """Demonstrate Haiku state with a counter that is incremented on every call.

    Prints the initial params and state, then the output and the updated
    state after each of three successive forward passes.  Nothing is returned.
    """
    # Every library used by the function is imported inside its own body.
    import jax
    import jax.numpy as jnp
    import haiku as hk
    import numpy as np

    def stateful_f(x):
        # `x` is supplied by keyword from init/apply below, so its name is fixed.
        counter = hk.get_state(
            "counter", shape=[], dtype=jnp.int32, init=jnp.ones
        )
        multiplier = hk.get_parameter(
            'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
        )
        # Store the incremented counter; the value returned below still uses
        # the counter as it was read above.
        hk.set_state("counter", counter + 1)
        return x + multiplier * counter

    separator = '##########'
    sample_x = jnp.array([[5., ]])
    rng_key = jax.random.PRNGKey(42)

    stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))
    params, state = stateful_forward.init(x=sample_x, rng=rng_key)
    print(f'Initial params:\n{params}\nInitial state:\n{state}')
    print(separator)

    for i in range(3):
        # Feed the state from the previous pass back in on every iteration.
        output, state = stateful_forward.apply(params, state, x=sample_x)
        print(f'After {i+1} iterations:\nOutput: {output}\nState: {state}')
        print(separator)

    
@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def haiku_nets_example():
    """Chain a built-in ``hk.nets.MLP`` with the hand-written linear layer.

    Returns:
        The parameters produced by initialising the combined module on a
        single 1x1 sample input.
    """
    # Every library used by the function is imported inside its own body.
    import jax
    import jax.numpy as jnp
    import haiku as hk
    import numpy as np

    # See: https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules
    class MyLinear1(hk.Module):
        """Linear layer computing ``dot(x, w) + b`` with a bias initialised to ones."""

        def __init__(self, output_size, name=None):
            super().__init__(name=name)
            self.output_size = output_size

        def __call__(self, inputs):
            fan_in = inputs.shape[-1]
            fan_out = self.output_size
            weights = hk.get_parameter(
                "w",
                shape=[fan_in, fan_out],
                dtype=inputs.dtype,
                # Truncated-normal initialiser scaled by 1/sqrt(fan_in).
                init=hk.initializers.TruncatedNormal(1. / np.sqrt(fan_in)),
            )
            bias = hk.get_parameter(
                "b",
                shape=[fan_out],
                dtype=inputs.dtype,
                init=jnp.ones,
            )
            return jnp.dot(inputs, weights) + bias

    class MyModuleCustom(hk.Module):
        """Module that feeds the MLP's output into a ``MyLinear1`` layer."""

        def __init__(self, output_size=2, name='custom_linear'):
            super().__init__(name=name)
            self._internal_linear_1 = hk.nets.MLP(
                output_sizes=[2, 3],
                name='hk_internal_linear',
            )
            self._internal_linear_2 = MyLinear1(
                output_size=output_size,
                name='old_linear',
            )

        def __call__(self, inputs):
            hidden = self._internal_linear_1(inputs)
            return self._internal_linear_2(hidden)

    def _custom_forward_fn(x):
        # `x` is supplied by keyword from init below, so its name is fixed.
        return MyModuleCustom()(x)

    sample_x = jnp.array([[5., ]])
    rng_key = jax.random.PRNGKey(42)
    custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))
    return custom_forward_without_rng.init(rng=rng_key, x=sample_x)


@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def hk_next_rng_key_example():
    """Demonstrate ``hk.next_rng_key()`` inside nested stochastic modules.

    Initialises the transformed function, then calls apply repeatedly: with
    the same key, with freshly split keys, and with keys drawn from an
    ``hk.PRNGSequence``.  Each call prints a pair of Bernoulli samples;
    nothing is returned.
    """
    # Every library used by the function is imported inside its own body.
    import jax
    import jax.numpy as jnp
    import haiku as hk
    import numpy as np

    class HkRandom2(hk.Module):
        """Draws a Bernoulli sample shaped like its input."""

        def __init__(self, rate=0.5):
            super().__init__()
            self.rate = rate

        def __call__(self, x):
            key1 = hk.next_rng_key()
            success_prob = 1.0 - self.rate
            return jax.random.bernoulli(key1, success_prob, shape=x.shape)

    class HkRandomNest(hk.Module):
        """Draws one sample itself and one through a nested ``HkRandom2``."""

        def __init__(self, rate=0.5):
            super().__init__()
            self.rate = rate
            self._another_random_module = HkRandom2()

        def __call__(self, x):
            # Request this module's key first, then let the nested module
            # request its own; the order of these two calls is kept as-is.
            key2 = hk.next_rng_key()
            p1 = self._another_random_module(x)
            success_prob = 1.0 - self.rate
            p2 = jax.random.bernoulli(key2, success_prob, shape=x.shape)
            print(f'Bernoullis are  : {p1, p2}')

    def _nested_forward(x):
        # `x` is supplied by keyword from init/apply below, so its name is fixed.
        return HkRandomNest()(x)

    # Note that the modules that are stochastic cannot be wrapped with hk.without_apply_rng()
    forward = hk.transform(_nested_forward)

    rng_key = jax.random.PRNGKey(42)
    scalar_input = jnp.array(1.)

    print("INIT:")
    params = forward.init(rng_key, x=scalar_input)
    print("APPLY:")
    # HkRandomNest.__call__ has no return statement, so the call is made for
    # its printed output only.
    forward.apply(params, x=scalar_input, rng=rng_key)

    # Three more passes that all reuse the same key.
    for _ in range(3):
        forward.apply(params, x=scalar_input, rng=rng_key)

    # Three passes, each with a key freshly split off from the running key.
    for _ in range(3):
        rng_key, apply_rng_key = jax.random.split(rng_key)
        forward.apply(params, x=scalar_input, rng=apply_rng_key)

    # Three passes with keys drawn one at a time from a PRNGSequence.
    rng_sequence = hk.PRNGSequence(rng_key)
    for _ in range(3):
        forward.apply(params, x=scalar_input, rng=next(rng_sequence))

In [None]:
# Test our functions locally by calling each one directly, before submitting them.
# NOTE(review): presumably the decorated functions remain callable in this process — confirm.
first_example()
stateful_inference_example()
haiku_nets_example()
hk_next_rng_key_example()

In [None]:
# Submit each of the four functions for code execution, one request per function.
guest_domain_client.api.services.code.request_code_execution(first_example)
guest_domain_client.api.services.code.request_code_execution(stateful_inference_example)
guest_domain_client.api.services.code.request_code_execution(haiku_nets_example)
guest_domain_client.api.services.code.request_code_execution(hk_next_rng_key_example)

In [None]:
# Call the submitted function through the domain's code service.
# NOTE(review): the request was only just submitted above, so this call presumably
# happens before the data owner has approved it — confirm the expected response.
guest_domain_client.api.services.code.first_example()

### Go to the Data Owner Notebook for Part 2!

## Part 3 - Downloading the Results

In [None]:
# Clear the client's cached API object and read `.api` again.
# NOTE(review): presumably this forces the API to be rebuilt so that it reflects the
# data owner's decision — confirm against the syft client implementation.
guest_domain_client._api = None
_ = guest_domain_client.api

In [None]:
# Call first_example through the domain's code service again and keep what it returns.
result = guest_domain_client.api.services.code.first_example()

In [None]:
# Retrieve the result of the execution.
# NOTE(review): first_example has no return statement, so this is presumably None — confirm.
result.get_result()

In [None]:
# Print the stdout attached to the result.
# NOTE(review): presumably this is the text printed by first_example during execution — confirm.
print(result.get_stdout())