Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Widening the scope of the package and dropping support for batching #214

Merged
merged 108 commits into from
Feb 1, 2023

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Feb 11, 2022

This PR is an attempt at a couple of things:

  1. Widening the scope, i.e. not being so restrictive regarding whether or not something is bijective.
  2. Removal of expected dimensionality of inputs from the type of the Bijector.
    • This had the benefit of disambiguating whether or not an input represents a single input or a collection of inputs, but it's super annoying to deal with both for the user and implementer + restricts us to only working with arrays, and transformations are expected to preserve dimensionality.
    • After this PR there is no longer "official" support for batching. Some bijectors work for batches because there's no need for disambiguation, but we don't claim that this works for all.
  3. We now also adopt the ChangesOfVariables.jl interface.

TODOs:

  • Should we remove (basically) unused and partially unsupported methods: logpdf_with_jac, logpdf_forward, forward(d::Distribution, ...)?
    • E.g. forward no longer works for multiple samples because we no longer have support for batched inputs.
    • Just removing these seems like the most straight-forward and useful approach atm, so we might as well.

@torfjelde
Copy link
Member Author

This is a fairly big one @devmotion , but I would greatly appreciate it if you had a super-quick look. The main part is just removing the dimensionality completely from the definition of the bijectors, in addition to a couple of small things:

  1. Adding mutating methods, e.g. transform!, logabsdetjac!, and with_logabsdet_jacobian! (should this maybe go to ChangesOfVariables.jl?).
  2. Added docs.
  3. Using elementwise(exp) to indicate that a method should be applied elementwise (equivalent to Base.Fix1(broadcast, exp)).

I'm also going to make a separate PR to remove the stuff related to ADbackend.

This will be a huge breaking release, as I think it's time we just rip the band-aid off.

@torfjelde
Copy link
Member Author

@yebai you might also want to take a look at this

docs/src/transforms.md Outdated Show resolved Hide resolved
docs/src/transforms.md Show resolved Hide resolved
InvertibleBatchNorm
Coupling,
InvertibleBatchNorm,
elementwise
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... I wonder if we could reuse some existing functionality in the ecosystem here. And/or if there is a shorter name. Regarding the first point, e.g., Transducers.Map seems similar?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy to take suggestions, but I'm also okay with elementwise, so IMO this shouldn't hold this PR back.

Not too big of a fan to depend use Transducers.Map though; seems like unnecessary complexity just to make Base.Fix1(broadcast, f).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I don't want to depend on Transducers either. Unfortunately, it seems we can't just define a curried version

Base.map(t::Transform) = ... 

or

Base.broadcast(t::Transform) = ...

since we would like to use elementwise also for functions such as exp?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we would like to use elementwise also for functions such as exp?

Exactly 😕

src/Bijectors.jl Outdated
Comment on lines 268 to 271
Base.@deprecate NamedBijector(bs) NamedTransform(bs)

@noinline function Base.inv(b::AbstractBijector)
Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv)
inverse(b)
end
Base.@deprecate Exp() elementwise(exp) false
Base.@deprecate Log() elementwise(log) false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are deprecations needed if it is a breaking release? Or would it be sufficient to add them to some changelog/announcement/NEWS.md?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed!

src/bijectors/corr.jl Outdated Show resolved Hide resolved
src/interface.jl Outdated

