New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement a general mixture log-probability rewrite #19
Implement a general mixture log-probability rewrite #19
Conversation
b4f90a4
to
9733f01
Compare
408b18b
to
dd95570
Compare
dd95570
to
6240cc0
Compare
13ccc2c
to
e265ba9
Compare
The multivariate case is blocked by aesara-devs/aesara#507 and aesara-devs/aesara#509. The latter issue is only a blocker if we want to prevent the evaluation of "unselected" mixture components. We can always produce an array of stacked |
e265ba9
to
6cd2419
Compare
The two tests in this PR pass locally under the changes in aesara-devs/aesara#516. |
6cd2419
to
b81d998
Compare
Codecov Report
@@ Coverage Diff @@
## main #19 +/- ##
==========================================
+ Coverage 90.52% 92.80% +2.27%
==========================================
Files 5 6 +1
Lines 802 917 +115
Branches 103 119 +16
==========================================
+ Hits 726 851 +125
+ Misses 44 30 -14
- Partials 32 36 +4
Continue to review full report at Codecov.
|
2d53b58
to
33bc7d5
Compare
rv.owner and isinstance(rv.owner.op, RandomVariable) for rv in mixture_rvs | ||
): | ||
# Currently, all mixture components must be `RandomVariable` outputs | ||
# TODO: Allow constants and make them Dirac-deltas |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could have a rewrite phase before this one where we convert constants to Dirac deltas (not needed for this PR)
aeppl/mixture.py
Outdated
# problem) | ||
join_axis = 0 # op.join_axis | ||
|
||
logp_val = at.full(tuple(at.shape(value)), -np.inf, dtype=value.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't be safer to fill it with nan
s?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Safer in what way? It probably shouldn't matter what that initial value is, since the indexing is supposed to cover all values.
Is this working with categorical distributions as well? If so, we could add a test |
Yes, definitely. |
a12a939
to
94b342f
Compare
94b342f
to
a74b7ec
Compare
This PR implements log-probabilities for general mixtures (e.g. heterogeneous
RandomVariable
s).This implementation does not produce a marginalized log-probability.
It covers cases like the following:
Instead of producing a log-probability graph that represents the computation
np.stack([logprob(X, obs), logprob(Y, obs)])[idx]
, which requires complete evaluation oflogprob(X, obs)
andlogprob(Y, obs)
even thoughidx
will only select some of thoselogprob
values from each mixture component, this implementation will compute something like the following:Currently, it does all the necessary setup and parsing needed to construct the desired log-probability graphs; however, the actual construction steps aren't necessarily finished. I still need to check/confirm a few things and—obviously—write tests.