# Haiku Level 0 Data Scientist Experience - Chapter 3
## Part 2 - New account registration and code execution requests

Link to the original Haiku tutorial: https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8-beta")
import jax
import jax.numpy as jnp
import haiku as hk

In [None]:
# Launch the domain node, then register a guest account on it and log in
node = sy.orchestra.launch(name="test-domain-1")
guest_domain_client = node.client

# One keyword per line keeps the registration details easy to scan.
guest_domain_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)

# Log in with the same e-mail/password pair that was just registered.
guest_domain_client.login(
    email="jane@caltech.edu",
    password="abc123",
)

In [None]:
# Create the functions for code execution
# ATTENTION: ALL LIBRARIES USED MUST BE IMPORTED INSIDE THE FUNCTION BODY!!!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def basic_strategy():
    """Demonstrate the simplest way to turn a stateful function stateless.

    A stack of parameter dictionaries acts as the "global state": ``transform``
    pushes the caller's params before running the wrapped function and pops
    them afterwards, while ``get_param`` reads from the top of that stack.
    The final, jitted linear layer output is printed; nothing is returned.
    """
    # Libraries are imported inside the function body.
    import jax
    import jax.numpy as jnp

    # Stack of parameter dicts; the top entry belongs to the transformed
    # function that is currently executing.
    param_stack = []

    def get_param(identifier):
        """Look up ``identifier`` in the params of the running call."""
        active_params = param_stack[-1]
        return active_params[identifier]

    def transform(f):
        """Wrap ``f`` so that it takes its params as an explicit first argument."""
        def apply_with_params(params, *args, **kwargs):
            param_stack.append(params)
            result = f(*args, **kwargs)
            param_stack.pop()
            return result

        return apply_with_params

    # What we need to build: a function that receives its params explicitly.
    def my_stateless_apply(params, x):
        return params['w'] * x

    class MyModule:
        """Module-style counterpart that fetches 'w' through ``get_param``."""

        def apply(self, x):
            return get_param('w') * x

    def linear(x):
        """Affine map whose weight and bias both come from ``get_param``."""
        weight = get_param('w')
        bias = get_param('b')
        return x @ weight + bias

    # Scalar example: the hand-written and the transformed versions are both
    # called with the same inputs; their return values are not used.
    params = {'w': 5}
    my_stateless_apply(params, 5)
    stateless_module_apply = transform(MyModule().apply)
    stateless_module_apply(params, 5)

    # Array example: the transformed function is compiled with jax.jit.
    params = {'w': jnp.ones((3, 5)), 'b': jnp.ones((5,))}
    apply = transform(linear)
    jitted_apply = jax.jit(apply)
    print(jitted_apply(params, jnp.ones((10, 3))))


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def mini_haiku():
    """Build a miniature Haiku step by step and train a small MLP with it.

    The body goes through three stages:

    1. a frame-based ``transform`` that yields an ``init``/``apply`` pair,
    2. unique parameter names derived from a per-frame module call stack,
    3. a gradient-descent loop that fits an MLP to a quadratic curve.

    Takes no arguments and returns ``None``. Parameter shapes and outputs are
    printed along the way and the final fit is plotted with matplotlib.
    """
    # Every library the body uses is imported here, so the function does not
    # rely on names that only exist in the surrounding notebook.
    import collections
    import dataclasses
    from typing import NamedTuple, Dict, Callable

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    import numpy as np

    # Since we're tracking more than just the current params,
    # we introduce the concept of a frame as the object that holds
    # state during a transformed execution.
    frame_stack = []

    class Frame(NamedTuple):
        """Tracks what's going on during a call of a transformed function."""
        params: Dict[str, jnp.ndarray]
        is_initialising: bool = False

    def current_frame():
        """Return the frame of the innermost transformed call in progress."""
        return frame_stack[-1]


    class Transformed(NamedTuple):
        """The pair of pure functions produced by ``transform``."""
        init: Callable
        apply: Callable


    def transform(f) -> Transformed:
        """Split ``f`` into ``init`` (creates params) and ``apply`` (uses them).

        ``Frame`` and ``frame_stack`` are looked up when ``init``/``apply``
        run, so the later redefinitions below are picked up automatically.
        """

        def init_f(*args, **kwargs):
            # Run f once with an empty, initialising frame and hand back
            # whatever parameters were created during that run.
            frame_stack.append(Frame({}, is_initialising=True))
            f(*args, **kwargs)
            frame = frame_stack.pop()
            return frame.params

        def apply_f(params, *args, **kwargs):
            # Run f against the caller-supplied parameters.
            frame_stack.append(Frame(params))
            outs = f(*args, **kwargs)
            frame_stack.pop()
            return outs

        return Transformed(init_f, apply_f)

    def get_param(identifier, shape):
        """Return the parameter ``identifier``, sampling it first during init."""
        if current_frame().is_initialising:
            current_frame().params[identifier] = np.random.normal(size=shape)

        return current_frame().params[identifier]

    # Make printing parameters a little more readable
    def parameter_shapes(params):
        """Map every parameter in ``params`` to its ``.shape``."""
        return jax.tree_util.tree_map(lambda p: p.shape, params)


    class Linear:
        """Linear layer of the given output width, with params 'w' and 'b'."""

        def __init__(self, width):
            self._width = width

        def __call__(self, x):
            w = get_param('w', shape=(x.shape[-1], self._width))
            b = get_param('b', shape=(self._width,))
            return x @ w + b

    init, apply = transform(Linear(4))

    data = jnp.ones((2, 3))

    params = init(data)
    print(parameter_shapes(params))
    print(apply(params, data))

    ## Adding unique parameter names

    # From here on, Frame, frame_stack, current_frame, get_param and Linear
    # are redefined; `transform` above keeps working with the new versions.
    @dataclasses.dataclass
    class Frame:
        """Tracks what's going on during a call of a transformed function."""
        params: Dict[str, jnp.ndarray]
        is_initialising: bool = False

        # Keeps track of how many modules of each class have been created so far.
        # Used to assign new modules unique names.
        module_counts: Dict[str, int] = dataclasses.field(
            default_factory=lambda: collections.defaultdict(lambda: 0))

        # Keeps track of the entire path to the current module method call.
        # Module methods, when called, will add themselves to this stack.
        # Used to give each parameter a unique name corresponding to the
        # method scope it is in.
        call_stack: list = dataclasses.field(default_factory=list)

        def create_param_path(self, identifier) -> str:
            """Creates a unique path for this param."""
            return '/'.join(['~'] + self.call_stack + [identifier])

        def create_unique_module_name(self, module_name: str) -> str:
            """Assigns a unique name to the module by appending its number to its name."""
            number = self.module_counts[module_name]
            self.module_counts[module_name] += 1
            return f"{module_name}_{number}"

    frame_stack = []

    def current_frame():
        """Return the frame of the innermost transformed call in progress."""
        return frame_stack[-1]


    class Module:
        """Base class that gives each instance a name unique within its frame."""

        def __init__(self):
            # Assign a unique (for the current `transform` call)
            # name to this instance of the module.
            self._unique_name = current_frame().create_unique_module_name(
                self.__class__.__name__)


    def module_method(f):
        """A decorator for Module methods."""
        # In the real Haiku, this doesn't face the user but is applied by a metaclass.

        def wrapped(self, *args, **kwargs):
            """A version of f that lets the frame know it's being called."""
            # Self is the instance to which this method is attached.
            module_name = self._unique_name
            call_stack = current_frame().call_stack
            call_stack.append(module_name)
            call_stack.append(f.__name__)
            outs = f(self, *args, **kwargs)
            # Pop outside of the asserts: `python -O` strips assert statements,
            # which would otherwise leave both entries on the stack.
            popped_method = call_stack.pop()
            popped_module = call_stack.pop()
            assert popped_method == f.__name__
            assert popped_module == module_name
            return outs

        return wrapped


    def get_param(identifier, shape):
        """Return the parameter stored under the current call path."""
        frame = current_frame()
        param_path = frame.create_param_path(identifier)

        if frame.is_initialising:
            frame.params[param_path] = np.random.normal(size=shape)

        return frame.params[param_path]


    class Linear(Module):
        """Linear layer whose params are stored under its unique call path."""

        def __init__(self, width):
            super().__init__()
            self._width = width

        @module_method  # Again, this decorator is behind-the-scenes in real Haiku.
        def __call__(self, x):
            w = get_param('w', shape=(x.shape[-1], self._width))
            b = get_param('b', shape=(self._width,))
            return x @ w + b

    init, apply = transform(lambda x: Linear(4)(x))

    params = init(data)
    print(parameter_shapes(params))
    print(apply(params, data))

    class MLP(Module):
        """Stack of Linear layers; a sigmoid is applied between layers."""

        def __init__(self, widths):
            super().__init__()
            self._widths = widths

        @module_method
        def __call__(self, x):
            for w in self._widths:
                out = Linear(w)(x)
                x = jax.nn.sigmoid(out)
            # The last layer's pre-activation output is returned.
            return out

    init, apply = transform(lambda x: MLP([3, 5])(x))
    print(parameter_shapes(init(data)))

    class ParameterReuseTest(Module):
        """Calls one Linear instance twice within a single forward pass."""

        @module_method
        def __call__(self, x):
            f = Linear(x.shape[-1])

            x = f(x)
            x = jax.nn.relu(x)
            return f(x)

    init, forward = transform(lambda x: ParameterReuseTest()(x))
    print(parameter_shapes(init(data)))

    ## Example training loop

    # Data: a quadratic curve.
    xs = np.linspace(-2., 2., num=128)[:, None]  # Generate array of shape (128, 1).
    ys = xs ** 2

    # Model
    def mlp(x):
        return MLP([128, 128, 1])(x)

    init, forward = transform(mlp)
    params = init(xs)
    # Inside a function a bare expression is discarded, so print explicitly.
    print(parameter_shapes(params))

    # Loss function and update function
    def loss_fn(params, x, y):
        """Mean squared error of the model prediction against ``y``."""
        return jnp.mean((forward(params, x) - y) ** 2)

    LEARNING_RATE = 0.003

    @jax.jit
    def update(params, x, y):
        """Apply one plain gradient-descent step to every parameter."""
        grads = jax.grad(loss_fn)(params, x, y)
        return jax.tree_util.tree_map(
            lambda p, g: p - LEARNING_RATE * g, params, grads
        )

    for _ in range(5000):
        params = update(params, xs, ys)

    plt.scatter(xs, ys, label='Data')
    plt.scatter(xs, forward(params, xs), label='Model prediction')
    plt.legend()
    plt.show()

