Skip to content

Commit

Permalink
Migrate flax from using old-style PRNG keys to new-style typed PRNG keys
Browse files Browse the repository at this point in the history
Functionally, this involves changing uses of jax.random.PRNGKey to jax.random.key. For details on this change and the motivation behind it, see the draft JEP at google/jax#17297, and please feel free to offer comments and feedback!

PiperOrigin-RevId: 565475405
  • Loading branch information
Jake VanderPlas authored and 8bitmp3 committed Oct 9, 2023
1 parent 93d31fd commit e51e71c
Show file tree
Hide file tree
Showing 113 changed files with 612 additions and 629 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ vNext
-
-
-
-
- Use new typed PRNG keys throughout flax: this essentially involved changing
uses of `jax.random.PRNGKey` to `jax.random.key`.
(See [JEP 9263](https://github.com/google/jax/pull/17297) for details).
If you notice dispatch performance regressions after this change, be sure
you update `jax` to version 0.4.16 or newer.
-
-
-
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class MLP(nn.Module):

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
```

Expand All @@ -142,7 +142,7 @@ class CNN(nn.Module):

model = CNN()
batch = jnp.ones((32, 64, 64, 10)) # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
```

Expand Down Expand Up @@ -174,7 +174,7 @@ model = AutoEncoder(encoder_widths=[20, 10, 5],
decoder_widths=[5, 10, 20],
input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.PRNGKey(0), batch)
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)
```
Expand Down
8 changes: 4 additions & 4 deletions docs/developer_notes/lift.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ManualVmapMLP(nn.Module):
return apply_fn({'params': mlp_params}, xs)

xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.PRNGKey(0), xs)
variables = ManualVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
Expand Down Expand Up @@ -270,7 +270,7 @@ def lift_transpose(fn, target='params', variables=True, rngs=True):
rng_filters=(rngs,))

x = jnp.ones((3, 2))
y, params = init(lift_transpose(core_nn.dense))(random.PRNGKey(0), x, 4)
y, params = init(lift_transpose(core_nn.dense))(random.key(0), x, 4)
```

NOTE that most users should not need to interact with `pack` directly.
Expand Down Expand Up @@ -310,7 +310,7 @@ class LinenVmapMLP(nn.Module):
VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0)
return VmapMLP(name='mlp')(xs)

