# JAX 101 - 06 Working with PyTrees
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html

## Part 0 - Data Owner Setup

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Launch the domain as the data owner and log in with the owner credentials.
# NOTE(review): reset=True presumably discards any state left over from a
# previous run — confirm against the syft orchestra docs.
node = sy.orchestra.launch(name="test-domain-1", reset=True)
domain_client = node.login(email="info@openmined.org", password="changethis")

## Part 1 - Data Scientist

In [None]:
# Register a data scientist account on the domain, then log in with it.
node = sy.orchestra.launch(name="test-domain-1")
data_scientist_client = node.client
data_scientist_client.register(name="Jane Doe", email="jane@caltech.edu", password="abc123", institution="Caltech", website="https://www.caltech.edu/")
# Keep the client that the login call returns; the registration client above
# is not guaranteed to become authenticated by logging in.
data_scientist_client = node.login(email="jane@caltech.edu", password="abc123")

In [None]:
# Create the functions to submit for code execution
# ATTENTION: ALL LIBRARIES A FUNCTION USES MUST BE IMPORTED INSIDE THE FUNCTION BODY!!!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def pytree_examples():
    """Print how many leaves JAX finds in several example pytrees.

    Lists, tuples and dicts are flattened recursively; the remaining values
    (ints, strings, an arbitrary object, a JAX array) each count as a leaf.
    """
    import jax
    import jax.numpy as jnp

    trees = [
        [1, 'a', object()],
        (1, (2, 3), ()),
        [1, {'k1': 2, 'k2': (3, 4)}, 5],
        {'a': 2, 'b': (2, 3)},
        jnp.array([1, 2, 3]),
    ]

    # Report each tree alongside its flattened leaves.
    for tree in trees:
        flat = jax.tree_util.tree_leaves(tree)
        description = repr(tree)
        print(f"{description:<45} has {len(flat)} leaves: {flat}")

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def common_pytree_functions():
    """Demonstrate ``tree_map`` over one pytree and over two pytrees at once.

    Returns:
        A pair ``(doubled, summed)``. ``doubled`` is ``list_of_lists`` with
        every leaf multiplied by 2. ``summed`` is the leaf-wise sum of
        ``list_of_lists`` with a tree of the same structure (here, itself).
    """
    # This function is executed on the domain node, where module-level imports
    # from the notebook are not available, so jax is imported here.
    import jax

    list_of_lists = [
        [1, 2, 3],
        [1, 2],
        [1, 2, 3, 4]
    ]

    # Apply a function to every leaf of a single tree.
    list_tree_map = jax.tree_util.tree_map(lambda x: x * 2, list_of_lists)

    # With several trees, tree_map calls the function on the matching leaves
    # of each tree, so the trees must share one structure.
    another_list_of_lists = list_of_lists
    another_list_tree_map = jax.tree_util.tree_map(
        lambda x, y: x + y, list_of_lists, another_list_of_lists
    )
    return list_tree_map, another_list_tree_map
    
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def ml_model():
    """Fit a small MLP to ``y = x ** 2`` with SGD written using tree utilities.

    The parameters are a pytree: a list with one ``{'weights', 'biases'}``
    dict per layer. ``jax.grad`` returns gradients with that same structure,
    so each SGD step is a single ``tree_map`` over ``(params, grads)``.

    Returns:
        The trained parameters, with the same list-of-dicts structure and the
        leaves converted to NumPy arrays.
    """
    import numpy as np
    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt

    def init_mlp_params(layer_widths):
        """Create per-layer weights (normal, scaled by sqrt(2 / n_in)) and unit biases."""
        params = []
        for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
            params.append(
                dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
                     biases=np.ones(shape=(n_out,))
                     )
            )
        return params

    params = init_mlp_params([1, 128, 128, 1])
    print("Shapes:\n", jax.tree_util.tree_map(lambda x: x.shape, params))

    def forward(params, x):
        """Apply the ReLU hidden layers, then a linear output layer."""
        *hidden, last = params
        for layer in hidden:
            x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
        return x @ last['weights'] + last['biases']

    def loss_fn(params, x, y):
        """Mean squared error between the model output and ``y``."""
        return jnp.mean((forward(params, x) - y) ** 2)

    LEARNING_RATE = 0.0001

    @jax.jit
    def update(params, x, y):
        """Return the parameters after one SGD step on ``(x, y)``."""
        # `grads` is a pytree with the same structure as `params` (jax.grad
        # supports pytrees), so the update is a leaf-wise map over both.
        grads = jax.grad(loss_fn)(params, x, y)
        return jax.tree_util.tree_map(
            lambda p, g: p - LEARNING_RATE * g, params, grads
        )

    xs = np.random.normal(size=(128, 1))
    ys = xs ** 2

    for _ in range(1000):
        # Rebind `params` so every step continues from the previous one.
        params = update(params, xs, ys)

    plt.scatter(xs, ys)
    plt.scatter(xs, forward(params, xs), label='Model prediction')
    plt.legend()

    # NOTE(review): after training the leaves are JAX arrays; convert back to
    # NumPy so the returned leaf type matches the initial parameters. This
    # assumes syft serializes NumPy arrays more reliably than JAX arrays — confirm.
    return jax.tree_util.tree_map(np.asarray, params)
    
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def custom_pytree_nodes():
    """Show how registering a class turns its instances into pytree nodes.

    Before registration, ``MyContainer`` instances are single leaves, so
    ``tree_map`` passes them to the mapped function unchanged (and ``x + 1``
    fails). After registering flatten/unflatten functions, JAX flattens each
    container into its ``a``, ``b`` and ``c`` fields.

    Returns:
        The leaves of two registered containers, ``[1, 2, 3, 4, 5, 6]``.
    """
    import jax
    from typing import Tuple, Iterable

    class MyContainer:
        """A named container."""

        def __init__(self, name: str, a: int, b: int, c: int):
            self.name = name
            self.a = a
            self.b = b
            self.c = c

    # Not registered yet: each MyContainer is one opaque leaf.
    print(jax.tree_util.tree_leaves([
        MyContainer('Alice', 1, 2, 3),
        MyContainer('Bob', 4, 5, 6)
    ]))

    # Mapping `x + 1` therefore tries to add 1 to a MyContainer instance,
    # which raises TypeError; print the message to demonstrate the failure.
    try:
        print(jax.tree_util.tree_map(lambda x: x + 1, [
            MyContainer('Alice', 1, 2, 3),
            MyContainer('Bob', 4, 5, 6)
        ]))
    except TypeError as e:
        print(e)

    def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
        """Returns an iterable over container contents, and aux data."""
        flat_contents = [container.a, container.b, container.c]

        # The name should not appear as a child, so it is auxiliary data.
        # Auxiliary data usually describes a node's structure (e.g. the keys
        # of a dict) -- anything that isn't one of the node's children.
        aux_data = container.name
        return flat_contents, aux_data

    def unflatten_MyContainer(
        aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
        """Converts aux data and the flat contents into a MyContainer."""
        return MyContainer(aux_data, *flat_contents)

    jax.tree_util.register_pytree_node(
        MyContainer, flatten_MyContainer, unflatten_MyContainer)

    # Registered: the leaves are now the integer fields of each container.
    return jax.tree_util.tree_leaves([
        MyContainer('Alice', 1, 2, 3),
        MyContainer('Bob', 4, 5, 6)
    ])

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def gotchas():
    """Print two common pytree surprises.

    1. Shapes are tuples, and tuples are pytree nodes. Mapping ``jnp.ones``
       over a tree of shapes therefore calls it on every integer dimension,
       not once per shape.
    2. ``None`` is a node with no children, so a list of ``None`` values has
       no leaves.
    """
    import jax
    import jax.numpy as jnp
    a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

    # Try to make another tree with ones instead of zeros
    shapes = jax.tree_util.tree_map(lambda x: x.shape, a_tree)
    print(jax.tree_util.tree_map(jnp.ones, shapes))

    # None contributes no leaves, so this prints an empty list.
    print(jax.tree_util.tree_leaves([None, None, None]))

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def patterns():
    """Transpose a list of same-structure trees into a tree of lists, two ways.

    Prints the result of a hand-written ``tree_map``-based transpose, then
    the result of the built-in ``jax.tree_util.tree_transpose``.
    """
    # This function is executed on the domain node, where module-level imports
    # from the notebook are not available, so jax is imported here.
    import jax

    def tree_transpose(list_of_trees):
        """Convert a list of trees of identical structure into a single tree of lists."""
        return jax.tree_util.tree_map(lambda *xs: list(xs), *list_of_trees)

    # Convert a dataset from row-major to column-major:
    episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
    print(tree_transpose(episode_steps))

    # The built-in needs both structures given explicitly as treedefs: the
    # outer one (a list with one entry per step) and the inner one (one step).
    print(jax.tree_util.tree_transpose(
        outer_treedef=jax.tree_util.tree_structure([0] * len(episode_steps)),
        inner_treedef=jax.tree_util.tree_structure(episode_steps[0]),
        pytree_to_transpose=episode_steps
    ))

In [None]:
# Test our functions locally before submitting them to the domain
pytree_examples()
common_pytree_functions()
ml_model()
custom_pytree_nodes()
gotchas()
patterns()

In [None]:
# Submit every function for code execution on the domain, checking that each
# request call did not return a SyftError.
for syft_fn in (
    pytree_examples,
    common_pytree_functions,
    ml_model,
    custom_pytree_nodes,
    gotchas,
    patterns,
):
    result = data_scientist_client.api.services.code.request_code_execution(syft_fn)
    assert not isinstance(result, sy.SyftError)

## Part 2 - Data Owner Reviewing and Approving Requests

In [None]:
# Get all messages from the domain as the data owner.
# NOTE(review): presumably these include the code-execution requests submitted
# above — verify. The bare name on the last line displays them in the notebook.
messages = domain_client.api.services.messages.get_all()
messages

In [None]:
from helpers import review_request, run_submitted_function, accept_request

# For each message: review it, run the submitted function, and accept the
# request together with the result that run produced.
for msg in messages:
    review_request(msg)
    accept_request(msg, run_submitted_function(msg))

## Part 3 - Downloading the Results

### Tutorial complete 👏

In [None]:
# Download the output of every approved function and check that none of the
# calls returned a SyftError; the collected results are displayed at the end.
code_api = data_scientist_client.api.services.code
results = {}
for fn_name in [
    "pytree_examples",
    "common_pytree_functions",
    "ml_model",
    "custom_pytree_nodes",
    "gotchas",
    "patterns",
]:
    result = getattr(code_api, fn_name)()
    assert not isinstance(result, sy.SyftError), f"{fn_name}: {result}"
    results[fn_name] = result
results

In [None]:
# Land the node launched above, but only when its node type is "python";
# nodes of any other type are left running.
if node.node_type.value == "python":
    node.land()