Implementing a custom bijector is a hassle: solve by adding macro? #137
Comments
IMO a macro seems too complicated and leads to un-Julian syntax. I also don't think it is necessarily easier for users to figure out how to write the macro than to just implement the three functions as they do currently. I'm not sure if
(f::Bijector)(x) = unthunk(forward(f, x).rv)
logabsdetjac(f::Bijector, x) = unthunk(forward(f, x).logabsdetjac)
(BTW I'm not sure about these names, maybe just make it a tuple or use something else than
function Bijectors.forward(f::MyCoolBijector, x)
    ....
    return (rv = ..., logabsdetjac = ...)
end
possibly using
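The idea in this comment, writing only forward and deriving the other two methods from it, can be sketched roughly as follows. This is a minimal, self-contained illustration rather than the Bijectors.jl API: AbstractBijector and ScaleBijector are hypothetical names, the unthunk step is left out so everything is computed eagerly, and the callable defined on the abstract type assumes Julia 1.3 or later.

```julia
# Hypothetical stand-in types for illustration; not the Bijectors.jl API.
abstract type AbstractBijector end

struct ScaleBijector <: AbstractBijector
    a::Float64  # scale factor, assumed nonzero
end

# The only method the user writes: value and log-abs-det-Jacobian together.
function forward(b::ScaleBijector, x::Real)
    return (rv = b.a * x, logabsdetjac = log(abs(b.a)))
end

# Generic fallbacks derived from `forward`.
# Defining a call method on an abstract type requires Julia >= 1.3.
(b::AbstractBijector)(x) = forward(b, x).rv
logabsdetjac(b::AbstractBijector, x) = forward(b, x).logabsdetjac

b = ScaleBijector(2.0)
y = b(3.0)                   # evaluation goes through `forward`
ladj = logabsdetjac(b, 3.0)  # so does the log-Jacobian
```

With eager evaluation like this, calling b(x) alone still pays for the log-Jacobian, which is the cost the unthunk variant in the comment is meant to avoid.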
But the "un-Julian syntax" is mostly due to the fact that we're dropping the
That's true, but the goal here isn't to make it easier to understand, but easier to go from "I want this bijector" to "I have this bijector". Using a macro we could make it so that there is a minimal amount of work for the user, in addition to getting the most efficient implementation for all the necessary functions. E.g.
It's part of the API 👍 And there are cases where it's definitely worth it, e.g.
I'm potentially for this. But it's worth noting that Bijectors.jl still works for Julia <1.3, it's just that we don't test properly and certain AD backends don't work. This introduction would completely break Bijectors.jl for Julia <1.3.
You're thinking along the lines of
function _forward(b, x)
    rv = @thunk ...
    logabsdetjac = @thunk ...
    return (rv = rv, logabsdetjac = logabsdetjac)
end
forward(b, x) = unthunk(_forward(b, x))
(b::Bijector)(x) = unthunk(_forward(b, x).rv)
logabsdetjac(b, x) = unthunk(_forward(b, x).logabsdetjac)
right? I'd argue that this is both a) more complicated to understand for the user, and b) way worse for performance, as closures have comparatively significant overhead.
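To make the thunk-based variant concrete, here is a small self-contained sketch. It assumes a hand-rolled Lazy wrapper and force function as stand-ins for @thunk and unthunk, and the names evaluate, logjac and CALLS are hypothetical. It only shows that each field is computed on demand; it does not measure the closure overhead mentioned above.

```julia
# Hypothetical minimal stand-in for a thunk: wraps a zero-argument closure.
struct Lazy{F}
    f::F
end
force(l::Lazy) = l.f()  # plays the role of `unthunk`

# Counts how often the log-Jacobian closure actually runs.
const CALLS = Ref(0)

# Both results are deferred; nothing is computed until `force` is called.
function _forward(a::Float64, x::Float64)
    rv = Lazy(() -> a * x)
    logabsdetjac = Lazy(() -> (CALLS[] += 1; log(abs(a))))
    return (rv = rv, logabsdetjac = logabsdetjac)
end

evaluate(a, x) = force(_forward(a, x).rv)
logjac(a, x) = force(_forward(a, x).logabsdetjac)

y = evaluate(2.0, 3.0)        # only the `rv` closure runs
calls_after_eval = CALLS[]
l = logjac(2.0, 3.0)          # now the log-Jacobian closure runs once
calls_after_logjac = CALLS[]
```

Note that two independent closures do not share intermediate results with each other, so any computation common to rv and logabsdetjac would need to be hoisted out explicitly in a design like this.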
I'm in favour of the suggestion, but it seems like a slightly different issue, no?
Agree, but also separate issue. We discussed renaming
Currently there are a couple of annoyances when implementing a new Bijector:
- (b::Bijector)(x::T) cannot be implemented for abstract types on Julia <1.3. This means that we have to implement batch computations on a case-by-case basis, which is both annoying and sometimes difficult to do in an AD-friendly and type-stable way (we have a bunch of mapvcat and eachcolmaphcat methods to do this, which is an unnecessary complication for a newcomer). We considered using a transform(b, x) method as the "evaluation" method, as this would allow us to have more generic implementations for batching, etc. But we decided not to do that, as it also felt clunky.
- forward(b::Bijector, x) is supposed to allow the user to share computation between the evaluation, i.e. (b::MyBijector)(x), and logabsdetjac(b, x). But it's annoying to have to first implement (b::MyBijector)(x) and logabsdetjac(b, x), which are mandatory, and then have to go through these methods to figure out what is shared and then copy-paste certain parts to a transform method, etc.
Since we're in Julia, my first idea is of course to throw a macro at the problem! I'm thinking we introduce transform but make it super easy for the user to define everything in one go. I.e. something along the lines of:
which is then transformed into something along the lines of
Then the only thing that is left for the user to implement is the inverse evaluation.
Also, I do have a somewhat "dirty" implementation ready (from which the above output was generated + MacroTools.prettify): https://gist.github.com/torfjelde/8675bba686afdf693476ae1c70f516d3.
This would then allow us to easily transition to transform, thus ensuring compatibility with Julia <1.3 but still using more generic methods, i.e. transform(b::Bijector{0}, x::AbstractVector) = b.(x). It would make it super easy to share computation in forward. Finally, we could start thinking about adding in complementary in-place methods, e.g. transform!(b::Bijector, x, out), logabsdetjac!(b::Bijector, x, out), etc., as a next step.
The only question is: are we overcomplicating things here? Is there an easier way of achieving what we want?
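As a rough illustration of the transform-based batching and the in-place methods mentioned above, here is a self-contained sketch. AbstractBijector{N} and ExpBijector are hypothetical stand-ins, not the Bijectors.jl types, and the in-place signature simply follows the transform!(b, x, out) form from the text. Because transform is an ordinary named function, the generic methods dispatch on the abstract type without needing callable abstract types, which is the Julia <1.3 restriction described earlier.

```julia
# Hypothetical stand-in types; N is the dimensionality of a single input.
abstract type AbstractBijector{N} end

struct ExpBijector <: AbstractBijector{0} end

# What the user writes: the scalar case only.
transform(b::ExpBijector, x::Real) = exp(x)
logabsdetjac(b::ExpBijector, x::Real) = x  # log|d/dx exp(x)| = x

# Generic batching for any 0-dimensional bijector: a vector is treated as a
# batch of scalars, so both methods are applied elementwise.
transform(b::AbstractBijector{0}, x::AbstractVector) =
    map(xi -> transform(b, xi), x)
logabsdetjac(b::AbstractBijector{0}, x::AbstractVector) =
    map(xi -> logabsdetjac(b, xi), x)

# In-place variant: write the results into a preallocated `out`.
function transform!(b::AbstractBijector{0}, x::AbstractVector, out::AbstractVector)
    map!(xi -> transform(b, xi), out, x)
    return out
end

b = ExpBijector()
x = [0.0, 1.0]
ys = transform(b, x)
ladjs = logabsdetjac(b, x)
out = zeros(2)
transform!(b, x, out)
```

A new 0-dimensional bijector defined this way would get the batched and in-place methods without any extra code, at least in this simplified setting.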