In [1]:
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
from jax import jit
x = np.random.rand(1000,1000)
y= jnp.array(x)

def a(x):

  for _ in range(10):
      x = 0.5*x + 0.1* jnp.sin(x)

  return x

b = jit(a)

%timeit -n 5 -r 5 a(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop

%timeit -n 5 -r 5 b(y).block_until_ready()
# 5 loops, best of 5: 341 µs per loop



The slowest run took 6.46 times longer than the fastest. This could mean that an intermediate result is being cached.
154 ms ± 158 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
28.4 ms ± 18.4 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [3]:
from jax import grad

def c(x):
  return 3*x**2 + 2*x + 5

def c_prime(x):
  return 6*x +2

grad(c)(1.0)
# DeviceArray(8., dtype=float32)

c_prime(1.0)

8.0

In [5]:
from jax import vmap
import jax.numpy as jnp

def d(x):

  return jnp.square(x)

d(jnp.arange(10))

#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)

vmap(d)(jnp.arange(10))

DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)

In [6]:
from jax import grad, jit

def cf(x):
  if x < 5:
    return 5. * x ** 2
  else:
    return -6 * x

print(grad(cf)(4.))  # ok!
print(grad(cf)(7.))  # ok!

40.0
-6.0


In [7]:
@jit
def a(x):
  for i in range(3):
    x = 2 * x
  return x

print(a(3))

24


In [8]:
@jit
def b(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(b(jnp.array([1., 2., 3.])))

6.0


In [9]:
from jax import grad, jit

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
try:
  f(2)
except Exception as e:
  print("Exception {}".format(e))

Exception Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /tmp/ipykernel_3735/319401360.py:3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


In [10]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.))

12.0


In [11]:
def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnums=(1,))

f(jnp.array([2., 3., 4.]), 2)

DeviceArray(5., dtype=float32)

In [12]:
from jax import random

key = random.PRNGKey(5)

print(key)

[0 5]


In [None]:
import numpy as np
from jax import pmap

def e(x):
  return jnp.sin(x) + x**2

e(np.arange(4))
#DeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)

pmap(e)(np.arange(4))

In [None]:
#won't run unless you have various devices running parallel to each other
from functools import partial
from jax.lax import pmean

@partial(pmap, axis_name="i")
def normalize(x):

  return x/ pmean(x,'i')

normalize(np.arange(8.))