# JAX 101 - 05 Pseudo Random Numbers
Adapted from the original JAX tutorial: [JAX 101 – Pseudo Random Numbers](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)

## Part 0 - Data Owner Setup

In [12]:
# Import the necessary libraries.
import syft as sy
# Check that the installed PySyft version lies within the range this notebook was written for.
sy.requires(">=0.8,<0.9")

# JAX / NumPy for local experimentation. The functions submitted below re-import
# everything they need, because their bodies are executed on the domain node.
import jax
import jax.numpy as jnp
import numpy as np

✅ The installed version of syft==0.8.1b3 matches the requirement >=0.8 and the requirement <0.9


In [13]:
# Launch the domain node and log in as the Data Owner.
# `reset=True` is expected to start from an empty store (NOTE(review): confirm in PySyft docs).
node = sy.orchestra.launch(
    name="test-domain-1",
    reset=True,
    dev_mode=True,
)
data_owner_client = node.login(
    email="info@openmined.org",
    password="changethis",
)

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed4ec881f0d0aede098dbb.sqlite



## Part 1 - Data Scientist Submitting Code

In [14]:
# Register a Data Scientist account on the domain, then log in as that user.
node = sy.orchestra.launch(name="test-domain-1")
guest_client = node.client
# The registration result is not checked here.
guest_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)
# Keep the client that `login` returns instead of discarding it. Depending on the
# PySyft version, it may be a new authenticated object rather than the guest client
# updated in place. Stop here if authentication failed.
data_scientist_client = guest_client.login(email="jane@caltech.edu", password="abc123")
assert not isinstance(data_scientist_client, sy.SyftError), data_scientist_client
data_scientist_client

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed4ec881f0d0aede098dbb.sqlite



<SyftClient - test-domain-1 <7bca415d13ed4ec881f0d0aede098dbb>: PythonConnection>

In [15]:
# Create the functions to submit for remote code execution.
# ATTENTION: every library a function uses must be IMPORTED INSIDE that function's body!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def random_numbers_numpy():
    """Show how NumPy's global, stateful PRNG behaves.

    Everything is printed; the function returns None. NumPy is imported inside
    the body because the function is executed on the domain node, not in this kernel.
    """
    import numpy as np

    def show_state():
        # The MT19937 state is very long; print only its first 460 characters.
        state_text = str(np.random.get_state())
        print(state_text[:460], '...')

    # Seeding with the same value twice yields the same state both times.
    for _ in range(2):
        np.random.seed(0)
        show_state()

    # A single draw is enough to advance the hidden global state.
    np.random.uniform()
    show_state()

    np.random.seed(0)
    print(np.random.uniform(size=3))

    # Sequential equivalence: three scalar draws and one size-3 draw from the
    # same seed are expected to print the same numbers.
    np.random.seed(0)
    print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
    np.random.seed(0)
    print("all at once: ", np.random.uniform(size=3))
    

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def random_numbers_jax():
    """Contrast NumPy's hidden global PRNG state with JAX's explicit PRNG keys.

    Everything is printed; the function returns None. All imports live inside
    the body because the function is executed on the domain node.
    """
    import numpy as np

    # With a global PRNG, foo()'s value depends on the order in which bar()
    # and baz() consume numbers from the shared generator.
    np.random.seed(0)

    def bar():
        return np.random.uniform()

    def baz():
        return np.random.uniform()

    def foo():
        return bar() + 2 * baz()

    print(foo())

    from jax import random

    # In JAX the key *is* the state: sampling twice with one key repeats the sample.
    key = random.PRNGKey(42)
    print(key)
    print(random.normal(key))
    print(random.normal(key))

    # Splitting derives fresh keys; a key that has been split must not be reused.
    print("old key", key)
    next_key, sample_key = random.split(key)
    del key  # already split -- never use it again
    sample = random.normal(sample_key)
    print(r"    \---SPLIT --> new key   ", next_key)
    print(r"             \--> new subkey", sample_key, "--> normal", sample)
    del sample_key  # consumed by the draw above
    # The `del`s are only for emphasis; simply not reusing a key is enough.

    # To keep going, continue from the surviving key (one or many splits at a time).
    key, subkey = random.split(next_key)
    key, *forty_two_subkeys = random.split(key, num=43)

    # Per the JAX tutorial, there is no sequential-equivalence guarantee: three
    # separately keyed draws need not match one shape-(3,) draw from the same key.
    subkeys = random.split(random.PRNGKey(42), 3)
    print("individually:", np.stack([random.normal(k) for k in subkeys]))
    print("all at once: ", random.normal(random.PRNGKey(42), shape=(3,)))

