# JAX 101 - 01 JAX Basics
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html

## Part 1 - Data Owner Setup

In [None]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Launch a local domain node named "test-domain-1" on port 8080.
# reset=True presumably discards state left over from a previous run
# (NOTE(review): confirm against the syft orchestra docs).
node = sy.orchestra.launch(name="test-domain-1", port=8080, reset=True)
# Log in as the data owner. These look like Syft's default admin credentials
# (NOTE(review): confirm); they are only suitable for a local tutorial.
data_owner_client = node.login(email="info@openmined.org", password="changethis")

In [None]:
# Add a dummy data subject (with one member) to the domain's registry.
# The assets created below are attached to this subject.
dummy_data_subject = sy.DataSubject(name="Dummy", aliases=["dummy"])
dummy_member = sy.DataSubject(name="Dummy Member", aliases=["dummy:member"])
dummy_data_subject.add_member(dummy_member)
registry = data_owner_client.data_subject_registry
response = registry.add_data_subject(dummy_data_subject)
# Display the registry's response in the notebook.
response

In [None]:
# Check that the data subjects were registered by listing all of them.
data_subjects = data_owner_client.api.services.data_subject.get_all()
data_subjects

In [None]:
# Build the dataset used by this tutorial.
# (A citation, url and contributors can optionally be set as well.)
dataset = sy.Dataset(name="JAX - Chapter 1 data")
dataset.set_description("Dummy dataset for Chapter 1")
# Asset 0: the private array used by the sum-of-squares examples.
sos_asset = sy.Asset(name='sum-of-squares-data')
sos_asset.set_description("dummy array for the first set of functions")
sos_asset.set_obj(np.array([1.0, 2.0, 3.0, 4.0]))
# The mock is the stand-in the data scientist works with (see Part 2); here it is
# all ones and explicitly flagged as NOT the real data.
sos_asset.set_mock(np.array([1.0, 1.0, 1.0, 1.0]), mock_is_real=False)
sos_asset.add_data_subject(dummy_data_subject)
sos_asset.set_shape((4,))
dataset.add_asset(sos_asset)

In [None]:
# Generate training data: ys = 3 * xs - 1 plus small Gaussian noise.
# A seeded generator makes the data, and therefore the trained parameters
# and the final check of this notebook, reproducible across runs.
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(seed=0)
xs = rng.normal(size=(100,))
noise = rng.normal(scale=0.1, size=(100,))
ys = xs * 3 - 1 + noise

plt.scatter(xs, ys)

In [None]:
# Register the synthetic training data as assets 1 (xs) and 2 (ys). Each array
# is also used as its own mock (mock_is_real=True).
def _make_array_asset(asset_name, values):
    """Create an asset named and described by ``asset_name`` holding ``values``."""
    new_asset = sy.Asset(name=asset_name)
    new_asset.set_description(asset_name)
    new_asset.set_obj(values)
    new_asset.set_mock(values, mock_is_real=True)
    new_asset.add_data_subject(dummy_data_subject)
    new_asset.set_shape((100,))
    return new_asset


x_asset = _make_array_asset('xs', xs)
dataset.add_asset(x_asset)

y_asset = _make_array_asset('ys', ys)
dataset.add_asset(y_asset)

In [None]:
# Upload the dataset (with its three assets) to the domain.
data_owner_client.upload_dataset(dataset)

## Part 2 - Data Scientist

In [None]:
# Register a data scientist account on the domain, then log in with it.
data_scientist_client = node.client
data_scientist_client.register(
    name="Jane Doe", email="jane@caltech.edu", password="abc123",
    institution="Caltech", website="https://www.caltech.edu/"
)
# Keep the client returned by login() instead of discarding it. Rebinding works
# whether login() authenticates this client in place or returns a new one.
# NOTE(review): assumes login() returns the authenticated client, as
# node.login() does for the data owner; confirm for the installed syft version.
data_scientist_client = data_scientist_client.login(
    email="jane@caltech.edu", password="abc123"
)

In [None]:
# Inspect the available data. Assets are presumably listed in the order the
# data owner added them: 0 = sum-of-squares-data, 1 = xs, 2 = ys.
results = data_scientist_client.api.services.dataset.get_all()
dataset = results[0]  # assumes the dataset uploaded in Part 1 is the only/first one
mock = dataset.assets[0].mock  # mock array of the sum-of-squares asset
xs = dataset.assets[1]  # Asset objects, not arrays; their .mock is used below
ys = dataset.assets[2]

In [None]:
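# Create the functions to submit for code execution.
# ATTENTION: every library a function uses must be imported INSIDE the function body;
# the notebook's top-level imports are not available when the function runs on the domain.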

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func():
    """Return the JAX array ``[0, 1, ..., 9]``; the function takes no inputs."""
    import jax.numpy as jnp

    return jnp.arange(10)

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_timeit():
    """Time the dot product of a 10-million-element JAX vector with itself.

    Uses the standard-library ``timeit`` module rather than the IPython
    ``%timeit`` magic. The magic only works inside an IPython kernel, so
    submitted code containing it cannot run on the domain. It also produced
    no return value.

    Returns:
        float: best observed wall-clock seconds per ``jnp.dot`` call
        (best of 7 repeats, like ``%timeit``'s default).
    """
    import timeit

    import jax.numpy as jnp

    long_vector = jnp.arange(int(1e7))
    timer = timeit.Timer(
        # block_until_ready() waits for JAX's asynchronous dispatch, so the
        # measurement covers the computation and not just enqueueing it.
        lambda: jnp.dot(long_vector, long_vector).block_until_ready()
    )
    # Pick a loop count the same way %timeit does (total time >= 0.2 s).
    number, _ = timer.autorange()
    return min(timer.repeat(repeat=7, number=number)) / number
    
@sy.syft_function(input_policy=sy.ExactMatch(x=mock),
                  output_policy=sy.SingleExecutionExactOutput())
def func_sum_of_squares(x):
    """Return ``sum(x**2)`` together with its gradient with respect to ``x``.

    Args:
        x: input array (the sum-of-squares asset).

    Returns:
        tuple: ``(sum of squares, gradient)``; the gradient equals ``2 * x``.
    """
    import jax
    import jax.numpy as jnp

    def sum_sq(values):
        squared = values**2
        return jnp.sum(squared)

    grad_fn = jax.grad(sum_sq)
    total = sum_sq(x)
    gradient = grad_fn(x)
    return total, gradient

@sy.syft_function(input_policy=sy.ExactMatch(x=mock),
                  output_policy=sy.SingleExecutionExactOutput())
def func_sum_squared_error(x):
    """Differentiate a sum-of-squared-errors loss against a fixed target.

    Args:
        x: input array; compared against the target ``[1.1, 2.1, 3.1, 4.1]``.

    Returns:
        tuple: ``(grad w.r.t. x, (grad w.r.t. x, grad w.r.t. target))``,
        i.e. the default ``jax.grad`` followed by ``argnums=(0, 1)``.
    """
    import jax
    import jax.numpy as jnp

    def sse(a, b):
        diff = a - b
        return jnp.sum(diff**2)

    target = jnp.asarray([1.1, 2.1, 3.1, 4.1])
    grad_wrt_first = jax.grad(sse)(x, target)
    grad_wrt_both = jax.grad(sse, argnums=(0, 1))(x, target)
    return grad_wrt_first, grad_wrt_both


@sy.syft_function(input_policy=sy.ExactMatch(x=mock),
                  output_policy=sy.SingleExecutionExactOutput())
def func_sum_squared_error_with_aux(x):
    """Differentiate a squared-error loss that also returns auxiliary data.

    Args:
        x: input array; compared against the target ``[1.1, 2.1, 3.1, 4.1]``.

    Returns:
        tuple: ``(gradient w.r.t. x, aux)`` where ``aux`` is ``x - target``.
    """
    import jax
    import jax.numpy as jnp

    def sum_squared_error(a, b):
        return jnp.sum((a - b)**2)

    def squared_error_with_aux(a, b):
        # Only the first output is differentiated; the second is passed through.
        return sum_squared_error(a, b), a - b

    target = jnp.asarray([1.1, 2.1, 3.1, 4.1])
    grad_with_aux = jax.grad(squared_error_with_aux, has_aux=True)
    return grad_with_aux(x, target)


@sy.syft_function(input_policy=sy.ExactMatch(x=mock),
                  output_policy=sy.SingleExecutionExactOutput())
def func_naive_modify(x):
    """Show that JAX arrays reject NumPy-style in-place item assignment.

    Args:
        x: array-like input; it is copied into a JAX array first.

    Returns:
        str | None: the error message produced by ``arr[0] = 123``, or None if
        the assignment unexpectedly succeeds.
    """
    import jax.numpy as jnp

    def in_place_modify(arr):
        arr[0] = 123
        return None

    try:
        in_place_modify(jnp.array(x))
    except TypeError as e:
        # JAX arrays are immutable; item assignment raises TypeError. Only
        # that error is expected here, so anything else propagates. The
        # message is returned so it becomes the function's result instead
        # of only console output.
        print(e)
        return str(e)
    return None


@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def func_jax_modify():
    """Return a copy of ``[1, 2, 3]`` with index 0 set to 123.

    ``.at[...].set(...)`` is JAX's functional substitute for in-place
    assignment: it returns a new array and leaves the original unchanged.
    """
    import jax.numpy as jnp

    original = jnp.array([1, 2, 3])
    updated = original.at[0].set(123)
    return updated

@sy.syft_function(input_policy=sy.ExactMatch(xs=xs.mock, ys=ys.mock),
                  output_policy=sy.SingleExecutionExactOutput())
def training_loop(xs, ys):
    """Fit the line ``y = w * x + b`` by full-batch gradient descent.

    Runs 1000 gradient-descent steps (learning rate 0.1) on the mean squared
    error, starting from ``w = b = 1``. It then scatter-plots the data with the
    fitted line and prints the learned parameters.

    Args:
        xs: 1-D array of inputs (a 100-element asset in this notebook).
        ys: 1-D array of targets, same length as ``xs``.

    Returns:
        tuple: ``(w, b)``, the learned slope and intercept as 0-d JAX arrays.
    """
    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt

    def model(theta, x):
        """Computes wx + b on a batch of input x."""
        w, b = theta
        return w * x + b

    def loss_fn(theta, x, y):
        prediction = model(theta, x)
        return jnp.mean((prediction - y)**2)

    # Build the gradient function once rather than on every step.
    grad_loss = jax.grad(loss_fn)

    # Compile the step with XLA, as the original JAX tutorial does. Without
    # jit, every one of the 1000 steps re-traces and runs op by op.
    @jax.jit
    def update(theta, x, y, lr=0.1):
        return theta - lr * grad_loss(theta, x, y)

    theta = jnp.array([1., 1.])

    for _ in range(1000):
        theta = update(theta, xs, ys)

    plt.scatter(xs, ys)
    plt.plot(xs, model(theta, xs))

    w, b = theta
    print(f"w: {w:<.2f}, b: {b:<.2f}")
    return w, b


In [None]:
# Run the functions locally on the mock data before submitting them,
# to catch errors without involving the data owner.
print(func())
print(func_timeit())
print(func_sum_of_squares(x=mock))
print(training_loop(xs=xs.mock, ys=ys.mock))

In [None]:
# Submit the function for code execution.
# Only training_loop is submitted here; uncomment a line below to also request
# execution of one of the other functions.
# data_scientist_client.api.services.code.request_code_execution(func)
# data_scientist_client.api.services.code.request_code_execution(func_timeit)
# data_scientist_client.api.services.code.request_code_execution(func_sum_of_squares)
# data_scientist_client.api.services.code.request_code_execution(func_sum_squared_error)
# data_scientist_client.api.services.code.request_code_execution(func_sum_squared_error_with_aux)
# data_scientist_client.api.services.code.request_code_execution(func_naive_modify)
# data_scientist_client.api.services.code.request_code_execution(func_jax_modify)
data_scientist_client.api.services.code.request_code_execution(training_loop)

In [None]:
# Try to fetch the result now. The data owner has not approved the request yet
# (that happens in Part 3), so this presumably returns an error or pending
# status rather than the trained parameters.
data_scientist_client.api.services.code.training_loop()

## Part 3 - Data Owner Reviewing and Approving Requests

In [None]:
# Log back in as the data owner to review the pending request.
data_owner_client = node.login(email="info@openmined.org", password="changethis")

In [None]:
# Get messages from the domain; the code-execution request submitted by the
# data scientist is reached through one of these (see messages[0].link below).
messages = data_owner_client.api.services.messages.get_all()
messages

In [None]:
# Fetch the dataset and load the real (private) data of each asset.
# NOTE(review): `.data.syft_action_data` appears to unwrap the underlying
# array; confirm for the installed syft version.
results = data_owner_client.api.services.dataset.get_all()
dataset = results[0]
mock = dataset.assets[0].data.syft_action_data  # real sum-of-squares data, despite the name
xs = dataset.assets[1].data.syft_action_data
ys = dataset.assets[2].data.syft_action_data

In [None]:
# Display the messages again.
messages

In [None]:
# Review the request: follow the first message to its request, then the
# request's first change to the submitted code, and print its name and source.
request = messages[0].link
# NOTE: this rebinds `func`, shadowing the syft function of the same name from Part 2.
func = request.changes[0].link
func_name = func.service_func_name
print(func_name)
print(func.raw_code)

In [None]:
# Run the submitted function on the real (private) data so its output can be
# reviewed before it is released to the data scientist.
# As the name `unsafe_function` suggests, this presumably runs the submitted
# code directly in this process, so read raw_code (printed above) first.
user_func = func.unsafe_function
real_result = user_func(xs=xs, ys=ys)
real_result

In [None]:
# Approve the request and deposit the result computed above, so the data
# scientist receives `real_result`.
# NOTE(review): approve() is followed by accept_by_depositing_result(); confirm
# for the installed syft version whether both calls are needed.
request.approve()
result = request.accept_by_depositing_result(real_result)
result

## Part 4 - Data Scientist

In [None]:
# Log back in as the data scientist.
# Keep the client returned by login() instead of discarding it. Rebinding works
# whether login() authenticates the client in place or returns a new one.
# NOTE(review): assumes login() returns the authenticated client, as
# node.login() does for the data owner; confirm for the installed syft version.
data_scientist_client = node.client.login(email="jane@caltech.edu", password="abc123")

In [None]:
# Now that the request is approved, fetch the deposited result.
result = data_scientist_client.api.services.code.training_loop()

In [None]:
# Display the result: the (w, b) pair returned by training_loop.
result

### Tutorial Complete 👏

In [None]:
# Sanity check: ys was generated as xs * 3 - 1 + noise, so the learned slope w
# should round to 3.
assert round(result[0]) == 3

In [None]:
# Presumably shuts down the domain started with sy.orchestra.launch at the top.
node.land()