
scan over discrete latent variables causes tracer leak #1998

@fehiepsi (Member)

Bug Description

This is one of the issues reported in #1981. Running the following test with tracer-leak checking enabled raises an error (or is reported as xfail).

Steps to Reproduce

JAX_CHECK_TRACER_LEAKS=1 pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke
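
For anyone reproducing this outside pytest, the same check can be switched on from Python. This is a minimal sketch that assumes only JAX's public config API (`jax.config.update` and the `jax.checking_leaks` context manager); it does not touch NumPyro or the failing test itself.

```python
import jax

# Option 1: flip the global flag, which is the in-process counterpart of
# setting JAX_CHECK_TRACER_LEAKS=1 in the environment.
jax.config.update("jax_check_tracer_leaks", True)
flag_after_update = jax.config.jax_check_tracer_leaks
jax.config.update("jax_check_tracer_leaks", False)  # restore the default

# Option 2: enable the check only inside a block.
with jax.checking_leaks():
    flag_inside_context = jax.config.jax_check_tracer_leaks
flag_after_context = jax.config.jax_check_tracer_leaks
```

The scoped form is handy for narrowing down which call leaks, since only code inside the `with` block is checked.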

Expected Behavior

The test should pass.

Activity

fehiepsi (Member, Author) commented on Mar 7, 2025

The leak seems to be caused by this line:

with funsor.adjoint.AdjointTape() as tape:

The stateful adjoint tape is not compatible with JAX's scan.
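
To illustrate the general failure mode (not funsor's actual code), here is a minimal sketch in plain JAX. The `recorded` list is a hypothetical stand-in for a stateful tape, and `leaky_step` and `pure_step` are made-up names. The point is only that a scan body is traced, so anything it writes into outer Python state is a tracer that outlives the trace.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for a stateful tape: a plain Python list that
# outlives the traced function. This is NOT funsor's AdjointTape.
recorded = []

def leaky_step(carry, x):
    new_carry = carry + x
    recorded.append(new_carry)  # side effect: stores a tracer outside the trace
    return new_carry, new_carry

def pure_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # all values leave through the return values

xs = jnp.arange(4.0)
init = jnp.zeros(())

final_leaky, _ = lax.scan(leaky_step, init, xs)
final_pure, ys = lax.scan(pure_step, init, xs)

# The numerical results agree, but the list now holds an abstract tracer
# rather than a concrete number.
leaked_is_tracer = isinstance(recorded[0], jax.core.Tracer)
```

Both scans return the same values, which is why the problem can go unnoticed until leak checking is enabled. The usual way to avoid it is the `pure_step` pattern: pass everything that must survive the loop through the carry or the stacked outputs.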

Switching back to the lazy interpretation seems to fix the leak, but it makes some tests fail.

-    with funsor.adjoint.AdjointTape() as tape:
+    with funsor.interpretations.lazy:
         with block(), enum(first_available_dim=first_available_dim):
             log_prob, model_tr, log_measures = _enum_log_density(
                 model, args, kwargs, {}, sum_op, prod_op
             )
 
     with approx:
-        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)
+        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

pscicluna commented on Jun 10, 2025

I was wondering what the status of this is. I'm trying to do SVI with discrete latents, and I think the errors I'm getting (tracer leaks that make it impossible to run) come from this bug. I have three questions:
1. Am I right to think that the fixes in #2002 deal with this problem, or is it actually #1999?
2. Is the fork in a state where it's reasonable to install from it and use it?
3. Is there anything I could do to help you get the PR over the line?


Metadata

Labels: bug
Participants: @pscicluna, @fehiepsi

scan over discrete latent variables causes tracer leak · Issue #1998 · pyro-ppl/numpyro