# JAX 101 - 04 Advanced Automatic Differentiation
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html

## Part 0 - Data Owner Setup

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Launch the "test-domain-1" node (reset=True, dev_mode=True) and log in
# with the data owner account.
node = sy.orchestra.launch(
    name="test-domain-1",
    reset=True,
    dev_mode=True,
)
data_owner_client = node.login(
    email="info@openmined.org",
    password="changethis",
)

## Part 1 - Data Scientist

In [None]:
# Get a handle on "test-domain-1", take its client, register a new
# data-scientist account with it, then log in as that account.
node = sy.orchestra.launch(name="test-domain-1")
data_scientist_client = node.client
data_scientist_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)
data_scientist_client.login(email="jane@caltech.edu", password="abc123")

In [None]:
# Create the functions to submit for code execution
# NOTE: every library a function uses must be imported inside that function's body.

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def higher_order_derivatives():
    """Print the first four derivatives of a cubic polynomial at x = 1.0.

    The polynomial is x**3 + 2*x**2 - 3*x + 1. Each derivative is obtained
    by applying `jax.grad` to the previous one. Takes no arguments and
    returns None; the four values are printed in order (1st to 4th).
    """
    import jax

    def polynomial(x):
        return x**3 + 2*x**2 - 3*x + 1

    # Differentiate repeatedly, evaluating each new derivative at 1.0.
    derivative = polynomial
    for _ in range(4):
        derivative = jax.grad(derivative)
        print(derivative(1.))


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def stopping_gradients():
    """Show how `jax.lax.stop_gradient` changes a TD-loss gradient, then batch it.

    Takes no arguments and returns None; every result is printed:

    1. the gradient of a squared TD error in which the target is
       differentiated as well ("Pseudo naive loss"),
    2. the gradient of the same loss with the target wrapped in
       `jax.lax.stop_gradient` ("Correct loss"),
    3. per-example gradients built with `jax.vmap(jax.grad(...))`, with and
       without `jax.jit`,
    4. a timing comparison of the un-jitted and jitted per-example gradients.

    All imports are local so that the function body is self-contained.
    """
    import timeit

    import jax
    import jax.numpy as jnp

    # Linear value function and its initial parameters.
    def value_fn(theta, state):
        return jnp.dot(theta, state)

    theta = jnp.array([0.1, -0.1, 0.])

    # An example transition.
    s_tm1 = jnp.array([1., 2., -1.])
    r_t = jnp.array(1.)
    s_t = jnp.array([2., 1., 0.])

    # Naive version: the gradient also flows through the target.
    def naive_td_loss(theta, s_tm1, r_t, s_t):
        v_tm1 = value_fn(theta, s_tm1)
        target = r_t + value_fn(theta, s_t)
        return (target - v_tm1) ** 2

    delta_theta = jax.grad(naive_td_loss)(theta, s_tm1, r_t, s_t)
    print("Pseudo naive loss", delta_theta)

    # Corrected version: the target is treated as a constant w.r.t. theta.
    def td_loss(theta, s_tm1, r_t, s_t):
        v_tm1 = value_fn(theta, s_tm1)
        target = r_t + value_fn(theta, s_t)
        return (jax.lax.stop_gradient(target) - v_tm1) ** 2

    delta_theta = jax.grad(td_loss)(theta, s_tm1, r_t, s_t)
    print("Correct loss", delta_theta)

    # Per-example gradients: theta is shared (None), the rest is mapped over axis 0.
    perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

    # Test it on a batch made of two copies of the example transition.
    batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
    batched_r_t = jnp.stack([r_t, r_t])
    batched_s_t = jnp.stack([s_t, s_t])

    print("Per example grads", perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t))

    # The same construction, built up one transformation at a time.
    dtdloss_dtheta = jax.grad(td_loss)
    print("Gradient loss on unbatched inputs", dtdloss_dtheta(theta, s_tm1, r_t, s_t))

    # Default vmap maps over axis 0 of every argument, so theta is batched too.
    almost_perex_grads = jax.vmap(dtdloss_dtheta)
    batched_theta = jnp.stack([theta, theta])
    member_gradient = almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
    print("Gradient for on member of a batch", member_gradient)

    inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
    grads = inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
    print("Inefficient gradients", grads)

    perex_grads = jax.jit(inefficient_perex_grads)
    grads = perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
    print("Efficient gradients", grads)

    # Timing comparison. The IPython `%timeit` magic is not valid Python
    # inside a function body, so the standard-library `timeit` module is used.
    # Both callables have already been invoked above, so the timed calls are
    # not first calls.
    def seconds_per_call(fn, number=100):
        """Return the mean wall-clock seconds of one `fn()` call over `number` calls."""
        return timeit.timeit(fn, number=number) / number

    # block_until_ready() makes each timed call wait for the computed result.
    slow = seconds_per_call(
        lambda: inefficient_perex_grads(
            theta, batched_s_tm1, batched_r_t, batched_s_t
        ).block_until_ready()
    )
    fast = seconds_per_call(
        lambda: perex_grads(
            theta, batched_s_tm1, batched_r_t, batched_s_t
        ).block_until_ready()
    )
    print(f"inefficient_perex_grads: {slow * 1e6:.1f} us per call")
    print(f"perex_grads (jitted): {fast * 1e6:.1f} us per call")


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def straight_through_estimator():
    """Compare `jnp.round` with a straight-through version of it at x = 3.2.

    Prints the value and the `jax.grad` gradient of both the plain rounding
    function and the straight-through variant. Takes no arguments and
    returns None.
    """
    import jax
    import jax.numpy as jnp

    stop = jax.lax.stop_gradient
    point = 3.2

    def rounded(x):
        # Rounding is non-differentiable.
        return jnp.round(x)

    def straight_through(x):
        # `x - stop(x)` is an exactly-zero expression (Sterbenz lemma) whose
        # gradient is exactly one; adding the gradient-stopped rounded value
        # keeps the forward result equal to rounded(x).
        return (x - stop(x)) + stop(rounded(x))

    print("f(x): ", rounded(point))
    print("straight_through_f(x):", straight_through(point))

    print("grad(f)(x):", jax.grad(rounded)(point))
    print("grad(straight_through_f)(x):", jax.grad(straight_through)(point))

