# JAX Level 0 Data Scientist Experience - Chapter 9 - Getting started with JAX MLPs, CNNs, and RNNs

Link to the original blog post by Robert Tjarko Lange: https://roberttlange.com/posts/2020/03/blog-post-10/

In [1]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8-beta")
# import jax
# import jax.numpy as jnp



✅ The installed version of syft==0.8.1b1 matches the requirement >=0.8b0


## Part 1 - User login and code execution requests

In [2]:
# Launch the node named "test-domain-1" and take its client
node = sy.orchestra.launch(name="test-domain-1")
guest_domain_client = node.client

# Register the demo user, then log in with the same credentials
guest_domain_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)
guest_domain_client.login(email="jane@caltech.edu", password="abc123")

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed1ec841f0d0aede098dbb.sqlite



<SyftClient - test-domain-1 <7bca415d13ed1ec841f0d0aede098dbb>: PythonConnection>

### What is this JAX thing?

In [3]:
# TODO: remove this temporary workaround
# Wrap a small placeholder array in an ActionObject and send it through
# guest_domain_client. The returned `ptr` is what the syft functions in the
# following cells pass to sy.ExactMatch(dummy=ptr) as their input policy;
# none of them read the array's contents.
import numpy as np
dummy_tensor = np.array([1,2,3])
dummy_action = sy.ActionObject.from_obj(dummy_tensor)
ptr = dummy_action.send(guest_domain_client)

In [4]:
# Create a function for code execution
@sy.syft_function(input_policy=sy.ExactMatch(dummy=ptr),
                  output_policy=sy.SingleExecutionExactOutput())
def func_dot_time_comparison(dummy):
    """Time a 1000x1000 matrix product computed with NumPy and with JAX.

    Prints the wall-clock time of three variants: ``np.dot``, ``jnp.dot``
    without waiting for the result, and ``jnp.dot`` followed by
    ``block_until_ready()``.

    Args:
        dummy: Placeholder input required by the input policy; not used.
    """
    # Note: using different naming conventions for numpy and jax
    # compared to the original blog post, i.e. onp => np, np => jnp.
    import time

    import numpy as np
    import jax.numpy as jnp
    from jax import random

    def report_wall_time(label, fn):
        """ Run fn once and print how long it took in milliseconds """
        # `%time` is an IPython line magic rather than Python syntax, so the
        # measurement is taken explicitly with a monotonic clock instead.
        start = time.perf_counter()
        fn()
        elapsed_ms = (time.perf_counter() - start) * 1e3
        print(f"{label}: wall time {elapsed_ms:.3f} ms")

    # Generate key which is used to generate random numbers
    key = random.PRNGKey(1)

    # Generate a random matrix
    x = random.uniform(key, (1000, 1000))

    # Compare running times of 3 different matrix multiplications
    report_wall_time("np.dot", lambda: np.dot(x, x))
    # Without block_until_ready() only the dispatch of the computation is
    # timed, because JAX may return before the result is available.
    report_wall_time("jnp.dot", lambda: jnp.dot(x, x))
    report_wall_time("jnp.dot + block_until_ready",
                     lambda: jnp.dot(x, x).block_until_ready())

### A Few Basic Concepts & Conventions - jit, grad & vmap

