Basic rewrite of the package #25
thanks @theogf, it looks like a great PR in progress. Just a clarification question, does |
That's a good question. After checking the ADVI paper, ADVI is DSVI with bijectors (which answers the question I had, @torfjelde). However, the previous implementation of ADVI did not reflect what was done in the paper. Reference: Automatic Differentiation Variational Inference https://arxiv.org/abs/1603.00788
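To make "DSVI with bijectors" concrete, here is a minimal sketch, assuming a mean-field Gaussian in unconstrained space and `exp` as the bijector onto the positive reals. All names and values are made up for illustration and are not taken from the PR.

```julia
using Random

# Hypothetical variational parameters in unconstrained space.
μ = [0.0, 1.0]          # mean
σ = [0.5, 0.2]          # standard deviations (mean-field)

rng = MersenneTwister(1)
ε = randn(rng, 2)       # base noise, independent of (μ, σ)
z = μ .+ σ .* ε         # reparameterized sample in unconstrained space
θ = exp.(z)             # bijector: maps each coordinate to (0, ∞)

# Change of variables: log q(θ) = log N(z; μ, σ²) - log|det J|,
# where J = diag(exp.(z)), so log|det J| = sum(z).
logq_z = sum(@. -0.5 * log(2π) - log(σ) - 0.5 * ((z - μ) / σ)^2)
logq_θ = logq_z - sum(z)
```

Because `ε` does not depend on `μ` or `σ`, gradients with respect to the variational parameters can flow through `z` and `θ`.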
Here is a larger discussion on having So I implemented The approach I took now for So basically I would argue to only leave the |
Pull Request Test Coverage Report for Build 589667488
💛 - Coveralls |
Since (most) tests are passing I am going to make it ready for review while I work on the tests and remove some unnecessary bits. One part I am quite unhappy about is the way gradients are computed. Every algorithm requires a different approach and I could not find a global one. This means a huge load of code copy/paste to make it work with the different backends...
I solved the gradient issue with going back to |
Good stuff! There's quite a bit that needs addressing though.
I've left some comments on what for sure needs changing. I still need to look through again and think about how we can get to a complete solution here.
Some more general comments:
- I would really, really like to not reimplement a bunch of distributions in this package. Most of these very particular distributions could probably go into DistributionsAD.jl if we need them.
- There's a big focus on mutating operations here. It seems a bit unnecessary given that it's a matter of decreasing memory usage by a constant factor of 2, no? Or am I missing something here? Also, you've mentioned Optimisers.jl; isn't that moving towards non-mutation states for optimisers?
But yeah, after you've responded to my comments, I'll have to go through again 👍
src/algorithms/advi.jl
Outdated
update_mean!(q, vec(mean(Δ, dims = 2)), opt)
update_cov!(alg, q, Δ, state, opt)
Too general IMO. ADVI can be used with non-normal distributions as the underlying distribution, in which case there is no such thing as `cov`.
Well actually not, and that's where I am trying to be more precise. The ADVI approach I implemented follows the ADVI paper, which states:
So the underlying distribution is always a Gaussian.
I will add a reference to the paper in the docs
If we want to train something different it will have to be a different algorithm.
I'm fully aware that the original paper only had this in mind, but it's unnecessarily restricted and is mainly just a consequence of the restricted tools they had available at the time.
It's called "Automatic Differentiation Variational Inference", which really refers to the fact that we use AD + the reparameterization trick.
IMO it just comes down to: what does it cost us to allow any reparameterization with any valid base distribution, and just let Gaussian be the default? That just seems overall superior, no?
Well, in theory yes, that sounds amazing. The problem is how to formulate the reparametrization trick for non-Gaussian distributions. It's not always possible, right?
Normalizing flows; affine transformation is just an instance:)
IMO, we don't even need all these implementations of different Gaussians. We just make them different affine transformations, e.g. introduce an `Affine <: Bijector` and define different evaluations depending on whether the scaling matrix is `Cholesky`, etc.
The above is the reasoning behind the current impl btw. Though I 100% agree we should make it more convenient to instantiate the different approaches.
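As a rough illustration of the "one affine map, several scaling matrices" idea, here is a minimal sketch using only `LinearAlgebra`. `AffineMap` and the local `logabsdetjac` are hypothetical names defined here for the example; this is not the Bijectors.jl API, and it dispatches on `Diagonal`/`LowerTriangular` rather than on a `Cholesky` object.

```julia
using LinearAlgebra

# Hypothetical type: x = μ + L * ε, with the structure of L kept in the type.
struct AffineMap{T<:AbstractVector,S}
    μ::T
    L::S   # scaling: Diagonal, LowerTriangular, dense Matrix, ...
end

# Forward map, shared by all scaling types.
(a::AffineMap)(ε::AbstractVector) = a.μ + a.L * ε

# log|det L|, specialized on the structure of the scaling matrix.
# For diagonal and triangular scalings only the diagonal matters.
logabsdetjac(a::AffineMap{<:Any,<:Union{Diagonal,LowerTriangular}}) =
    sum(log ∘ abs, diag(a.L))

# Generic fallback for any other scaling matrix.
logabsdetjac(a::AffineMap) = first(logabsdet(a.L))
```

With this layout, a mean-field Gaussian is `AffineMap(μ, Diagonal(σ))` and a full-rank one is `AffineMap(μ, LowerTriangular(L))`, both applied to standard normal noise, and any other base distribution could be plugged in without new distribution types.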
src/algorithms/advi.jl
Outdated
nsamples(alg::ADVI) = alg.samples_per_step
niters(alg::ADVI) = alg.max_iters

function compats(::ADVI)
Is this specifying which distributions it's compatible with?
IMO this shouldn't be here; it will be difficult to keep track of + too restrictive (e.g. ADVI works for any distribution for which we can use the reparam trick, e.g. any `TransformedDistribution`) + it will be incorrect as soon as someone decides to extend functionality in a different package/user code. Instead we should let methods fail according to missing impls.
See the previous point.
src/interface.jl
Outdated
## Verify that the algorithm can work with the corresponding variational distribution
function check_compatibility(alg, q)
    if !compat(alg, q)
        throw(ArgumentError("Algorithm $(alg) cannot work with distributions of type $(typeof(q)), compatible distributions are: $(compats(alg))"))
    end
end

function compat(alg::VariationalInference, q)
    return q isa compats(alg)
end

function compats(::Any)
    return ()
end
## Verify that the algorithm can work with the corresponding variational distribution
function check_compatibility(alg, q)
    if !compat(alg, q)
        throw(ArgumentError("Algorithm $(alg) cannot work with distributions of type $(typeof(q)), compatible distributions are: $(compats(alg))"))
    end
end
function compat(alg::VariationalInference, q)
    return q isa compats(alg)
end
function compats(::Any)
    return ()
end
I think my solution is awesome :D
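For comparison, here is a minimal sketch of what "let methods fail according to missing impls" could look like, with no compatibility list at all. Every name here (`ToyADVI`, `ToyGaussian`, `ToyOther`, `update!`, `try_update`) is hypothetical and only serves the illustration; none of it is code from this PR.

```julia
# Hypothetical algorithm and distribution types.
abstract type AbstractVI end
struct ToyADVI <: AbstractVI end

struct ToyGaussian
    μ::Vector{Float64}
end
struct ToyOther end

# Only the supported combination gets a method. Another package can add
# support for its own distribution by defining one more `update!` method.
update!(::ToyADVI, q::ToyGaussian, Δ) = (q.μ .+= Δ; q)

# An unsupported pair surfaces as a MethodError naming the missing method,
# so no central list of compatible types has to be maintained.
function try_update(alg, q, Δ)
    try
        update!(alg, q, Δ)
        return :ok
    catch err
        err isa MethodError || rethrow()
        return :unsupported
    end
end
```

The `try_update` wrapper exists only to show the failure mode; in practice the `MethodError` would simply propagate to the caller.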
end
include("gradients.jl")
include("interface.jl")
# include("optimisers.jl") # Relying on Tracker...
Can we not add Tracker to `test/Project.toml`?
I think this is currently outdated, I need to have another look
Yeah the problem is that it only works with Tracker and I had deleted it for a while
Am confused. Isn't using ForwardDiff? o.O
See here:
AdvancedVI.jl/src/optimisers.jl
Line 42 in 808acbe
)::Array{typeof(Tracker.data(Δ)), 1}
src/distributions/distributions.jl
Outdated
abstract type AbstractPosteriorMvNormal{T} <:
    Distributions.ContinuousMultivariateDistribution end
Why is this needed vs. `AbstractMvNormal` or whatever it's called?
I was not aware of `AbstractMvNormal` until recently. It is still practical to have `AbstractPosteriorMvNormal` to define stuff like `mean(d::AbstractPosteriorMvNormal) = d.μ`
src/distributions/cholmvnormal.jl
Outdated
function to_vec(q::CholMvNormal)
    vcat(q.μ, vec(q.Γ))
end
Can we not make use of something similar to `Flux.params` here since we already depend on `Functors.jl`?
Well you need to tell that to ForwardDiff.jl, ReverseDiff.jl and co :D
There are no ties between `Flux.params` and the AD-framework used anymore :)
It's using Functors.jl to define what's trainable and what isn't. So `Flux.params` should just work for the above distributions. Then you just `vec` and `vcat`.
Ah sorry, I misunderstood you; I thought you were complaining about the `vcat`.
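The "collect the trainable arrays, then `vec` and `vcat`" idea can be sketched without any dependency as follows. `ToyCholMvNormal` is a hypothetical stand-in for the PR's `CholMvNormal` (its fields `μ` and `Γ` follow the snippet above), and `trainable` is a hand-written stand-in for whatever a Functors.jl-based parameter collection would return; neither is the real API.

```julia
# Hypothetical stand-in for the variational distribution.
struct ToyCholMvNormal
    μ::Vector{Float64}
    Γ::Matrix{Float64}
end

# Stand-in for a Functors.jl-style collection of trainable arrays.
trainable(q::ToyCholMvNormal) = (q.μ, q.Γ)

# Generic flattening: `vec` each array, then `vcat` the pieces.
# Works for any type that defines `trainable`.
to_vec(q) = reduce(vcat, map(vec, trainable(q)))
```

The point of the design is that `to_vec` no longer needs one method per distribution; only the declaration of what is trainable is type-specific.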
Is there a timeframe for this PR?
Not really, I think the problem is that it tries to do too many things at the same time. @torfjelde What do you think about having the variational distributions defined as bijected distributions in another PR?
What's missing?
Closed in favour of #45
Hey, following #24 I am making a first attempt to make large changes (I don't believe this could be done by small changes and I don't think anyone is using this package atm anyway).
Here are the current main changes:
- ~~The problem of having to deal with a vector of variational parameters is gone since by using the reparametrization trick, we take the gradients of the samples and not of the parameters.~~ I implement both approaches: one with a vector of parameters and one without.
- `optimize!` is now a loop over `step!`, which has to be defined for each algorithm.
- Instead of `grad!`, there is now `gradlogπ!` and eventually `gradentropy`. The first is computed via sampling; the second mostly has closed-form solutions, but a generic fallback with samples can be made.
- ~~I entirely removed the `VariationalObjective` object which right now objectively does not bring anything (I guess if we start to consider more divergences we could start to add such objects).~~ I left it and set the default behavior to use `ELBO`.
- `Requires` based package (does anyone still use Tracker?)
- Following `AbstractMCMC`, `step!` now takes a `state` as an argument, which is initialized via `init` for each algorithm. This allows using the right preallocations for each method!
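The `init` / `step!` / `optimize!` structure described above can be sketched roughly as follows. This is a toy under stated assumptions, not the PR's code: `ToyAlg`, its fields, and the quadratic target are all hypothetical, and the "variational parameters" are just a plain vector.

```julia
# Hypothetical algorithm type; names mirror the description only.
abstract type AbstractVIAlg end

struct ToyAlg <: AbstractVIAlg
    max_iters::Int
    stepsize::Float64
end

# `init` preallocates whatever the algorithm needs between steps.
init(::ToyAlg, q::Vector{Float64}) = (grad = similar(q),)

# One update: gradient ascent on logπ(x) = -sum(abs2, x) / 2, whose
# gradient is -x. The gradient is written into the preallocated buffer.
function step!(alg::ToyAlg, q::Vector{Float64}, state)
    state.grad .= -1 .* q
    q .+= alg.stepsize .* state.grad
    return q
end

# `optimize!` is only a loop over `step!`; each algorithm supplies
# its own `init` and `step!` methods.
function optimize!(alg::AbstractVIAlg, q)
    state = init(alg, q)
    for _ in 1:alg.max_iters
        step!(alg, q, state)
    end
    return q
end
```

Calling `optimize!(ToyAlg(100, 0.1), [1.0, -2.0])` shrinks the vector toward zero in place, reusing the same gradient buffer on every iteration.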