variables = LinenVmapMLP().init(random.PRNGKey(0), xs)
variables = LinenVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
Expand Down Expand Up @@ -346,7 +346,7 @@ class LinenStatefulVmapMLP(nn.Module):
def __call__(self, xs, *, train):
VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0)
return VmapMLP(name='mlp')(xs, train=train)
variables = LinenStatefulVmapMLP().init(random.PRNGKey(0), xs)
variables = LinenStatefulVmapMLP().init(random.key(0), xs)
```

All we had to add to `nn.vmap` is `'batch_stats': 0`, indicating that the batch stats are vectorized rather than shared along the first axis.
Expand Down
12 changes: 6 additions & 6 deletions docs/developer_notes/module_lifecycle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Now we want to construct and use the ``MLP`` Module:

mlp = MLP(hidden_size=5, out_size=3)
x = jax.numpy.ones((1, 2))
variables = mlp.init(random.PRNGKey(0), x)
variables = mlp.init(random.key(0), x)
y = mlp.apply(variables, x)


Expand All @@ -70,8 +70,8 @@ Let's take a closer look at initialization. Surprisingly, there actually is no s

.. testcode::

# equivalent to: variables = mlp.init(random.PRNGKey(0), x)
_, variables = mlp.apply({}, x, rngs={"params": random.PRNGKey(0)}, mutable=True)
# equivalent to: variables = mlp.init(random.key(0), x)
_, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True)


Thus, ``init`` is nothing more than a wrapper around ``apply`` where:
Expand Down Expand Up @@ -155,7 +155,7 @@ Another benefit of defining submodules and/or variables inline is that you can a

mdl = CompactScaledMLP(hidden_size=4, out_size=5)
x = jax.numpy.ones((3, 2))
vars = mdl.init(random.PRNGKey(0), x)
vars = mdl.init(random.key(0), x)
assert vars["params"]["scale"].shape == (2,)

Many of the standard Linen Modules like ``nn.Dense`` use shape inference already to avoid the need to specify input shapes (like the number of input features to a Dense layer).
Expand Down Expand Up @@ -207,7 +207,7 @@ The latter is done as follows:
return mdl(z, "decode")

mdl = CorrectModule()
vars = nn.init(init_fn, mdl)(random.PRNGKey(0))
vars = nn.init(init_fn, mdl)(random.key(0))
assert vars["params"]["Dense_0"]["kernel"].shape == (2, 8)
assert vars["params"]["Dense_1"]["kernel"].shape == (8, 4)

Expand Down Expand Up @@ -348,7 +348,7 @@ Function closure is the most common way to accidentally hide a JAX array or Line

x = jax.numpy.ones((3, 2))
mdl = Foo()
vars = mdl.init(random.PRNGKey(0), x)
vars = mdl.init(random.key(0), x)
assert vars['params']['Dense_0']['kernel'].shape == (3, 2, 2)


Expand Down
2 changes: 1 addition & 1 deletion docs/flip/1009-optimizer-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def get_learning_rate(step):


model = Model()
rng = jax.random.PRNGKey(0)
rng = jax.random.key(0)
ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16)
batch = next(iter(ds))
variables = model.init(rng, jnp.array(batch['image'][:1]))
Expand Down
4 changes: 2 additions & 2 deletions docs/flip/2396-rnn.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __call__(self, x):
nn.LSTMCell, variable_broadcast="params", split_rngs={"params": False}
)
carry = LSTM.initialize_carry(
jax.random.PRNGKey(0), batch_dims=x.shape[:1], size=self.hidden_size
jax.random.key(0), batch_dims=x.shape[:1], size=self.hidden_size
)
carry, x = LSTM()(carry, x)
return x
Expand Down Expand Up @@ -91,7 +91,7 @@ Where:
* `initial_carry`: the initial carry, if not provided it will be initialized
using the cell's :meth:`RNNCellBase.initialize_carry` method.
* `init_key`: a PRNG key used to initialize the carry, if not provided
``jax.random.PRNGKey(0)`` will be used. Most cells will ignore this
``jax.random.key(0)`` will be used. Most cells will ignore this
argument.
* `seq_lengths`: an optional integer array of shape ``(*batch)`` indicating
the length of each sequence, elements whose index in the time dimension
Expand Down
4 changes: 2 additions & 2 deletions docs/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@
"import jax.numpy as jnp # JAX NumPy\n",
"\n",
"cnn = CNN()\n",
"print(cnn.tabulate(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1))))"
"print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1))))"
]
},
{
Expand Down Expand Up @@ -521,7 +521,7 @@
},
"outputs": [],
"source": [
"init_rng = jax.random.PRNGKey(0)"
"init_rng = jax.random.key(0)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ import jax
import jax.numpy as jnp # JAX NumPy
cnn = CNN()
print(cnn.tabulate(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1))))
print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1))))
```

