Implement a general mixture log-probability rewrite #19

Merged: 4 commits, Jul 17, 2021

brandonwillard (Member) commented Jun 22, 2021

This PR implements log-probabilities for general mixtures (e.g. heterogeneous RandomVariables).

This implementation does not produce a marginalized log-probability; the mixing (index) variable is not summed out.

It covers cases like the following:

import aesara.tensor as at

from aeppl import joint_logprob


X_rv = at.random.normal(0, 1, name="X")
Y_rv = at.random.gamma(0.5, 0.5, name="Y")

p = at.scalar("p")
I_rv = at.random.bernoulli(p, name="I")

M_rv = at.stack([X_rv, Y_rv])[I_rv]

m = M_rv.type()
M_logp = joint_logprob(M_rv, {M_rv: m})

A naive log-probability graph would represent the computation np.stack([logprob(X, obs), logprob(Y, obs)])[idx]. That requires complete evaluation of logprob(X, obs) and logprob(Y, obs), even though idx only selects some of those logprob values from each mixture component. This implementation instead computes something like the following:

res[idx == 0] = logprob(X[idx == 0], obs[idx == 0])
res[idx == 1] = logprob(Y[idx == 1], obs[idx == 1])
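
The two pseudocode lines above can be sketched in plain NumPy. This only illustrates the masking idea; it is not the Aesara graph that the PR builds. The helper names, the hand-written densities, and the shape/rate parameterization of the gamma are all assumptions made for this example.

```python
import math

import numpy as np


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma), written out to avoid extra dependencies."""
    z = (x - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)


def gamma_logpdf(x, shape=0.5, rate=0.5):
    """Log-density of Gamma(shape, rate) for x > 0 (parameterization assumed here)."""
    return (
        shape * np.log(rate)
        - math.lgamma(shape)
        + (shape - 1.0) * np.log(x)
        - rate * x
    )


def mixture_logprob_masked(obs, idx):
    """Evaluate each component only on the observations assigned to it."""
    res = np.full(obs.shape, -np.inf)
    is_x = idx == 0
    is_y = idx == 1
    res[is_x] = normal_logpdf(obs[is_x])  # normal component sees only its own values
    res[is_y] = gamma_logpdf(obs[is_y])   # gamma component never sees the negative value
    return res


obs = np.array([-1.2, 0.7, 2.5, 0.3])
idx = np.array([0, 1, 1, 0])
res = mixture_logprob_masked(obs, idx)
```

In this sketch the gamma density is never evaluated at -1.2, which lies outside its support; a stack-then-index version would evaluate it there and then discard the result.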

Currently, it does all the setup and parsing needed to construct the desired log-probability graphs; however, the actual construction steps may not be finished. I still need to confirm a few things and, obviously, write tests.

@brandonwillard added the enhancement, important, and graph rewriting labels Jun 22, 2021
@brandonwillard self-assigned this Jun 22, 2021
@brandonwillard brandonwillard force-pushed the general-mixtures branch 2 times, most recently from b4f90a4 to 9733f01 Compare June 22, 2021 07:30
@brandonwillard brandonwillard force-pushed the general-mixtures branch 2 times, most recently from 408b18b to dd95570 Compare June 26, 2021 05:22
brandonwillard (Member, Author) commented Jul 7, 2021

The multivariate case is blocked by aesara-devs/aesara#507 and aesara-devs/aesara#509.

The latter issue is only a blocker if we want to prevent the evaluation of "unselected" mixture components. We can always produce an array of stacked logprob values evaluated conditional on each mixture component and then index that array with the mixing distribution; however, this is a prohibitively slow approach for large enough mixture models (e.g. almost anything involving time-varying mixture components used with non-trivially-sized datasets).
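
To make the cost difference concrete, here is a small NumPy sketch that counts how many values each strategy evaluates. The `CountingLogp` class and both function names are hypothetical, and the components are simple unit-variance normals chosen only for illustration.

```python
import numpy as np


class CountingLogp:
    """Unit-variance normal log-density that counts how many values it evaluates."""

    def __init__(self, mu):
        self.mu = mu
        self.n_evaluated = 0

    def __call__(self, x):
        self.n_evaluated += x.size
        return -0.5 * (x - self.mu) ** 2 - 0.5 * np.log(2.0 * np.pi)


def stacked_then_indexed(components, obs, idx):
    """Evaluate every component on all observations, then pick one per observation."""
    stacked = np.stack([logp(obs) for logp in components])  # shape (K, N)
    return stacked[idx, np.arange(obs.size)]


def selected_only(components, obs, idx):
    """Evaluate each component only on the observations it was selected for."""
    res = np.full(obs.shape, -np.inf)
    for k, logp in enumerate(components):
        mask = idx == k
        res[mask] = logp(obs[mask])
    return res


obs = np.array([0.1, -0.4, 1.3, 2.2, -1.0, 0.5])
idx = np.array([0, 2, 1, 1, 0, 2])

stack_components = [CountingLogp(mu) for mu in (-1.0, 0.0, 1.0)]
stacked_res = stacked_then_indexed(stack_components, obs, idx)
stacked_evals = sum(c.n_evaluated for c in stack_components)  # K * N evaluations

select_components = [CountingLogp(mu) for mu in (-1.0, 0.0, 1.0)]
selected_res = selected_only(select_components, obs, idx)
selected_evals = sum(c.n_evaluated for c in select_components)  # N evaluations
```

Both strategies return the same values, but with K components and N observations the stacked version performs K * N density evaluations while the selected version performs N.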

brandonwillard (Member, Author) commented

The two tests in this PR pass locally under the changes in aesara-devs/aesara#516.

codecov bot commented Jul 11, 2021

Codecov Report

Merging #19 (a74b7ec) into main (b07159b) will increase coverage by 2.27%.
The diff coverage is 93.75%.

@@            Coverage Diff             @@
##             main      #19      +/-   ##
==========================================
+ Coverage   90.52%   92.80%   +2.27%     
==========================================
  Files           5        6       +1     
  Lines         802      917     +115     
  Branches      103      119      +16     
==========================================
+ Hits          726      851     +125     
+ Misses         44       30      -14     
- Partials       32       36       +4     
| Impacted Files | Coverage Δ |
|---|---|
| aeppl/joint_logprob.py | 98.00% <83.33%> (+6.33%) ⬆️ |
| aeppl/mixture.py | 93.63% <93.63%> (ø) |
| aeppl/opt.py | 82.60% <100.00%> (+22.00%) ⬆️ |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@brandonwillard brandonwillard force-pushed the general-mixtures branch 2 times, most recently from 2d53b58 to 33bc7d5 Compare July 11, 2021 21:37
@brandonwillard brandonwillard marked this pull request as ready for review July 11, 2021 21:54
rv.owner and isinstance(rv.owner.op, RandomVariable) for rv in mixture_rvs
):
# Currently, all mixture components must be `RandomVariable` outputs
# TODO: Allow constants and make them Dirac-deltas
Contributor

We could have a rewrite phase before this one where we convert constants to Dirac deltas (not needed for this PR)
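
As a rough illustration of what treating a constant as a Dirac delta would mean for the log-probability, here is a NumPy sketch. The helper `dirac_delta_logprob` is hypothetical and is not part of aeppl; it only shows the point-mass convention of log-probability 0 at the constant and -inf elsewhere.

```python
import numpy as np


def dirac_delta_logprob(value, constant):
    """Hypothetical helper: log-probability of a point mass located at `constant`."""
    value = np.asarray(value, dtype=float)
    # All mass sits on `constant`: log(1) = 0 there, log(0) = -inf everywhere else.
    return np.where(value == constant, 0.0, -np.inf)


delta_logp = dirac_delta_logprob([1.0, 2.0, 1.0], 1.0)
```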

aeppl/mixture.py Outdated
# problem)
join_axis = 0 # op.join_axis

logp_val = at.full(tuple(at.shape(value)), -np.inf, dtype=value.dtype)
Contributor

Wouldn't it be safer to fill it with NaNs?

Member Author

Safer in what way? It probably shouldn't matter what that initial value is, since the indexing is supposed to cover all values.
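
The trade-off in this exchange can be sketched with NumPy. The function below is a hypothetical stand-in that uses placeholder log-probability values, not the code in aeppl/mixture.py. When the index masks cover every entry, the initial fill value is overwritten everywhere; the two choices only differ if some index falls outside the components.

```python
import numpy as np


def fill_and_assign(fill_value, idx, n_components):
    """Fill a buffer, then overwrite entries component by component."""
    logp = np.full(idx.shape, fill_value, dtype=float)
    for k in range(n_components):
        logp[idx == k] = -1.0 * (k + 1)  # placeholder log-probability values
    return logp


# Every index refers to an existing component: the fill value never survives.
idx_ok = np.array([0, 1, 1, 0])
covered_inf = fill_and_assign(-np.inf, idx_ok, 2)
covered_nan = fill_and_assign(np.nan, idx_ok, 2)

# Index 2 is out of range for two components: the fill value leaks through.
idx_bad = np.array([0, 1, 2, 0])
gap_inf = fill_and_assign(-np.inf, idx_bad, 2)
gap_nan = fill_and_assign(np.nan, idx_bad, 2)
```

In the uncovered case, a -inf fill reads as a valid zero-probability result, while a NaN fill propagates through sums and makes the gap visible.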

@ricardoV94
Contributor

Is this working with categorical distributions as well? If so, we could add a test

@brandonwillard
Member Author

> Is this working with categorical distributions as well? If so, we could add a test

Yes, definitely.

@brandonwillard brandonwillard force-pushed the general-mixtures branch 3 times, most recently from a12a939 to 94b342f Compare July 16, 2021 00:59
@brandonwillard brandonwillard merged commit 9436fc0 into aesara-devs:main Jul 17, 2021
@brandonwillard brandonwillard deleted the general-mixtures branch July 17, 2021 02:17