# Jax Level 0 Data Scientist Experience - Chapter 6
## Part 2 - New account registration and code execution requests

Link to the original Jax tutorial: https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html

In [1]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8-beta")
import jax
import jax.numpy as jnp



✅ The installed version of syft==0.8.0b8 matches the requirement >=0.8b0


In [2]:
# Register a client to the domain
node = sy.orchestra.launch(name="test-domain-1")
guest_domain_client = node.client
guest_domain_client.register(name="Jane Doe", email="jane@caltech.edu", password="abc123", institution="Caltech", website="https://www.caltech.edu/")
guest_domain_client.login(email="jane@caltech.edu", password="abc123")

SQLite Store Path:
!open file:///tmp/7bca415d13ed1ec841f0d0aede098dbb.sqlite

> Domain: test-domain-1 - 7bca415d13ed1ec841f0d0aede098dbb - NodeType.DOMAIN

Services:
ActionService
DataSubjectMemberService
DataSubjectService
DatasetService
MessageService
MetadataService
NetworkService
PolicyService
ProjectService
RequestService
UserCodeService
UserService


<SyftClient - test-domain-1 <7bca415d13ed1ec841f0d0aede098dbb>: PythonConnection>

In [3]:
# Create a function for code execution
# ATTENTION: ALL LIBRARIES USED SHOULD BE DEFINED INSIDE THE FUNCTION CONTEXT!!!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def pytree_examples():
    import jax
    import jax.numpy as jnp

    example_trees = [
        [1, 'a', object()],
        (1, (2, 3), ()),
        [1, {'k1': 2, 'k2': (3, 4)}, 5],
        {'a': 2, 'b': (2, 3)},
        jnp.array([1, 2, 3]),
    ]

    # Let's see how many leaves they have:
    for pytree in example_trees:
        leaves = jax.tree_util.tree_leaves(pytree)
        print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def common_pytree_functions():
    list_of_lists = [
        [1, 2, 3],
        [1, 2],
        [1, 2, 3, 4]
    ]

    list_tree_map = jax.tree_map(lambda x: x*2, list_of_lists)
    
    another_list_of_lists = list_of_lists
    another_list_tree_map = jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
    return list_tree_map, another_list_tree_map
    
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def ml_model():
    import numpy as np
    import jax
    import jax.numpy as jnp

    def init_mlp_params(layer_widths):
        params = []
        for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
            params.append(
                dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
                    biases=np.ones(shape=(n_out,))
                    )
            )
        return params

    params = init_mlp_params([1, 128, 128, 1])
    print("Shapes:\n", jax.tree_map(lambda x: x.shape, params))

    def forward(params, x):
        *hidden, last = params
        for layer in hidden:
            x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
        return x @ last['weights'] + last['biases']

    def loss_fn(params, x, y):
        return jnp.mean((forward(params, x) - y) ** 2)

    LEARNING_RATE = 0.0001

    @jax.jit
    def update(params, x, y):
        grads = jax.grad(loss_fn)(params, x, y)
        # Note that `grads` is a pytree with the same structure as `params`.
        # `jax.grad` is one of the many JAX functions that has
        # built-in support for pytrees.

        # This is handy, because we can apply the SGD update using tree utils:
        return jax.tree_map(
            lambda p, g: p - LEARNING_RATE * g, params, grads
        )
    import matplotlib.pyplot as plt

    xs = np.random.normal(size=(128, 1))
    ys = xs ** 2

    for _ in range(1000):
        arams = update(params, xs, ys)

    plt.scatter(xs, ys)
    plt.scatter(xs, forward(params, xs), label='Model prediction')
    plt.legend()
    return params
    
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def custom_pytree_nodes():
    import jax
    class MyContainer:
        """A named container."""

        def __init__(self, name: str, a: int, b: int, c: int):
            self.name = name
            self.a = a
            self.b = b
            self.c = c
    print(jax.tree_util.tree_leaves([
            MyContainer('Alice', 1, 2, 3),
            MyContainer('Bob', 4, 5, 6)
        ]))
    
    try:
        print(jax.tree_map(lambda x: x + 1, [
            MyContainer('Alice', 1, 2, 3),
            MyContainer('Bob', 4, 5, 6)
        ]))
    except Exception as e:
        print(e)
        
    from typing import Tuple, Iterable

    def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
        """Returns an iterable over container contents, and aux data."""
        flat_contents = [container.a, container.b, container.c]

        # we don't want the name to appear as a child, so it is auxiliary data.
        # auxiliary data is usually a description of the structure of a node,
        # e.g., the keys of a dict -- anything that isn't a node's children.
        aux_data = container.name
        return flat_contents, aux_data

    def unflatten_MyContainer(
        aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
        """Converts aux data and the flat contents into a MyContainer."""
        return MyContainer(aux_data, *flat_contents)

    jax.tree_util.register_pytree_node(
        MyContainer, flatten_MyContainer, unflatten_MyContainer)

    return jax.tree_util.tree_leaves([
        MyContainer('Alice', 1, 2, 3),
        MyContainer('Bob', 4, 5, 6)
    ])

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def gotchas():
    import jax
    import jax.numpy as jnp
    a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

    # Try to make another tree with ones instead of zeros
    shapes = jax.tree_map(lambda x: x.shape, a_tree)
    print(jax.tree_map(jnp.ones, shapes))
    print(jax.tree_util.tree_leaves([None, None, None]))

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def patterns():
    def tree_transpose(list_of_trees):
        """Convert a list of trees of identical structure into a single tree of lists."""
        return jax.tree_map(lambda *xs: list(xs), *list_of_trees)


    # Convert a dataset from row-major to column-major:
    episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
    print(tree_transpose(episode_steps))
    
    print(jax.tree_transpose(
        outer_treedef = jax.tree_structure([0 for e in episode_steps]),
        inner_treedef = jax.tree_structure(episode_steps[0]),
        pytree_to_transpose = episode_steps
    ))


In [4]:
# Test our function locally 
pytree_examples()
common_pytree_functions()
ml_model()
custom_pytree_nodes()
gotchas()
patterns()



Convolve: [11. 20. 29.]
Naive solution for batched convolve [[11. 20. 29.]
 [11. 20. 29.]]
Vectorized solution for batched convolve [[11. 20. 29.]
 [11. 20. 29.]]
Convolve: [11. 20. 29.]
Automatic vectorization [[11. 20. 29.]
 [11. 20. 29.]]
Automatic vectorization with axes [[11. 11.]
 [20. 20.]
 [29. 29.]]
Automatic vectorization with one ax [[11. 20. 29.]
 [11. 20. 29.]]
Jitted Automatic vectorization [[11. 20. 29.]
 [11. 20. 29.]]


In [5]:
# Submit the function for code execution
guest_domain_client.api.services.code.request_code_execution(pytree_examples)
guest_domain_client.api.services.code.request_code_execution(common_pytree_functions)
guest_domain_client.api.services.code.request_code_execution(ml_model)
guest_domain_client.api.services.code.request_code_execution(custom_pytree_nodes)
guest_domain_client.api.services.code.request_code_execution(gotchas)
guest_domain_client.api.services.code.request_code_execution(patterns)

queue_task: Start
task_runner: Start


syft.service.code.user_code.SubmitUserCode


task_producer: Start


Ok(syft.service.code.user_code.UserCode)


HANDLE API ( fieldsName = ["data"],
  fieldsData = [
    [ "\000\000\000\000\377\001\000\000\000\000\000\000\000\000\004\000\031\000\000\000\016\000\000\000\031\000\000\000\016\000\000\000\005\000\000\000\212\000\000\000\000\000\000\000\000\000\000\000result.result.Ok\000\000\000\000\000\000\000\000\005\000\000\000:\000\000\000\005\000\000\000\016\000\000\000_value\000\000\001\000\000\000\302|\000\000\000\000\000\000\362\001\000\000\000\000\000\000\000\000\004\000!\000\000\000N\000\000\000A\000\000\000N\000\000\000\005\000\000\000*\001\000\000\000\000\000\000\000\000\000\000syft.service.request.request.Request\000\000\000\000E\000\000\000r\000\000\000y\000\000\000\322\000\000\000\265\000\000\000B\000\000\000\211\004\000\000\032\000\000\000\355\004\000\000J\000\000\000U\005\000\000j\000\000\000\241\005\000\000j\000\000\000\205\006\000\000\332\000\000\000\r\a\000\000:\000\000\000)\000\000\000\016\000\000\000e\000\000\000\016\000\000\000\225\000\000\000\016\000\000\000i\004\000\000\016\00

```python
class Request:
  id: str = 5c8e58623dfb47ac817578be839a3857
  requesting_user_verify_key: str = 8ab18aff8196409a0c0aa2da0d34dad0ad56269ce5ed2566a342c404cea9aae2
  approving_user_verify_key: str = None
  request_time: str = 2023-04-21 08:40:21
  approval_time: str = None
  status: str = RequestStatus.PENDING
  node_uid: str = 7bca415d13ed1ec841f0d0aede098dbb
  request_hash: str = "74dcb8f8847d84a69458be95fd99032d08a9463c98d53026e8ed4ef8a9f4640e"
  changes: str = [syft.service.request.request.UserCodeStatusChange]

```

In [9]:
guest_domain_client.api.services.code.higher_order_derivatives()

queue_task: Start
task_runner: Start
HANDLE API ( fieldsName = ["data"],
  fieldsData = [
    [ "\000\000\000\0005\000\000\000\000\000\000\000\000\000\004\000!\000\000\000\016\000\000\000!\000\000\000\016\000\000\000\005\000\000\000\032\001\000\000\000\000\000\000\000\000\000\000syft.service.response.SyftNotReady\000\000\000\000\000\000\005\000\000\000B\000\000\000\005\000\000\000\016\000\000\000message\000\001\000\000\000\302\t\000\000\000\000\000\000&\000\000\000\000\000\000\000\000\000\004\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\005\000\000\000j\000\000\000\t\000\000\000\016\000\000\000builtins.str\000\000\000\000\001\000\000\000j\a\000\000<class \'syft.service.code.user_code.UserCode\'> Your code is waiting for approval: {NodeView(node_name=\'test-domain-1\', verify_key=aec6ea4dfc049ceacaeeebc493167a88a200ddc367b1fa32da652444b635d21f): <UserCodeStatus.SUBMITTED: \'submitted\'>}\000\000\000" ] ],
  fullyQualifiedName = "syft.client.api.SyftAPIData" )
HAND

### Go to the Data Owner Notebook for Part 2!

## Part 3 - Downloading the Results

In [12]:
result = guest_domain_client.api.services.code.higher_order_derivatives()

queue_task: Start
task_runner: Start


Policy not valid


<syft.service.action.action_service.ActionService object at 0x7fbfd874a050>
Ok(syft.service.code.user_code.UserCodeExecutionResult)
HANDLE API ( fieldsName = ["data"],
  fieldsData = [
  fullyQualifiedName = "syft.client.api.SyftAPIData" )
HANDLE API SIGNED syft.client.api.SignedSyftAPICall
task_runner: End
task_producer: Start
task_producer: End
queue_task: End
HANDLE API ( fieldsName = ["data"],
  fieldsData = [
  fullyQualifiedName = "syft.client.api.SyftAPIData" )
HANDLE API SIGNED syft.client.api.SignedSyftAPICall


In [13]:
result.get_result()

None

In [14]:
print(result.get_stdout())

Convolve: [11. 20. 29.]
Automatic vectorization [[11. 20. 29.]
 [11. 20. 29.]]
Automatic vectorization with axes [[11. 11.]
 [20. 20.]
 [29. 29.]]
Automatic vectorization with one ax [[11. 20. 29.]
 [11. 20. 29.]]
Jitted Automatic vectorization [[11. 20. 29.]
 [11. 20. 29.]]

