Fix _DistributionTerm.expand return-type annotation#667
Conversation
`expand` was `@defop`-annotated to return `jax.Array`, which routed its result through `_ArrayTerm` and broke chained distribution methods like `Normal(mu_term, 1.0).expand([J]).to_event(1)`. Annotate the return type as `dist.Distribution` so the result is no longer dispatched as an array. Note: `to_event` was already correctly annotated.
|
Not ready yet. Reproduction in #666 still fails, now as ( |
Without this, `defdata` dispatch on `dist.Distribution` falls through to `collections.abc.Callable` (since `Distribution` defines `__call__`) and produces `_CallableTerm`, breaking chained distribution methods like `Normal(mu_term, 1.0).expand([J]).to_event(1)`. The fallback subclasses `_DistributionTerm` and carries forward the receiver's distribution family, so further chained methods remain available on the resulting term. Add a regression test pinning the chain-construction behavior. Note: this resolves the AttributeError reported in #666. Running the full repro under MCMC still fails downstream because other distribution property defops (`support`, `batch_shape`, ...) raise `NotHandled` on non-eager receivers and route through the same defdata pattern — a broader integration matter, not the chain-construction bug.
The original issue framed the bug as blocking "the standard NumPyro vectorised hierarchical-model idiom". Add a test that exercises the idiomatic form (`numpyro.plate` around a sample) end-to-end under MCMC and asserts sample shapes. Companion to the existing narrow chain test.
The plate-based MCMC test was using `Normal(mu, 1.0)` where `mu` is a sampled real array, so the receiver was fully eager and `.expand` / `.to_event` were never called. The test passed on master too — it didn't pin anything #666-specific. The narrow chain test does the actual job and is verified to fail on master.
|
Fixed |
Per review feedback, no new term classes. Two minimal changes: 1. ``_DistributionTerm._is_eager`` (on the base) now recurses into Term arguments that are themselves ``_DistributionTerm``s, consulting their ``_is_eager``. Previously any ``_DistributionTerm``-valued arg was treated as non-eager (since ``is_eager_array`` returned False for Distribution Terms), which prevented downstream methods from resolving even when the underlying base was eager. 2. ``_DistributionMethodTerm._pos_base_dist`` builds a real NumPyro distribution by recursively materialising the receiver and applying the deferred op (``to_event`` / ``expand``) directly. With this and the ``_is_eager`` fix, every inherited ``@defop`` method (``batch_shape``, ``event_shape``, ``support``, ``log_prob``, ``sample``, ...) resolves via the standard machinery — no parallel property/sample overrides needed. Adds two regression tests: - equational shape/support laws for ``.expand`` and ``.to_event`` on a free-variable receiver bound by an effectful handler - end-to-end MCMC over the literal #666 idiom (``Normal(mu_term, 1.0).expand([3]).to_event(1)`` with ``mu_term`` bound via ``handler``) Both fail on master with the expected error mode and pass on this branch.
|
Ready now! |
…ases Two cleanups: 1. `_DistributionMethodTerm._pos_base_dist` previously branched on ``self._op is _DistributionTerm.to_event`` / ``... is .expand`` to pick the materialisation. Replace the if-chain with a single ``getattr(base, self._op.__name__)(*self._args[1:], **self._kwargs)``. Same semantics for the two existing ops, and now correct for any future ``dist.Distribution``-returning method op without a per-op switch. 2. Remove ``xfail="to_event not implemented"`` from the ``Beta(...).to_event(k)`` and ``Dirichlet(...).to_event(k)`` parametrised cases; they pass now under this PR. The ``Independent(TransformedDistribution(...), k)`` case stays xfail but the reason is updated to ``"TransformedDistribution not implemented"`` to reflect the actual blocker (TransformedDistribution itself is unsupported in effectful).
The basic ``D.to_event(k)`` chain works after this PR, but ``test_dist_expand`` then composes ``.expand_by`` on the resulting ``_DistributionMethodTerm`` whose receiver carries indexed (named-dim) arrays. Materialising via the receiver's ``_pos_base_dist`` and then ``base.expand(...)`` mishandles those indices (``Cannot broadcast distribution of shape (5, 3) to shape (3,)``) — that's a deeper indexed-dim integration matter, out of scope for the #666 annotation fix. Restore the xfail with a more accurate reason.
| assert isinstance(chained, numpyro.distributions.Distribution) | ||
|
|
||
|
|
||
| def test_expand_to_event_shape_laws(): |
There was a problem hiding this comment.
This test doesn't expand a symbolic distribution, because mu is always concrete. Did it fail before?
|
Working on merging #613 and find a cleaner solution. |
|
@jfeser I ran into this problem of using import typing
from effectful.internals.unification import unify
T = typing.TypeVar("T")
class Foo: pass
class Bar(Foo): pass
assert unify(T, Bar) == {T: Bar} # FAILS: gets {T: Foo} So whenever you unify an inherited type with T, it resolves to the parent class. Should we default to principal type or widening in unifying/type inferring? |
|
That's definitely a bug. |
|
I am not sure about the premise of this PR. The latest proposed typing rule is not consistent with the signature of def expand(self, batch_shape: Sequence[int]) -> Distribution: ...I'm missing context on what you're trying to accomplish but I think it's unwise in general to attempt to monkey-patch typing behavior of upstream libraries within |
|
Sorry, that's my mistake. In that case, we might register |
|
Oh got it. Will revert back to The context was that I was attempting to factor out the interfacing between effectful symbols and numpyro here. So then you could do something like this: import jax.numpy as jnp
import jax.random as jr
from effectful.handlers.numpyro import Normal, HalfNormal
from effectful_mcmc import sample, MCMC, NUTS
def linear_regression(x, y):
alpha = sample(Normal(0.0, 10.0))
beta = sample(Normal(0.0, 10.0))
sigma = sample(HalfNormal(2.0))
sample(Normal(alpha + beta * x, sigma), obs=y)
return alpha, beta, sigma # return per-site effects
mcmc = MCMC(NUTS(linear_regression), num_warmup=500, num_samples=1000)
mcmc.run(jr.PRNGKey(0), x_data, y_data)
# Posteriors are keyed by the Operation behind each sample site;
# `term.op` recovers it. No string names to keep in sync.
alpha, beta, sigma = mcmc.model_return_value
samples = mcmc.get_samples()
print(float(samples[alpha.op].mean()),
float(samples[beta.op].mean()),
float(samples[sigma.op].mean()))Which is a bit more like normal Python. I was thinking something like this could be factored out of RoboTL. |
106ab34 to
9d79358
Compare
Closes #666.
Per @eb8680's feedback — sticking to NumPyro's
Distribution.expandsignature (-> Distribution) rather than monkey-patching it to-> Self. Reverted thetyping.Selfannotation attempt and the merge of #613. Back to registering adefdatacase fordist.Distribution(per @jfeser's suggestion in the same thread).Three small changes in
effectful/handlers/numpyro.py:_DistributionTerm.expandreturn annotation:jax.Array→dist.Distribution. Was routing.expand([J])through_ArrayTerm, breaking.to_event(1)chaining._DistributionTerm._is_eagerrecurses into_DistributionTerm-valued args (rather than treating them as non-eager viais_eager_array). Otherwise a fully-eagerD.to_event(k)would be tagged non-eager and every downstream property/sample method raised.defdata.register(dist.Distribution)mapping to_DistributionMethodTerm. Without it,dist.Distributionfalls through to_CallableTerm(sinceDistributiondefines__call__). The class overrides_pos_base_distto materialise bygetattr(base, self._op.__name__)(*args[1:], **kwargs), so every inherited@defop(batch_shape,event_shape,support,log_prob,sample, …) resolves through the standard machinery against the real NumPyro distribution.Three regression tests (each verified to fail on master):
AttributeErrorfrom_DistributionTerm.expandand.to_eventreturn_ArrayTerminstead ofDistribution, breaking standard NumPyro vectorization idiom #666).expand/.to_eventNormal(mu_term, 1.0).expand([3]).to_event(1)withmu_termbound viaeffectful.ops.semantics.handlerBeta(...).to_event(k)/Dirichlet(...).to_event(k)parametrised cases stay xfail — the basic chain works, but composing.expand_byon top with indexed (named-dim) arrays hits a deeper indexed-dim integration matter, out of scope for this PR.Lint clean (ruff / format / mypy).