In [1]:
import aesara.tensor as at
import jax.numpy as jnp
import numpy as np
import jax

In [2]:
jax_key = jax.random.PRNGKey(1701)
jax_key

DeviceArray([   0, 1701], dtype=uint32)

In [3]:
sample = 0 + jax.random.normal(key=jax_key, dtype=jnp.float16) * 1
sample_exp = jnp.exp(sample)
#jax_key['jax_state'] = random.split(jax_key, num=1)[0]

In [4]:
def sample_fn(rng, size, dtype, *parameters):
    #rng_key = rng["jax_state"]
    loc, scale = parameters
    sample = loc + jax.random.normal(rng, size, dtype) * scale
    sample_exp = jax.numpy.exp(sample)
    #rng["jax_state"] = jax.random.split(rng, num=1)[0]
    return (rng, sample_exp)

In [5]:
sample_fn(jax_key, jnp.array([5]), jnp.float32, 0, 1)

(DeviceArray([   0, 1701], dtype=uint32),
 DeviceArray([0.45558405, 3.867971  , 3.0604928 , 0.36627936, 1.8835636 ],            dtype=float32))

In [6]:
def sample_fn(rng, size, dtype, *parameters):
    rng_key = rng["jax_state"]
    (
        df,
        loc,
        scale,
    ) = parameters
    sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
    rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
    return (rng, sample)

In [7]:
def _scatter_add_one(operand, indices, updates):
    return lax.scatter_add(operand, indices, updates,
                           lax.ScatterDimensionNumbers(update_window_dims=(),
                                                       inserted_window_dims=(0,),
                                                       scatter_dims_to_operand_dims=(0,)))


def _categorical(key, p, shape):
    # this implementation is fast when event shape is small, and slow otherwise
    # Ref: https://stackoverflow.com/a/34190035
    shape = shape or p.shape[:-1]
    s = np.cumsum(p, axis=-1)
    print(f"sum of probs. {s}")
    r = random.uniform(key, shape=shape + (1,))
    # FIXME: replace this computation by using binary search as suggested in the above
    # reference. A while_loop + vmap for a reshaped 2D array would be enough.
    return np.sum(s < r, axis=-1)


def categorical(key, p, shape=()):
    return _categorical(key, p, shape)

def promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    else:
        shapes = [np.shape(arg) for arg in args]
        num_dims = len(lax.broadcast_shapes(shape, *shapes))
        return [lax.reshape(arg, (1,) * (num_dims - len(s)) + s)
                if len(s) < num_dims else arg for arg, s in zip(args, shapes)]

In [8]:
def _multinomial(key, p, n, n_max, shape=()):

    print(np.shape(n), np.shape(p)[:-1])
    if np.shape(n) != np.shape(p)[:-1]:
        print('in first if')
        broadcast_shape = lax.broadcast_shapes(np.shape(n), np.shape(p)[:-1])
        n = np.broadcast_to(n, broadcast_shape)
        p = np.broadcast_to(p, broadcast_shape + np.shape(p)[-1:])

    shape = shape or p.shape[:-1]
    print(f"shape: {shape}")
    indices = categorical(key, p, (n_max,) + shape)
    print(f"indices: {indices}")
    # mask out values when counts is heterogeneous
    if np.ndim(n) > 0:
        print("in second if")
        mask = promote_shapes(np.arange(n_max) < np.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = np.moveaxis(mask, -1, 0).astype(indices.dtype)
        excess = np.concatenate([np.expand_dims(n_max - n, -1), np.zeros(np.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (np.reshape(indices * mask, (n_max, -1,))).T
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(np.zeros((indices_2D.shape[0], p.shape[-1]),
                                                            dtype=indices.dtype),
                                                   np.expand_dims(indices_2D, axis=-1),
                                                   np.ones(indices_2D.shape, dtype=indices.dtype))
    print(f"samples 2d: {samples_2D}")
    print(f"shape: {shape}, p.shape: {p.shape[-1:]}")
    return np.reshape(samples_2D, shape + p.shape[-1:]) - excess

In [9]:
def multinomial(key, p, n, shape=()):
    n_max = int(np.max(n))
    return _multinomial(key, p, n, n_max, shape)

In [10]:
multinomial(jax_key, jnp.array([1/6]*6), 10, (2, 2))

() ()
shape: (2, 2)
sum of probs. [0.16666667 0.33333334 0.5        0.6666667  0.8333334  1.        ]


NameError: name 'random' is not defined

In [26]:
def _scatter_add_one(operand, indices, updates):
    return jax.lax.scatter_add(operand, indices, updates,
                           jax.lax.ScatterDimensionNumbers(update_window_dims=(),
                                                       inserted_window_dims=(0,),
                                                       scatter_dims_to_operand_dims=(0,)))

def _categorical(key, p, shape):
    shape = shape or p.shape[:-1]
    s = jax.numpy.cumsum(p, axis=-1)
    print(f"cumsum probs: {s}")
    r = jax.random.uniform(key, shape=shape + (1,))
    #print(f"uniform dist r: {r}")
    return jax.numpy.sum(s < r, axis=-1)

def multinomial_sample_fn(rng, size, dtype, *parameters):
    """add sampling functionality"""

    #rng_key = rng["jax_state"]
    n, p = parameters
    n_max = jax.numpy.max(n)
    size = size or p.shape[:-1]
    print(f"size: {size}")
    indices = _categorical(rng, p, (n_max,) + size)
    print(f"indices 1d: {indices}")
    
    indices_2d = (jax.numpy.reshape(indices, (n_max, -1,))).T
    print(f"indices 2d: {indices_2d}")
    samples_2d = jax.vmap(_scatter_add_one, (0, 0, 0))(
        jax.numpy.zeros((indices_2d.shape[0], p.shape[-1]),
        dtype=indices.dtype),
        jax.numpy.expand_dims(indices_2d, axis=-1),
        jax.numpy.ones(indices_2d.shape, dtype=indices.dtype))
    
    # jax_unique = jax.numpy.unique(
    #     indices_2d, return_counts=True)
    # jax_unique = jax.numpy.bincount(
    #     indices_2d)
    # print(f"jax unique: {jax_unique[1]}")
    
    print(f"samples_2d: {samples_2d}")
    print(f"size: {size}, p.shape: {p.shape[-1:]}")
    return jax.numpy.reshape(samples_2d, size + p.shape[-1:])
    

In [28]:
multinomial_sample_fn(
    jax_key,
    None,
    jnp.int32, 
    10, jnp.array([1/6]*6)
)

size: ()
cumsum probs: [0.16666667 0.33333334 0.5        0.6666667  0.8333334  1.        ]
indices 1d: [4 2 2 5 2 1 5 5 0 3]
indices 2d: [[4 2 2 5 2 1 5 5 0 3]]
samples_2d: [[1 1 3 1 1 3]]
size: (), p.shape: (6,)


DeviceArray([1, 1, 3, 1, 1, 3], dtype=int32)