# JAX 101 - 02 Jitting
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

## Part 0 - Data Owner Setup

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Start a fresh development domain on port 8080 (reset=True wipes any previous
# state) and log in as the data owner.
node = sy.orchestra.launch(
    name="test-domain-1",
    port=8080,
    reset=True,
    dev_mode=True,
)
data_owner_client = node.login(email="info@openmined.org", password="changethis")

## Part 1 - Data Scientist

In [None]:
# Register a data-scientist account on the domain, then log in as that user.
# `login` hands back the authenticated client (the same pattern the data-owner
# cells use with `node.login`), so its result must be kept. Otherwise every
# later call in this notebook would still go through the unauthenticated
# guest client.
node = sy.orchestra.launch(name="test-domain-1")
guest_client = node.client
guest_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)
data_scientist_client = guest_client.login(email="jane@caltech.edu", password="abc123")
# Fail here, not at the first service call, if the login did not succeed.
assert not isinstance(data_scientist_client, sy.SyftError), data_scientist_client

In [None]:
# Define the functions to submit for code execution.
# ATTENTION: every library a function uses must be imported INSIDE that function's body!

@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def func_log2():
    """Print the jaxpr that ``jax.make_jaxpr`` traces for a base-2 logarithm.

    The inner function also appends its argument to a plain Python list. That
    is ordinary Python, not a JAX operation, so it does not appear in the
    printed jaxpr. Nothing is returned.
    """
    import jax
    import jax.numpy as jnp

    traced_inputs = []

    def log2(x):
        traced_inputs.append(x)
        return jnp.log(x) / jnp.log(2.0)

    jaxpr = jax.make_jaxpr(log2)(3.0)
    print(jaxpr)


@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def func_log2_with_print():
    """Print the jaxpr of a log2 that also calls ``print`` on its input.

    The ``print`` executes while ``jax.make_jaxpr`` traces the function, so it
    is expected to show the traced stand-in for ``x`` rather than 3.0.
    Nothing is returned.
    """
    import jax
    import jax.numpy as jnp

    def log2_with_print(x):
        print("printed x:", x)
        numerator = jnp.log(x)
        denominator = jnp.log(2.0)
        return numerator / denominator

    trace_fn = jax.make_jaxpr(log2_with_print)
    print(trace_fn(3.0))


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_log2_if_rank_2():
    """Print the jaxpr of a function that branches on its input's rank.

    ``x.ndim`` is shape information, so the Python ``if`` is decided while
    tracing. For the rank-1 input used below, only the ``else`` branch
    (returning ``x`` unchanged) ends up in the jaxpr. Nothing is returned.
    """
    # Both modules are imported here because only the function body is sent
    # to the domain; the notebook's top-level imports are not available there.
    import jax
    import jax.numpy as jnp

    def log2_if_rank_2(x):
        if x.ndim == 2:
            ln_x = jnp.log(x)
            ln_2 = jnp.log(2.0)
            return ln_x / ln_2
        else:
            return x

    print(jax.make_jaxpr(log2_if_rank_2)(jnp.array([1, 2, 3])))

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_selu():
    """Compare the runtime of SELU with and without ``jax.jit``.

    Timings use the standard-library ``timeit`` module instead of IPython's
    ``%timeit`` magic. The magic is IPython-only syntax, so it cannot run once
    this function executes outside a notebook kernel (e.g. on the domain).
    Results are printed; nothing is returned.
    """
    import timeit

    import jax
    import jax.numpy as jnp

    def time_call(fn, repeat=7):
        """Print the best per-call time of ``fn()``, similar to ``%timeit``."""
        timer = timeit.Timer(fn)
        # autorange() picks a loop count whose total runtime is >= 0.2 s.
        number, _ = timer.autorange()
        best = min(timer.repeat(repeat=repeat, number=number)) / number
        print(f"{best * 1e3:.3f} ms per loop (best of {repeat} runs, {number} loops each)")

    def selu(x, alpha=1.67, lambda_=1.05):
        return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

    x = jnp.arange(1000000)
    # block_until_ready() waits for the result, so the computation itself is
    # timed rather than just its dispatch.
    time_call(lambda: selu(x).block_until_ready())

    selu_jit = jax.jit(selu)

    # Warm up: the first call traces and compiles, which should not be timed.
    selu_jit(x).block_until_ready()

    time_call(lambda: selu_jit(x).block_until_ready())
    

