# Pseudorandom numbers
In this section we focus on `jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.

PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from one sample to the next.
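To make the seed/state vocabulary concrete, here is a toy linear congruential generator (LCG) in plain Python. It is only an illustration: NumPy's legacy global generator is the Mersenne Twister (`MT19937`, visible in the state printed below) and JAX uses Threefry, not an LCG. The class name `ToyLCG` and the constants (the commonly quoted Numerical Recipes values) are assumptions of this sketch.

```python
class ToyLCG:
  """A toy PRNG: each sample is a deterministic function of a carried state."""

  # Numerical Recipes LCG constants (illustrative choice, not NumPy's or JAX's).
  A, C, M = 1664525, 1013904223, 2**32

  def __init__(self, seed):
    self.state = seed % self.M  # the seed fully determines the initial state

  def uniform(self):
    # Advance the state, then map it to a float in [0, 1).
    self.state = (self.A * self.state + self.C) % self.M
    return self.state / self.M

rng = ToyLCG(seed=0)
first = [rng.uniform() for _ in range(3)]

rng = ToyLCG(seed=0)  # re-seeding replays exactly the same sequence
second = [rng.uniform() for _ in range(3)]
```

Nothing about the output is random in any physical sense: the same seed always produces the same sequence, which is exactly the property that makes PRNG-based programs reproducible.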

Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.

To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation, we will discuss both approaches in this section.

## Random numbers in NumPy
Pseudo random number generation is natively supported in NumPy by the `numpy.random` module. In NumPy, pseudo random number generation is based on a global `state`, which can be set to a deterministic initial condition using `numpy.random.seed()`.

```python
import numpy as np
np.random.seed(0)
```

```python
def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')

print_truncated_random_state()
```

```
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
```


Re-seeding with the same value restores the same initial state:

```python
np.random.seed(0)
print_truncated_random_state()
```

```
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
```


Each call to a random function updates the global state:

```python
_ = np.random.uniform()
print_truncated_random_state()
```

```
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...
```


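NumPy lets you sample either a single number or an entire vector of numbers in one function call:

```python
np.random.seed(0)
print(np.random.uniform(size=3))
```

```
[0.5488135  0.71518937 0.60276338]
```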


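NumPy also provides a *sequential equivalence* guarantee: sampling N numbers in a row individually gives the same result as sampling a vector of N numbers in one call:

```python
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
```

```
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]
```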


## Random numbers in JAX
JAX’s random number generation differs from NumPy’s in important ways, because NumPy’s PRNG design makes it hard to simultaneously guarantee a number of desirable properties. Specifically, in JAX we want PRNG generation to be:

1. reproducible,

2. parallelizable,

3. vectorizable.

We discuss the reasons below, starting with the implications of a PRNG design based on a global state.

```python
import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())
```

```
1.9791922366721637
```


The function `foo` sums two scalars sampled from a uniform distribution.

The output of this code can only satisfy requirement #1 if we assume a predictable order of execution for `bar()` and `baz()`. This is not a problem in NumPy, which always evaluates code in the order defined by the Python interpreter. In JAX, however, this is more problematic: for efficient execution, we want the JIT compiler to be free to reorder, elide, and fuse various operations in the function we define. Further, when executing in multi-device environments, execution efficiency would be hampered by the need for each process to synchronize a global state.
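The dependence on evaluation order is easy to observe in NumPy itself. The sketch below redefines `bar` and `baz` as in the example above and evaluates them once in the original order and once in the opposite order; with the same seed, the combined result `bar() + 2 * baz()` changes, because each function simply receives whichever draw the global state hands out next.

```python
import numpy as np

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

np.random.seed(0)
b = bar()   # consumes the first draw from the global state
z = baz()   # consumes the second draw
in_order = b + 2 * z

np.random.seed(0)
z = baz()   # now baz() gets the first draw...
b = bar()   # ...and bar() gets the second
reversed_order = b + 2 * z

print(in_order, reversed_order)
```

A compiler that reordered the two calls, as XLA is free to do for independent operations, would silently change the program's output under this model.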

### Explicit random state
To sidestep this issue, JAX does not use an implicit global random state; instead, it tracks state explicitly through a random `key`:

```python
from jax import random

key = random.key(42)
print(key)
```

```
Array((), dtype=key<fry>) overlaying:
[ 0 42]
```


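A key is an array with a special dtype that identifies the PRNG implementation (`key<fry>` above, corresponding to the default Threefry implementation). JAX's random functions are deterministic functions of their key, so passing the same key twice returns the same value:

```python
print(random.normal(key))
print(random.normal(key))
```

```
-0.18471177
-0.18471177
```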


To obtain new values, split the key with `random.split()`. A key that has been passed to a random function or to `split()` should be treated as consumed and never reused:

```python
for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
```

```
draw 0: 1.369469404220581
draw 1: -0.19947023689746857
draw 2: -2.298278331756592
```


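`random.split()` can also produce more than two keys at once:

```python
key, subkey = random.split(key)                      # two new keys
key, *forty_two_subkeys = random.split(key, num=43)  # one key to carry on, 42 to consume
```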
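Applying this to the `foo`/`bar`/`baz` example from earlier gives a minimal sketch in which each function receives its own key explicitly. Because `bar` and `baz` no longer share hidden state, their results do not depend on the order in which they run, and the same key gives the same answer with or without `jax.jit`. The key-taking signatures of `bar` and `baz` are our adaptation for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random

def bar(key): return random.uniform(key)
def baz(key): return random.uniform(key)

def foo(key):
  # Split once so bar() and baz() get independent keys; neither call
  # depends on the other, so the compiler is free to reorder them.
  bar_key, baz_key = random.split(key)
  return bar(bar_key) + 2 * baz(baz_key)

key = random.key(0)
eager = foo(key)
compiled = jax.jit(foo)(key)
print(eager, compiled)
```

The data dependence is now visible to the compiler: `bar` and `baz` depend only on their own keys, which is what allows reordering and fusion without changing results.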

### Lack of sequential equivalence
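Another difference between NumPy’s and JAX’s random modules relates to the sequential equivalence guarantee that NumPy provides, as seen in the NumPy example above, where drawing three values individually and drawing them all at once gave the same result.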

As in NumPy, JAX’s random module also allows sampling of vectors of numbers. However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).

In the example below, sampling 3 values from a normal distribution individually, using three subkeys, gives a different result from using a single key and specifying `shape=(3,)`:

```python
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
```

```
individually: [-0.04838832  0.10796154 -1.2226542 ]
all at once:  [ 0.18693547 -1.2806505  -1.5593132 ]
```


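Dropping sequential equivalence leaves JAX free to vectorize: instead of drawing from each subkey in a Python loop, `jax.vmap` over the three subkeys reproduces the individually drawn values in a single vectorized call:

```python
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
```

```
vectorized: [-0.04838832  0.10796154 -1.2226542 ]
```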


# JAX PRNG Design

We want a PRNG design that

1. is **expressive** in that it is convenient to use and it doesn’t constrain the user’s ability to write numerical programs with exactly the behavior that they want,

2. enables **reproducible** program execution in a backend-independent way,

3. has semantics that are **invariant to `@jit` compilation boundaries and device backends**,

4. enables **vectorization for generating array values** using SIMD hardware,

5. is **parallelizable** in that it doesn’t add sequencing constraints between random function calls that otherwise would have no data dependence,

6. scales to multi-replica, multi-core, and distributed computation,

7. **fits with JAX and XLA semantics** and design philosophies (which are ultimately motivated by other practical concerns).

## Three programming models and toy example programs

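Here is a toy example in the style of a stateful global PRNG, as commonly used in NumPy programs:

```python
import numpy as np

# Stateful model: bar() and baz() both draw from, and mutate, the global RNG.
def foo(): return bar() + baz()
def bar(): return RNG.rand(3, 4)
def baz(): return RNG.rand(3, 4)
def main():
  global RNG
  RNG = np.random.RandomState(0)
  return foo()
```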

To achieve reproducibility here, we would need to control the order of evaluation for bar() and baz() even though there is no explicit data dependence from one to the other. This kind of sequencing requirement stemming from reproducibility (#2) violates parallelizability (#5) and doesn’t fit with JAX or XLA’s functional semantics (#7), in which subexpressions can be evaluated in any order. Even if we didn’t require reproducibility and thus allowed any evaluation order, parallelization across calls (#5) would still be made difficult by the need to update shared state. Moreover, because the same PRNG state would need to be accessed and maintained in both Python and any compiled code, this model would likely lead to engineering challenges in achieving compilation invariance (#3) and scaling to multiple replicas (#6). Finally, expressiveness is limited (#1) because there is no way for foo() to call bar() or baz() without affecting its own (implicit) PRNG state.

Whether the model supports vectorization (#4) depends on some additional details. In NumPy, PRNG vectorization is limited by a sequential equivalence guarantee:

```python
rng = np.random.RandomState(0)

rng.randn(2)
```

```
array([1.76405235, 0.40015721])
```

```python
rng = np.random.RandomState(0)

np.stack([rng.randn() for _ in range(2)])
```

```
array([1.76405235, 0.40015721])
```

To allow for vectorization (#4) within primitive PRNG function calls that generate arrays (e.g. calls to rand() with a shape argument), we drop this sequential equivalence guarantee. This vectorization can be supported by any of the three programming models discussed in this section, though it motivates an implementation in terms of a counter-based PRNG, as described in the next section.

The stateful PRNG user programming model is not promising. Here’s an example of a functional model that lacks a key ingredient we call splitting:

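```python
# Pseudocode: `rand(rng, shape)` is a hypothetical primitive that returns
# (samples, new_rng), and `RandomState` creates the initial PRNG state.
def foo(rng_1):
  y, rng_2 = baz(rng_1)
  z, rng_3 = bar(rng_2)
  return y + z, rng_3

def bar(rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def baz(rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def main():
  return foo(RandomState(0))
```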

This model explicitly threads the PRNG state through all functions (primitive or non-primitive) that generate random values: that is, every random function must both accept and return the state. Now there is an explicit data dependence between the call to baz() and the call to bar() in foo(), so the data flow (and hence sequencing) is made explicit and fits with JAX’s existing semantics (#7), unlike in the previous model. This explicit threading can also make the semantics invariant to compilation boundaries (#3).

Explicit threading is inconvenient for the programmer. But worse, it hasn’t actually improved the expressiveness (#1): there is still no way for foo() to call into bar() or baz() while maintaining its own PRNG state. Without knowledge of their callers or the subroutines they call, functions must defensively pass in and return the rng state everywhere. Moreover, it also doesn’t improve the prospects for parallelization (#5) or scaling to multiple replicas (#6) because everything is still sequential, even if the sequencing is made explicit in the functional programming sense.

In short, making the code functional by explicitly threading state isn’t enough to achieve our expressiveness (#1) and performance (#5, #6) goals.

The key problem in both of the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence, we use functional splittable PRNGs. Splitting is a mechanism to ‘fork’ an existing PRNG state into two new PRNG states while maintaining the usual desirable PRNG properties.

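```python
# Pseudocode: `split(rng, n)` is a hypothetical primitive returning n new
# PRNG states, and `rand(rng, shape)` returns only the samples.
def foo(rng_1):
  rng_2, rng_3 = split(rng_1, 2)
  return bar(rng_2) + baz(rng_3)

def bar(rng):
  return rand(rng, (3, 4))

def baz(rng):
  return rand(rng, (3, 4))

def main():
  return foo(RandomState(0))
```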

Some points to notice:

1. there is no sequential dependence between the calls to bar() and baz() and they can be evaluated in either order without affecting the value of the result, which solves the remaining performance goals (#5, #6),

2. functions do not need to return updated versions of PRNGs, and it is straightforward to call a random subroutine without affecting existing PRNG states, improving expressiveness (#1) relative to the other functional model.

The example doesn’t show it, but as a consequence of choice (2), the only way to advance the PRNG state is to call split(). That is, we have two ways to achieve (1), and they differ in whether they burden the user program with explicit calls to split(), as in the above example, or instead with explicit threading. We prefer the former, i.e. the version with explicit splitting, because we can easily implement the explicit-threading version in terms of it.
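To illustrate that last claim, here is a sketch of an explicit-threading interface written on top of JAX's splittable keys. `rand_threaded` is a hypothetical helper, not part of `jax.random`: it splits the incoming key, uses one half to draw samples, and hands the other half back to the caller as the "next" state, which is exactly the calling convention of the threaded programming model.

```python
import jax.numpy as jnp
from jax import random

def rand_threaded(key, shape):
  """Hypothetical threaded API built on split(): returns (samples, new_key)."""
  new_key, subkey = random.split(key)
  return random.uniform(subkey, shape), new_key

def foo(rng_1):
  # Same structure as the threaded toy program: the state flows from the
  # first draw to the second and is finally returned to the caller.
  y, rng_2 = rand_threaded(rng_1, (3, 4))
  z, rng_3 = rand_threaded(rng_2, (3, 4))
  return y + z, rng_3

out, final_key = foo(random.key(0))
```

The reverse direction is not possible: a purely threaded API offers no way to derive two independent states from one, so splitting is the more general primitive.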

## Design
We can use the counter-based PRNG design, and in particular the Threefry hash function, as described in *Parallel random numbers: as easy as 1, 2, 3* (Salmon et al., 2011). We use the counter to achieve efficient vectorization: for a given key, we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement splittable PRNGs: that is, splitting is a way to generate two new keys from an existing one.

```haskell
type Sample = Int256
type Key = Sample  -- important identification for splitting
type Count = Int32

hash :: Key -> Count -> Int256  -- output type equal to Key and Sample

split :: Key -> (Key, Key)
split key = (hash key 0, hash key 1)

draw_samples :: Key -> Int -> [Sample]
draw_samples key n = map (hash key) [1..n]
```

Surprisingly, drawing a sample is very similar to splitting! The crucial difference is in how the output is used (even though the types are identified): in one case the value is used to form random samples of interest (e.g. turning random bits into a Float representing a random normal), while in the other case it is used as a key for further hashing.

The asymmetry in the hash function arguments, of type Key and Count, is that the latter is trivial and computationally cheap to advance by an arbitrary amount, since we just need to increase the integer value, while the former is only advanced by hashing. That’s why we use the count argument for vectorization.
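The type signatures above can be transcribed almost directly into Python. The sketch below is a toy: it substitutes SHA-256 from Python's `hashlib` for Threefry (chosen only because it conveniently produces 256-bit outputs matching `Int256`), so it is neither the hash JAX uses nor fast. The names `hash_` and `to_unit_float` are hypothetical. It shows the two properties discussed above: `split` and `draw_samples` are the same hash with different counters, and drawing more samples only extends the counter range.

```python
import hashlib

def hash_(key: int, count: int) -> int:
  """Stand-in for Threefry: a 256-bit hash of (key, count)."""
  data = key.to_bytes(32, "little") + count.to_bytes(4, "little")  # Int256, Int32
  return int.from_bytes(hashlib.sha256(data).digest(), "little")

def split(key: int) -> tuple[int, int]:
  # split key = (hash key 0, hash key 1)
  return hash_(key, 0), hash_(key, 1)

def draw_samples(key: int, n: int) -> list[int]:
  # draw_samples key n = map (hash key) [1..n]
  return [hash_(key, i) for i in range(1, n + 1)]

def to_unit_float(sample: int) -> float:
  # Turn random bits into a float in [0, 1) using the top 53 bits.
  return (sample >> (256 - 53)) / float(1 << 53)

root = 42
left, right = split(root)
samples = draw_samples(root, 5)
```

Because only the counter changes, `draw_samples(root, 3)` is a prefix of `draw_samples(root, 5)`, and every counter can be hashed independently, which is what makes the scheme vectorizable. Transcribed literally, `draw_samples(key, n)[0]` and the second key from `split(key)` are the same value (both are `hash key 1`), which is why a given key should be used either for splitting or for sampling, not both.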

## More realistic example user programs
A training step on the host often requires a PRNG (for example, for dropout or for VAE training). Here is how a dropout layer and layer combinators might consume PRNG keys:

```python
import jax.numpy as jnp
from jax import random

def Dropout(rate, mode='train'):
  """`rate` is the probability of *keeping* each input."""
  def init_fun(input_shape):
    return input_shape, ()
  def apply_fun(rng, params, inputs):
    if mode == 'train':
      keep = random.bernoulli(rng, rate, inputs.shape)
      return jnp.where(keep, inputs / rate, 0)
    else:
      return inputs
  return init_fun, apply_fun
```

```python
from jax import random

def serial(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    # One independent key per layer, so no layer's randomness depends on another's.
    rngs = random.split(rng, len(layers))
    for layer_rng, layer_params, layer_apply in zip(rngs, params, apply_funs):
      inputs = layer_apply(layer_rng, layer_params, inputs)
    return inputs
  return init_fun, apply_fun

def parallel(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    # Each parallel branch gets its own key and its own input.
    rngs = random.split(rng, len(layers))
    return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
  return init_fun, apply_fun
```
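The host-side training loop that drives such layers can be sketched as follows. Everything here is illustrative and hypothetical: a linear model with inverted dropout on its inputs, a fixed synthetic batch, and the names `KEEP_RATE`, `LEARNING_RATE`, `loss`, `update`, and `train`. The point is the key handling: the loop splits off a fresh `step_key` each iteration and passes it into the jitted update, so each step gets new dropout masks while the whole run stays reproducible from one seed.

```python
import jax
import jax.numpy as jnp
from jax import random

KEEP_RATE = 0.9      # probability of keeping an input, like `rate` in Dropout above
LEARNING_RATE = 0.1  # illustrative value

def loss(params, key, x, y):
  # Inverted dropout on the inputs, driven entirely by the explicit key.
  keep = random.bernoulli(key, KEEP_RATE, x.shape)
  x = jnp.where(keep, x / KEEP_RATE, 0.0)
  return jnp.mean((x @ params - y) ** 2)

@jax.jit
def update(key, params, x, y):
  grads = jax.grad(loss)(params, key, x, y)  # differentiate w.r.t. params only
  return params - LEARNING_RATE * grads

def train(num_steps, seed=0):
  key = random.key(seed)
  key, data_key = random.split(key)
  x = random.normal(data_key, (32, 4))       # synthetic inputs
  y = x @ jnp.array([1.0, -2.0, 3.0, 0.5])    # synthetic targets
  params = jnp.zeros(4)
  for _ in range(num_steps):
    key, step_key = random.split(key)  # fresh key per step; `key` carries on
    params = update(step_key, params, x, y)
  return params, x, y
```

Calling `train` twice with the same seed yields identical parameters, even though every step used a different dropout mask.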

## Tradeoffs and alternatives
1. We’re not exploiting any device hardware PRNG

    * We don’t currently have enough control over the hardware PRNG’s state for all backends.

    * Even if we did, it would be backend-dependent and we might have to introduce sequential dependencies between random calls to ensure deterministic ordering and hence reproducibility.

    * We don’t know of any workloads for which the software PRNG should become a bottleneck.

    * We could consider providing an additional API that allows access to a hardware PRNG for users who want to give up other desiderata (like strict reproducibility).

2. We give up the sequential equivalence guarantee, in which creating a random array in one call produces the same values as creating the flattened array one random element at a time.

    * This property is likely incompatible with vectorization (a high priority).

    * We don’t know of any users or examples for which this property is important.

    * Users could write a layer on top of this API to provide this guarantee.

3. We can’t follow the `numpy.random` API exactly.
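As an illustration of the last point under item 2, here is one way such a layer might look. It is a hypothetical sketch, not an API that JAX provides: the helpers `normal_at` and `sequential_normal` are our own names. Element `i` of the sequence is drawn with `random.fold_in(key, i)`, so a vectorized call over indices and a one-at-a-time loop see exactly the same per-element keys.

```python
import jax
import jax.numpy as jnp
from jax import random

def normal_at(key, index):
  """Hypothetical helper: the index-th standard normal of a sequence."""
  return random.normal(random.fold_in(key, index))

def sequential_normal(key, n, start=0):
  """Draw elements start..start+n-1 of the sequence in one vectorized call."""
  indices = jnp.arange(start, start + n)
  return jax.vmap(normal_at, in_axes=(None, 0))(key, indices)

key = random.key(0)
all_at_once = sequential_normal(key, 3)
one_at_a_time = jnp.stack([normal_at(key, i) for i in range(3)])
```

The price is an extra `fold_in` per element, which is why this behavior makes more sense as an opt-in layer than as the default.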