In [1]:
import jax
from jax import numpy as jnp
from functools import partial

In [2]:
def update_1(x, key):
  """
  Negates one randomly chosen site in every sample.

  x is 2d array of shape (N_samples, N_sites). For each row a site index is
  drawn with jax.random.randint in [0, N_sites) from ``key``, and the value at
  that site is replaced by its negation. Returns a new array; x is unchanged.
  """
  n_samples, n_sites = x.shape[0], x.shape[1]
  sites = jax.random.randint(key, (n_samples,), 0, n_sites)
  # Pair row r with its chosen site sites[r] and flip all of them in one scatter.
  rows = jnp.arange(n_samples)
  return x.at[rows, sites].set(-x[rows, sites])

def update_2(x, key):
  """
  Negates every value of every sample (a deterministic global flip).

  ``key`` is accepted only so the signature matches other updates such as
  update_1; it is not used.
  """
  del key  # no randomness needed for a global flip
  return -x

@partial(jax.jit, static_argnums=(1, 2))
def weighted_update(x, u1, u2, key):
  """
  Randomly applies either u1 or u2 (probability 50% each) to every sample of x.

  Both updates are first computed for the whole batch; afterwards, for each
  sample, the row from one of the two results is picked.
  x is 2d array of shape (N_samples, N_sites).
  u1 and u2 are static jit arguments (must be hashable) called as u(x, key);
  presumably they preserve the shape of x — TODO confirm for new updates.
  """
  key_u1, key_u2, key_select = jax.random.split(key, 3)
  # candidates[k] is the full batch produced by update k (k = 0 -> u1, 1 -> u2).
  candidates = jnp.stack([u1(x, key_u1), u2(x, key_u2)])

  n_samples = x.shape[0]
  which = jax.random.choice(key_select, 2, shape=(n_samples,), p=jnp.array([0.5, 0.5]))
  # Row r of the result is candidates[which[r], r].
  return candidates[which, jnp.arange(n_samples)]

@partial(jax.jit, static_argnums=(1, 2))
def weighted_update_branched(x, u1, u2, key):
  """
  Applies update u1 and update u2 with a probability of 50% each to every
  sample in x by random. This is done by first selecting the update and then
  calculating it per sample with jax.lax.switch.
  x is 2d array of shape (N_samples, N_sites); the result has shape
  (N_samples, N_sites), like weighted_update (assuming u1/u2 preserve shape).

  u1 and u2 are static jit arguments called as u(batch, key). Each sample is
  passed as a (1, N_sites) batch together with its own PRNG key, so random
  updates (e.g. which site update_1 flips) are drawn independently per sample.

  Note: the jax.lax.cond documentation states that under jax.vmap a batched
  predicate turns cond into a select, i.e. both branches are evaluated.
  """
  ukey, select_key = jax.random.split(key, 2)
  n_samples = x.shape[0]
  # Same convention as weighted_update: choice 0 -> u1, choice 1 -> u2.
  choices = jax.random.choice(select_key, 2, shape=(n_samples,), p=jnp.array([0.5, 0.5]))
  # One key per sample: sharing a single key across the vmap would make every
  # sample draw identical random numbers inside u1/u2.
  sample_keys = jax.random.split(ukey, n_samples)

  def update_one(s, choice, k):
    # Updates expect a 2d batch: wrap the row as (1, N_sites), then drop the
    # leading axis again so the vmapped output is (N_samples, N_sites).
    return jax.lax.switch(choice, (u1, u2), s[None, :], k)[0]

  return jax.vmap(update_one)(x, choices, sample_keys)


In [4]:
rkey = jax.random.PRNGKey(0)
N_samples = 4*4096
N_sites = 2000
x = jnp.ones(shape=(N_samples, N_sites))
# Run each update function once to trigger jax compilation. JAX dispatches
# work asynchronously, so block until the results exist before timing.
wup1 = weighted_update(x, update_1, update_2, rkey).block_until_ready()
wup2 = weighted_update_branched(x, update_1, update_2, rkey).block_until_ready()

In [5]:
%%timeit
up1 = weighted_update(x, update_1, update_2, rkey)

2.8 ms ± 1.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
%%timeit
up2 = weighted_update_branched(x, update_1, update_2, rkey)

2.82 ms ± 797 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


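Both variants take about the same time (≈2.8 ms each above), so selecting the update first does not save work. The `jax.lax.cond` documentation explains why: when `cond` is transformed with `jax.vmap` over a batch of predicates, it is converted into a `select`, so both branches are evaluated for every sample. See https://github.com/google/jax/pull/16335/commits/005d4ca78eec595527972de5ed80575185be05e0 and https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html.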