@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def func_jit_error():
    """Show that jitting a function which branches on its argument's value fails.

    Under ``jax.jit`` the argument is traced, so the Python-level test
    ``x > 0`` has no concrete truth value. Calling this is expected to raise.
    """
    import jax

    def f(x):
        # Python condition on the *value* of x.
        return x if x > 0 else 2 * x

    jax.jit(f)(10)  # Expected to raise.


@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def func_jit_error2():
    """Show that jitting a Python ``while`` loop bounded by a traced argument fails.

    The loop condition compares against ``n``, which is traced under
    ``jax.jit`` and so has no concrete value. Calling this is expected to raise.
    """
    import jax

    def g(x, n):
        counter = 0
        while counter < n:
            counter = counter + 1
        return x + counter

    jitted_g = jax.jit(g)
    jitted_g(10, 20)  # Expected to raise.


@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def func_jit_loop():
    """Run a plain Python ``while`` loop whose body is a jitted function.

    Only the loop body is compiled. The loop itself runs in Python on concrete
    values, so its condition may depend on them. Returns ``10 + 20`` computed
    by stepping a counter up to 20.
    """
    import jax

    def _step(prev_i):
        return prev_i + 1

    loop_body = jax.jit(_step)

    def g_inner_jitted(x, n):
        i = 0
        while i < n:
            i = loop_body(i)
        return x + i

    answer = g_inner_jitted(10, 20)
    return answer


@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def func_jit_static_args():
    """Make the two failing examples jittable by marking arguments static.

    With ``static_argnums`` / ``static_argnames`` the marked argument is passed
    to the function as a concrete Python value, so Python control flow on it
    works. Prints ``f(10)`` (10) and then ``g(10, 20)`` (30).
    """
    import jax

    def f(x):
        if x > 0:
            return x
        return 2 * x

    def g(x, n):
        i = 0
        while i < n:
            i += 1
        return x + i

    f_static = jax.jit(f, static_argnums=(0,))
    print(f_static(10))

    g_static = jax.jit(g, static_argnames=("n",))
    print(g_static(10, 20))
    