In [16]:
# Dry-run both functions in this kernel before submitting them to the domain.
random_numbers_numpy()
random_numbers_jax()

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 110

In [17]:
# Submit both functions for Data Owner review. Each call returns the created
# request; check that neither came back as a SyftError (the same check used when
# downloading results) so a failed submission stops the notebook here.
code_service = data_scientist_client.api.services.code
for submitted_function in (random_numbers_numpy, random_numbers_jax):
    request = code_service.request_code_execution(submitted_function)
    assert not isinstance(request, sy.SyftError), request
request

```python
class Request:
  id: str = d429f1a38b1f41a5995bded876831a24
  requesting_user_verify_key: str = a44a9d2519d3b980bae2b375d7d5b3e122e1d23d4f50511f501ec958485ab92c
  approving_user_verify_key: str = None
  request_time: str = 2023-05-29 06:24:53
  approval_time: str = None
  status: str = RequestStatus.PENDING
  node_uid: str = 7bca415d13ed4ec881f0d0aede098dbb
  request_hash: str = "d39a9684633206aac95c0d6d1d60d959b83acc7385d83464fa6db868d1c1e5ee"
  changes: str = [syft.service.request.request.UserCodeStatusChange]

```

## Part 2 - Data Owner Reviewing and Approving Requests

In [18]:
# Re-authenticate as the Data Owner to review the pending requests.
data_owner_client = node.login(
    email="info@openmined.org",
    password="changethis",
)

In [19]:
# Fetch every message on the domain; each code submission arrives as an "Approval Request".
message_service = data_owner_client.api.services.messages
messages = message_service.get_all()
messages

Unnamed: 0,type,id,subject,status,created_at,linked_obj
0,syft.service.message.messages.Message,57c7046c56474b47b311d9e4d309d3ed,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:24:53,<<class 'syft.service.request.request.Request'...
1,syft.service.message.messages.Message,51941d442bfc4d1397ae55d3504114bd,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:24:53,<<class 'syft.service.request.request.Request'...


In [20]:
# `helpers` is a project-local module. The behaviour described below is inferred from
# the function names and the cell output -- NOTE(review): confirm against helpers.py.
from helpers import review_request, run_submitted_function, accept_request

# For each approval request: display the submitted code for review, run it, and
# approve the request together with the result it produced.
for message in messages:
    review_request(message)
    real_result = run_submitted_function(message)
    accept_request(message, real_result)

random_numbers_jax
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def random_numbers_jax():
    
    import numpy as np

    np.random.seed(0)

    def bar(): return np.random.uniform()
    def baz(): return np.random.uniform()
    def foo(): return bar() + 2 * baz()

    print(foo())
    
    from jax import random
    key = random.PRNGKey(42)
    print(key)
    print(random.normal(key))
    print(random.normal(key))
    
    print("old key", key)
    new_key, subkey = random.split(key)
    del key  # The old key is discarded -- we must never use it again.
    normal_sample = random.normal(subkey)
    print(r"    \---SPLIT --> new key   ", new_key)
    print(r"             \--> new subkey", subkey, "--> normal", normal_sample)
    del subkey  # The subkey is also discarded after use.

    # Note: you don't actually need to `del` keys -- that's just for emphasis.
    # Not reusing the same values is enough.

    key = ne

exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request d429f1a38b1f41a5995bded876831a24 changes applied'
random_numbers_numpy
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def random_numbers_numpy():
    import numpy as np
    np.random.seed(0)
    
    def print_truncated_random_state():
        """To avoid spamming the outputs, print only part of the state."""
        full_random_state = np.random.get_state()
        print(str(full_random_state)[:460], '...')
    print_truncated_random_state()

    np.random.seed(0)
    print_truncated_random_state()

    _ = np.random.uniform()
    print_truncated_random_state()
    
    np.random.seed(0)
    print(np.random.uniform(size=3))
    
    np.random.seed(0)
    print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

    np.random.seed(0)
    print("all at once: ", np.random.uniform(size=3))

syft.service.code.user_code.UserCodeExecutionResult
message='Request f7b3caa880924c6ebb603356d0370510

exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


<Figure size 640x480 with 0 Axes>

## Part 3 - Downloading the Results

### Tutorial complete 👏

In [21]:
# Download the output of each approved function, in submission order, and make
# sure neither call came back as a SyftError.
code_api = data_scientist_client.api.services.code
for function_name in ("random_numbers_numpy", "random_numbers_jax"):
    result = getattr(code_api, function_name)()
    assert not isinstance(result, sy.SyftError)

In [22]:
# Shut the node down (`land`) only when its node type is "python"
# (presumably the in-process node launched above -- TODO confirm for other node types).
is_python_node = node.node_type.value == "python"
if is_python_node:
    node.land()