# JAX 101 - 04 Advanced Automatic Differentiation
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html

## Part 0 - Data Owner Setup

In [1]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np



✅ The installed version of syft==0.8.1b3 matches the requirement >=0.8 and the requirement <0.9


In [2]:
# Launch the domain
node = sy.orchestra.launch(name="test-domain-1", reset=True, dev_mode=True)
data_owner_client = node.login(email="info@openmined.org", password="changethis")

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed4ec881f0d0aede098dbb.sqlite



## Part 1 - Data Scientist

In [3]:
# Register a client to the domain
node = sy.orchestra.launch(name="test-domain-1")
data_scientist_client = node.client
data_scientist_client.register(name="Jane Doe", email="jane@caltech.edu", password="abc123", institution="Caltech", website="https://www.caltech.edu/")
data_scientist_client.login(email="jane@caltech.edu", password="abc123")

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed4ec881f0d0aede098dbb.sqlite



<SyftClient - test-domain-1 <7bca415d13ed4ec881f0d0aede098dbb>: PythonConnection>

In [4]:
# Create a function for code execution
# ATTENTION: ALL LIBRARIES USED SHOULD BE DEFINED INSIDE THE FUNCTION CONTEXT!!!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def higher_order_derivatives():
    import jax
    f = lambda x: x**3 + 2*x**2 - 3*x + 1

    dfdx = jax.grad(f)
    d2fdx = jax.grad(dfdx)
    d3fdx = jax.grad(d2fdx)
    d4fdx = jax.grad(d3fdx)
    
    print(dfdx(1.))
    print(d2fdx(1.))
    print(d3fdx(1.))
    print(d4fdx(1.))


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def stopping_gradients():
    import jax
    import jax.numpy as jnp
    
    # Value function and initial parameters
    value_fn = lambda theta, state: jnp.dot(theta, state)
    theta = jnp.array([0.1, -0.1, 0.])
    
    # An example transition.
    s_tm1 = jnp.array([1., 2., -1.])
    r_t = jnp.array(1.)
    s_t = jnp.array([2., 1., 0.])

    def td_loss(theta, s_tm1, r_t, s_t):
        v_tm1 = value_fn(theta, s_tm1)
        target = r_t + value_fn(theta, s_t)
        return (target - v_tm1) ** 2

    td_update = jax.grad(td_loss)
    delta_theta = td_update(theta, s_tm1, r_t, s_t)

    print("Pseudo naive loss", delta_theta)
    
    def td_loss(theta, s_tm1, r_t, s_t):
        v_tm1 = value_fn(theta, s_tm1)
        target = r_t + value_fn(theta, s_t)
        return (jax.lax.stop_gradient(target) - v_tm1) ** 2

    td_update = jax.grad(td_loss)
    delta_theta = td_update(theta, s_tm1, r_t, s_t)

    print("Correct loss", delta_theta)
    
    perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

    # Test it:
    batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
    batched_r_t = jnp.stack([r_t, r_t])
    batched_s_t = jnp.stack([s_t, s_t])

    print("Per example grads", perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t))

    dtdloss_dtheta = jax.grad(td_loss)

    print("Gradient loss on unbatched inputs", dtdloss_dtheta(theta, s_tm1, r_t, s_t))

    almost_perex_grads = jax.vmap(dtdloss_dtheta)

    batched_theta = jnp.stack([theta, theta])
    member_gradient = almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
    print("Gradient for on member of a batch", member_gradient)

    inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
    grads = inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
    print("Inefficient gradients", grads)
    
    perex_grads = jax.jit(inefficient_perex_grads)
    grads = perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
    print("Efficient gradients", grads)
    
    %timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
    %timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def straight_through_estimator():
    import jax
    import jax.numpy as jnp
    
    def f(x):
        return jnp.round(x)  # non-differentiable

    def straight_through_f(x):
        # Create an exactly-zero expression with Sterbenz lemma that has
        # an exactly-one gradient.
        zero = x - jax.lax.stop_gradient(x)
        return zero + jax.lax.stop_gradient(f(x))

    print("f(x): ", f(3.2))
    print("straight_through_f(x):", straight_through_f(3.2))

    print("grad(f)(x):", jax.grad(f)(3.2))
    print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))

In [5]:
# Test our function locally 
higher_order_derivatives()
stopping_gradients()
straight_through_estimator()

4.0
10.0
6.0
0.0
Pseudo naive loss [ 2.4 -2.4  2.4]
Correct loss [-2.4 -4.8  2.4]
Per example grads [[-2.4 -4.8  2.4]
 [-2.4 -4.8  2.4]]
Gradient loss on unbatched inputs [-2.4 -4.8  2.4]
Gradient for on member of a batch [[-2.4 -4.8  2.4]
 [-2.4 -4.8  2.4]]
