# JAX 101 - 03 Vectorization
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html

## Part 0 - Data Owner Setup

In [None]:
# Import the necessary libraries
import syft as sy
# Check that the installed PySyft is in the 0.8.x range this notebook targets
# (presumably raises or warns otherwise — confirm against the PySyft docs).
sy.requires(">=0.8,<0.9")

# Module-level imports for interactive use in this notebook. The functions
# submitted with @sy.syft_function below do not rely on these: they re-import
# what they need inside their own bodies.
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Launch the domain (Part 0, acting as the data owner).
# reset=True presumably discards any existing state for a node with this name — verify.
node = sy.orchestra.launch(name="test-domain-1", reset=True)
# Log in as the domain admin. These look like default development credentials
# (NOTE(review): do not reuse them outside local testing).
domain_client = node.login(email="info@openmined.org", password="changethis")

## Part 1 - Data Scientist

In [None]:
# Register a data-scientist account on the domain, then log in as that user.
node = sy.orchestra.launch(name="test-domain-1")
# Client obtained before logging in; used only to create the new account.
guest_client = node.client
guest_client.register(name="Jane Doe", email="jane@caltech.edu", password="abc123", institution="Caltech", website="https://www.caltech.edu/")
# Keep the client returned by login (same pattern as the data-owner cell) instead
# of discarding it and relying on login() to authenticate the pre-login client in place.
data_scientist_client = node.login(email="jane@caltech.edu", password="abc123")
data_scientist_client

In [None]:
# Create a function for code execution
# ATTENTION: ALL LIBRARIES USED MUST BE IMPORTED INSIDE THE FUNCTION BODY!!!
# (Notebook-level imports are presumably not available where the submitted
# function runs, so each submitted function imports what it needs itself.)

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def manual_vectorization():
    """Demonstrate batching a small 1-D convolution by hand.

    Prints three results: a single convolution, a naive batched version that
    loops over batch rows, and a manually vectorized version that processes
    the whole batch per output position. Has no return statement, so the
    function returns None; all results are printed.
    """
    import jax.numpy as jnp  # only jnp is used; `jax` itself is not needed

    x = jnp.arange(5)
    w = jnp.array([2., 3., 4.])

    def convolve(x, w):
        # Slide a 3-wide window over x and dot it with w (w is not flipped).
        # The window width is fixed by the slice, so this assumes len(w) == 3.
        # Produces len(x) - 2 outputs (no padding).
        output = []
        for i in range(1, len(x)-1):
            output.append(jnp.dot(x[i-1:i+2], w))
        return jnp.array(output)

    print("Convolve:", convolve(x, w))

    # Batch of two identical examples (row i of xs pairs with row i of ws).
    xs = jnp.stack([x, x])
    ws = jnp.stack([w, w])

    def manually_batched_convolve(xs, ws):
        # Naive batching: call the unbatched convolve once per batch row.
        output = []
        for i in range(xs.shape[0]):
            output.append(convolve(xs[i], ws[i]))
        return jnp.stack(output)

    print("Naive solution for batched convolve", manually_batched_convolve(xs, ws))

    def manually_vectorized_convolve(xs, ws):
        # Loop over output positions instead of batch rows: each step takes a
        # (batch, 3) window, multiplies elementwise by ws, and sums over axis 1.
        # Stacking along axis 1 gives a (batch, len - 2) result.
        output = []
        for i in range(1, xs.shape[-1] -1):
            output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
        return jnp.stack(output, axis=1)

    print("Vectorized solution for batched convolve", manually_vectorized_convolve(xs, ws))
    
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def automatic_vectorization():
    """Demonstrate ``jax.vmap`` on a small 1-D convolution.

    Prints a single convolution, then the batched result obtained via
    ``jax.vmap`` with default axes, with explicit ``in_axes``/``out_axes`` on
    transposed inputs, with the kernel left unbatched, and finally the vmapped
    function compiled with ``jax.jit``. Has no return statement (returns None).
    """
    import jax
    import jax.numpy as jnp

    signal = jnp.arange(5)
    kernel = jnp.array([2., 3., 4.])

    def convolve(x, w):
        # 3-wide sliding window over x dotted with w (w is not flipped).
        return jnp.array(
            [jnp.dot(x[k - 1:k + 2], w) for k in range(1, len(x) - 1)]
        )

    print("Convolve:", convolve(signal, kernel))

    batch_x = jnp.stack([signal, signal])
    batch_w = jnp.stack([kernel, kernel])

    # Default vmap: maps over axis 0 of every argument and of the output.
    vmapped = jax.vmap(convolve)
    print("Automatic vectorization", vmapped(batch_x, batch_w))

    # Same computation with the batch dimension on axis 1 of inputs and output.
    vmapped_cols = jax.vmap(convolve, in_axes=1, out_axes=1)
    print("Automatic vectorization with axes",
          vmapped_cols(jnp.transpose(batch_x), jnp.transpose(batch_w)))

    # Batch only the signal; the single unbatched kernel is reused for every row.
    vmapped_shared_kernel = jax.vmap(convolve, in_axes=(0, None))
    print("Automatic vectorization with one ax",
          vmapped_shared_kernel(batch_x, kernel))

    # vmap composes with jit: compile the batched function.
    compiled = jax.jit(vmapped)
    print("Jitted Automatic vectorization", compiled(batch_x, batch_w))

In [None]:
# Test our functions locally: call the decorated functions directly as a smoke
# test before submitting them (presumably executes in this notebook process).
manual_vectorization()
automatic_vectorization()

In [None]:
# Submit both functions to the domain as code-execution requests; the data
# owner reviews and approves them in Part 2.
data_scientist_client.api.services.code.request_code_execution(manual_vectorization)
data_scientist_client.api.services.code.request_code_execution(automatic_vectorization)

## Part 2 - Data Owner Reviewing and Approving Requests

In [None]:
# As the data owner, fetch all messages from the domain (expected to include
# the code-execution requests submitted above) and display them.
messages = domain_client.api.services.messages.get_all()
messages

In [None]:
# Helper routines from the local `helpers` module (not shown in this notebook).
from helpers import review_request, run_submitted_function, accept_request

# For every message: review it, run the submitted function, and accept the
# request with the result that run produced.
# NOTE(review): every request is accepted unconditionally and the return value
# of review_request (if any) is ignored — confirm that is intended.
for message in messages:
    review_request(message)
    real_result = run_submitted_function(message)
    accept_request(message, real_result)

## Part 3 - Downloading the Results

### Tutorial complete 👏

In [None]:
# As the data scientist, retrieve the results of the approved requests.
# NOTE(review): both submitted functions only print and have no return statement,
# so their result value is None — confirm whether the printed output is surfaced here.
data_scientist_client.api.services.code.manual_vectorization()
data_scientist_client.api.services.code.automatic_vectorization()

In [None]:
# Land the node (presumably shuts it down) only when its node type is "python",
# which presumably means it was launched inside this process — verify.
if node.node_type.value == "python":
    node.land()