In [1]:
import aesara.tensor as at
import jax.numpy as jnp
import numpy as np
import jax

In [2]:
jax_key = jax.random.PRNGKey(1701)
jax_key

DeviceArray([   0, 1701], dtype=uint32)

In [3]:
sample = 0 + jax.random.normal(key=jax_key, dtype=jnp.float16) * 1
sample_exp = jnp.exp(sample)
#jax_key['jax_state'] = random.split(jax_key, num=1)[0]

In [4]:
def sample_fn(rng, size, dtype, *parameters):
    #rng_key = rng["jax_state"]
    loc, scale = parameters
    sample = loc + jax.random.normal(rng, size, dtype) * scale
    sample_exp = jax.numpy.exp(sample)
    #rng["jax_state"] = jax.random.split(rng, num=1)[0]
    return (rng, sample_exp)

In [5]:
sample_fn(jax_key, jnp.array([5]), jnp.float32, 0, 1)

(DeviceArray([   0, 1701], dtype=uint32),
 DeviceArray([0.45558405, 3.867971  , 3.0604928 , 0.36627936, 1.8835636 ],            dtype=float32))

In [94]:
def _scatter_add_one(operand, indices, updates):

    outcome_cnts = jax.lax.scatter_add(
        operand, indices, updates, 
        jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)
            )
        )

    return outcome_cnts

def _categorical(key, p, shape):
    shape = shape or p.shape[:-1]
    s = jax.numpy.cumsum(p, axis=-1)
    r = jax.random.uniform(key, shape=shape + (1,))
    return jax.numpy.sum(s < r, axis=-1)

def multinomial_sample_fn(rng, size, dtype, *parameters):
    """add sampling functionality"""

    #rng_key = rng["jax_state"]
    n, p = parameters
    n_max = jax.numpy.max(n)
    size = size or p.shape[:-1]
    print(f"size: {size}")
    indices = _categorical(rng, p, (n_max,) + size)
    
    indices_2d = (jax.numpy.reshape(indices, (n_max, -1,))).T
    print(f"indices 2d: {indices_2d}")
    samples_2d = jax.vmap(_scatter_add_one, (0, 0, 0))(
        jax.numpy.zeros((indices_2d.shape[0], p.shape[-1]), dtype=indices.dtype),
        jax.numpy.expand_dims(indices_2d, axis=-1),
        jax.numpy.ones(indices_2d.shape, dtype=indices.dtype)
        )
    
    print(f"samples_2d: {samples_2d}")
    print(f"size: {size}, p.shape: {p.shape[-1:]}")
    return indices_2d, jax.numpy.reshape(samples_2d, size + p.shape[-1:])
    

In [96]:
indices_2d, samples = multinomial_sample_fn(
    jax_key,
    None,
    jnp.int32, 
    10, jnp.array([1/6]*6)
)

size: ()
indices 2d: [[4 2 2 5 2 1 5 5 0 3]]
samples_2d: [[1 1 3 1 1 3]]
size: (), p.shape: (6,)


In [97]:
shape = jax.numpy.array([10.], dtype=jax.numpy.int32)
p = jnp.array([1/6]*6)
n_max = jax.numpy.max(10)
size = None or p.shape[:-1]

s = jax.numpy.cumsum(jnp.array([1/6]*6))
r = jax.random.uniform(jax_key, shape=(n_max,) + size + (1,)) ## 10x1

# 10x6 bool matrix indicating if label r > s (cum. probs.)
bool_matrix = s < r
jax.numpy.sum(bool_matrix, axis=-1)

DeviceArray([4, 2, 2, 5, 2, 1, 5, 5, 0, 3], dtype=int32)

In [120]:
# operand: array to be updated 
operand = jax.numpy.zeros((indices_2d.shape[0], jnp.array([1/6]*6).shape[-1]), dtype=indices_2d.dtype)
# indices: add dim to to last axis and indices to which update should be applied
indices = jax.numpy.expand_dims(indices_2d, axis=-1)
# updates: cnt += 1 to operand for each outcome occurence
updates = jax.numpy.ones(indices_2d.shape, dtype=indices.dtype)

In [121]:
operand, updates

(DeviceArray([[0, 0, 0, 0, 0, 0]], dtype=int32),
 DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32))

In [56]:
indices

DeviceArray([[[4],
              [2],
              [2],
              [5],
              [2],
              [1],
              [5],
              [5],
              [0],
              [3]]], dtype=int32)