-
-
Notifications
You must be signed in to change notification settings - Fork 155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MultinomialRV
JAX implementation
#1360
Conversation
Thanks! We also need to figure out if the licenses are compatible and how to do proper attribution if you took inspiration from someone else's implementation. It looks like Numpyro is licensed under Apache 2.0 |
MultinomialRV
MultinomialRV
MultinomialRV
MultinomialRV
JAX implementation
Based on the NumPyro Apache License 2.0 section 4, we may reproduce and distribute copies of the Work or Derivative Works (the JAX implementation of MultinomialRV) provided we:
Since inspiration was drawn from a few functions and not an entire file, I suggest we include, in addition to (1), (2), (3), in the documentation for this RV, something along the lines
|
aesara/link/jax/dispatch/random.py
Outdated
def _categorical(key, p, shape): | ||
shape = shape or p.shape[:-1] | ||
s = jax.numpy.cumsum(p, axis=-1) | ||
r = jax.random.uniform(key, shape=shape + (1,)) | ||
|
||
return jax.numpy.sum(s < r, axis=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am surprised because JAX does have an implementation for the categorical distribution here that uses their implementation for the Gumbel distribution. Is this justified in the codebase, or is it just because it was implemented at a time where jax.random.categorical
was not available (which you should be able to determine with git blame
)?
aesara/link/jax/dispatch/random.py
Outdated
samples_2d = jax.vmap(_scatter_add_one, (0, 0, 0))( | ||
jax.numpy.zeros((outcomes_2d.shape[0], p.shape[-1]), dtype=outcomes.dtype), | ||
jax.numpy.expand_dims(outcomes_2d, axis=-1), | ||
jax.numpy.ones(outcomes_2d.shape, dtype=outcomes.dtype) | ||
) | ||
|
||
sample = jax.numpy.reshape(samples_2d, size + p.shape[-1:]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couldn't we use jax.nn.one_hot
on the output of the categorical and then reduce the resulting tensor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, your proposal seems to be a more elegant solution; thanks. Just committed.
b83b1ac
to
3b006eb
Compare
I rebased your branch on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I performed the git pull --merge
, it kept a copy of the previous code (without the new key splitting). This was likely a mistake on my end. Therefore, I deleted it.
aesara/link/jax/dispatch/random.py
Outdated
samples_2d = jax.vmap(_scatter_add_one, (0, 0, 0))( | ||
jax.numpy.zeros((outcomes_2d.shape[0], p.shape[-1]), dtype=outcomes.dtype), | ||
jax.numpy.expand_dims(outcomes_2d, axis=-1), | ||
jax.numpy.ones(outcomes_2d.shape, dtype=outcomes.dtype) | ||
) | ||
|
||
sample = jax.numpy.reshape(samples_2d, size + p.shape[-1:]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, your proposal seems to be a more elegant solution; thanks. Just committed.
Yes you need to |
aesara/link/jax/dispatch/random.py
Outdated
def _categorical(key, p, shape): | ||
shape = shape or p.shape[:-1] | ||
s = jax.numpy.cumsum(p, axis=-1) | ||
r = jax.random.uniform(key, shape=shape + (1,)) | ||
|
||
return jax.numpy.sum(s < r, axis=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that different from jax.random.categorical
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the jax.random.categorical
uses their implementation of the Gumbel distribution. However, after using git blame
, it seems the NumPyro team used this implementation before the implementation of the jax.random.categorical
.
If we would like to update this code to use the jax.random.categorical
, it would be the following:
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
n, p = parameters
n_max = jax.numpy.max(n)
size = size or p.shape[:-1]
logits = jax.scipy.special.logit(p)
indices = jax.random.categorical(jax_key, logits, shape=(n_max,) + size)
one_hot = jax.nn.one_hot(indices, p.shape[0])
sample = jax.numpy.sum(one_hot, axis=0, dtype=dtype, keepdims=False)
rng["jax_state"] = rng_key
return (rng, sample)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That looks much simpler, great!
Since the multinomial distribution is slightly more complex than the other ones when it comes to shapes we should make sure that the output shape of the samples that are generated in the JAX backend is identical to those of the samples generated with the other backends.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You just need to check the Gumbel implementation, they form a N by N array.
.gitignore
Outdated
@@ -55,3 +55,4 @@ aesara-venv/ | |||
testing-report.html | |||
coverage.xml | |||
.coverage.* | |||
jax_multinomial_test.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a job for a "local" Git ignore (see here).
@GStechschulte I think we should follow @AdrienCorenflos's suggestion here. I'll take another look at it this week. |
8932215
to
662c586
Compare
@GStechschulte I rebased your branch on |
662c586
to
d93fb8d
Compare
d93fb8d
to
148217a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've squared, rebased, and added a few fixes. The JAX steps are not properly accounting for the shapes of the distribution parameters, so that needs to be finished. I refactored the tests so that they cover one of the most basic cases and another that should confirm that the sizes/shapes are handled correctly (when they are).
@rlouf, I had to add a case for Constant
s in assert_size_argument_jax_compatible
. You'll need to confirm that this is valid more generally.
FYI: the tests aren't being run with shape inference, since the jax_mode
used by the tests doesn't include "ShapeOpt"
(or whatever its tag is). Standard JAX mode should, since it's included with the "fast_run"
tag, but, if we want to test for non-trivial shape scenarios (e.g. the shape value isn't explicitly constant, but can be "inferred" as a constant value), we'll need to add it.
That's valid. I was so focused on the complex case that I forgot the simplest one. |
This draft PR is a work in progress and contains a JAX implementation of
MultinomialRV
for issue #1326. The implementation builds off theMultinomial Distribution
implementation in NumPyro. Likewise, the output is similar to that of thenumpy
implementation. Below, you will find a brief outline of the functions used to construct theMultinomialRV
.def _categorical(key, p, shape)
def _scatter_add_ones(operand, indices, update)
jax.lax.scatter_add()
functionoperand
is a zero filled array.indices
is theoutcomes
array with an added dimension and specifies the indices to which the update should be applied to.update
is an array filled with ones and can be thought of as acnt += 1
for eachoperand
array is updated+1
using theupdate
array according to the outcomes in theindices
array.I still need to add a test for this. Thanks!