
Basic rewrite of the package #25

Closed

wants to merge 29 commits into from

Conversation

@theogf (Member) commented Feb 12, 2021

Hey, following #24, I am making a first attempt at large changes (I don't believe this could be done through small changes, and I don't think anyone is using this package at the moment anyway).

Here are the current main changes:

  • Introduction of a collection of variational distributions with specific parametrizations. This makes it possible to change the variational parameters in-place and avoids passing a vector around.
  • The problem of having to deal with a vector of variational parameters is gone: with the reparametrization trick, we take gradients through the samples rather than of a flat parameter vector. I implemented both approaches, one with a vector of parameters and one without.
  • Samples (and parametrized samples) are stored in place, so there is no additional allocation.
  • optimize! is now a loop over step!, which has to be defined for each algorithm.
  • Instead of grad!, there is now gradlogπ! and eventually gradentropy. The first is computed via sampling; the second mostly has closed-form solutions, but a generic fallback with samples can be made.
  • I had entirely removed the VariationalObjective object, which right now objectively does not bring anything (I guess if we start to consider more divergences we could add such objects back); in the end I left it and set the default behavior to use the ELBO.
  • I removed Tracker and made it an optional dependency via Requires (does anyone still use Tracker?).
  • [NEW] Similarly to AbstractMCMC, step! now takes a state as an argument, which is initialized via init for each algorithm. This allows using the right preallocations for each method! (See the sketch after this list.)
  • [NEW] I added a basic, ugly version of BBVI.
  • [NEW] Started to adapt the framework to be compatible with Optimisers.jl (avoiding a dependency on Flux for optimisers).
  • More things to come.
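
To make the init/step!/optimize! design above concrete, here is a minimal sketch; MyAlgorithm, its fields, and the state layout are hypothetical illustrations, not the PR's exact API:

abstract type VariationalAlgorithm end

struct MyAlgorithm <: VariationalAlgorithm
    n_samples::Int  # samples drawn per step
    max_iters::Int  # number of optimization steps
end

# `init` builds the algorithm-specific state once, including preallocated buffers.
init(alg::MyAlgorithm, q, logπ) = (samples = zeros(length(q), alg.n_samples),)

# `step!` updates q (and the buffers in the state) in place and returns the state.
function step!(alg::MyAlgorithm, q, logπ, state, opt)
    # Sample into state.samples, estimate gradients, update the parameters of q...
    return state
end

# `optimize!` is then just a loop over `step!`.
function optimize!(alg::MyAlgorithm, q, logπ, opt)
    state = init(alg, q, logπ)
    for _ in 1:alg.max_iters
        state = step!(alg, q, logπ, state, opt)
    end
    return q
end

Keeping the buffers in the state rather than in the algorithm struct lets each method preallocate exactly what it needs, as in AbstractMCMC.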

@yebai (Member) commented Feb 14, 2021

Thanks @theogf, it looks like a great PR in progress. Just a clarification question: does DSVI = ADVI?

@theogf (Member, Author) commented Feb 15, 2021

> Thanks @theogf, it looks like a great PR in progress. Just a clarification question: does DSVI = ADVI?

That's a good question. After checking the ADVI paper, ADVI is DSVI with bijectors (which answers the question I had, @torfjelde). However, the previous implementation of ADVI did not reflect what was done in the paper.

Reference: Automatic Differentiation Variational Inference https://arxiv.org/abs/1603.00788

@theogf changed the title from "[WIP] Basic rewrite of the package with addition of DSVI" to "[WIP] Basic rewrite of the package" on Feb 17, 2021
@theogf (Member, Author) commented Feb 18, 2021

Here is a larger discussion on having q(θ) and/or q.

So I implemented ADVI and a basic version of BBVI. Interestingly, the first one relies on q and the second one on q(θ).
This leads me to think that the choice between q(θ) and q depends on the algorithm used and should be treated as such.

The approach I took for BBVI is to use the (amazing) state approach. Since we work with our own distributions (though this is not a restriction), I defined a to_vec(q) and a to_dist(q, θ) to jump between the two representations. This way we create an initial θ in init, update it during the run, and finally return to_dist(q, θ) at the end.

So basically I would argue for keeping only the q approach and dealing with q(θ) internally when needed.
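
For concreteness, a minimal sketch of what the to_vec / to_dist pair could look like for a diagonal Gaussian (the struct and its fields are illustrative, not necessarily the PR's types):

# Hypothetical diagonal-Gaussian variational distribution.
struct DiagMvNormal{T<:Real}
    μ::Vector{T}  # mean
    σ::Vector{T}  # per-dimension scale
end

# Flatten the distribution into a single parameter vector θ...
to_vec(q::DiagMvNormal) = vcat(q.μ, q.σ)

# ...and rebuild a distribution of the same type from θ.
function to_dist(q::DiagMvNormal, θ::AbstractVector)
    n = length(q.μ)
    return DiagMvNormal(θ[1:n], θ[n+1:2n])
end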

@coveralls commented Feb 18, 2021

Pull Request Test Coverage Report for Build 589667488

  • 100 of 153 (65.36%) changed or added relevant lines in 14 files are covered.
  • 20 unchanged lines in 1 file lost coverage.
  • Overall coverage decreased (-5.8%) to 53.695%

Changes Missing Coverage:

File                                 Covered Lines   Changed/Added Lines   %
src/interface.jl                     19              21                    90.48%
src/ad.jl                            1               4                     25.0%
src/distributions/cholmvnormal.jl    7               10                    70.0%
src/algorithms/advi.jl               23              27                    85.19%
src/compat/zygote.jl                 0               6                     0.0%
src/distributions/diagmvnormal.jl    8               14                    57.14%
src/compat/reversediff.jl            0               7                     0.0%
src/compat/tracker.jl                0               10                    0.0%
src/distributions/distributions.jl   7               19                    36.84%

Files with Coverage Reduction:

File                New Missed Lines   %
src/optimisers.jl   20                 0%

Totals:
Change from base Build 387007800: -5.8%
Covered Lines: 109
Relevant Lines: 203

💛 - Coveralls

@theogf (Member, Author) commented Feb 18, 2021

Since (most) tests are passing, I am going to mark it ready for review while I work on the tests and remove some unnecessary bits.

One part I am quite unhappy about is the way gradients are computed. Every algorithm requires a different approach, and I could not find a global one. This means a huge amount of copy/pasted code to make it work with the different AD backends...
I would be happy to hear some suggestions.

@theogf marked this pull request as ready for review February 18, 2021 17:38
@theogf changed the title from "[WIP] Basic rewrite of the package" to "Basic rewrite of the package" on Feb 18, 2021
@theogf (Member, Author) commented Feb 22, 2021

I solved the gradient issue by going back to grad! as a very generic in-place gradient-computing function.
Tests sometimes fail because of their randomness; I don't know if there is a better approach here.
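
As a point of reference, here is a minimal sketch of such a generic in-place gradient function, with a ForwardDiff-based fallback (the exact signature in the PR may differ):

using ForwardDiff

# Write the gradient of `f` at `θ` into the preallocated buffer `out`.
# Each AD backend would provide its own method of this function.
function grad!(out::AbstractVector, f, θ::AbstractVector)
    ForwardDiff.gradient!(out, f, θ)
    return out
end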

src/algorithms/bbvi.jl (outdated review thread, resolved)
@torfjelde (Member) left a comment:

Good stuff! There's quite a bit that needs addressing though.

I've left some comments on what for sure needs changing. I still need to look through again and think about how we can get to a complete solution here.

Some more general comments:

  1. I would really, really like to not reimplement a bunch of distributions in this package. Most of these very particular distributions could probably go into DistributionsAD.jl if we need them.
  2. There's a big focus on mutating operations here. It seems a bit unnecessary given that it's a matter of decreasing memory usage by a constant factor of 2, no? Or am I missing something here? Also, you've mentioned Optimisers.jl; isn't that moving towards non-mutating state for optimisers?

But yeah, after you've responded to my comments, I'll have to go through again 👍

src/ad.jl (outdated review thread, resolved)
src/algorithms/advi.jl (outdated review thread, resolved)
Comment on lines 55 to 56
update_mean!(q, vec(mean(Δ, dims = 2)), opt)
update_cov!(alg, q, Δ, state, opt)
Member:

Too general IMO. ADVI can be used with non-normal distributions as the underlying distribution, in which case there is no such thing as a covariance.

Member Author:

Well, actually not, and that's where I am trying to be more precise. The ADVI approach I implemented follows the ADVI paper, which states:

[screenshot from the ADVI paper defining the variational approximation as a Gaussian]

So the underlying distribution is always a Gaussian.
I will add a reference to the paper in the docs.

If we want to train something different, it will have to be a different algorithm.
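
For reference, the Gaussian reparametrization trick the paper builds on, as a minimal standalone sketch (the function name is illustrative):

using LinearAlgebra, Random

# Reparametrized sample from N(μ, L L'): z = μ + L ε with ε ~ N(0, I).
# Gradients w.r.t. μ and L flow through this deterministic map rather than
# through the random sampling itself.
function reparam_sample(rng::AbstractRNG, μ::AbstractVector, L::LowerTriangular)
    ε = randn(rng, length(μ))
    return μ + L * ε
end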

Member:

I'm fully aware that the original paper only had this in mind, but it's unnecessarily restrictive and is mainly just a consequence of the limited tools they had available at the time.

It's called "Automatic Differentiation Variational Inference", which really refers to the fact that we use AD + the reparameterization trick.

IMO it just comes down to: what does it cost us to allow any reparameterization with any valid base distribution, and just let Gaussian be the default? That just seems overall superior, no?

Member Author:

Well, in theory, yes, that sounds amazing. The problem is how you formulate the reparametrization trick for non-Gaussian distributions. It's not always possible, right?

Member:

Normalizing flows; an affine transformation is just an instance :)

Member:

IMO, we don't even need all these implementations of different Gaussians. We just make them different affine transformations, e.g. introduce an Affine <: Bijector and define different evaluations depending on whether the scaling matrix is a Cholesky factor, etc.
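
A minimal sketch of that idea (standalone here rather than subtyping the actual Bijectors.jl Bijector, whose parametrization varies across versions; the Affine type and its methods are illustrative):

using LinearAlgebra

# Hypothetical affine bijector: x ↦ shift + scale * x.
struct Affine{Ts<:AbstractVector,TL}
    shift::Ts
    scale::TL  # e.g. a Diagonal or a LowerTriangular (Cholesky factor)
end

(b::Affine)(x::AbstractVector) = b.shift + b.scale * x

# The log-abs-det-Jacobian only depends on the scale; specialized evaluations
# can be added per scale type (Diagonal, LowerTriangular, dense, ...).
logabsdetjac(b::Affine, x::AbstractVector) = sum(log ∘ abs, diag(b.scale))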

Member:

The above is the reasoning behind the current impl btw. Though I 100% agree we should make it more convenient to instantiate the different approaches.

nsamples(alg::ADVI) = alg.samples_per_step
niters(alg::ADVI) = alg.max_iters

function compats(::ADVI)
Member:

Is this specifying which distributions it's compatible with?
IMO this shouldn't be here; it will be difficult to keep track of + too restrictive (e.g. ADVI works for any distribution for which we can use the reparam trick, e.g. any TransformedDistribution) + it will be incorrect as soon as someone decides to extend functionality in a different package/user code. Instead we should let methods fail according to missing implementations.

Member Author:

See the previous point.

src/algorithms/advi.jl (outdated review thread, resolved)
src/interface.jl (outdated)
Comment on lines 72 to 85
## Verify that the algorithm can work with the corresponding variational distribution
function check_compatibility(alg, q)
    if !compat(alg, q)
        throw(ArgumentError("Algorithm $(alg) cannot work with distributions of type $(typeof(q)), compatible distributions are: $(compats(alg))"))
    end
end

function compat(alg::VariationalInference, q)
    return q isa compats(alg)
end

function compats(::Any)
    return ()
end
Member:

Suggested change: remove the whole check_compatibility / compat / compats block above.

Member Author:

I think my solution is awesome :D

end
include("gradients.jl")
include("interface.jl")
# include("optimisers.jl") # Relying on Tracker...
Member:

Can we not add Tracker to test/Project.toml?

Member Author:

I think this is currently outdated; I need to have another look.

Member Author:

Yeah, the problem is that it only works with Tracker, and I had deleted it for a while.

Member:

I am confused. Isn't it using ForwardDiff? o.O

Member Author:

See here:

)::Array{typeof(Tracker.data(Δ)), 1}

src/utils.jl (outdated review thread, resolved)
Comment on lines 2 to 3
abstract type AbstractPosteriorMvNormal{T} <:
              Distributions.ContinuousMultivariateDistribution end
Member:

Why is this needed vs. AbstractMvNormal or whatever it's called?

Member Author:

I was not aware of AbstractMvNormal until recently. It is still practical to have AbstractPosteriorMvNormal to define things like mean(d::AbstractPosteriorMvNormal) = d.μ.

Comment on lines 33 to 35
function to_vec(q::CholMvNormal)
    vcat(q.μ, vec(q.Γ))
end
Member:

Can we not make use of something similar to Flux.params here since we already depend on Functors.jl?

Member Author:

Well you need to tell that to ForwardDiff.jl, ReverseDiff.jl and co :D

Member:

There are no ties between Flux.params and the AD framework used anymore :)
It uses Functors.jl to define what's trainable and what isn't. So Flux.params should just work for the above distributions. Then you just vec and vcat.

Member Author:

Ah, sorry, I misunderstood you; I thought you were complaining about the vcat.
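
For concreteness, a minimal sketch of the Functors.jl approach discussed above (the struct mirrors the CholMvNormal from the diff; flatten_params is an illustrative helper, not an existing API):

using Functors

struct CholMvNormal{Tμ,TΓ}
    μ::Tμ  # mean
    Γ::TΓ  # Cholesky factor of the covariance
end
@functor CholMvNormal  # declares μ and Γ as children/trainable parameters

# Flatten all children into a single parameter vector, Flux.params-style.
function flatten_params(q)
    children, _ = Functors.functor(q)
    return mapreduce(vec, vcat, children)
end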

@trappmartin (Member) commented:

Is there a timeframe for this PR?

@theogf (Member, Author) commented Sep 17, 2021

> Is there a timeframe for this PR?

Not really. I think the problem is that it tries to do too many things at the same time.
Also, there are some things where we are not entirely sure how to proceed.

@torfjelde What do you think about having the variational distributions defined as bijected distributions in another PR?

@ParadaCarleton (Member) commented:

> > Is there a timeframe for this PR?
>
> Not really. I think the problem is that it tries to do too many things at the same time. Also, there are some things where we are not entirely sure how to proceed.
>
> @torfjelde What do you think about having the variational distributions defined as bijected distributions in another PR?

What's missing?

@yebai (Member) commented Jun 10, 2023

Closed in favour of #45

@yebai closed this Jun 10, 2023
@yebai deleted the tg/rework_advi branch June 10, 2023 12:15