Skip to content

Fix _DistributionTerm.expand return-type annotation#667

Open
datvo06 wants to merge 8 commits into
masterfrom
fix/distribution-term-expand-return-type
Open

Fix _DistributionTerm.expand return-type annotation#667
datvo06 wants to merge 8 commits into
masterfrom
fix/distribution-term-expand-return-type

Conversation

@datvo06
Copy link
Copy Markdown
Contributor

@datvo06 datvo06 commented May 26, 2026

Closes #666.

Per @eb8680's feedback — sticking to NumPyro's Distribution.expand signature (-> Distribution) rather than monkey-patching it to -> Self. Reverted the typing.Self annotation attempt and the merge of #613. Back to registering a defdata case for dist.Distribution (per @jfeser's suggestion in the same thread).

Three small changes in effectful/handlers/numpyro.py:

  1. _DistributionTerm.expand return annotation: jax.Arraydist.Distribution. Was routing .expand([J]) through _ArrayTerm, breaking .to_event(1) chaining.

  2. _DistributionTerm._is_eager recurses into _DistributionTerm-valued args (rather than treating them as non-eager via is_eager_array). Otherwise a fully-eager D.to_event(k) would be tagged non-eager and every downstream property/sample method raised.

  3. defdata.register(dist.Distribution) mapping to _DistributionMethodTerm. Without it, dist.Distribution falls through to _CallableTerm (since Distribution defines __call__). The class overrides _pos_base_dist to materialise by getattr(base, self._op.__name__)(*args[1:], **kwargs), so every inherited @defop (batch_shape, event_shape, support, log_prob, sample, …) resolves through the standard machinery against the real NumPyro distribution.

Three regression tests (each verified to fail on master):

Beta(...).to_event(k) / Dirichlet(...).to_event(k) parametrised cases stay xfail — the basic chain works, but composing .expand_by on top with indexed (named-dim) arrays hits a deeper indexed-dim integration matter, out of scope for this PR.

Lint clean (ruff / format / mypy).

`expand` was `@defop`-annotated to return `jax.Array`, which routed its
result through `_ArrayTerm` and broke chained distribution methods like
`Normal(mu_term, 1.0).expand([J]).to_event(1)`.

Annotate the return type as `dist.Distribution` so the result is no
longer dispatched as an array. Note: `to_event` was already correctly
annotated.
@datvo06
Copy link
Copy Markdown
Contributor Author

datvo06 commented May 26, 2026

Not ready yet. Reproduction in #666 still fails, now as (_CallableTerm instead of _ArrayTerm previously). Also havent added the reproduction test.

datvo06 added 3 commits May 26, 2026 16:30
Without this, `defdata` dispatch on `dist.Distribution` falls through to
`collections.abc.Callable` (since `Distribution` defines `__call__`) and
produces `_CallableTerm`, breaking chained distribution methods like
`Normal(mu_term, 1.0).expand([J]).to_event(1)`.

The fallback subclasses `_DistributionTerm` and carries forward the
receiver's distribution family, so further chained methods remain
available on the resulting term. Add a regression test pinning the
chain-construction behavior.

Note: this resolves the AttributeError reported in #666. Running the
full repro under MCMC still fails downstream because other distribution
property defops (`support`, `batch_shape`, ...) raise `NotHandled` on
non-eager receivers and route through the same defdata pattern — a
broader integration matter, not the chain-construction bug.
The original issue framed the bug as blocking "the standard NumPyro
vectorised hierarchical-model idiom". Add a test that exercises the
idiomatic form (`numpyro.plate` around a sample) end-to-end under MCMC
and asserts sample shapes. Companion to the existing narrow chain test.
The plate-based MCMC test was using `Normal(mu, 1.0)` where `mu` is a
sampled real array, so the receiver was fully eager and `.expand` /
`.to_event` were never called. The test passed on master too — it
didn't pin anything #666-specific. The narrow chain test does the
actual job and is verified to fail on master.
@datvo06
Copy link
Copy Markdown
Contributor Author

datvo06 commented May 26, 2026

Fixed expand().to_event(), but for full MCMC examples in #666, I'd need to update it, since NumPyro'd eventually need mu() actual value.

Per review feedback, no new term classes. Two minimal changes:

1. ``_DistributionTerm._is_eager`` (on the base) now recurses into Term
   arguments that are themselves ``_DistributionTerm``s, consulting their
   ``_is_eager``. Previously any ``_DistributionTerm``-valued arg was
   treated as non-eager (since ``is_eager_array`` returned False for
   Distribution Terms), which prevented downstream methods from
   resolving even when the underlying base was eager.

2. ``_DistributionMethodTerm._pos_base_dist`` builds a real NumPyro
   distribution by recursively materialising the receiver and applying
   the deferred op (``to_event`` / ``expand``) directly. With this and
   the ``_is_eager`` fix, every inherited ``@defop`` method
   (``batch_shape``, ``event_shape``, ``support``, ``log_prob``,
   ``sample``, ...) resolves via the standard machinery — no parallel
   property/sample overrides needed.

Adds two regression tests:
- equational shape/support laws for ``.expand`` and ``.to_event`` on a
  free-variable receiver bound by an effectful handler
- end-to-end MCMC over the literal #666 idiom
  (``Normal(mu_term, 1.0).expand([3]).to_event(1)`` with ``mu_term``
  bound via ``handler``)

Both fail on master with the expected error mode and pass on this branch.
@datvo06
Copy link
Copy Markdown
Contributor Author

datvo06 commented May 26, 2026

Ready now!

@datvo06 datvo06 marked this pull request as ready for review May 26, 2026 21:30
@datvo06 datvo06 requested a review from jfeser May 26, 2026 21:30
datvo06 added 3 commits May 26, 2026 17:36
…ases

Two cleanups:

1. `_DistributionMethodTerm._pos_base_dist` previously branched on
   ``self._op is _DistributionTerm.to_event`` / ``... is .expand`` to
   pick the materialisation. Replace the if-chain with a single
   ``getattr(base, self._op.__name__)(*self._args[1:], **self._kwargs)``.
   Same semantics for the two existing ops, and now correct for any
   future ``dist.Distribution``-returning method op without a per-op
   switch.

2. Remove ``xfail="to_event not implemented"`` from the
   ``Beta(...).to_event(k)`` and ``Dirichlet(...).to_event(k)``
   parametrised cases; they pass now under this PR. The
   ``Independent(TransformedDistribution(...), k)`` case stays xfail but
   the reason is updated to ``"TransformedDistribution not implemented"``
   to reflect the actual blocker (TransformedDistribution itself is
   unsupported in effectful).
The basic ``D.to_event(k)`` chain works after this PR, but
``test_dist_expand`` then composes ``.expand_by`` on the resulting
``_DistributionMethodTerm`` whose receiver carries indexed (named-dim)
arrays. Materialising via the receiver's ``_pos_base_dist`` and then
``base.expand(...)`` mishandles those indices (``Cannot broadcast
distribution of shape (5, 3) to shape (3,)``) — that's a deeper
indexed-dim integration matter, out of scope for the #666 annotation
fix. Restore the xfail with a more accurate reason.
Comment thread effectful/handlers/numpyro.py
assert isinstance(chained, numpyro.distributions.Distribution)


def test_expand_to_event_shape_laws():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test doesn't expand a symbolic distribution, because mu is always concrete. Did it fail before?

@datvo06 datvo06 marked this pull request as draft May 27, 2026 14:32
@datvo06
Copy link
Copy Markdown
Contributor Author

datvo06 commented May 27, 2026

Working on merging #613 and find a cleaner solution.

@datvo06
Copy link
Copy Markdown
Contributor Author

datvo06 commented May 27, 2026

@jfeser I ran into this problem of using unify on an inherited type:

  import typing                                                                                                                                                                                                                                                       
  from effectful.internals.unification import unify                                                                                                                                                                                                                   
  
  T = typing.TypeVar("T")                                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                    
  class Foo: pass
  class Bar(Foo): pass

  assert unify(T, Bar) == {T: Bar}      # FAILS: gets {T: Foo}          

So whenever you unify an inherited type with T, it resolves to the parent class. Should we default to principal type or widening in unifying/type inferring?
This is blocking the Self-based routing, because then Normal(mu()).expand(s) would result in dist.Distribution instead of NormalTerm via Self -> dist.Normal anddispatch falls back to _CallableTerm.
Which raises another error.

@datvo06 datvo06 mentioned this pull request May 27, 2026
@jfeser
Copy link
Copy Markdown
Contributor

jfeser commented May 27, 2026

That's definitely a bug.

@eb8680
Copy link
Copy Markdown
Contributor

eb8680 commented May 27, 2026

I am not sure about the premise of this PR. The latest proposed typing rule is not consistent with the signature of numpyro.distributions.Distribution.expand, which only promises to return a Distribution, not something of the same type as self:

def expand(self, batch_shape: Sequence[int]) -> Distribution: ...

I'm missing context on what you're trying to accomplish but I think it's unwise in general to attempt to monkey-patch typing behavior of upstream libraries within effectful like this and I'd strongly recommend sticking to the NumPyro signature for expand.

@jfeser
Copy link
Copy Markdown
Contributor

jfeser commented May 27, 2026

Sorry, that's my mistake. In that case, we might register _DistributionTerm as a defdata case for dist.Distribution.

@datvo06
Copy link
Copy Markdown
Contributor Author

datvo06 commented May 27, 2026

Oh got it. Will revert back to _DistributionTerm defdata

The context was that I was attempting to factor out the interfacing between effectful symbols and numpyro here. So then you could do something like this:

import jax.numpy as jnp
import jax.random as jr

from effectful.handlers.numpyro import Normal, HalfNormal
from effectful_mcmc import sample, MCMC, NUTS

def linear_regression(x, y):
    alpha = sample(Normal(0.0, 10.0))
    beta  = sample(Normal(0.0, 10.0))
    sigma = sample(HalfNormal(2.0))
    sample(Normal(alpha + beta * x, sigma), obs=y)
    return alpha, beta, sigma                      # return per-site effects

mcmc = MCMC(NUTS(linear_regression), num_warmup=500, num_samples=1000)
mcmc.run(jr.PRNGKey(0), x_data, y_data)

# Posteriors are keyed by the Operation behind each sample site;
# `term.op` recovers it. No string names to keep in sync.
alpha, beta, sigma = mcmc.model_return_value
samples = mcmc.get_samples()
print(float(samples[alpha.op].mean()),
      float(samples[beta.op].mean()),
      float(samples[sigma.op].mean()))

Which is a bit more like normal Python. I was thinking something like this could be factored out of RoboTL.

@datvo06 datvo06 force-pushed the fix/distribution-term-expand-return-type branch from 106ab34 to 9d79358 Compare May 30, 2026 19:40
@datvo06 datvo06 marked this pull request as ready for review June 1, 2026 18:06
@datvo06 datvo06 requested a review from jfeser June 1, 2026 18:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

_DistributionTerm.expand and .to_event return _ArrayTerm instead of Distribution, breaking standard NumPyro vectorization idiom

3 participants