+++ {"id": "4b5ac16e"}
Expand Down Expand Up @@ -332,7 +332,7 @@ executionInfo:
timestamp: 1673483485436
id: e4f6f4d3
---
init_rng = jax.random.PRNGKey(0)
init_rng = jax.random.key(0)
```

+++ {"id": "80fbb60b"}
Expand Down
4 changes: 2 additions & 2 deletions docs/guides/batch_norm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ The ``batch_stats`` collection must be extracted from the ``variables`` for late

mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.PRNGKey(0), x)
variables = mlp.init(jax.random.key(0), x)
params = variables['params']


jax.tree_util.tree_map(jnp.shape, variables)
---
mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.PRNGKey(0), x, train=False) #!
variables = mlp.init(jax.random.key(0), x, train=False) #!
params = variables['params']
batch_stats = variables['batch_stats'] #!

Expand Down
10 changes: 5 additions & 5 deletions docs/guides/convert_pytorch_to_flax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the tr
# [outC, inC] -> [inC, outC]
kernel = jnp.transpose(kernel, (1, 0))

key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (1, 3))

variables = {'params': {'kernel': kernel, 'bias': bias}}
Expand Down Expand Up @@ -62,7 +62,7 @@ and the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will d
# [outC, inC, kH, kW] -> [kH, kW, inC, outC]
kernel = jnp.transpose(kernel, (2, 3, 1, 0))

key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

variables = {'params': {'kernel': kernel, 'bias': bias}}
Expand Down Expand Up @@ -154,7 +154,7 @@ Other than the transpose operation before reshaping, we can convert the weights
variables = {'params': {'conv': {'kernel': conv_kernel, 'bias': conv_bias},
'fc': {'kernel': fc_kernel, 'bias': fc_bias}}}

key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

j_out = j_model.apply(variables, x)
Expand Down Expand Up @@ -192,7 +192,7 @@ while Flax multiplies the estimated statistic with ``momentum`` and the new obse
variables = {'params': {'scale': scale, 'bias': bias},
'batch_stats': {'mean': mean, 'var': var}}

key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

j_bn = nn.BatchNorm(momentum=0.9, use_running_average=True)
Expand Down Expand Up @@ -241,7 +241,7 @@ operation. ``nn.pool()`` is the core function behind |nn.avg_pool()|_ and |nn.ma
return y


key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1)))
Expand Down
6 changes: 3 additions & 3 deletions docs/guides/dropout.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ desirable properties for neural networks. To learn more, refer to the
`Pseudorandom numbers in JAX tutorial <https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html>`__.

**Note:** Recall that JAX has an explicit way of giving you PRNG keys:
you can fork the main PRNG state (such as ``key = jax.random.PRNGKey(seed=0)``)
you can fork the main PRNG state (such as ``key = jax.random.key(seed=0)``)
into multiple new PRNG keys with ``key, subkey = jax.random.split(key)``. You
can refresh your memory in
`🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng>`__.
Expand All @@ -41,10 +41,10 @@ into three keys, including one for Flax Linen ``Dropout``.
:title_right: With Dropout
:sync:

root_key = jax.random.PRNGKey(seed=0)
root_key = jax.random.key(seed=0)
main_key, params_key = jax.random.split(key=root_key)
---
root_key = jax.random.PRNGKey(seed=0)
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3) #!

**Note:** In Flax, you provide *PRNG streams* with *names*, so that you can use them later
Expand Down
4 changes: 2 additions & 2 deletions docs/guides/ensembling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ directly.

train_ds, test_ds = get_datasets()
#!
rng = jax.random.PRNGKey(0)
rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate, momentum) #!
Expand All @@ -246,7 +246,7 @@ directly.
---
train_ds, test_ds = get_datasets()
test_ds = jax_utils.replicate(test_ds) #!
rng = jax.random.PRNGKey(0)
rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(jax.random.split(init_rng, jax.device_count()), #!
Expand Down
18 changes: 9 additions & 9 deletions docs/guides/extracting_intermediates.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Note that, by default ``sow`` appends values every time it is called:
return output, features

batch = jnp.ones((1,28,28,1))
variables = init(jax.random.PRNGKey(0), batch)
variables = init(jax.random.key(0), batch)
preds, feats = predict(variables, batch)

assert len(feats) == 2 # Tuple with two values since module was called twice.
Expand Down Expand Up @@ -180,7 +180,7 @@ avoid using ``nn.compact`` altogether.
return RefactoredCNN().apply({"params": params}, x,
method=lambda module, x: module.features(x))

params = init(jax.random.PRNGKey(0), batch)
params = init(jax.random.key(0), batch)

features(params, batch)

Expand Down Expand Up @@ -209,7 +209,7 @@ In the following code example we check if any intermediate activations are non-f
fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
return y, fin

variables = init(jax.random.PRNGKey(0), batch)
variables = init(jax.random.key(0), batch)
y, is_finite = predict(variables, batch)
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non-finite intermediate detected!"
Expand Down Expand Up @@ -250,8 +250,8 @@ non-layer intermediates, but the filter function won't be applied to it.
def predict(params, x):
return Model().apply({"params": params}, x, capture_intermediates=True)

batch = jax.random.uniform(jax.random.PRNGKey(1), (1,3))
params = init(jax.random.PRNGKey(0), batch)
batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model was not stored because it's not a Flax layer
---
Expand All @@ -276,8 +276,8 @@ non-layer intermediates, but the filter function won't be applied to it.
filter_fn = lambda mdl, method_name: isinstance(mdl.name, str) and (mdl.name in {'Dense_0', 'Dense_2'}) #!
return Model().apply({"params": params}, x, capture_intermediates=filter_fn) #!

batch = jax.random.uniform(jax.random.PRNGKey(1), (1,3))
params = init(jax.random.PRNGKey(0), batch)
batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model is stored and isn't filtered out by the filter function #!

Expand Down Expand Up @@ -337,7 +337,7 @@ your model more explicitly.
return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x)

batch = jnp.ones((1,28,28,1))
params = init(jax.random.PRNGKey(0), batch)
params = init(jax.random.key(0), batch)
features(params, batch)

Extracting gradients of intermediate values
Expand Down Expand Up @@ -367,7 +367,7 @@ the model:
y = jnp.empty((1, 2)) # random data

model = Model()
variables = model.init(jax.random.PRNGKey(1), x)
variables = model.init(jax.random.key(1), x)
params, perturbations = variables['params'], variables['perturbations']

Finally compute the gradients of the loss with respect to the perturbations,
Expand Down
Loading

0 comments on commit e51e71c

Please sign in to comment.