In [2]:
import aesara.tensor as at
import jax.numpy as jnp
import numpy as np
import jax

In [3]:
jax_key = jax.random.PRNGKey(1701)
jax_key

DeviceArray([   0, 1701], dtype=uint32)

In [4]:
sample = 0 + jax.random.normal(key=jax_key, dtype=jnp.float16) * 1
sample_exp = jnp.exp(sample)
#jax_key['jax_state'] = random.split(jax_key, num=1)[0]

In [5]:
def sample_fn(rng, size, dtype, *parameters):
    #rng_key = rng["jax_state"]
    loc, scale = parameters
    sample = loc + jax.random.normal(rng, size, dtype) * scale
    sample_exp = jax.numpy.exp(sample)
    #rng["jax_state"] = jax.random.split(rng, num=1)[0]
    return (rng, sample_exp)

In [6]:
sample_fn(jax_key, jnp.array([5]), jnp.float32, 0, 1)

(DeviceArray([   0, 1701], dtype=uint32),
 DeviceArray([0.45558405, 3.867971  , 3.0604928 , 0.36627936, 1.8835636 ],            dtype=float32))

In [136]:
def _scatter_add_one(operand, indices, updates):

    outcome_cnts = jax.lax.scatter_add(
        operand, indices, updates, 
        jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)
            )
        )

    return outcome_cnts

def _categorical(key, p, shape):
    shape = shape or p.shape[:-1]
    s = jax.numpy.cumsum(p, axis=-1)
    r = jax.random.uniform(key, shape=shape + (1,))
    return jax.numpy.sum(s < r, axis=-1)

def multinomial_sample_fn(rng, size, dtype=jax.numpy.int32, *parameters):
    """add sampling functionality"""

    #rng_key = rng["jax_state"]
    n, p = parameters
    n_max = jax.numpy.max(n)
    size = size or p.shape[:-1]
    indices = _categorical(rng, p, (n_max,) + size)
    
    # indices_2d = (jax.numpy.reshape(indices, (n_max, -1,))).T
    # samples_2d = jax.vmap(_scatter_add_one, (0, 0, 0))(
    #     jax.numpy.zeros((indices_2d.shape[0], p.shape[-1]), dtype=indices.dtype),
    #     jax.numpy.expand_dims(indices_2d, axis=-1),
    #     jax.numpy.ones(indices_2d.shape, dtype=indices.dtype)
    #     )

    one_hot = jax.nn.one_hot(indices, p.shape[0])
    samples = jax.numpy.sum(one_hot, axis=0, dtype=dtype, keepdims=False)
    
    # print(f"samples _scatter    : \n {samples_2d}")
    print(f"samples one_hot sum : \n {samples}")
    
    return samples

In [138]:
samples = multinomial_sample_fn(
    jax_key,
    (2, 2),
    jnp.int32, 
    10, jnp.array([1/6]*6)
)


samples one_hot sum : 
 [[[2 1 1 2 2 2]
  [4 3 0 0 1 2]]

 [[2 1 2 2 1 2]
  [5 1 2 0 0 2]]]


In [9]:
shape = jax.numpy.array([10.], dtype=jax.numpy.int32)
p = jnp.array([1/6]*6)
n_max = jax.numpy.max(10)
size = None or p.shape[:-1]

s = jax.numpy.cumsum(jnp.array([1/6]*6))
r = jax.random.uniform(jax_key, shape=(n_max,) + size + (1,)) ## 10x1

# 10x6 bool matrix indicating if label r > s (cum. probs.)
bool_matrix = s < r
jax.numpy.sum(bool_matrix, axis=-1)

DeviceArray([4, 2, 2, 5, 2, 1, 5, 5, 0, 3], dtype=int32)

In [11]:
# operand: array to be updated 
operand = jax.numpy.zeros((indices_2d.shape[0], jnp.array([1/6]*6).shape[-1]), dtype=indices_2d.dtype)
# indices: add dim to to last axis and indices to which update should be applied
indices = jax.numpy.expand_dims(indices_2d, axis=-1)
# updates: cnt += 1 to operand for each outcome occurence
updates = jax.numpy.ones(indices_2d.shape, dtype=indices.dtype)

In [12]:
operand, updates

(DeviceArray([[0, 0, 0, 0, 0, 0]], dtype=int32),
 DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32))

In [13]:
if indices.dtype == jax.numpy.int32:
    print('yay')

yay


In [14]:
p = np.array([1/6]*6, dtype=np.float64)
n = np.array([20], dtype=np.int64)

In [15]:
func = lambda *args: args

In [16]:
func(p, n)

(array([0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,
        0.16666667]),
 array([20]))

In [17]:
some = np.array([1.0, 2.0], dtype=np.float64)
#anoth = np.array(1.0, dtype=np.float64)
func = lambda *args: (0, args[0])

In [18]:
func(some)

(0, array([1., 2.]))