MeasurableDimshuffle logp fails when value broadcastable information is lost #191

Open
ricardoV94 opened this issue Oct 8, 2022 · 0 comments

ricardoV94 (Contributor) commented Oct 8, 2022

```python
import aesara.tensor as at
from aeppl import joint_logprob

# Using a multivariate to force our Dimshuffle rewrite
x = at.random.dirichlet([1, 1, 1])[None, ...]
y = at.random.dirichlet([1, 1, 1], size=(9,))
z = at.concatenate([x, y], axis=0)

z_vv = z.clone()
joint_logprob({z: z_vv})  # ValueError: Cannot drop a non-broadcastable dimension: [False, False], [1]
```

The logp of the concatenate splits the value z_vv in two, and the first piece is not inferred to be broadcastable along axis 0, even though it has length 1 at runtime. We could fix this by adding a specify_shape of 1 for the dimensions we are dropping in the MeasurableDimshuffle logp here:

```python
value = value.dimshuffle(undo_ds)
```
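A toy, pure-Python/NumPy model of the failure and of the proposed fix may help. It assumes nothing about Aesara's internals beyond what the error message shows: `ToyTensor`, `toy_split`, `toy_specify_shape` and `toy_dimshuffle` are hypothetical names, not Aesara or AePPL APIs. They only track a static broadcastable pattern next to the runtime data, and the error text is written to mirror the one quoted in the issue.

```python
import numpy as np


class ToyTensor:
    """Hypothetical stand-in for a symbolic tensor: runtime data plus a static
    broadcastable pattern (True = dimension statically known to have length 1)."""

    def __init__(self, data, broadcastable):
        self.data = np.asarray(data)
        self.broadcastable = tuple(broadcastable)


def toy_split(t, sizes, axis=0):
    # Toy split with symbolic sizes: every piece keeps the input's static
    # pattern, so a length-1 piece is NOT marked broadcastable.
    bounds = np.cumsum(sizes)[:-1]
    return [ToyTensor(p, t.broadcastable) for p in np.split(t.data, bounds, axis=axis)]


def toy_specify_shape(t, shape):
    # Toy shape assertion: check the runtime shape, then mark every dimension
    # asserted to have length 1 as broadcastable (None means "unknown").
    for dim, (actual, expected) in enumerate(zip(t.data.shape, shape)):
        if expected is not None and actual != expected:
            raise AssertionError(f"dim {dim}: expected length {expected}, got {actual}")
    return ToyTensor(t.data, [b or s == 1 for b, s in zip(t.broadcastable, shape)])


def toy_dimshuffle(t, new_order):
    # Dimensions missing from new_order are dropped; mirroring the error quoted
    # in the issue, only statically broadcastable dimensions may be dropped.
    dropped = [d for d in range(t.data.ndim) if d not in new_order]
    if any(not t.broadcastable[d] for d in dropped):
        raise ValueError(
            f"Cannot drop a non-broadcastable dimension: "
            f"{list(t.broadcastable)}, {list(new_order)}"
        )
    kept = [d for d in range(t.data.ndim) if d in new_order]
    data = np.squeeze(t.data, axis=tuple(dropped))
    data = np.transpose(data, [kept.index(d) for d in new_order])
    return ToyTensor(data, [t.broadcastable[d] for d in new_order])


# z_vv: a (10, 3) value whose static pattern is (False, False), as for z.clone().
z_vv = ToyTensor(np.full((10, 3), 1 / 3), [False, False])
first, rest = toy_split(z_vv, [1, 9], axis=0)

try:
    toy_dimshuffle(first, [1])  # try to drop axis 0 of the (1, 3) piece
except ValueError as err:
    print(err)  # Cannot drop a non-broadcastable dimension: [False, False], [1]

# Proposed fix: assert length 1 on the dropped axis before the dimshuffle.
fixed = toy_dimshuffle(toy_specify_shape(first, (1, None)), [1])
print(fixed.data.shape)  # (3,)
```

In this toy, the assertion moves the length-1 requirement from the static type to a runtime check, which still fails loudly if the piece does not actually have length 1.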

More generally, is there a reason why we don't always add a specify_shape when dropping dimensions in Aesara via Dimshuffle, instead of raising that error?

ricardoV94 changed the title from "Dimshuffle rewrite fails when information about value broadcastable dimension is lost" to "MeasurableDimshuffle logp fails when value broadcastable information is lost" Oct 8, 2022
ricardoV94 added a commit to ricardoV94/pymc that referenced this issue Oct 8, 2022