In [None]:
# Test our functions locally: call both directly in this notebook
# before submitting them for code execution
basic_strategy()
mini_haiku()

In [None]:
# Submit both functions for code execution, one request per function
guest_domain_client.api.services.code.request_code_execution(basic_strategy)
guest_domain_client.api.services.code.request_code_execution(mini_haiku)

In [None]:
# Call basic_strategy through the domain's code service.
# NOTE(review): presumably this does not yield a result until the data owner
# has approved the request - confirm against the Data Owner notebook.
guest_domain_client.api.services.code.basic_strategy()

### Go to the Data Owner Notebook for Part 2!

## Part 3 - Downloading the Results

In [None]:
# Clear the client's cached API object, then read `.api` again.
# NOTE(review): presumably this makes the client rebuild its API so that
# code approved in the meantime becomes visible - confirm.
guest_domain_client._api = None
_ = guest_domain_client.api

In [None]:
# Call mini_haiku through the domain's code service and keep what it returns
result = guest_domain_client.api.services.code.mini_haiku()

In [None]:
# Fetch the result of the execution; as the last expression of the cell its
# value is displayed by the notebook.
# NOTE(review): mini_haiku has no return statement, so this is presumably None - confirm.
result.get_result()

In [None]:
# Print what get_stderr() returns for this execution
# (presumably the stderr output captured while mini_haiku ran - confirm).
print(result.get_stderr())