Inefficient gradients [[-2.4 -4.8  2.4]
 [-2.4 -4.8  2.4]]
Efficient gradients [[-2.4 -4.8  2.4]
 [-2.4 -4.8  2.4]]
4.03 ms ± 42.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.15 µs ± 125 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
f(x):  3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0


In [6]:
# Submit the function for code execution
data_scientist_client.api.services.code.request_code_execution(higher_order_derivatives)
data_scientist_client.api.services.code.request_code_execution(stopping_gradients)
data_scientist_client.api.services.code.request_code_execution(straight_through_estimator)

```python
class Request:
  id: str = a6faa675879d48308d92829dfb258e73
  requesting_user_verify_key: str = c5dc97891534f57b6030b4eecd2a5d95976e393e3e2f41d15fec7ebbf518a9d7
  approving_user_verify_key: str = None
  request_time: str = 2023-05-29 06:23:13
  approval_time: str = None
  status: str = RequestStatus.PENDING
  node_uid: str = 7bca415d13ed4ec881f0d0aede098dbb
  request_hash: str = "3059373cddc0bdee707154cc110710fb6fee5a73f23619d079ce1fdeea5254da"
  changes: str = [syft.service.request.request.UserCodeStatusChange]

```

## Part 2 - Data Owner Reviewing and Approving Requests

In [7]:
data_owner_client = node.login(email="info@openmined.org", password="changethis")

In [8]:
# Get messages from domain
messages = data_owner_client.api.services.messages.get_all()
messages

Unnamed: 0,type,id,subject,status,created_at,linked_obj
0,syft.service.message.messages.Message,ecd9f02838ec4987a9135a7274617a8a,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:23:13,<<class 'syft.service.request.request.Request'...
1,syft.service.message.messages.Message,b3ec0d138e4b4c3ba7fe30e090624a49,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:23:13,<<class 'syft.service.request.request.Request'...
2,syft.service.message.messages.Message,d55fd935cf204a97b7b0c01e0424bfc8,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:23:13,<<class 'syft.service.request.request.Request'...


In [9]:
from helpers import review_request, run_submitted_function, accept_request

for message in messages:
    review_request(message)
    real_result = run_submitted_function(message)
    accept_request(message, real_result)

straight_through_estimator
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def straight_through_estimator():
    import jax
    import jax.numpy as jnp
    
    def f(x):
        return jnp.round(x)  # non-differentiable

    def straight_through_f(x):
        # Create an exactly-zero expression with Sterbenz lemma that has
        # an exactly-one gradient.
        zero = x - jax.lax.stop_gradient(x)
        return zero + jax.lax.stop_gradient(f(x))

    print("f(x): ", f(3.2))
    print("straight_through_f(x):", straight_through_f(3.2))

    print("grad(f)(x):", jax.grad(f)(3.2))
    print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request a6faa675879d48308d92829dfb258e73 changes applied'
higher_order_derivatives
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def higher_order_derivatives():
    import jax
    f = lambda x: x**3 + 2*x**2 - 3*x + 1

    dfdx = jax.grad(f)
    d2fdx = jax.grad(dfdx)
    d3fdx = jax.grad(d2fdx)
    d4fdx = jax.grad(d3fdx)
    
    print(dfdx(1.))
    print(d2fdx(1.))
    print(d3fdx(1.))
    print(d4fdx(1.))

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request c33c7abf37884e10a2cd9c98e87b7e99 changes applied'
stopping_gradients
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def stopping_gradients():
    import jax
    import jax.numpy as jnp
    
    # Value function and initial parameters
    value_fn = lambda theta, state: jnp.dot(theta, state)
    theta = jnp.array([0.1, -0.1, 0.])
    
    # An example transition.
    s_tm1 = jnp.array([1., 2., -1.])
    r_t = jnp.array(1.)
    s_t = jnp.array([2., 1., 0.])

    def td_loss(theta, s_tm1, r_t, s_t):
        v_tm1 = value_fn(theta, s_tm1)
        target = r_t + value_fn(theta, s_t)
        return (target - v_tm1) ** 2

    td_update = jax.grad(td_loss)
    delta_theta = td_update(theta, s_tm1, r_t, s_t)

    print("Pseudo naive loss", delta_theta)
    
    def td_loss(theta, s_tm1, r_t, s_t):
        v_tm1 = value_fn(theta, s_tm1)
        target = r_t + value_fn(theta, s_t)
        return (jax.lax.stop_gradi

exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request c7ad57e183ed4c259eadf87c02fe9e53 changes applied'


<Figure size 640x480 with 0 Axes>

## Part 3 - Downloading the Results

### Tutorial complete 👏

In [10]:
result = data_scientist_client.api.services.code.higher_order_derivatives()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.stopping_gradients()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.straight_through_estimator()
assert not isinstance(result, sy.SyftError)

In [11]:
if node.node_type.value == "python":
    node.land()