Transform `x` using `b`, treating `x` as a single input.
"""
transform(f::Function, x) = f(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it matters, but Julia won't specialize on f here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good catch!

src/interface.jl Outdated Show resolved Hide resolved
"""
isclosedform(b::Bijector)::bool
isclosedform(b⁻¹::Inverse{<:Bijector})::bool
logabsdetjac(b, x) = last(with_logabsdet_jacobian(b, x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed? Seems like something that - if desired - should maybe go to ChangesOfVariables.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally find this useful, and the last time we spoke about this (JuliaMath/ChangesOfVariables.jl#3), there didn't seem to be a desire to add it 😕

src/interface.jl Outdated
Comment on lines 157 to 161
# Useful for checking if compositions, etc. are invertible or not.
Base.:+(::NotInvertible, ::Invertible) = NotInvertible()
Base.:+(::Invertible, ::NotInvertible) = NotInvertible()
Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible()
Base.:+(::Invertible, ::Invertible) = Invertible()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend removing these definitions. Seems like a misuse of +.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

"""
with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x))
inverse(t::Transform) = Inverse(t)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just remove this definition, and it should be sufficient to operate with InverseFunctions.NoInverse instead of Invertible/NoInvertible?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@torfjelde
Copy link
Member Author

I think I've replied/addressed your comments now @devmotion

Some I'm of the opinion that we leave until later PRs, given how long this PR has been in the pipeline + a lot of improvements we want to do (e.g. adding a VecCorr bijector) are dependent on this PR making it's way through.

@devmotion
Copy link
Member

Yeah, let's just make another breaking release if it becomes necessary. I think I checked yesterday and Bijectors only has ~9 direct dependents, so it's not too bad if we iterate and release multiple breaking versions if needed (of course, it would be better to have the optimal design right away but that's completely unrealistic 😄).

@torfjelde
Copy link
Member Author

torfjelde commented Feb 1, 2023

Bueno! You "happy" with the current version of the PR then?

@torfjelde torfjelde merged commit 8b924d0 into master Feb 1, 2023
@delete-merged-branch delete-merged-branch bot deleted the tor/write-without-batch branch February 1, 2023 22:47
@yebai
Copy link
Member

yebai commented Feb 2, 2023

Thanks, @torfjelde @devmotion -- it looks good to me. I agree that we can keep improving the design in new PRs.

bors bot pushed a commit to TuringLang/DynamicPPL.jl that referenced this pull request Feb 3, 2023
This PR makes DPPL compatible with the changes to come in TuringLang/Bijectors.jl#214.

Tests are passing locally.

Closes #455 Closes #456
yebai added a commit to TuringLang/DynamicPPL.jl that referenced this pull request Mar 2, 2023
…eep existing compat) (#469)

* Fixed a typo in tutorial (#451)

* CompatHelper: bump compat for Turing to 0.24 for package turing, (keep existing compat) (#450)

This pull request changes the compat entry for the `Turing` package from `0.21` to `0.21, 0.24` for package turing.
This keeps the compat entries for earlier versions.



Note: I have not tested your package with this new compat entry.
It is your responsibility to make sure that your package tests pass before you merge this pull request.

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

* Some minor utility improvements (#452)

This PR does the following:
- Moves the `varname_leaves` from `TestUtils` to main module.
  - It can be very useful in Turing.jl for constructing `Chains` and the like, so I think it's a good idea to make it part of the main module rather than keeping it "hidden" there.
- Makes the default `varinfo` in the constructor of `LogDensityFunction` be `model.context` rather than a new `DynamicPPL.DefaultContext`.
  - The `context` pass to `evaluate!!` will override the leaf-context in `model.context`, and so the current default constructor always uses `DefaultContext` as the leaf-context, even if the `Model` has been `contextualize`d with some other leaf-context, e.g. `PriorContext`. This PR fixes this issue.

* Always run CI  (#453)

I find the current `bors` workflow a bit tedious. Most of the time, I summon `bors` to see the CI results (see e.g. #438). Given that most `CI` tests are quick (< 10mins), we can always run them by default. 

The most time-consuming `IntegrationTests` is still run by `bors` to avoid excessive CI runs.

* Compat with new Bijectors.jl (#454)

This PR makes DPPL compatible with the changes to come in TuringLang/Bijectors.jl#214.

Tests are passing locally.

Closes #455 Closes #456

* Another Bijectors.jl compat bound bump (#457)

* CompatHelper: bump compat for MCMCChains to 6 for package test, (keep existing compat) (#467)

This pull request changes the compat entry for the `MCMCChains` package from `4.0.4, 5` to `4.0.4, 5, 6` for package test.
This keeps the compat entries for earlier versions.



Note: I have not tested your package with this new compat entry.
It is your responsibility to make sure that your package tests pass before you merge this pull request.

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

* CompatHelper: bump compat for AbstractPPL to 0.6 for package test, (keep existing compat)

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: github-actions[bot] <compathelper_noreply@julialang.org>
Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants