# Jax Level 2

In [30]:
import syft as sy
# Create a Syft worker named "Command_center" and take its root client.
worker = sy.Worker(name="Command_center")
root_domain_client = worker.root_client
# Handle to the JAX entry of the client's lib API. Results obtained through it
# are Syft objects (see the `type(x)` cell further down), not raw JAX arrays.
jax = root_domain_client.api.lib.jax

> Worker: Command_center - 724e8da3a2c34f688961909ac6b624b6 - NodeType.DOMAIN

Services:
ActionService
DataSubjectMemberService
DataSubjectService
DatasetService
MessageService
MetadataService
NetworkService
PolicyService
ProjectService
RequestService
UserCodeService
UserService


In [31]:
from syft.serde.lib_service_registry import action_execute_registry_libs
# Display the registry subtree for jax.numpy; the printed tree lists every
# registered entry together with its tag (ALL_EXECUTE in the output below).
action_execute_registry_libs.jax.numpy

├───jax.numpy (ALL_EXECUTE)
│    ├───fft (ALL_EXECUTE)
│    │    ├───ifft (ALL_EXECUTE)
│    │    ├───ifft2 (ALL_EXECUTE)
│    │    ├───ifftn (ALL_EXECUTE)
│    │    ├───ifftshift (ALL_EXECUTE)
│    │    ├───ihfft (ALL_EXECUTE)
│    │    ├───irfft (ALL_EXECUTE)
│    │    ├───irfft2 (ALL_EXECUTE)
│    │    ├───irfftn (ALL_EXECUTE)
│    │    ├───fft (ALL_EXECUTE)
│    │    ├───fft2 (ALL_EXECUTE)
│    │    ├───fftfreq (ALL_EXECUTE)
│    │    ├───fftn (ALL_EXECUTE)
│    │    ├───fftshift (ALL_EXECUTE)
│    │    ├───hfft (ALL_EXECUTE)
│    │    ├───rfft (ALL_EXECUTE)
│    │    ├───rfft2 (ALL_EXECUTE)
│    │    ├───rfftfreq (ALL_EXECUTE)
│    │    └───rfftn (ALL_EXECUTE)
│    ├───linalg (ALL_EXECUTE)
│    │    ├───cholesky (ALL_EXECUTE)
│    │    ├───det (ALL_EXECUTE)
│    │    ├───eig (ALL_EXECUTE)
│    │    ├───eigh (ALL_EXECUTE)
│    │    ├───eigvals (ALL_EXECUTE)
│    │    ├───eigvalsh (ALL_EXECUTE)
│    │    ├───inv (ALL_EXECUTE)
│    │    ├───lstsq (ALL_EXECUTE)
│    │    ├───matrix_po

In [32]:
# Same lookup as in the first cell; re-binding `jax` here is redundant but harmless.
jax = root_domain_client.api.lib.jax
# Shorthand for the jax.numpy namespace reached through the Syft client.
jnp = jax.numpy

## The basics

In [33]:
# Build 0..9 through the Syft-side jnp and print it.
x = jnp.arange(10)
print(x)

[0 1 2 3 4 5 6 7 8 9]


In [34]:
# Bare expression: shows the object's repr instead of its print() form.
x

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)

In [35]:
# Check which Python type the proxied call actually returned.
type(x)

syft.service.action.jax.DeviceArrayObject

## Time difference

In [36]:
# Baseline measurement with the locally installed JAX. It is imported under a
# different name so the Syft `jax` handle bound earlier is not overwritten.
import jax as real_jax
long_vector = real_jax.numpy.arange(int(1e7))

# block_until_ready() is chained onto the result so the timing waits for the
# computation to finish instead of only measuring the dispatch.
%timeit real_jax.numpy.dot(long_vector, long_vector).block_until_ready()

4.65 ms ± 54.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [37]:
# Same benchmark through the Syft-side jnp. Note that `long_vector` is rebound
# here, so it no longer refers to the array created with real_jax above.
long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()

16.3 ms ± 404 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## JAX's first transformation: grad

In [38]:
def sum_of_squares(x):
  """Return the sum of the element-wise squares of ``x``."""
  squared = x**2
  return jnp.sum(squared)

In [39]:
# NOTE(review): the recorded output of this cell is a KjException and shows no
# printed values. Which statement raised is not visible from the output;
# presumably it is the hand-off of the local function to the domain -- confirm.

# Ask for the gradient function of sum_of_squares through the Syft `jax` handle.
sum_of_squares_dx = jax.grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

print(sum_of_squares(x))

print(sum_of_squares_dx(x))

KjException: capnp/message.c++:99: failed: expected segment != nullptr && segment->checkObject(segment->getStartPtr(), ONE * WORDS); Message did not contain a root pointer.
stack: 7fa22895b734 7fa228808935 7fa22881dc0b 7fa27002426e 55c437be6cd6 55c437ba0a60 55c437b58424 55c437be3741 55c437be3db4 55c437acca50 55c437c37afc 55c437be3741 55c437b58424 55c437be3741 55c437b590b4 55c437be3741 55c437b58424 55c437be3741 55c437b58424 55c437be3741 55c437b590b4 55c437be3741 55c437c79997 55c437c7e591 55c437b58424 55c437be3741 55c437b58424 55c437be3741 55c437b58424 55c437be3741 55c437b590b4

# Blocking Question:
## How do we send a function to the domain, without triggering an approval from the Data Owner?

