# JAX 101 - 02 Jitting
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

## Part 0 - Data Owner Setup

In [1]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np



✅ The installed version of syft==0.8.1b3 matches the requirement >=0.8 and the requirement <0.9


In [2]:
# Launch the domain
node = sy.orchestra.launch(name="test-domain-1", port=8080, reset=True, dev_mode=True)
data_owner_client = node.login(email="info@openmined.org", password="changethis")

Starting test-domain-1 server on 0.0.0.0:8080




Waiting for server to start
Waiting for server to start
Waiting for server to start

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed4ec881f0d0aede098dbb.sqlite



INFO:     Started server process [70869]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8080 (Press CTRL+C to quit)


Waiting for server to start
INFO:     127.0.0.1:53080 - "GET /api/v2/metadata HTTP/1.1" 200 OK
Server Started
INFO:     127.0.0.1:53082 - "GET /api/v2/metadata HTTP/1.1" 200 OK
INFO:     127.0.0.1:53082 - "POST /api/v2/login HTTP/1.1" 200 OK
INFO:     127.0.0.1:53082 - "GET /api/v2/api?verify_key=aec6ea4dfc049ceacaeeebc493167a88a200ddc367b1fa32da652444b635d21f HTTP/1.1" 200 OK


## Part 1 - Data Scientist

In [3]:
# Register a client to the domain
node = sy.orchestra.launch(name="test-domain-1")
data_scientist_client = node.client
data_scientist_client.register(name="Jane Doe", email="jane@caltech.edu", password="abc123", institution="Caltech", website="https://www.caltech.edu/")
data_scientist_client.login(email="jane@caltech.edu", password="abc123")

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed4ec881f0d0aede098dbb.sqlite



<SyftClient - test-domain-1 <7bca415d13ed4ec881f0d0aede098dbb>: PythonConnection>

In [4]:
# Create a function for code execution
# ATTENTION: ALL LIBRARIES USED SHOULD BE DEFINED INSIDE THE FUNCTION CONTEXT!!!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_log2():
    import jax
    import jax.numpy as jnp

    global_list = []

    def log2(x):
        global_list.append(x)
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.0)
        return ln_x / ln_2

    print(jax.make_jaxpr(log2)(3.0))


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_log2_with_print():
    import jax
    import jax.numpy as jnp
    
    def log2_with_print(x):
        print("printed x:", x)
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.0)
        return ln_x / ln_2

    print(jax.make_jaxpr(log2_with_print)(3.))


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_log2_if_rank_2():
    import jax.numpy as jnp
    
    def log2_if_rank_2(x):
        if x.ndim == 2:
            ln_x = jnp.log(x)
            ln_2 = jnp.log(2.0)
            return ln_x / ln_2
        else:
            return x
        
    print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_selu():
    import jax
    import jax.numpy as jnp

    def selu(x, alpha=1.67, lambda_=1.05):
        return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

    x = jnp.arange(1000000)
    %timeit selu(x).block_until_ready()
    
    selu_jit = jax.jit(selu)

    # Warm up
    selu_jit(x).block_until_ready()

    %timeit selu_jit(x).block_until_ready()
    

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_error():
    import jax
    
    # Condition on value of x.

    def f(x):
        if x > 0:
            return x
        else:
            return 2 * x

    f_jit = jax.jit(f)
    f_jit(10)  # Should raise an error. 


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_error2():
    import jax
    
    # While loop conditioned on x and n.

    def g(x, n):
        i = 0
        while i < n:
            i += 1
        return x + i

    g_jit = jax.jit(g)
    g_jit(10, 20)  # Should raise an error. )


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_loop():
    import jax
    
    # While loop conditioned on x and n with a jitted body.

    @jax.jit
    def loop_body(prev_i):
        return prev_i + 1

    def g_inner_jitted(x, n):
        i = 0
        while i < n:
            i = loop_body(i)
        return x + i

    return g_inner_jitted(10, 20)


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_static_args():
    import jax
    
    # While loop conditioned on x and n.

    def g(x, n):
        i = 0
        while i < n:
            i += 1
        return x + i

    def f(x):
        if x > 0:
            return x
        else:
            return 2 * x
        
    f_jit_correct = jax.jit(f, static_argnums=0)
    print(f_jit_correct(10))
    
    g_jit_correct = jax.jit(g, static_argnames=['n'])
    print(g_jit_correct(10, 20))
    

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_partial_decorator():
    import jax
    
    from functools import partial

    @partial(jax.jit, static_argnames=['n'])
    def g_jit_decorated(x, n):
        i = 0
        while i < n:
            i += 1
        return x + i

    return g_jit_decorated(10, 20)


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_caching():
    from functools import partial
    import jax

    def unjitted_loop_body(prev_i):
        return prev_i + 1

    def g_inner_jitted_partial(x, n):
        i = 0
        while i < n:
            # Don't do this! each time the partial returns
            # a function with different hash
            i = jax.jit(partial(unjitted_loop_body))(i)
        return x + i

    def g_inner_jitted_lambda(x, n):
        i = 0
        while i < n:
            # Don't do this!, lambda will also return
            # a function with a different hash
            i = jax.jit(lambda x: unjitted_loop_body(x))(i)
        return x + i

    def g_inner_jitted_normal(x, n):
        i = 0
        while i < n:
            # this is OK, since JAX can find the
            # cached, compiled function
            i = jax.jit(unjitted_loop_body)(i)
        return x + i

    print("jit called in a loop with partials:")
    %timeit g_inner_jitted_partial(10, 20).block_until_ready()

    print("jit called in a loop with lambdas:")
    %timeit g_inner_jitted_lambda(10, 20).block_until_ready()

    print("jit called in a loop with caching:")
    %timeit g_inner_jitted_normal(10, 20).block_until_ready()


