# Haiku Level 0 Data Scientist Experience - Chapter 5
## Part 1 - New account registration and code execution requests

Link to the original Haiku tutorial: https://dm-haiku.readthedocs.io/en/latest/notebooks/parameter_sharing.html

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8-beta")
import jax
import jax.numpy as jnp
import haiku as hk

In [None]:
# Launch the "test-domain-1" node, take its guest client, and use that client
# to create Jane Doe's account and then log in with the same credentials.
# NOTE: demo-only credentials.
node = sy.orchestra.launch(name="test-domain-1")
guest_domain_client = node.client

jane_email = "jane@caltech.edu"
jane_password = "abc123"
guest_domain_client.register(
    name="Jane Doe",
    email=jane_email,
    password=jane_password,
    institution="Caltech",
    website="https://www.caltech.edu/",
)
guest_domain_client.login(email=jane_email, password=jane_password)

In [None]:
# Create a function for code execution
# ATTENTION: every library the function uses must be imported inside the function body!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def example():
    """Walk through the Haiku parameter-sharing tutorial.

    Each demo prints what it finds (module names, parameter shapes, output
    means, expected failures) to stdout. The function itself returns None.
    Every library is imported inside the body because the function is
    submitted to the domain and executed there.
    """
    #@title Imports and accessory functions
    import functools
    import haiku as hk
    import jax
    import jax.numpy as jnp

    def parameter_shapes(params):
        """Map every leaf of ``params`` to its shape, for readable printing."""
        return jax.tree_util.tree_map(lambda p: p.shape, params)

    def transform_and_print_shapes(fn, x_shape=(2, 3)):
        """Transform ``fn``, initialise it on ``jnp.ones(x_shape)`` and print
        the name and shape of the resulting parameters."""
        rng = jax.random.PRNGKey(42)
        x = jnp.ones(x_shape)

        transformed_fn = hk.transform(fn)
        params = transformed_fn.init(rng, x)
        print('\nThe name and shape of the parameters are:')
        print(parameter_shapes(params))

    def assert_all_equal(params_1, params_2):
        """Assert that two parameter trees of identical structure are
        element-wise equal."""
        assert all(jax.tree_util.tree_leaves(
            jax.tree_util.tree_map(lambda a, b: (a == b).all(), params_1, params_2)))

    # Initializer shared by every `w` parameter created below.
    w_init = hk.initializers.TruncatedNormal(stddev=1)

    class SimpleModule(hk.Module):
        """A simple module class with one variable."""

        def __init__(self, output_channels, name=None):
            super().__init__(name)
            assert isinstance(output_channels, int)
            self._output_channels = output_channels

        def __call__(self, x):
            w_shape = (x.shape[-1], self._output_channels)
            w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
            return jnp.dot(x, w)

    # --- Module naming ------------------------------------------------------
    # The print() calls inside the functions below run while the transformed
    # function is being initialised.

    def f(x):
        # With no explicit name, Haiku names the instance after its class in
        # snake_case: `simple_module`.
        simple = SimpleModule(output_channels=2)
        simple_out = simple(x)  # implicitly calls simple.__call__()
        print(f'The name assigned to "simple" is: "{simple.module_name}".')
        return simple_out

    transform_and_print_shapes(f)

    def f(x):
        # This instance will be named `simple_module`.
        simple_one = SimpleModule(output_channels=2)
        # A second instance of the same class gets a numeric suffix:
        # `simple_module_1`.
        simple_two = SimpleModule(output_channels=2)
        first_out = simple_one(x)
        second_out = simple_two(x)
        print(f'The name assigned to "simple_one" is: "{simple_one.module_name}".')
        print(f'The name assigned to "simple_two" is: "{simple_two.module_name}".')
        return first_out + second_out

    transform_and_print_shapes(f)

    def f(x):
        # Calling the same instance twice reuses its single `w`.
        simple_one = SimpleModule(output_channels=2)
        first_out = simple_one(x)
        second_out = simple_one(x)  # share parameters w/ previous call
        print(f'The name assigned to "simple_one" is: "{simple_one.module_name}".')
        return first_out + second_out

    transform_and_print_shapes(f)

    class NestedModule(hk.Module):
        """A module class with a nested module created in the constructor."""

        def __init__(self, output_channels, name=None):
            super().__init__(name)
            assert isinstance(output_channels, int)
            self._output_channels = output_channels
            self.inner_simple = SimpleModule(self._output_channels)

        def __call__(self, x):
            w_shape = (x.shape[-1], self._output_channels)
            # Another variable that is also called `w`; it does not clash with
            # the inner module's `w` because each lives in its own name scope.
            w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
            return jnp.dot(x, w) + self.inner_simple(x)

    def f(x):
        # The outer instance is named `nested_module`; the SimpleModule created
        # in its constructor is scoped underneath it (Haiku marks
        # constructor-created submodules with `~`, i.e.
        # `nested_module/~/simple_module`).
        nested = NestedModule(output_channels=2)
        nested_out = nested(x)
        print('The name assigned to outer module (i.e., "nested") is: '
              f'"{nested.module_name}".')
        print('The name assigned to the inner module (i.e., inside "nested") is: "'
              f'{nested.inner_simple.module_name}".')
        return nested_out

    transform_and_print_shapes(f)

    class TwiceNestedModule(hk.Module):
        """A module class with a nested module containing a nested module."""

        def __init__(self, output_channels, name=None):
            super().__init__(name)
            assert isinstance(output_channels, int)
            self._output_channels = output_channels
            self.inner_nested = NestedModule(self._output_channels)

        def __call__(self, x):
            w_shape = (x.shape[-1], self._output_channels)
            w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
            return jnp.dot(x, w) + self.inner_nested(x)

    def f(x):
        """Create the module instances and inspect their names."""
        # Instantiate a TwiceNestedModule. Its constructor creates a
        # NestedModule, whose constructor in turn creates a SimpleModule; each
        # one is scoped under its parent.
        outer = TwiceNestedModule(output_channels=2)
        outer_out = outer(x)
        print(f'The name assigned to the most outer class is: "{outer.module_name}".')
        print('The name assigned to the module inside "outer" is: "'
              f'{outer.inner_nested.module_name}".')
        print('The name assigned to the module inside it is "'
              f'{outer.inner_nested.inner_simple.module_name}".')
        return outer_out

    transform_and_print_shapes(f)

    # --- Sharing across transformed functions with hk.multi_transform ------
    def f(x):
        """A SimpleModule followed by a Linear layer."""
        module_instance = SimpleModule(output_channels=2)
        out = module_instance(x)
        linear = hk.Linear(40)
        return linear(out)

    def g(x):
        """A SimpleModule whose output is doubled."""
        module_instance = SimpleModule(output_channels=2)
        return module_instance(x) * 2  # twice

    # Transform both functions, and print their respective parameter shapes.
    rng = jax.random.PRNGKey(42)
    x = jnp.ones((2, 3))
    transformed_f = hk.transform(f)
    params_f = transformed_f.init(rng, x)
    transformed_g = hk.transform(g)
    params_g = transformed_g.init(rng, x)
    print('f parameters:', parameter_shapes(params_f))
    print('g parameters:', parameter_shapes(params_g))

    # Transform both functions at once with hk.multi_transform, and print the
    # resulting merged parameter structure.
    def multitransform_f_and_g():
        def template(x):
            return f(x), g(x)
        return template, (f, g)

    init, (f_apply, g_apply) = hk.multi_transform(multitransform_f_and_g)
    merged_params = init(rng, x)

    print('\nThe name and shape of the multi-transform parameters are:\n',
          parameter_shapes(merged_params))

    # --- Same module names in separately transformed functions --------------
    def f(x):
        """Apply SimpleModule to x."""
        module_instance = SimpleModule(output_channels=2)
        out = module_instance(x)
        return out

    def g(x):
        """Like f, but double the output"""
        module_instance = SimpleModule(output_channels=2)
        out = module_instance(x)
        return out * 2

    # Transform both functions, and print the parameter shapes.
    rng = jax.random.PRNGKey(42)
    x = jnp.ones((2, 3))

    transformed_f = hk.transform(f)
    params_f = transformed_f.init(rng, x)
    transformed_g = hk.transform(g)
    params_g = transformed_g.init(rng, x)

    print('f parameters:', parameter_shapes(params_f))
    print('g parameters:', parameter_shapes(params_g))

    # --- Sharing by merging parameter dictionaries ---------------------------
    def f(x):
        """A SimpleModule followed by a Linear layer."""
        module_instance = SimpleModule(output_channels=2)
        out = module_instance(x)
        linear = hk.Linear(40)
        return linear(out)

    def g(x):
        """A SimpleModule followed by an MLP."""
        module_instance = SimpleModule(output_channels=2)
        out = module_instance(x)
        linear = hk.nets.MLP((10, 40))
        return linear(out)

    # Transform both functions, and print the parameter shapes.
    rng = jax.random.PRNGKey(42)
    x = jnp.ones((2, 3))

    transformed_f = hk.transform(f)
    params_f = transformed_f.init(rng, x)
    transformed_g = hk.transform(g)
    params_g = transformed_g.init(rng, x)

    print('\nThe name and shape of the f parameters are:\n',
          parameter_shapes(params_f))
    print('\nThe name and shape of the g parameters are:\n',
          parameter_shapes(params_g))

    merged_params = hk.data_structures.merge(params_f, params_g)
    print('\nThe name and shape of the shared parameters are:\n',
          parameter_shapes(merged_params))

    f_out = transformed_f.apply(merged_params, rng, x)
    g_out = transformed_g.apply(merged_params, rng, x)

    print('f_out mean:', f_out.mean())
    print('g_out mean:', g_out.mean())

    # --- When merging breaks: clashing names with different shapes ----------
    def f(x):
        """A SimpleModule, an MLP and a final Linear(4) layer."""
        module_instance = SimpleModule(output_channels=2)
        out = module_instance(x)
        mlp = hk.nets.MLP((10, 5))
        out = mlp(out)
        last_linear = hk.Linear(4)
        return last_linear(out)

    def g(x):
        """Same as f, with a bigger final layer."""
        module_instance = SimpleModule(output_channels=2)
        out = module_instance(x)
        mlp = hk.nets.MLP((10, 5))
        out = mlp(out)
        last_linear = hk.Linear(20)  # another Linear, but bigger
        return last_linear(out)

    # Transform both functions, and print the parameter shapes.
    rng = jax.random.PRNGKey(42)
    x = jnp.ones((2, 3))

    transformed_f = hk.transform(f)
    params_f = transformed_f.init(rng, x)
    transformed_g = hk.transform(g)
    params_g = transformed_g.init(rng, x)

    print('\nThe name and shape of the f parameters are:\n',
          parameter_shapes(params_f))
    print('\nThe name and shape of the g parameters are:\n',
          parameter_shapes(params_g))

    merged_params = hk.data_structures.merge(params_f, params_g)
    print('\nThe name and shape of the merged parameters are:\n',
          parameter_shapes(merged_params))

    # Applying f with the merged parameters is expected to fail: both
    # functions own a `linear/w`, and the merged one has g's shape.
    try:
        transformed_f.apply(merged_params, rng, x)  # fails
        # ValueError: 'linear/w' with retrieved shape (5, 20) does not match shape=[5, 4] dtype=dtype('float32')
    except Exception as err:  # expected failure: report it instead of hiding it
        print('\nApplying f with the merged parameters failed as expected:',
              repr(err))

    # Module instances cannot be created outside hk.transform, so they cannot
    # simply be built once up front and captured by f and g.
    try:
        module_instance = SimpleModule(output_channels=2)  # this fails
        # ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.
        mlp = hk.nets.MLP((10, 5))

        def f(x):
            """A SimpleModule followed by a Linear layer."""
            out = module_instance(x)
            out = mlp(out)
            linear = hk.Linear(4)
            return linear(out)

        def g(x):
            """A SimpleModule followed by a bigger Linear layer."""
            out = module_instance(x)
            out = mlp(out)
            linear = hk.Linear(20)  # another Linear, but bigger
            return linear(out)
    except Exception as err:  # expected failure: report it instead of hiding it
        print('\nCreating a module outside hk.transform failed as expected:',
              repr(err))

    # --- Sharing via a lazily populated cache -------------------------------
    class CachedModule:
        """Create a SimpleModule and an MLP on the first call, then reuse
        them on every later call of the same CachedModule instance."""

        def __call__(self, *inputs):
            # Create the instances if they are not in the cache yet.
            if not hasattr(self, 'cached_simple_module'):
                self.cached_simple_module = SimpleModule(output_channels=2)
            if not hasattr(self, 'cached_mlp'):
                self.cached_mlp = hk.nets.MLP((10, 5))

            # Apply the cached instances.
            out = self.cached_simple_module(*inputs)
            out = self.cached_mlp(out)
            return out

    def f(x):
        """A SimpleModule followed by a Linear layer."""
        shared_preprocessing = CachedModule()
        out = shared_preprocessing(x)
        linear = hk.Linear(4)
        return linear(out)

    def g(x):
        """A SimpleModule followed by a bigger Linear layer."""
        shared_preprocessing = CachedModule()
        out = shared_preprocessing(x)
        linear = hk.Linear(20)  # another Linear, but bigger
        return linear(out)

    # Transform both functions, and print the parameter shapes.
    rng = jax.random.PRNGKey(42)
    x = jnp.ones((2, 3))

    transformed_f = hk.transform(f)
    params_f = transformed_f.init(rng, x)
    transformed_g = hk.transform(g)
    params_g = transformed_g.init(rng, x)

    print('\nThe name and shape of the f parameters are:\n',
          parameter_shapes(params_f))
    print('\nThe name and shape of the g parameters are:\n',
          parameter_shapes(params_g))

    # Verify that the MLP parameters are shared.
    assert_all_equal(params_f['mlp/~/linear_0'],
                     params_g['mlp/~/linear_0'])
    assert_all_equal(params_f['mlp/~/linear_1'],
                     params_g['mlp/~/linear_1'])
    print('\nThe MLP parameters are shared!')

    # --- Sharing via a decorator that caches an hk.to_module instance -------
    def share_parameters():
        """Decorator factory: the first call of the decorated function wraps
        it with hk.to_module and instantiates it; every later call reuses
        that same module instance."""
        def decorator(fn):
            def wrapper(*args, **kwargs):
                if wrapper.instance is None:
                    wrapper.instance = hk.to_module(fn)()
                return wrapper.instance(*args, **kwargs)
            wrapper.instance = None
            return functools.wraps(fn)(wrapper)
        return decorator

    class Wrapper:
        """Holds f and g, which both route through one shared
        ``shared_preprocessing`` module."""

        @share_parameters()
        def shared_preprocessing(self, x):
            simple_module = SimpleModule(output_channels=2)
            out = simple_module(x)
            mlp = hk.nets.MLP((10, 5))
            return mlp(out)

        def f(self, x):
            """A SimpleModule followed by a Linear layer."""
            out = self.shared_preprocessing(x)
            linear = hk.Linear(4)
            return linear(out)

        def g(self, x):
            """A SimpleModule followed by a bigger Linear layer."""
            out = self.shared_preprocessing(x)
            linear = hk.Linear(20)  # another Linear, but bigger
            return linear(out)

    # Transform both functions, and print the parameter shapes.
    rng = jax.random.PRNGKey(42)
    x = jnp.ones((2, 3))

    wrapper = Wrapper()
    transformed_f = hk.transform(wrapper.f)
    params_f = transformed_f.init(rng, x)
    # Transform g, not f a second time. Otherwise the "g parameters" below
    # would really be f's, and the sharing check would compare f with itself.
    transformed_g = hk.transform(wrapper.g)
    params_g = transformed_g.init(rng, x)

    print('\nThe name and shape of the f parameters are:\n',
          parameter_shapes(params_f))
    print('\nThe name and shape of the g parameters are:\n',
          parameter_shapes(params_g))

    # Verify that the shared MLP parameters are shared.
    assert_all_equal(params_f['shared_preprocessing/mlp/~/linear_0'],
                     params_g['shared_preprocessing/mlp/~/linear_0'])
    assert_all_equal(params_f['shared_preprocessing/mlp/~/linear_1'],
                     params_g['shared_preprocessing/mlp/~/linear_1'])
    print('\nThe MLP parameters are shared!')