Passing a function as an argument is very common in JAX: every transformation (e.g. `grad`) takes a function as its input.

Building Block in Haiku:

In [40]:
# NOTE: these imports rebind `jax` and `jnp` to the locally installed JAX,
# replacing the Syft handles assigned earlier in the notebook. Every later cell
# that uses `jax` / `jnp` therefore works with the local library.
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

In [41]:
class MyLinear1(hk.Module):
  """Linear layer ``x @ w + b`` whose parameters are obtained via ``hk.get_parameter``."""

  def __init__(self, output_size, name=None):
    """Remember the size of the last output dimension; ``name`` goes to ``hk.Module``."""
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    """Apply the layer to ``x``; the input width is read from ``x.shape[-1]``."""
    in_features = x.shape[-1]
    out_features = self.output_size
    # Weight initialiser is scaled by 1/sqrt(input width); the bias starts at ones.
    weight_init = hk.initializers.TruncatedNormal(1. / np.sqrt(in_features))
    weights = hk.get_parameter(
        "w", shape=[in_features, out_features], dtype=x.dtype, init=weight_init)
    bias = hk.get_parameter("b", shape=[out_features], dtype=x.dtype, init=jnp.ones)
    return jnp.dot(x, weights) + bias

In [42]:
def _forward_fn_linear1(x):
  """Build a ``MyLinear1`` with two outputs and apply it to ``x``."""
  return MyLinear1(output_size=2)(x)

# hk.transform turns the function into a Transformed(init, apply) pair, which
# the next cells inspect.
forward_linear1 = hk.transform(_forward_fn_linear1)

In [43]:
# Show the transformed object: it carries an `init` and an `apply` function.
forward_linear1

Transformed(init=<function without_state.<locals>.init_fn at 0x7fa0c558d7e0>, apply=<function without_state.<locals>.apply_fn at 0x7fa0c558d000>)

In [44]:
# Signature of the init function: init_fn(*args, **kwargs), as shown in the output.
forward_linear1.init

<function haiku._src.transform.without_state.<locals>.init_fn(*args, **kwargs)>

In [45]:
# Signature of the apply function: it takes `params` first, then the call arguments.
forward_linear1.apply

<function haiku._src.transform.without_state.<locals>.apply_fn(params, *args, **kwargs)>

## Error Propagation

In [None]:
import numpy as np

# Plain NumPy array used as the mutable reference case (rebinds the earlier `x`).
x = np.array([1, 2, 3])

def in_place_modify(x):
  """Overwrite the first element of ``x`` with 123, mutating the caller's object.

  Nothing is returned (the result is ``None``); the effect is visible only
  through the argument that was passed in.
  """
  x[0] = 123

# Mutate the NumPy array through the function, then display it to see the change.
in_place_modify(x)
x

In [None]:
# Same call on an array built with jnp. NOTE(review): this cell has no recorded
# output; with the local jax.numpy, item assignment on the array is expected to
# raise because JAX arrays are immutable -- confirm and show the error deliberately.
in_place_modify(jnp.array(x)) 

In [None]:
def jax_in_place_modify(x):
  """Return the result of ``x.at[0].set(123)``, i.e. ``x`` with index 0 set to 123."""
  first_slot = x.at[0]
  return first_slot.set(123)

# Use the functional `.at[...].set(...)` form; the updated array is the return
# value of the call (displayed as the cell result).
y = jnp.array([1, 2, 3])
jax_in_place_modify(y)

In [None]:
# Display `y` again to check whether the call above changed the original array.
y

## Basic Model Training

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Synthetic data for a line with slope 3 and intercept -1 plus Gaussian noise.
# A seeded Generator is used so the data (and therefore the fitted w, b) are
# identical on every Restart & Run All.
rng = np.random.default_rng(0)
xs = rng.normal(size=(100,))
noise = rng.normal(scale=0.1, size=(100,))
ys = xs * 3 - 1 + noise

plt.scatter(xs, ys)

In [None]:
def model(theta, x):
  """Computes wx + b on a batch of input x."""
  w, b = theta
  return w * x + b

def loss_fn(theta, x, y):
  """Mean squared error between ``model(theta, x)`` and the targets ``y``."""
  prediction = model(theta, x)
  return jnp.mean((prediction-y)**2)

def update(theta, x, y, lr=0.1):
  """Return ``theta`` after one gradient-descent step on ``loss_fn``.

  Args:
    theta: current parameters ``(w, b)``; either a plain array or an object
      that carries the underlying array in ``.syft_action_data``.
    x: input batch.
    y: target batch.
    lr: step size.

  Returns:
    ``theta - lr * grad(loss_fn)(theta, x, y)``.
  """
  # The gradient is taken on the raw array. Objects that wrap it expose it as
  # `.syft_action_data`; a plain array has no such attribute, so fall back to
  # `theta` itself instead of raising AttributeError.
  raw_theta = getattr(theta, "syft_action_data", theta)
  return theta - lr * real_jax.grad(loss_fn)(raw_theta, x, y)

In [None]:
# Initial guess for (w, b).
theta = jnp.array([1., 1.])

# 1000 gradient-descent steps, each using the full data set.
for _ in range(1000):
  theta = update(theta, xs, ys)

# Overlay the fitted line on the scattered data.
plt.scatter(xs, ys)
plt.plot(xs, model(theta, xs))

# The data were generated with slope 3 and intercept -1, so those are the
# values to compare the printed w and b against.
w, b = theta
print(f"w: {w:<.2f}, b: {b:<.2f}")