In [5]:
# Test our function locally 
func_log2()
func_log2_with_print()
func_log2_if_rank_2()
func_selu()
# func_jit_error()
# func_jit_error2()
func_jit_loop()
func_jit_static_args()
func_jit_partial_decorator()
func_jit_caching()

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f64[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f64[][39m = log a
    c[35m:f64[][39m = log 2.0
    d[35m:f64[][39m = div b c
  [34m[22m[1min [39m[22m[22m(d,) }
printed x: Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f64[][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f64[][39m = log a
    c[35m:f64[][39m = log 2.0
    d[35m:f64[][39m = div b c
  [34m[22m[1min [39m[22m[22m(d,) }
{ [34m[22m[1mlambda [39m[22m[22m; a[35m:i64[3][39m. [34m[22m[1mlet[39m[22m[22m  [34m[22m[1min [39m[22m[22m(a,) }
9.49 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.97 ms ± 482 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10
30
jit called in a loop with partials:
284 ms ± 6.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
280 ms ± 2.87 ms per loop (mean ± 

In [6]:
# Submit the function for code execution
data_scientist_client.api.services.code.request_code_execution(func_log2)
data_scientist_client.api.services.code.request_code_execution(func_log2_with_print)
data_scientist_client.api.services.code.request_code_execution(func_log2_if_rank_2)
data_scientist_client.api.services.code.request_code_execution(func_selu)
# data_scientist_client.api.services.code.request_code_execution(func_jit_error)
# data_scientist_client.api.services.code.request_code_execution(func_jit_error2)
data_scientist_client.api.services.code.request_code_execution(func_jit_loop)
data_scientist_client.api.services.code.request_code_execution(func_jit_static_args)
data_scientist_client.api.services.code.request_code_execution(func_jit_partial_decorator)
data_scientist_client.api.services.code.request_code_execution(func_jit_caching)

```python
class Request:
  id: str = 3584410ee57e4ee2b9319cfeb6cbc0f0
  requesting_user_verify_key: str = ee16171561655265b480ec7970b148a66c6521831267e11cbf13fba3430e891d
  approving_user_verify_key: str = None
  request_time: str = 2023-05-29 06:12:26
  approval_time: str = None
  status: str = RequestStatus.PENDING
  node_uid: str = 7bca415d13ed4ec881f0d0aede098dbb
  request_hash: str = "72cbee9d8353405dd31058aa513041052464cd711a1563f3bb0c242ec9fe8fa6"
  changes: str = [syft.service.request.request.UserCodeStatusChange]

```

## Part 2 - Data Owner Reviewing and Approving Requests

In [7]:
data_owner_client = node.login(email="info@openmined.org", password="changethis")


In [8]:
# Get messages from domain
messages = data_owner_client.api.services.messages.get_all()
messages

Unnamed: 0,type,id,subject,status,created_at,linked_obj
0,syft.service.message.messages.Message,a3903dbfedc443f8bf7f9ffb5b5eb115,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:12:25,<<class 'syft.service.request.request.Request'...
1,syft.service.message.messages.Message,21a2010fae7249f9a851d898465a0258,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:12:25,<<class 'syft.service.request.request.Request'...
2,syft.service.message.messages.Message,be900b6cc94643d9bd64addb2731df10,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:12:25,<<class 'syft.service.request.request.Request'...
3,syft.service.message.messages.Message,c822c6bb1b5c491d8ccf16118f75bed0,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:12:25,<<class 'syft.service.request.request.Request'...
4,syft.service.message.messages.Message,1baaaba754f14ef18faf96f278cf5dc1,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:12:26,<<class 'syft.service.request.request.Request'...


In [9]:
# Review, run and accept the requests
from helpers import review_request, run_submitted_function, accept_request

for message in messages:
    review_request(message)
    real_result = run_submitted_function(message)
    accept_request(message, real_result)

func_selu
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_selu():
    import jax
    import jax.numpy as jnp

    def selu(x, alpha=1.67, lambda_=1.05):
        return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

    x = jnp.arange(1000000)
    get_ipython().run_line_magic('timeit', 'selu(x).block_until_ready()')
    
    selu_jit = jax.jit(selu)

    # Warm up
    selu_jit(x).block_until_ready()

    get_ipython().run_line_magic('timeit', 'selu_jit(x).block_until_ready()')

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request 97771809d36546608040eaf7e23b5deb changes applied'
func_jit_loop
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_loop():
    import jax
    
    # While loop conditioned on x and n with a jitted body.

    @jax.jit
    def loop_body(prev_i):
        return prev_i + 1

    def g_inner_jitted(x, n):
        i = 0
        while i < n:
            i = loop_body(i)
        return x + i

    return g_inner_jitted(10, 20)

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request 540b560546884194964054c3445d866f changes applied'
func_jit_static_args
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_static_args():
    import jax
    
    # While loop conditioned on x and n.

    def g(x, n):
        i = 0
        while i < n:
            i += 1
        return x + i

    def f(x):
        if x > 0:
            return x
        else:
            return 2 * x
        
    f_jit_correct = jax.jit(f, static_argnums=0)
    print(f_jit_correct(10))
    
    g_jit_correct = jax.jit(g, static_argnames=['n'])
    print(g_jit_correct(10, 20))

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request 540e80a929fc41a99dc0851604296e34 changes applied'
func_jit_partial_decorator
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_partial_decorator():
    import jax
    
    from functools import partial

    @partial(jax.jit, static_argnames=['n'])
    def g_jit_decorated(x, n):
        i = 0
        while i < n:
            i += 1
        return x + i

    return g_jit_decorated(10, 20)

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request b6771ed8d5ab4797b3a4940294ad13c3 changes applied'
func_jit_caching
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jit_caching():
    from functools import partial
    import jax

    def unjitted_loop_body(prev_i):
        return prev_i + 1

    def g_inner_jitted_partial(x, n):
        i = 0
        while i < n:
            # Don't do this! each time the partial returns
            # a function with different hash
            i = jax.jit(partial(unjitted_loop_body))(i)
        return x + i

    def g_inner_jitted_lambda(x, n):
        i = 0
        while i < n:
            # Don't do this!, lambda will also return
            # a function with a different hash
            i = jax.jit(lambda x: unjitted_loop_body(x))(i)
        return x + i

    def g_inner_jitted_normal(x, n):
        i = 0
        while i < n:
            # this is OK, since JAX can find the
            # cached, compiled func

exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request 3584410ee57e4ee2b9319cfeb6cbc0f0 changes applied'
func_log2
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_log2():
    import jax
    import jax.numpy as jnp

    global_list = []

    def log2(x):
        global_list.append(x)
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.0)
        return ln_x / ln_2

    print(jax.make_jaxpr(log2)(3.0))

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request dd05f11283834cc88797add48f30abee changes applied'
func_log2_if_rank_2
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_log2_if_rank_2():
    import jax.numpy as jnp
    
    def log2_if_rank_2(x):
        if x.ndim == 2:
            ln_x = jnp.log(x)
            ln_2 = jnp.log(2.0)
            return ln_x / ln_2
        else:
            return x
        
    print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request ea840454227343cd848859f2e2b644f1 changes applied'
func_log2_with_print
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_log2_with_print():
    import jax
    import jax.numpy as jnp
    
    def log2_with_print(x):
        print("printed x:", x)
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.0)
        return ln_x / ln_2

    print(jax.make_jaxpr(log2_with_print)(3.))

syft.service.code.user_code.UserCodeExecutionResult


exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request 9b0b145a40b7455ba3bfe8fb49aa8c99 changes applied'


<Figure size 640x480 with 0 Axes>

##  Part 3 - Data Scientist

In [10]:
result = data_scientist_client.api.services.code.func_log2()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.func_log2_with_print()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.func_log2_if_rank_2()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.func_jit_loop()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.func_jit_static_args()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.func_jit_partial_decorator()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.func_jit_caching()
assert not isinstance(result, sy.SyftError)

### Tutorial complete 👏

In [11]:
if node.node_type.value == "python":
    node.land()