In [None]:
# Run the function locally first, to check that it executes end to end
# before it is submitted to the domain.
example()

In [None]:
# Submit the function to the domain as a code execution request; the request
# is then handled from the Data Owner notebook (Part 2).
guest_domain_client.api.services.code.request_code_execution(example)

In [None]:
# Call the submitted function through the domain API. This happens before the
# Data Owner's Part 2, so presumably it does not execute yet; the cell shows
# what the API returns at this stage (TODO confirm expected output).
guest_domain_client.api.services.code.example()

### Go to the Data Owner Notebook for Part 2!

## Part 3 - Downloading the Results

In [None]:
# Drop the client's cached API object and access `.api` again so it is
# rebuilt. Presumably needed to pick up the changes made by the Data Owner in
# Part 2 (TODO confirm against the syft client implementation).
guest_domain_client._api = None
_ = guest_domain_client.api

In [None]:
# Call the function through the domain API again (after Part 2) and keep the
# returned object for inspection below.
result = guest_domain_client.api.services.code.example()

In [None]:
# Display the value produced by the execution on the domain. `example` has no
# return statement, so this is presumably None.
result.get_result()

In [None]:
# Print what the remote execution wrote to stderr (presumably warnings and
# tracebacks, if any).
print(result.get_stderr())

In [None]:
# Print what the remote execution wrote to stdout. This presumably holds all
# of the print() output of `example` (module names, parameter shapes, means).
print(result.get_stdout())