@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def func_jit_partial_decorator():
    """Apply ``jax.jit`` with ``static_argnames`` via ``functools.partial``.

    ``partial`` pre-binds the jit options so ``jax.jit`` can be used as a
    decorator. Returns ``g_jit_decorated(10, 20)``, i.e. 30.
    """
    from functools import partial

    import jax

    @partial(jax.jit, static_argnames=("n",))
    def g_jit_decorated(x, n):
        steps = 0
        while steps < n:
            steps += 1
        return x + steps

    result = g_jit_decorated(10, 20)
    return result


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_caching():
    """Show how calling ``jax.jit`` on a fresh callable each time defeats caching.

    Three loops of 20 steps each call ``jax.jit`` inside the loop: once on a
    new ``partial``, once on a new ``lambda``, and once on the same plain
    function. The timing of each is printed.

    Timings use the standard-library ``timeit`` module instead of IPython's
    ``%timeit`` magic. The magic is IPython-only syntax, so it cannot run when
    this function executes on the domain. Nothing is returned.
    """
    import timeit
    from functools import partial

    import jax

    def time_call(fn, repeat=7):
        """Print the best per-call time of ``fn()``, similar to ``%timeit``."""
        timer = timeit.Timer(fn)
        # autorange() picks a loop count whose total runtime is >= 0.2 s.
        number, _ = timer.autorange()
        best = min(timer.repeat(repeat=repeat, number=number)) / number
        print(f"{best * 1e3:.3f} ms per loop (best of {repeat} runs, {number} loops each)")

    def unjitted_loop_body(prev_i):
        return prev_i + 1

    def g_inner_jitted_partial(x, n):
        i = 0
        while i < n:
            # Don't do this! Each call to partial returns a new object with a
            # different hash, so the compiled function is not reused.
            i = jax.jit(partial(unjitted_loop_body))(i)
        return x + i

    def g_inner_jitted_lambda(x, n):
        i = 0
        while i < n:
            # Don't do this! Each lambda is also a new function with a
            # different hash, so it is compiled again every iteration.
            i = jax.jit(lambda x: unjitted_loop_body(x))(i)
        return x + i

    def g_inner_jitted_normal(x, n):
        i = 0
        while i < n:
            # This is OK: the same function object is passed every time, so
            # JAX can find the cached, compiled function.
            i = jax.jit(unjitted_loop_body)(i)
        return x + i

    print("jit called in a loop with partials:")
    time_call(lambda: g_inner_jitted_partial(10, 20).block_until_ready())

    print("jit called in a loop with lambdas:")
    time_call(lambda: g_inner_jitted_lambda(10, 20).block_until_ready())

    print("jit called in a loop with caching:")
    time_call(lambda: g_inner_jitted_normal(10, 20).block_until_ready())


In [None]:
# Run every function locally before submitting it.
# func_jit_error and func_jit_error2 are skipped because they are meant to raise.
for local_fn in (
    func_log2,
    func_log2_with_print,
    func_log2_if_rank_2,
    func_selu,
    func_jit_loop,
    func_jit_static_args,
    func_jit_partial_decorator,
    func_jit_caching,
):
    local_fn()

In [None]:
# Submit each function to the domain for code execution. The two deliberately
# failing examples (func_jit_error, func_jit_error2) are not submitted.
for submitted_fn in (
    func_log2,
    func_log2_with_print,
    func_log2_if_rank_2,
    func_selu,
    func_jit_loop,
    func_jit_static_args,
    func_jit_partial_decorator,
):
    data_scientist_client.api.services.code.request_code_execution(submitted_fn)
# Kept as the cell's final expression so the notebook displays its response.
data_scientist_client.api.services.code.request_code_execution(func_jit_caching)

## Part 2 - Data Owner Reviewing and Approving Requests

In [None]:
# Log back in as the data owner to handle the submitted requests.
data_owner_client = node.login(
    email="info@openmined.org",
    password="changethis",
)


In [None]:
# Fetch all messages on the domain (presumably the code-execution requests
# submitted above) and display them.
message_service = data_owner_client.api.services.messages
messages = message_service.get_all()
messages

In [None]:
# For every message: review it, run the submitted function, then accept the
# request together with the result that run produced.
from helpers import accept_request, review_request, run_submitted_function

for msg in messages:
    review_request(msg)
    accept_request(msg, run_submitted_function(msg))

## Part 3 - Data Scientist

In [None]:
# Fetch the result of every approved function. Each call must come back
# without a SyftError; the assertion message names the function that failed.
# NOTE(review): func_selu is submitted above but never fetched here — confirm
# whether that is intentional.
for fn_name in (
    "func_log2",
    "func_log2_with_print",
    "func_log2_if_rank_2",
    "func_jit_loop",
    "func_jit_static_args",
    "func_jit_partial_decorator",
    "func_jit_caching",
):
    # Look the function up on a fresh `.api` each time, as the original calls did.
    result = getattr(data_scientist_client.api.services.code, fn_name)()
    assert not isinstance(result, sy.SyftError), f"{fn_name} failed: {result}"

### Tutorial complete 👏

In [None]:
# Shut the node down, but only when its node type is "python".
is_python_node = node.node_type.value == "python"
if is_python_node:
    node.land()