In [5]:
@sy.syft_function(input_policy=sy.ExactMatch(dummy=ptr),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit(dummy):
    """Compare the run time of a plain ReLU with its jit-compiled version.

    Prints the wall-clock time of the un-jitted call, of the first jitted
    call (which includes compilation), and of a second jitted call.

    Args:
        dummy: Placeholder input required by the input policy; not used.
    """
    # Note: using different naming conventions for numpy and jax
    # compared to the original blog post, i.e. np => jnp.
    import time

    import jax.numpy as jnp
    from jax import random, jit

    def ReLU(x):
        """ Rectified Linear Unit (ReLU) activation function """
        return jnp.maximum(0, x)

    def report_wall_time(label, fn):
        """ Run fn once and print how long it took in milliseconds """
        # `%time` is an IPython line magic rather than Python syntax, so the
        # measurement is taken explicitly with a monotonic clock instead.
        start = time.perf_counter()
        fn()
        elapsed_ms = (time.perf_counter() - start) * 1e3
        print(f"{label}: wall time {elapsed_ms:.3f} ms")

    # NOTE: generating the random matrix again.
    # Any way to share code between syft.functions?
    key = random.PRNGKey(1)
    x = random.uniform(key, (1000, 1000))

    jit_ReLU = jit(ReLU)

    report_wall_time("ReLU", lambda: ReLU(x).block_until_ready())
    # Call jitted version to compile for evaluation time!
    report_wall_time("jit_ReLU (first call)",
                     lambda: jit_ReLU(x).block_until_ready())
    report_wall_time("jit_ReLU (second call)",
                     lambda: jit_ReLU(x).block_until_ready())

In [6]:
@sy.syft_function(input_policy=sy.ExactMatch(dummy=ptr),
                  output_policy=sy.SingleExecutionExactOutput())
def func_grad(dummy):
    """Print the ReLU derivative at 2.0 from jax.grad and from finite differences.

    Args:
        dummy: Placeholder input required by the input policy; not used.
    """
    # Note: using different naming conventions for numpy and jax
    # compared to the original blog post, i.e. onp => np, np => jnp.
    import numpy as np
    import jax.numpy as jnp
    from jax import jit, grad

    # NOTE: using the same ReLU function as in the previous example
    def ReLU(x):
        """ Rectified Linear Unit (ReLU) activation function """
        return jnp.maximum(0, x)

    def FiniteDiffGrad(x):
        """ Compute the finite difference derivative approx for the ReLU"""
        # Central difference with a step of 1e-3 on either side of x
        return np.array((ReLU(x + 1e-3) - ReLU(x - 1e-3)) / (2 * 1e-3))

    # The gradient is only evaluated at a scalar, so no random input matrix
    # needs to be generated for this comparison.

    # Compare the Jax gradient with a finite difference approximation
    print("Jax Grad: ", jit(grad(jit(ReLU)))(2.))
    print("FD Gradient:", FiniteDiffGrad(2.))

In [7]:
@sy.syft_function(input_policy=sy.ExactMatch(dummy=ptr),
                  output_policy=sy.SingleExecutionExactOutput())
def func_vmap(dummy):
    """Apply one ReLU layer to a batch of vectors in three different ways.

    The layer output is computed (1) sample by sample and stacked, (2) with a
    hand-written batched matrix product, and (3) by vectorising the
    single-sample layer with ``vmap`` and compiling it with ``jit``.

    Args:
        dummy: Placeholder input required by the input policy; not used.
    """
    # Note: compared to the original blog post, jax.numpy is named jnp here.
    import jax.numpy as jnp
    from jax import random, jit, vmap

    batch_dim = 32
    feature_dim = 100
    hidden_dim = 512

    # Draw each array from its own subkey; reusing one key for several
    # draws would produce correlated samples.
    x_key, w_key, b_key = random.split(random.PRNGKey(1), 3)

    # Generate a batch of vectors to process
    X = random.normal(x_key, (batch_dim, feature_dim))

    # Generate Gaussian weights and biases
    params = [random.normal(w_key, (hidden_dim, feature_dim)),
              random.normal(b_key, (hidden_dim, ))]

    # NOTE: using the same ReLU function as in the previous example
    def ReLU(x):
        """ Rectified Linear Unit (ReLU) activation function """
        return jnp.maximum(0, x)

    def relu_layer(params, x):
        """ Simple ReLu layer for single sample """
        # jnp.dot (not numpy's dot) so the layer can be traced by vmap/jit
        return ReLU(jnp.dot(params[0], x) + params[1])

    def batch_version_relu_layer(params, x):
        """ Error prone batch version """
        # Operate on the argument x rather than the enclosing X
        return ReLU(jnp.dot(x, params[0].T) + params[1])

    # vmap version of the ReLU layer: params are shared (None), the input is
    # mapped over its leading axis (0). Built once, then called like a layer.
    vmap_relu_layer = jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))

    out = jnp.stack([relu_layer(params, row) for row in X])
    out = batch_version_relu_layer(params, X)
    out = vmap_relu_layer(params, X)

In [8]:
# Test our functions locally
# All local calls below are commented out. The bare string literal is the
# only expression in this cell; it holds timing output, presumably from an
# earlier local run of func_dot_time_comparison.
# func_dot_time_comparison()
"""
NOTE: running time comparison results differ from the ones on the blog post:
CPU times: user 45.3 ms, sys: 6.12 ms, total: 51.5 ms
Wall time: 16 ms
CPU times: user 24.1 ms, sys: 3.28 ms, total: 27.4 ms
Wall time: 6.85 ms
CPU times: user 149 ms, sys: 23 ms, total: 172 ms
Wall time: 30.4 ms
"""

# func_jit()
# func_grad()
# func_vmap()

'\nNOTE: running time comparison results differ from the ones on the blog post:\nCPU times: user 45.3 ms, sys: 6.12 ms, total: 51.5 ms\nWall time: 16 ms\nCPU times: user 24.1 ms, sys: 3.28 ms, total: 27.4 ms\nWall time: 6.85 ms\nCPU times: user 149 ms, sys: 23 ms, total: 172 ms\nWall time: 30.4 ms\n'

In [14]:
# Submit the function for code execution
# Only func_vmap is submitted; the requests for the other three functions
# are left commented out.
# guest_domain_client.api.services.code.request_code_execution(func_dot_time_comparison)
# guest_domain_client.api.services.code.request_code_execution(func_jit)
# guest_domain_client.api.services.code.request_code_execution(func_grad)
guest_domain_client.api.services.code.request_code_execution(func_vmap)


```python
class Request:
  id: str = 4fa6610a7a034987b40fb0e664830e39
  requesting_user_verify_key: str = e63f655cceae01f1118fe0353c1a3f2898488b21677810fbc69a3b05cac4723e
  approving_user_verify_key: str = None
  request_time: str = 2023-05-11 09:47:05
  approval_time: str = None
  status: str = RequestStatus.PENDING
  node_uid: str = 7bca415d13ed1ec841f0d0aede098dbb
  request_hash: str = "012f462d5b608c048b585b5bcff2f9b5545ea8bdd630529b51b6868fdc5fdc9f"
  changes: str = [syft.service.request.request.UserCodeStatusChange]

```

## Part 3
### Training a MNIST Multilayer Perceptron in JAX

In [10]:
# seaborn used by helpers.py
# (%pip is a notebook magic, so this line only runs inside a notebook kernel)
%pip install seaborn

You should consider upgrading via the '/Users/antti/.pyenv/versions/3.10.4/envs/jax_1/bin/python -m pip install --upgrade pip' command.
Note: you may need to restart the kernel to use updated packages.


In [11]:
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_mlp():
    # Import some additional JAX and dataloader helpers
    from jax.scipy.special import logsumexp
    from jax.experimental import optimizers

    # import torch
    # from torchvision import datasets, transforms

    import time
    from helpers import plot_mnist_examples

    