In [1]:
import jax
import jax.numpy as jnp

#### Transformations

In [5]:
def selu(x, alpha=1.67, lambda_=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lambda_ * x`` where ``x > 0`` and
    ``lambda_ * alpha * (exp(x) - 1)`` everywhere else.

    Args:
        x: Scalar or array input.
        alpha: Scale of the exponential (non-positive) branch.
        lambda_: Overall output scale.

    Returns:
        The activated value(s).
    """
    is_positive = x > 0
    # jnp.where differentiates both branches. Feeding 0 to the exponential
    # wherever the linear branch is selected keeps exp() from overflowing
    # there, which would otherwise turn the gradient into 0 * inf = NaN.
    safe_x = jnp.where(is_positive, 0.0, x)
    # expm1 avoids the cancellation in alpha * exp(x) - alpha for x near zero.
    negative_branch = alpha * jnp.expm1(safe_x)
    return lambda_ * jnp.where(is_positive, x, negative_branch)

# Wrap selu with jax.jit, then call the wrapped function on a Python scalar
# and print the result.
selu_jit = jax.jit(selu)
print(selu_jit(1.0))

1.05


In [9]:
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  """Scaled exponential linear unit, wrapped with jax.jit via the decorator.

  Selects ``x`` where ``x > 0`` and ``alpha * exp(x) - alpha`` elsewhere,
  then scales the selected value by ``lambda_``.
  """
  is_positive = x > 0
  negative_branch = alpha * jnp.exp(x) - alpha
  selected = jnp.where(is_positive, x, negative_branch)
  return lambda_ * selected

# Last expression of the cell, so the returned Array is displayed as output.
selu(1.0)

Array(1.05, dtype=float32, weak_type=True)

#### Tracing

In [11]:
@jax.jit
def f(x):
    """Return ``x + 1``, printing ``x`` whenever the Python body executes."""
    # Under jax.jit the body runs during tracing, so what gets printed is the
    # tracer standing in for x rather than concrete values.
    print(x)
    incremented = x + 1
    return incremented

# Call the jitted function on an integer array; the print inside f emits the
# traced representation shown in the output below.
x = jnp.arange(5)
result = f(x)

Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>


#### Jaxprs

In [12]:
# Rebind x to a float array and display the jaxpr obtained by tracing selu
# with that input.
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)

{ lambda ; a:f32[5]. let
    b:f32[5] = pjit[
      name=selu
      jaxpr={ lambda ; c:f32[5]. let
          d:bool[5] = gt c 0.0
          e:f32[5] = exp c
          f:f32[5] = mul 1.6699999570846558 e
          g:f32[5] = sub f 1.6699999570846558
          h:f32[5] = pjit[
            name=_where
            jaxpr={ lambda ; i:bool[5] j:f32[5] k:f32[5]. let
                l:f32[5] = select_n i k j
              in (l,) }
          ] d c g
          m:f32[5] = mul 1.0499999523162842 h
        in (m,) }
    ] a
  in (b,) }

#### Pytrees

In [14]:
# A nested list/tuple container: two Python ints plus a tuple of two arrays.
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

# Show the container structure and the flattened list of leaves.
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]


In [15]:
# A dict container mixing a Python int with two arrays.
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# Show the structure and leaves; in the output below both follow sorted key
# order ('W', 'b', 'n') rather than insertion order.
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]


In [16]:
from typing import NamedTuple

class Params(NamedTuple):
  """Two-field named tuple used as an example container."""
  a: int
  b: float

# Build a Params instance and show its tree structure and leaves.
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]
