Skip to content

Adds support for passing initial value to transform #2038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 9, 2025

Conversation

tingiskhan
Copy link
Contributor

Changes made

Added support for passing a parameter called initial_value to RecursiveLinearTransform to facilitate forecasting when using the transform for latent variables.

Links to related PRs

#2037

Tests

I added one test case to the "general" part, and parameterized test_batched_recursive_linear_transform for testing arbitrary initial values.

Dependencies

No new dependencies introduced.

@tillahoffmann
Copy link
Collaborator

Thank you for the contribution, @tingiskhan! Do you think this could also be achieved using a composition of an AffineTransform and RecursiveLinearTransform?

@tingiskhan
Copy link
Contributor Author

Thank you for the contribution, @tingiskhan! Do you think this could also be achieved using a composition of an AffineTransform and RecursiveLinearTransform?

That’s a very good point, I’ll take a look at it next week (if you don’t beat me to it)!

@tingiskhan
Copy link
Contributor Author

tingiskhan commented Jun 20, 2025

@tillahoffmann So, I've given it some thought and see three ways forward (there might be more):

  1. Keep as is (i.e. pass initial value to RecursiveLinearTransform).
  2. Use a composition of AffineTransform and original RecursiveLinearTransform (as you suggested). The "issue" I see with this method is that we'll need to pass the length of the "series" on __init__ in order to construct the "loc" variable of AffineTransform correctly: loc * jnp.eye((length, 1)) (where loc corresponds to initial value). This would break the design of transforms a bit I think, since the transform will require the shape of sample on init.
  3. Construct a derived class from AffineTransform that only takes the initial value as input (corresponding to loc and pass 1 to scale) and then in __call__/_inverse handle the building of the offset given the sample (same as in 2.).

Of these three I think either 1. or 3. is the most appropriate. I'm not sure how re-usable the new component of 3. would be, but I'm open to either suggestions.

I'm not sure this PR is even needed to be honest, but the current way of passing initial values other than 0 is not so clear (might just require an entry to the docs?).

@fehiepsi
Copy link
Member

If initial value is 0, we have y_1=x_1, so we can just prepend the initial value to x and y?

@tingiskhan
Copy link
Contributor Author

@fehiepsi Yes, I suppose that would work as well! What solution would be preferred (if at all)?

@fehiepsi
Copy link
Member

fehiepsi commented Jul 4, 2025

Hi @tingiskhan, I think supporting initial values makes sense.

@tingiskhan
Copy link
Contributor Author

@fehiepsi, are these changes what you had in mind?

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending some nits. Thanks @tingiskhan!

@tingiskhan
Copy link
Contributor Author

@fehiepsi, I’m guessing the errors are unrelated to my changes?

@fehiepsi
Copy link
Member

fehiepsi commented Jul 5, 2025

Could you help update the tolerant of the failing tests?

@tingiskhan
Copy link
Contributor Author

@fehiepsi, the tests were failing because of the jax.live_arrays check since I was creating jax arrays in __init__ of RecursiveLinearTransform, and sampling an initial value in one of the parametrized tests. I had to rethink the logic of how and when we broadcast the initial value, see latest commit.

@tingiskhan
Copy link
Contributor Author

tingiskhan commented Jul 8, 2025

@fehiepsi, equinox released a new version yesterday. Pinned it (sort of) and adressed the exception by passing the mode parameter

@fehiepsi fehiepsi merged commit 6fd6045 into pyro-ppl:master Jul 9, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants