# JAX 101 - 03 Vectorization
Link to the original JAX tutorial: https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html

## Part 0 - Data Owner Setup

In [1]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8,<0.9")

import jax
import jax.numpy as jnp
import numpy as np



✅ The installed version of syft==0.8.1b3 matches the requirement >=0.8 and the requirement <0.9


In [2]:
# Launch the domain
node = sy.orchestra.launch(name="test-domain-1", reset=True, dev_mode=True)
data_owner_client = node.login(email="info@openmined.org", password="changethis")

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed4ec881f0d0aede098dbb.sqlite



## Part 1 - Data Scientist

In [3]:
# Register a client to the domain
node = sy.orchestra.launch(name="test-domain-1")
data_scientist_client = node.client
data_scientist_client.register(name="Jane Doe", email="jane@caltech.edu", password="abc123", institution="Caltech", website="https://www.caltech.edu/")
data_scientist_client.login(email="jane@caltech.edu", password="abc123")

SQLite Store Path:
!open file:///var/folders/sz/hkfsnn612hq56r7cs5rd540r0000gn/T/7bca415d13ed4ec881f0d0aede098dbb.sqlite



<SyftClient - test-domain-1 <7bca415d13ed4ec881f0d0aede098dbb>: PythonConnection>

In [4]:
# Create a function for code execution
# ATTENTION: ALL LIBRARIES USED SHOULD BE DEFINED INSIDE THE FUNCTION CONTEXT!!!

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def manual_vectorization():
    import jax
    import jax.numpy as jnp

    x = jnp.arange(5)
    w = jnp.array([2., 3., 4.])

    def convolve(x, w):
        output = []
        for i in range(1, len(x)-1):
            output.append(jnp.dot(x[i-1:i+2], w))
        return jnp.array(output)

    print("Convolve:", convolve(x, w))

    xs = jnp.stack([x, x])
    ws = jnp.stack([w, w])
    
    def manually_batched_convolve(xs, ws):
        output = []
        for i in range(xs.shape[0]):
            output.append(convolve(xs[i], ws[i]))
        return jnp.stack(output)

    print("Naive solution for batched convolve", manually_batched_convolve(xs, ws))


    def manually_vectorized_convolve(xs, ws):
        output = []
        for i in range(1, xs.shape[-1] -1):
            output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
        return jnp.stack(output, axis=1)

    print("Vectorized solution for batched convolve", manually_vectorized_convolve(xs, ws))
    
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def automatic_vectorization():
    import jax
    import jax.numpy as jnp
    
    x = jnp.arange(5)
    w = jnp.array([2., 3., 4.])

    def convolve(x, w):
        output = []
        for i in range(1, len(x)-1):
            output.append(jnp.dot(x[i-1:i+2], w))
        return jnp.array(output)

    print("Convolve:", convolve(x, w))

    xs = jnp.stack([x, x])
    ws = jnp.stack([w, w])
    
    auto_batch_convolve = jax.vmap(convolve)

    print("Automatic vectorization", auto_batch_convolve(xs, ws))
    
    auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

    xst = jnp.transpose(xs)
    wst = jnp.transpose(ws)

    print("Automatic vectorization with axes", auto_batch_convolve_v2(xst, wst))
    
    batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

    print("Automatic vectorization with one ax", batch_convolve_v3(xs, w))
    
    jitted_batch_convolve = jax.jit(auto_batch_convolve)

    print("Jitted Automatic vectorization", jitted_batch_convolve(xs, ws))

In [5]:
# Test our function locally 
manual_vectorization()
automatic_vectorization()

Convolve: [11. 20. 29.]
Naive solution for batched convolve [[11. 20. 29.]
 [11. 20. 29.]]
Vectorized solution for batched convolve [[11. 20. 29.]
 [11. 20. 29.]]
Convolve: [11. 20. 29.]
Automatic vectorization [[11. 20. 29.]
 [11. 20. 29.]]
Automatic vectorization with axes [[11. 11.]
 [20. 20.]
 [29. 29.]]
Automatic vectorization with one ax [[11. 20. 29.]
 [11. 20. 29.]]
Jitted Automatic vectorization [[11. 20. 29.]
 [11. 20. 29.]]


