In [18]:
import jax
import jax.numpy as jnp

# Costly function to be JIT-compiled
def selu(x, alpha=1.67, lambda_=1.05):
  """SELU-style activation: scaled identity for x > 0, scaled exponential otherwise.

  Args:
    x: Input value or array.
    alpha: Scale applied to the exponential branch.
    lambda_: Overall output scale.

  Returns:
    ``lambda_ * x`` where ``x > 0``; ``lambda_ * (alpha * exp(x) - alpha)`` elsewhere.
  """
  # Both branches are evaluated; jnp.where then selects element-wise.
  negative_branch = alpha * jnp.exp(x) - alpha
  activated = jnp.where(x > 0, x, negative_branch)
  return lambda_ * activated


## Patterns for using JAX JIT with Covalent


1. Compile the function and run it inside each electron

*Useful if the jitted function is compute-intensive, with a runtime on the order of minutes. This ships the function to a cluster, compiles it for that architecture, and then executes it there.*

In [2]:

import covalent as ct


@ct.electron
def compile_run(func, *args, **kwargs):
    """JIT-compile ``func`` and immediately call it, all within this electron.

    Args:
        func: Callable to wrap with ``jax.jit``.
        *args: Positional arguments forwarded to the compiled callable.
        **kwargs: Keyword arguments forwarded to the compiled callable.

    Returns:
        Whatever the jitted ``func`` returns for the given arguments.
    """
    # Compilation happens here, i.e. wherever this electron body executes.
    return jax.jit(func)(*args, **kwargs)



2. Compile the function once and run it inside each electron

*Useful (and fastest) if the hardware that performs the JIT compilation is the same as the hardware the function is shipped to (for example, local execution).*

In [None]:
@ct.electron
def compile_once(func, *args, **kwargs):
    """Wrap ``func`` with ``jax.jit`` and return the wrapped callable.

    Args:
        func: Callable to wrap.
        *args: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        The ``jax.jit``-wrapped version of ``func``.
    """
    return jax.jit(func)

@ct.electron
def run_once(func, *args, **kwargs):
    """Call ``func`` with the supplied arguments and return its result.

    Args:
        func: Callable to invoke (in this notebook, the output of ``compile_once``).
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value returned by ``func``.
    """
    outcome = func(*args, **kwargs)
    return outcome

Run 10 parallel calls of the jitted function with both patterns.

In [14]:
# Number of electron calls issued by each lattice below.
PARALLEL = 10
# Input value handed to selu in every call (one hundred thousand).
X = 100_000

Pattern 2 : Compile once and use it multiple times 

- Needs same hardware for compilation and execution
- Fastest for local execution

![](assets/compile_once.jpg)


In [16]:
@ct.lattice
def parallel_compile_once(n,x):
  """Compile ``selu`` a single time, then call the compiled function ``n`` times.

  Args:
    n: Number of ``run_once`` calls to issue.
    x: Input forwarded to every call.

  Returns:
    A list holding the ``n`` ``run_once`` results, in call order.
  """
  # One compile_once call feeds all of the run_once calls.
  jitted = compile_once(selu)
  return [run_once(jitted, x) for _ in range(n)]
# Submit the compile-once workflow.
runid = ct.dispatch(parallel_compile_once)(n=PARALLEL, x=X)

# Block until the workflow finishes, then report its start-to-end duration.
result = ct.get_result(runid, wait=True)
print(f"Compile once and execute : {result.end_time-result.start_time}")

Compile once and execute : 0:00:05.284116


Pattern 1: Compile and run it every time

- Compiles for the cluster architecture and runs there; suited to long-running functions
- Slow for local execution because the function is recompiled every time

![compile again](assets/compile_again.jpg)

In [17]:
@ct.lattice
def parallel_compile_and_run(n,x):
  """Issue ``n`` ``compile_run`` calls, each of which jits ``selu`` and runs it.

  Args:
    n: Number of ``compile_run`` calls to issue.
    x: Input forwarded to every call.

  Returns:
    A list holding the ``n`` ``compile_run`` results, in call order.
  """
  # Every call goes through jax.jit again inside compile_run.
  return [compile_run(selu, x) for _ in range(n)]
# Submit the compile-and-run workflow.
runid = ct.dispatch(parallel_compile_and_run)(n=PARALLEL, x=X)

# Block until the workflow finishes, then report its start-to-end duration.
result = ct.get_result(runid, wait=True)
print(f"Compile and execute each time: {result.end_time-result.start_time}")

Compile and execute each time: 0:00:11.689765
