## JAX arrays (jax.Array)

We typically don’t call the `jax.Array` constructor directly, but rather create arrays via JAX API functions. For example, `jax.numpy` provides familiar NumPy-style array construction functions such as `jax.numpy.zeros()`, `jax.numpy.linspace()`, and `jax.numpy.arange()`.

In [1]:
import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)

True

### Array devices and sharding

JAX arrays have a `devices()` method that reports where the contents of the array are stored, and a `sharding` attribute that describes how the array is laid out across those devices.

In [2]:
x.devices()

{CpuDevice(id=0)}

In [3]:
x.sharding

SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

## Transformations

Along with functions to operate on arrays, JAX includes a number of transformations which operate on JAX functions. These include

* `jax.jit()`: Just-in-time (JIT) compilation; see Just-in-time compilation

* `jax.vmap()`: Vectorizing transform; see Automatic vectorization

* `jax.grad()`: Gradient transform; see Automatic differentiation

as well as several others. Transformations accept a function as an argument, and return a new transformed function. For example, here’s how you might JIT-compile a simple SELU function:

In [4]:
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))

1.05


The same transformation can also be applied with Python’s decorator syntax:

In [5]:
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

Transformations like `jit()`, `vmap()`, `grad()`, and others are key to using JAX effectively, and we’ll cover them in detail in later sections.
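
As a rough sketch of how the other two transformations listed above compose, the snippet below applies `jax.grad()` to the same `selu` function and then uses `jax.vmap()` to evaluate that derivative over a batch of inputs. The names `dselu` and `batched_dselu` are illustrative only and not part of the JAX API.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# `jax.grad` returns a new function computing d(selu)/dx for a scalar input.
dselu = jax.grad(selu)
slope_at_one = dselu(1.0)  # for x > 0 the slope is simply `lambda_`

# `jax.vmap` maps the scalar derivative over the leading axis of a batch.
batched_dselu = jax.vmap(dselu)
slopes = batched_dselu(jnp.array([1.0, 2.0, 3.0]))

print(slope_at_one, slopes)
```

Because each transformation returns an ordinary function, the results can be nested freely, for instance by wrapping `batched_dselu` in `jax.jit()`.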

## Tracing
The magic behind transformations is the notion of a Tracer. Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order to extract the sequence of operations that the function encodes.

You can see this by printing any array value within transformed JAX code; for example:

In [6]:
@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x)

Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>


The value printed is not the array `x`, but a `Tracer` instance that represents essential attributes of `x`, such as its `shape` and `dtype`. By executing the function with traced values, JAX can determine the sequence of operations encoded by the function before those operations are actually executed: transformations like `jit()`, `vmap()`, and `grad()` can then map this sequence of input operations to a transformed sequence of operations.
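
One way to see that Python-level code inside a jitted function runs only while JAX is tracing is to record a side effect. The sketch below appends to a plain Python list (`trace_log`, a hypothetical name) each time the function body executes. It assumes the default `jax.jit()` caching behavior, where a call with an already-seen input shape and dtype reuses the earlier trace.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the (shape, dtype) seen each time the body is traced

@jax.jit
def f(x):
  # This line is plain Python, so it runs only during tracing.
  trace_log.append((x.shape, x.dtype))
  return x + 1

f(jnp.arange(5))  # first call: the function is traced
f(jnp.arange(5))  # same shape and dtype: the cached computation is reused
f(jnp.arange(3))  # new shape: the function is traced again

print(len(trace_log))
```

The tracer exposes `shape` and `dtype` just like an array, which is why the body can read them even though no concrete values are available.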

### Jaxprs
JAX has its own intermediate representation for sequences of operations, known as a jaxpr. A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations.

In [7]:
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [8]:
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)

{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) }
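
The jaxpr can also be inspected programmatically. As a minimal sketch, the object returned by `jax.make_jaxpr()` holds a list of equations, and each equation records the primitive it applies. The variable names below are illustrative, and the exact primitive names may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

closed_jaxpr = jax.make_jaxpr(selu)(jnp.arange(5.0))

# Each equation in the jaxpr applies exactly one primitive operation.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```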

## Pytrees

JAX functions and transformations fundamentally operate on arrays, but in practice it is convenient to write code that works with collections of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the pytree abstraction to treat such collections in a uniform manner.

In [9]:
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]


In [10]:
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]


In [11]:
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]


JAX has a number of general-purpose utilities for working with pytrees; for example, the function `jax.tree.map()` can be used to map a function over every leaf in a tree, and `jax.tree.reduce()` can be used to apply a reduction across the leaves in a tree.
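
A minimal sketch of both utilities on a small parameter dictionary follows. The `params` values and variable names are made up for illustration.

```python
import operator

import jax
import jax.numpy as jnp

params = {'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# Map: apply a function to every leaf, keeping the tree structure.
doubled = jax.tree.map(lambda leaf: 2 * leaf, params)

# Reduce: sum each leaf to a scalar, then add the scalars across leaves.
per_leaf_sums = jax.tree.map(jnp.sum, doubled)
total = jax.tree.reduce(operator.add, per_leaf_sums)

print(jax.tree.structure(doubled))
print(total)
```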

## What is a pytree?
A pytree is a container-like structure built out of container-like Python objects: “leaf” pytrees and/or more pytrees. A pytree can include lists, tuples, and dicts. A leaf is anything that is not itself a pytree container, such as an array; a single leaf on its own also counts as a pytree.

In the context of machine learning (ML), a pytree can contain:

* Model parameters

* Dataset entries

* Reinforcement learning agent observations

When working with datasets, you can often come across pytrees (such as lists of lists of dicts).

Below is an example of a simple pytree. In JAX, you can use `jax.tree.leaves()` to extract the flattened leaves from a tree:

In [12]:
import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
  # `jax.tree.leaves()` extracts the flattened leaves from the pytree.
  leaves = jax.tree.leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x00000162C5ADCA50>] has 3 leaves: [1, 'a', <object object at 0x00000162C5ADCA50>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]


In [13]:
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

In [14]:
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

### Example of jax.tree.map with ML model parameters

In [15]:
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

In [16]:
jax.tree.map(lambda x: x.shape, params)

[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

In [17]:
# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Define the parameter update function using stochastic gradient descent.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that have
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )
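
To illustrate how `update` might be used, here is a self-contained sketch that repeats the definitions above and runs a few steps on toy data. The seeded initializer, the dataset (`xs`, `ys`, fitting `y = x**2`), and the step count are assumptions made for this example, not part of the original walkthrough.

```python
import jax
import jax.numpy as jnp
import numpy as np

def init_mlp_params(layer_widths, seed=0):
  # Seeded variant of the initializer above (the seed is an assumption).
  rng = np.random.default_rng(seed)
  return [
      dict(weights=rng.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
           biases=np.ones(shape=(n_out,)))
      for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])
  ]

def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss_fn)(params, x, y)
  return jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Hypothetical toy dataset: learn y = x**2 from 128 random points.
xs = np.random.default_rng(1).normal(size=(128, 1))
ys = xs ** 2

params = init_mlp_params([1, 128, 128, 1])
new_params = params
for _ in range(10):
  new_params = update(new_params, xs, ys)

# The updated parameters keep the same pytree structure as the originals.
same_structure = jax.tree.structure(new_params) == jax.tree.structure(params)
print(same_structure)
```

Since `update` returns a pytree with the same structure it received, its output can be fed straight back in as the next input.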

## Custom pytree nodes
This section explains how to extend the set of Python types that JAX treats as internal nodes in pytrees (pytree nodes) by using `jax.tree_util.register_pytree_node()` together with `jax.tree.map()`.

Why would you need this? In the previous examples, pytrees were built from lists, tuples, and dicts, with everything else treated as a pytree leaf. If you define your own container class, JAX considers it a leaf unless you register it, even if the class holds trees inside it:

In [18]:
class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

jax.tree.leaves([
    Special(0, 1),
    Special(2, 4),
])

[<__main__.Special at 0x162c5d13aa0>, <__main__.Special at 0x162c5ad99a0>]

In [19]:
from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children.

  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
)

In [20]:
jax.tree.map(lambda x: x + 1,
  [
   RegisteredSpecial(0, 1),
   RegisteredSpecial(2, 4),
  ])

[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]
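
The example above passes `None` as auxiliary data. As a sketch of the other case, the hypothetical `Labeled` class below stores a static `label` in `aux_data`, so it is carried through `jax.tree.map()` untouched while only `value` is treated as a child. The class and function names are invented for this illustration.

```python
import jax
from jax.tree_util import register_pytree_node

class Labeled:
  """Hypothetical container: `label` is static metadata, `value` is a child."""

  def __init__(self, label, value):
    self.label = label
    self.value = value

def labeled_flatten(v):
  # Children are flattened recursively; aux_data is stored in the treedef.
  return (v.value,), v.label

def labeled_unflatten(aux_data, children):
  # Rebuild the object from the stored label and the (possibly mapped) child.
  return Labeled(aux_data, *children)

register_pytree_node(Labeled, labeled_flatten, labeled_unflatten)

mapped = jax.tree.map(lambda x: x * 10, [Labeled('a', 1), Labeled('b', 2)])
leaves = jax.tree.leaves(mapped)
print(leaves, [m.label for m in mapped])
```

Auxiliary data ends up inside the tree definition, so it should be hashable and comparable; a string label satisfies both.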

In [21]:
from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])

['Alice', 1, 2, 3, 'Bob', 4, 5, 6]

Note that the `name` field appears as a leaf here, because every element of a named tuple is treated as a child.

### Explicit key paths
In a pytree, each leaf has a key path. A key path for a leaf is a list of keys, where the length of the list is equal to the depth of the leaf in the pytree. Each key is a hashable object that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for dicts is different from the type of keys for tuples.

For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.

JAX has the following `jax.tree_util.*` functions for working with key paths:

* `jax.tree_util.tree_flatten_with_path()`: Works similarly to `jax.tree.flatten()`, but returns key paths.

* `jax.tree_util.tree_map_with_path()`: Works similarly to `jax.tree.map()`, but the function also takes key paths as arguments.

* `jax.tree_util.keystr()`: Given a general key path, returns a reader-friendly string expression.

For example, one use case is to print debugging information related to a certain leaf value:

In [22]:
import collections

ATuple = collections.namedtuple("ATuple", ('name',))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')

Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo


In [23]:
for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')

Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
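
The cells above use `tree_flatten_with_path()` and `keystr()`. For completeness, here is a small sketch of `jax.tree_util.tree_map_with_path()`, where the mapped function receives the key path as its first argument. The tree and the label format are made up for illustration.

```python
import jax

tree = {'k1': 2, 'k2': (3, 4)}

# The mapped function receives (key_path, leaf) instead of just the leaf.
labeled = jax.tree_util.tree_map_with_path(
    lambda path, leaf: f"{jax.tree_util.keystr(path)}={leaf}", tree
)

print(labeled)
```

The result has the same structure as the input, with each leaf replaced by a string describing where it sits in the tree.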


## Common pytree gotchas
This section covers some of the most common problems (“gotchas”) encountered when using JAX pytrees.

### Mistaking pytree nodes for leaves
A common gotcha to look out for is accidentally introducing tree nodes instead of leaves:

In [24]:
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)

[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
 (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
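
What happens here is that the shape of an array is a tuple, which is itself a pytree node, so `jnp.ones` is called on each integer (`2`, `3`, ...) rather than on the whole shape. Two possible workarounds are sketched below, offered as suggestions rather than the only fixes: tell `jax.tree.map()` to treat tuples as leaves through its `is_leaf` argument, or skip the intermediate tree of shapes altogether.

```python
import jax
import jax.numpy as jnp

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
shapes = jax.tree.map(lambda x: x.shape, a_tree)

# Option A: treat each shape tuple as a single leaf.
ones_from_shapes = jax.tree.map(
    jnp.ones, shapes, is_leaf=lambda node: isinstance(node, tuple)
)

# Option B: avoid the intermediate tree of shapes entirely.
ones_direct = jax.tree.map(jnp.ones_like, a_tree)

print([x.shape for x in ones_from_shapes])
print([x.shape for x in ones_direct])
```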

### Handling of None by `jax.tree_util`
`jax.tree_util` functions treat `None` as the absence of a pytree node, not as a leaf:

In [25]:
jax.tree.leaves([None, None, None])

[]

In [26]:
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)

[None, None, None]

### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: `jax.tree.map()` (more basic) and `jax.tree.transpose()` (more flexible, but also more complex and verbose).

Option 1: Use `jax.tree.map()`.

In [27]:
def tree_transpose(list_of_trees):
  """
  Converts a list of trees of identical structure into a single tree of lists.
  """
  return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)

{'obs': [3, 4], 't': [1, 2]}

Option 2: For more complex transposes, use `jax.tree.transpose()`, which is more verbose but lets you specify the structure of the inner and outer pytrees for more flexibility.

In [28]:
jax.tree.transpose(
  outer_treedef = jax.tree.structure([0 for e in episode_steps]),
  inner_treedef = jax.tree.structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
)

{'obs': [3, 4], 't': [1, 2]}