In [None]:
# Call each function locally as a sanity check; each one prints its results
higher_order_derivatives()
stopping_gradients()
straight_through_estimator()

In [None]:
# Submit one code-execution request per function through the data scientist's client
data_scientist_client.api.services.code.request_code_execution(higher_order_derivatives)
data_scientist_client.api.services.code.request_code_execution(stopping_gradients)
data_scientist_client.api.services.code.request_code_execution(straight_through_estimator)

## Part 2 - Data Owner Reviewing and Approving Requests

In [None]:
# Log in to the node with the data owner account (same credentials as in Part 0)
data_owner_client = node.login(email="info@openmined.org", password="changethis")

In [None]:
# Fetch all messages from the domain as the data owner
messages = data_owner_client.api.services.messages.get_all()
# Leave `messages` as the cell's last expression so the notebook displays it
messages

In [None]:
# Helper functions come from the local `helpers` module (not shown here).
from helpers import review_request, run_submitted_function, accept_request

# Process every message in turn: review it, run the submitted function, and
# accept the request, passing along the result of that run.
# NOTE(review): presumably each message wraps one code-execution request
# submitted in Part 1 -- confirm against `helpers`.
for message in messages:
    review_request(message)
    real_result = run_submitted_function(message)
    accept_request(message, real_result)

## Part 3 - Downloading the Results

### Tutorial complete 👏

In [None]:
# Call each submitted function through the data scientist's code service and
# check that none of the calls came back as a SyftError.
for function_name in (
    "higher_order_derivatives",
    "stopping_gradients",
    "straight_through_estimator",
):
    # Look the endpoint up on a freshly accessed code service for every call.
    code_service = data_scientist_client.api.services.code
    result = getattr(code_service, function_name)()
    assert not isinstance(result, sy.SyftError)

In [None]:
# Only call `node.land()` when the node's type value is "python".
# NOTE(review): presumably this tears down the locally launched node -- confirm.
if node.node_type.value == "python":
    node.land()