In [6]:
# Submit the function for code execution
data_scientist_client.api.services.code.request_code_execution(manual_vectorization)
data_scientist_client.api.services.code.request_code_execution(automatic_vectorization)

```python
class Request:
  id: str = 89d5600734374ef587248ee2c22c7262
  requesting_user_verify_key: str = 8ab25caf213965e212ecef5d4a870e456933e271f2080ead7210905dadf185c7
  approving_user_verify_key: str = None
  request_time: str = 2023-05-29 06:21:52
  approval_time: str = None
  status: str = RequestStatus.PENDING
  node_uid: str = 7bca415d13ed4ec881f0d0aede098dbb
  request_hash: str = "db2d9bbe74afbc26be66c3b317a1e6fe8acab604e4bf5ce65f85bd3458fbb4cd"
  changes: str = [syft.service.request.request.UserCodeStatusChange]

```

## Part 2 - Data Owner Reviewing and Approving Requests

In [7]:
data_owner_client = node.login(email="info@openmined.org", password="changethis")

In [8]:
# Get messages from domain
messages = data_owner_client.api.services.messages.get_all()
messages

Unnamed: 0,type,id,subject,status,created_at,linked_obj
0,syft.service.message.messages.Message,1cb3436730a0458ab56101145d0d7607,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:21:52,<<class 'syft.service.request.request.Request'...
1,syft.service.message.messages.Message,4a483cd316ca4e8fb5e4e65ddf4609ae,Approval Request,MessageStatus.UNDELIVERED,2023-05-29 06:21:52,<<class 'syft.service.request.request.Request'...


In [9]:
from helpers import review_request, run_submitted_function, accept_request

for message in messages:
    review_request(message)
    real_result = run_submitted_function(message)
    accept_request(message, real_result)

manual_vectorization
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def manual_vectorization():
    import jax
    import jax.numpy as jnp

    x = jnp.arange(5)
    w = jnp.array([2., 3., 4.])

    def convolve(x, w):
        output = []
        for i in range(1, len(x)-1):
            output.append(jnp.dot(x[i-1:i+2], w))
        return jnp.array(output)

    print("Convolve:", convolve(x, w))

    xs = jnp.stack([x, x])
    ws = jnp.stack([w, w])
    
    def manually_batched_convolve(xs, ws):
        output = []
        for i in range(xs.shape[0]):
            output.append(convolve(xs[i], ws[i]))
        return jnp.stack(output)

    print("Naive solution for batched convolve", manually_batched_convolve(xs, ws))


    def manually_vectorized_convolve(xs, ws):
        output = []
        for i in range(1, xs.shape[-1] -1):
            output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
        return jnp.stack(output

exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


automatic_vectorization
@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def automatic_vectorization():
    import jax
    import jax.numpy as jnp
    
    x = jnp.arange(5)
    w = jnp.array([2., 3., 4.])

    def convolve(x, w):
        output = []
        for i in range(1, len(x)-1):
            output.append(jnp.dot(x[i-1:i+2], w))
        return jnp.array(output)

    print("Convolve:", convolve(x, w))

    xs = jnp.stack([x, x])
    ws = jnp.stack([w, w])
    
    auto_batch_convolve = jax.vmap(convolve)

    print("Automatic vectorization", auto_batch_convolve(xs, ws))
    
    auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

    xst = jnp.transpose(xs)
    wst = jnp.transpose(ws)

    print("Automatic vectorization with axes", auto_batch_convolve_v2(xst, wst))
    
    batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

    print("Automatic vectorization with one ax", batch_convolve_v3(xs,

exec_result=syft.service.code.user_code.UserCodeExecutionResult
action_object=Pointer:
syft.service.code.user_code.UserCodeExecutionResult


message='Request 89d5600734374ef587248ee2c22c7262 changes applied'


<Figure size 640x480 with 0 Axes>

## Part 3 - Downloading the Results

### Tutorial complete 👏

In [10]:
result = data_scientist_client.api.services.code.manual_vectorization()
assert not isinstance(result, sy.SyftError)

result = data_scientist_client.api.services.code.automatic_vectorization()
assert not isinstance(result, sy.SyftError)

In [11]:
if node.node_type.value == "python":
    node.land()