# JAX 101 - 05 Pseudo Random Numbers
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html

## Part 0 - Data Owner Setup

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Launch the domain named "test-domain-1" through sy.orchestra and log in
# with the data owner account.
# NOTE(review): reset=True presumably starts the node from a clean state and
# dev_mode=True presumably enables development settings -- confirm against the
# installed syft version.
node = sy.orchestra.launch(name="test-domain-1", reset=True, dev_mode=True)
data_owner_client = node.login(email="info@openmined.org", password="changethis")

## Part 1 - Data Scientist

In [None]:
# Register a data scientist account on the domain and log in with it.
# launch() is called again with the same domain name, this time without
# reset/dev_mode, and `node` is rebound to its result.
node = sy.orchestra.launch(name="test-domain-1")
data_scientist_client = node.client
# The credentials below are hard-coded demo values used only by this tutorial.
data_scientist_client.register(name="Jane Doe", email="jane@caltech.edu", password="abc123", institution="Caltech", website="https://www.caltech.edu/")
# NOTE(review): the return value of login() is discarded, so this presumably
# authenticates `data_scientist_client` in place -- confirm for the syft
# version in use.
data_scientist_client.login(email="jane@caltech.edu", password="abc123")

In [None]:
# Create the functions for code execution
# ATTENTION: ALL LIBRARIES USED MUST BE IMPORTED INSIDE THE FUNCTION BODY!!!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def random_numbers_numpy():
    """Walk through NumPy's global, stateful pseudo-random number generator.

    Prints a truncated view of the global random state after seeding, after
    seeding again and after a single draw. Then prints three uniform samples
    drawn in one call, followed by a comparison of three one-at-a-time draws
    with one vectorised draw made from the same seed.

    Takes no arguments and returns nothing; all output goes to stdout.
    """
    # Imported inside the function so the function body carries its own imports.
    import numpy as np

    def show_state_prefix():
        """Print only the first 460 characters of the global random state."""
        state_text = str(np.random.get_state())
        print(state_text[:460], '...')

    # State immediately after seeding.
    np.random.seed(0)
    show_state_prefix()

    # Seeding again with the same value, then showing the state once more.
    np.random.seed(0)
    show_state_prefix()

    # A single draw (value discarded) before showing the state a third time.
    np.random.uniform()
    show_state_prefix()

    # Three samples drawn in one call from a fresh seed.
    np.random.seed(0)
    print(np.random.uniform(size=3))

    # The same seed consumed one scalar draw at a time ...
    np.random.seed(0)
    one_by_one = []
    for _ in range(3):
        one_by_one.append(np.random.uniform())
    print("individually:", np.stack(one_by_one))

    # ... versus a single vectorised draw of the same length.
    np.random.seed(0)
    batch = np.random.uniform(size=3)
    print("all at once: ", batch)
    

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def random_numbers_jax():
    """Contrast NumPy's global RNG state with JAX's explicit PRNG keys.

    First prints a value built from two draws on NumPy's seeded global
    generator. Then creates a JAX PRNG key, samples from it twice, splits it
    into a new key and a subkey, and finally compares sampling from three
    split subkeys one at a time with a single vectorised draw from the key.

    Takes no arguments and returns nothing; all output goes to stdout.
    """
    import numpy as np

    np.random.seed(0)

    def bar():
        return np.random.uniform()

    def baz():
        return np.random.uniform()

    def foo():
        # bar() is evaluated before baz(); both consume the global NumPy state.
        first = bar()
        second = baz()
        return first + 2 * second

    print(foo())

    from jax import random

    key = random.PRNGKey(42)
    print(key)
    # The same key is passed to normal() twice on purpose.
    for _ in range(2):
        print(random.normal(key))

    print("old key", key)
    new_key, subkey = random.split(key)
    # Discard the old key for emphasis: it must never be used again.
    del key
    draw = random.normal(subkey)
    print(r"    \---SPLIT --> new key   ", new_key)
    print(r"             \--> new subkey", subkey, "--> normal", draw)
    # The subkey has been consumed, so it is discarded as well.
    del subkey

    # The `del` statements above are only for emphasis; simply not reusing
    # the same key values is enough.

    # To repeat the procedure, new_key takes over the role of the key.
    key = new_key
    key, subkey = random.split(key)
    # One split call can also produce many subkeys at once.
    key, *forty_two_subkeys = random.split(key, num=43)

    # Sample from three subkeys of a fresh key, one at a time ...
    key = random.PRNGKey(42)
    samples = []
    for subkey in random.split(key, 3):
        samples.append(random.normal(subkey))
    print("individually:", np.stack(samples))

    # ... versus one call that draws three values from the key directly.
    key = random.PRNGKey(42)
    batch = random.normal(key, shape=(3,))
    print("all at once: ", batch)

In [None]:
# Test both functions locally by calling them directly
random_numbers_numpy()
random_numbers_jax()

In [None]:
# Submit both functions for code execution, one request per function,
# using the data scientist's client
data_scientist_client.api.services.code.request_code_execution(random_numbers_numpy)
data_scientist_client.api.services.code.request_code_execution(random_numbers_jax)

## Part 2 - Data Owner Reviewing and Approving Requests

In [None]:
# Log in as the data owner again so the submitted requests can be reviewed
data_owner_client = node.login(email="info@openmined.org", password="changethis")

In [None]:
# Get all messages from the domain as the data owner
messages = data_owner_client.api.services.messages.get_all()
# Bare expression: the notebook displays the value as the cell output
messages

In [None]:
# These helpers come from the local `helpers` module, which is not shown here.
from helpers import review_request, run_submitted_function, accept_request

# Process every message in turn: review it, run the submitted function, and
# accept the request with the result of that run.
# NOTE(review): this assumes every message is a code-execution request --
# confirm that the helpers tolerate other message types.
for message in messages:
    review_request(message)
    real_result = run_submitted_function(message)
    accept_request(message, real_result)

## Part 3 - Downloading the Results

### Tutorial complete 👏

In [None]:
# Call each submitted function through the code service as the data scientist
# and check that the call did not come back as a SyftError.
result = data_scientist_client.api.services.code.random_numbers_numpy()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.random_numbers_jax()
assert not isinstance(result, sy.SyftError)

In [None]:
# Final step of the tutorial.
# NOTE(review): land() presumably shuts down the node started by
# sy.orchestra.launch -- confirm against the syft documentation.
node.land()