Basic rewrite of the package 2023 edition #45

Closed
wants to merge 147 commits

Conversation

Red-Portal
Member

@Red-Portal Red-Portal commented Mar 14, 2023

Hi, this is the initial pull request for the rewrite of AdvancedVI, as a successor to #25.

The following panel will be updated in real-time, reflecting the discussions happening below.

Roadmap

  • Change the gradient computation interface such that different algorithms can directly manipulate the gradients.
  • Migrate to the LogDensityProblems interface.
  • Migrate to AbstractDifferentiation.jl. Not mature enough yet.
  • Use the ADTypes interface.
  • Use Functors.jl for flattening/unflattening variational parameters.
  • Add more interfaces for calling optimize (see "Missing API method" #32).
  • Add pre-packaged variational families.
    • location-scale family
    • Reduce memory usage of full-rank parameterization (seems like there's an unfavorable compute-memory trade-off; see this thread)
  • Migrate to Optimisers.jl.
  • Implement minibatch subsampling (probably requires changes upstream, e.g., in DynamicPPL, too) (separate issue)
  • Add callback option (see "Callback function during training" #5)
  • Add control variate interface
  • Add BBVI (score gradient). Not urgent.
  • Tests
  • Benchmarks
    • Compare performance against the current version.
    • ~~Compare against competing libraries (e.g., NumPyro, Stan, and probably a bare-bones Julia/C++ implementation).~~
  • Support GPU computation (although Bijectors will be a bottleneck for this). (separate issue)

Topics to Discuss

  • Should we use AbstractDifferentiation? Not now.
  • ✔️ Should we migrate to Optimisers? (probably yes)
  • ✔️ Should we call restructure inside of optimize so that the flattening/unflattening is completely abstracted away from the user? In the current state of things, Flux would then have to be added as a dependency; otherwise we'd have to roll our own implementation of destructure. destructure is now part of Optimisers, which is much more lightweight (see the sketch after this list).
  • ✔️ Should we keep TruncatedADAGrad and DecayedADAGrad? I think these are quite outdated and I would advise people against using them, so how about deprecating them? Planning to deprecate.
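
To make the destructure point concrete, here is a small sketch (not code from this PR; θ is a made-up container of variational parameters) of flattening and rebuilding with Optimisers.destructure:

    using Optimisers

    # Hypothetical container of variational parameters.
    θ = (location = zeros(3), scale = ones(3))

    # Flatten into a single vector plus a closure that rebuilds the container.
    λ, restructure = Optimisers.destructure(θ)   # λ == [0, 0, 0, 1, 1, 1]
    θ′ = restructure(λ .+ 0.5)                   # rebuild with updated parameters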

Demo

    using Turing
    using Bijectors
    using Optimisers
    using ForwardDiff
    using ADTypes
    using DynamicPPL
    using LinearAlgebra
    using LogDensityProblems

    import AdvancedVI as AVI

    μ_y, σ_y = 1.0, 1.0
    μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0]

    Turing.@model function normallognormal()
        y ~ LogNormal(μ_y, σ_y)
        z ~ MvNormal(μ_z, Σ_z)
    end
    model   = normallognormal()
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model)
    d       = LogDensityProblems.dimension(prob)

    # Initial mean-field Gaussian approximation: random location, unit scale.
    μ = randn(d)
    L = Diagonal(ones(d))
    q = AVI.MeanFieldGaussian(μ, L)

    # Construct the ADVI objective and optimize for up to n_max_iter iterations.
    n_max_iter = 10^4
    q, stats = AVI.optimize(
        AVI.ADVI(prob, b⁻¹, 10),
        q,
        n_max_iter;
        adbackend = AutoForwardDiff(),
        optimizer = Optimisers.Adam(1e-3)
    )

@Red-Portal Red-Portal added enhancement New feature or request help wanted Extra attention is needed labels Mar 14, 2023
@Red-Portal Red-Portal changed the title from "Basic rewrite of the package 2023 edition [WIP]" to "[WIP] Basic rewrite of the package 2023 edition" Mar 14, 2023
@Red-Portal Red-Portal removed enhancement New feature or request help wanted Extra attention is needed labels Mar 14, 2023
@Red-Portal Red-Portal marked this pull request as draft March 16, 2023 20:55
This is to avoid having to reconstruct transformed distributions all
the time. The direct use of bijectors also avoids going through lots
of abstraction layers that could break.

Instead, transformed distributions could be constructed only once when
returning the VI result.
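
As a small illustration of this idea (a sketch, not code from this PR; q and b⁻¹ refer to the variational approximation and inverse bijector from the demo above, and I'm assuming q behaves like an ordinary Distributions.jl distribution):

using Bijectors

# Inside the optimization loop, samples are drawn from q and pushed through
# b⁻¹ directly, e.g. z = b⁻¹(rand(q)), so no transformed-distribution object
# is rebuilt at every step. Only the final result is wrapped, once:
q_result = Bijectors.transformed(q, b⁻¹)
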
@torfjelde
Member

I'll have a look at the PR itself later, but for now:

AdvancedVI.jl naively reconstructs/deconstructs MvNormal from its variational parameters. This is okay for the mean-field parameterization, but for a full-rank or non-diagonal covariance parameterization it is a little more complicated, since MvNormal in Distributions.jl asks for a PDMat. So the variational parameters must first be converted to a matrix, then to a PDMat, and then fed to MvNormal. For high-dimensional problems, I'm not sure this is ideal.
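
For a concrete picture of the conversion being described, here is a hypothetical unflatten for the full-rank case (illustration only, not code from this PR):

using Distributions, LinearAlgebra, PDMats

# Hypothetical reconstruction of a full-rank MvNormal from a flat parameter
# vector λ = [μ; vec(L)], where L is a d×d lower-triangular Cholesky factor.
function unflatten_fullrank(λ::AbstractVector, d::Int)
    μ = λ[1:d]
    L = LowerTriangular(reshape(λ[d+1:end], d, d))
    # Distributions.jl wants a positive-definite matrix type; PDMat
    # materializes Σ = L*L' internally before MvNormal will accept it.
    return MvNormal(μ, PDMat(Cholesky(L)))
end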

In relation to the topic above, I'm starting to believe that implementing our own custom distribution (just as @theogf previously did in #25) might be a good idea in terms of performance, especially for reparameterization-trick-based methods. However, instead of reinventing the wheel (by implementing every distribution in existence) or tying ourselves to a small number of specific distributions (a custom MvNormal, that is), I think implementing a single general LocationScale distribution would be feasible, where the user provides the underlying univariate base distribution. Through this, we could support distributions like the multivariate Laplace, which is not even supported in Distributions.jl, with a single general object.
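
A minimal sketch of what such a general location-scale family could look like (hypothetical names, not the design being proposed in this PR):

using Distributions, LinearAlgebra, Random

# q(z) = location + scale*ε, where ε has iid entries drawn from a
# user-chosen univariate base distribution (Normal, Laplace, TDist, ...).
struct LocationScale{L,S,D<:ContinuousUnivariateDistribution}
    location::L
    scale::S     # e.g. Diagonal (mean-field) or LowerTriangular (full-rank)
    base::D
end

function Random.rand(rng::Random.AbstractRNG, q::LocationScale)
    ε = rand(rng, q.base, length(q.location))
    return q.location + q.scale * ε
end

# Change-of-variables density: only a diagonal/triangular solve and a
# log-determinant of the scale are needed, never the covariance itself.
function logdensity(q::LocationScale, z)
    u = q.scale \ (z - q.location)
    return sum(logpdf.(q.base, u)) - logdet(q.scale)
end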

Maybe we should make this into a discussion. I feel like there are several different approaches we can take here.

For flattening the parameters, @theogf has proposed ParameterHandling.jl, but it currently does not work well with AD. The current alternative is ModelWrappers.jl, but it comes with many dependencies, which is potentially a governance concern.

For this one in particular, we have an implementation in DynamicPPL that could potentially be moved to its own package if we really want to: https://github.com/TuringLang/DynamicPPL.jl/blob/b23acff013a9111c8ce2c89dbf5339e76234d120/src/utils.jl#L434-L473

But this has a couple of issues:

  1. Requires 2× the memory, since we can't release the original object (we need it as the first argument for reconstruction, since these things often depend on runtime information, e.g. the dimensionality of an MvNormal).
  2. Can't specialize on which parameters we actually want, e.g. maybe we only want to learn the mean parameter of an MvNormal.

(1) can be addressed by instead taking a closure-approach a la Functors.jl:

using Distributions, LinearAlgebra

function flatten(d::MvNormal{<:Real,<:Diagonal})
    dim = length(d)
    # The closure only needs `dim`, so the original `d` can be released.
    function MvNormal_unflatten(x)
        return MvNormal(x[1:dim], Diagonal(x[dim+1:end]))
    end

    return vcat(d.μ, diag(d.Σ)), MvNormal_unflatten
end
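
Usage would then look like the following (assuming Distributions and LinearAlgebra are loaded):

d = MvNormal(zeros(2), Diagonal(ones(2)))
λ, unflatten = flatten(d)    # λ == [0.0, 0.0, 1.0, 1.0]
d′ = unflatten(λ .+ 0.1)     # rebuild from the flat vector; `d` itself is no longer needed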

For (2), we have a couple of immediate options:
a) Define "wrapper" distributions.
b) Take a contextual dispatch approach.

For (a) we'd have something like:

using Distributions

abstract type WrapperDistribution{V,F,D<:Distribution{V,F}} <: Distribution{V,F} end

# HACK: Probably shouldn't do this.
inner_dist(x::WrapperDistribution) = x.inner

# TODO: Specialize further on `x` to avoid hitting default implementations?
Distributions.logpdf(d::WrapperDistribution, x) = logpdf(inner_dist(d), x)
# Etc.

struct MeanParameterized{V,F,D<:Distribution{V,F}} <: WrapperDistribution{V,F,D}
    inner::D
end
MeanParameterized(inner::Distribution{V,F}) where {V,F} =
    MeanParameterized{V,F,typeof(inner)}(inner)

function flatten(d::MeanParameterized{V,F,<:MvNormal}) where {V,F}
    μ = mean(d.inner)
    function MeanParameterized_MvNormal_unflatten(x)
        return MeanParameterized(MvNormal(x, d.inner.Σ))
    end

    return μ, MeanParameterized_MvNormal_unflatten
end

Pros:

  • It's fairly simple to implement.

Cons:

  • Requires wrapping all the distributions all the time.
  • Nice until we have other sorts of nested distributions, in which case this can get real ugly real fast.

For (b) we'd have something like

struct MeanOnly end

function flatten(::MeanOnly, d::MvNormal)
    μ = mean(d)
    function MvNormal_meanonly_unflatten(x)
        # Only the mean is learned; the covariance of `d` is kept fixed.
        return MvNormal(x, d.Σ)
    end

    return μ, MvNormal_meanonly_unflatten
end

Pros:

  • Cleaner as it avoids nesting.
  • Can easily support "wrapper" distributions since it can just pass the context downwards.

Cons:

  • Somewhat unclear to me how to make all this composable, e.g. how do we handle arbitrary structs containing distributions?

@Red-Portal
Member Author

Red-Portal commented Mar 23, 2023

Hi @torfjelde

Maybe we should make this into a discussion. I feel like there are several different approaches we can take here.

Should we proceed here or create a separate issue?

Whatever approach we take, I think the key would be to avoid inverting or even computing the covariance matrix, provided that we operate with a Cholesky factor. None of the steps of ADVI require any of these, except for the STL estimator, where we do need to invert the Cholesky factor.
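
To spell this out with a small sketch (illustrative only, not this package's code):

using LinearAlgebra

d = 5
μ = randn(d)
L = LowerTriangular(0.1*randn(d, d) + I)   # Cholesky factor, so Σ = L*L'

# Reparameterized sampling touches only L, never Σ:
ε = randn(d)
z = μ + L*ε

# The entropy term needs only the log-determinant of L:
entropy = d/2*(1 + log(2π)) + sum(log ∘ abs, diag(L))

# Only the STL estimator needs ∇_z log q(z) = -Σ⁻¹(z - μ), which amounts to
# two triangular solves against L rather than forming or inverting Σ:
score = -(L' \ (L \ (z - μ)))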

@torfjelde
Member

Created a discussion: #46

@Red-Portal
Member Author

Red-Portal commented Jun 9, 2023

@torfjelde Hi, I have significantly changed the sketch for the project structure.

  1. As you previously suggested, the ELBO objective is now formed in a modular way.
  2. I've also migrated to AbstractDifferentiation instead of rolling our own custom differentiation glue functions.

Any comments on the new structure? Also, do you approve the use of AbstractDifferentiation?

@yebai
Member

yebai commented Jun 9, 2023

Also, do you approve the use of AbstractDifferentiation?

@devmotion what are your current thoughts on AbstractDifferentiation?

@Red-Portal
Member Author

Red-Portal commented Jun 9, 2023

I've now added the pre-packaged location-scale family. Overall, to the user, the basic interface looks like the following:

    using Turing, Bijectors, ForwardDiff, ADTypes
    using DynamicPPL, LogDensityProblems, LinearAlgebra
    import Flux
    import AdvancedVI as AVI

    μ_y, σ_y = 1.0, 1.0
    μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0]

    Turing.@model function normallognormal()
        y ~ LogNormal(μ_y, σ_y)
        z ~ MvNormal(μ_z, Σ_z)
    end
    model   = normallognormal()
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model)
    d       = LogDensityProblems.dimension(prob)

    μ = randn(d)
    L = Diagonal(ones(d))
    q = AVI.MeanFieldGaussian(μ, L)

    λ₀, restructure  = Flux.destructure(q)

    function rebuild(λ′)
        restructure(λ′)
    end
    λ = AVI.optimize(
        AVI.ADVI(prob, b⁻¹, 10),
        rebuild,
        10000,
        λ₀;
        optimizer = Flux.ADAM(1e-3),
        adbackend = AutoForwardDiff()
    )
    q = restructure(λ)

    # Extract the location μ and scale factor L of the fitted approximation.
    μ = q.transform.outer.a
    L = q.transform.inner.a
    Σ = L*L'

    μ_true      = vcat(μ_y, μ_z)
    Σ_diag_true = vcat(σ_y, diag(Σ_z))

    @info("VI Estimation Error",
          norm- μ_true),
          norm(diag(Σ) - Σ_diag_true),)

Some additional notes to the comments above,

  1. Should we call restructure inside of optimize so that the flattening/unflattening is completely abstracted away from the user? In the current state of things, Flux would then have to be added as a dependency; otherwise we'd have to roll our own implementation of destructure.
  2. Should we keep TruncatedADAGrad and DecayedADAGrad? I think these are quite outdated and I would advise people against using them, so how about deprecating them?
  3. We should probably migrate to Optimisers.jl; the current optimization infrastructure is quite old. A rough sketch of what an Optimisers.jl-based update loop could look like follows below.
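
(The sketch below uses illustrative placeholders: the parameters and the gradient are stand-ins, not this package's API.)

using Optimisers

function fit(λ; n_iter = 10_000)
    state = Optimisers.setup(Optimisers.Adam(1e-3), λ)
    for _ in 1:n_iter
        g = randn(length(λ))          # stand-in for an estimated ELBO gradient
        state, λ = Optimisers.update(state, λ, g)
    end
    return λ
end

λ = fit(randn(10))                    # hypothetical flat variational parameters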

@yebai
Member

yebai commented Dec 22, 2023

@Red-Portal is there anything in this PR not yet merged by #49 and #50?

@Red-Portal
Member Author

@yebai Yes, we have the documentation still left. I'm currently working on it.

@yebai yebai closed this Jun 3, 2024