# JAX 101 - 08 Stateful Computations
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/07-state.html

## Part 0 - Data Owner Setup

In [None]:
# Import the necessary libraries
import syft as sy
# Pin the PySyft version range this notebook was written against (0.8.x).
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Launch the domain
# Start a test domain named "test-domain-1" (reset=True presumably discards any
# state left from earlier runs — confirm in the Syft docs) and log in as the
# data owner (info@openmined.org).
node = sy.orchestra.launch(name="test-domain-1", reset=True)
data_owner_client = node.login(email="info@openmined.org", password="changethis")

## Part 1 - Data Scientist

In [None]:
# Register a client to the domain
# Create Jane's account through the not-yet-logged-in node client.
data_scientist_client = node.client
data_scientist_client.register(
    name="Jane Doe",
    email="jane@caltech.edu",
    password="abc123",
    institution="Caltech",
    website="https://www.caltech.edu/",
)
# Log in through the node, exactly as the data-owner cells do, and KEEP the
# returned client. Previously the return value of
# `data_scientist_client.login(...)` was thrown away, which only works if
# `login` mutates the guest client in place instead of returning a new one.
data_scientist_client = node.login(email="jane@caltech.edu", password="abc123")

In [None]:
# Create a function for code execution
# ATTENTION: every library a function uses must be imported INSIDE that function's body!

@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def counter_1():
    """Show why a stateful method misbehaves under ``jax.jit``.

    Prints three counts from the plain Python method, then three counts from
    a jitted version of the same bound method.
    """
    import jax

    class Counter:
        """A simple counter that keeps its value on the instance."""

        def __init__(self):
            self.n = 0

        def count(self) -> int:
            """Add one to the stored value and return it."""
            self.n = self.n + 1
            return self.n

        def reset(self):
            """Put the stored value back to zero."""
            self.n = 0

    def print_three(step):
        # Invoke `step` three times and print each result.
        for _ in range(3):
            print(step())

    counter = Counter()
    print_three(counter.count)

    # The jitted method only updates `self.n` while it is being traced, so the
    # compiled version does not keep counting (see the JAX "Stateful
    # Computations" tutorial linked at the top of this notebook).
    counter.reset()
    print_three(jax.jit(counter.count))
        

@sy.syft_function(
    input_policy=sy.ExactMatch(),
    output_policy=sy.SingleExecutionExactOutput(),
)
def counter_2():
    """Show the functional fix: pass the counter state in and out explicitly.

    ``CounterV2.count`` takes the current state and returns
    ``(value, new_state)``, so nothing is hidden on the instance and the
    jitted method can be driven the same way as the plain one.
    """
    import jax
    from typing import Tuple

    CounterState = int

    class CounterV2:

        def count(self, n: CounterState) -> Tuple[int, CounterState]:
            # Output and new state are kept as separate return slots for
            # didactic purposes, even though both equal `n + 1` here.
            nxt = n + 1
            return nxt, nxt

        def reset(self) -> CounterState:
            return 0

    def print_three(step, state):
        # Feed `state` through `step` three times, printing each output.
        for _ in range(3):
            value, state = step(state)
            print(value)

    counter = CounterV2()
    print_three(counter.count, counter.reset())
    print_three(jax.jit(counter.count), counter.reset())
        
        
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def linear_regression():
    """Fit ``y = w*x + b`` to synthetic noisy data with plain SGD under ``jax.jit``.

    The model parameters live in an explicit ``Params`` pytree that is passed
    into and returned from the jitted ``update`` step, so no hidden state is
    mutated. The data and fit are also drawn with matplotlib.

    Returns:
        Tuple ``(weight, bias)`` of the fitted parameters as NumPy arrays.
    """
    # Syft functions must import every library they use inside their own
    # body. `numpy` was previously missing here and only resolved through the
    # notebook's global `np`.
    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    import numpy as np
    from typing import NamedTuple

    class Params(NamedTuple):
        weight: jnp.ndarray
        bias: jnp.ndarray

    def init(rng) -> Params:
        """Returns the initial model params."""
        weights_key, bias_key = jax.random.split(rng)
        weight = jax.random.normal(weights_key, ())
        bias = jax.random.normal(bias_key, ())
        return Params(weight, bias)

    def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Computes the least squares error of the model's predictions on x against y."""
        pred = params.weight * x + params.bias
        return jnp.mean((pred - y) ** 2)

    LEARNING_RATE = 0.005

    @jax.jit
    def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
        """Performs one SGD update step on params using the given data."""
        grad = jax.grad(loss)(params, x, y)

        # If we were using Adam or another stateful optimizer,
        # we would also do something like
        # ```
        # updates, new_optimizer_state = optimizer(grad, optimizer_state)
        # ```
        # and then use `updates` instead of `grad` to actually update the params.
        # (And we'd include `new_optimizer_state` in the output, naturally.)

        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias is deprecated in newer JAX releases.
        new_params = jax.tree_util.tree_map(
            lambda param, g: param - g * LEARNING_RATE, params, grad)

        return new_params

    rng = jax.random.PRNGKey(42)

    # Generate true data from y = w*x + b + noise
    true_w, true_b = 2, -1
    x_rng, noise_rng = jax.random.split(rng)
    xs = jax.random.normal(x_rng, (128, 1))
    noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
    ys = xs * true_w + true_b + noise

    # Fit regression
    params = init(rng)
    for _ in range(1000):
        params = update(params, xs, ys)

    plt.scatter(xs, ys)
    plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
    plt.legend()
    return np.asanyarray(params.weight), np.asanyarray(params.bias)

In [None]:
# Test our function locally 
# Run each function in this notebook before submitting it. Jupyter displays
# only the value of the last expression (linear_regression's result).
counter_1()
counter_2()
linear_regression()

In [None]:
# Submit the function for code execution
# Ask the domain to run each function; the data owner reviews these requests
# in Part 2. Only the last call's return value is displayed by Jupyter.
data_scientist_client.api.services.code.request_code_execution(counter_1)
data_scientist_client.api.services.code.request_code_execution(counter_2)
data_scientist_client.api.services.code.request_code_execution(linear_regression)

## Part 2 - Data Owner Reviewing and Approving Requests

In [None]:
# Log back in as the data owner (same credentials as in Part 0).
data_owner_client = node.login(email="info@openmined.org", password="changethis")

In [None]:
# Get messages from domain
# Presumably one message per code-execution request submitted above — confirm
# in the displayed output.
messages = data_owner_client.api.services.messages.get_all()
messages

In [None]:
# `helpers` is a local module shipped alongside these tutorials (not shown here).
from helpers import review_request, run_submitted_function, accept_request

# For every incoming message: inspect it, execute the submitted function on the
# data owner's side, then accept the request with the result just computed.
for message in messages:
    review_request(message)
    real_result = run_submitted_function(message)
    accept_request(message, real_result)

## Part 3 - Downloading the Results

### Tutorial complete 👏

In [None]:
# Fetch each function's approved result as the data scientist and check that
# none of them came back as a SyftError (the error is shown if one does).
result = data_scientist_client.api.services.code.counter_1()
assert not isinstance(result, sy.SyftError), result

result = data_scientist_client.api.services.code.counter_2()
assert not isinstance(result, sy.SyftError), result

# Like the two calls above, the endpoint lives under `services.code`; the
# previous `services.linear_regression` path skipped `.code`.
result = data_scientist_client.api.services.code.linear_regression()
assert not isinstance(result, sy.SyftError), result

In [None]:
# Tear down the node started with sy.orchestra.launch in Part 0.
node.land()