Add the forward-backward Wasserstein Gaussian variational inference algorithm #210
Merged
Changes from all commits: 30 commits, all authored by Red-Portal.

- `342eaf5` refactor split shared batch reshuffling into `src/reshuffling.jl`
- `ba28e2e` Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into wasse…
- `f9ddebb` refactor move `SubsampledNormals` code under `test/models/`
- `563bb5d` fix wrong capability for subsamplednormals
- `6674b42` fix remove unused fields
- `bd0b889` fix wrong variable name, add missing rng argument
- `fec5055` Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into wasse…
- `bb2fa3a` fix missing `dimension` and wrong variance in subsampled normals
- `7b6246f` add Wasserstein VI algorithm
- `a3201f8` add tests for Wasserstein VI
- `f478e15` add docs for Wasserstein VI
- `9583ef9` fix formatting
- `ba31866` fix formatting
- `455118f` fix formatting
- `8321555` fix typos
- `c985932` add optional Stein's identity expected Hessian estimator
- `17c1069` fix test also capability in subsampling convergence test
- `283961f` update docs remove comment that hessian is required for fwdbwdwass
- `724103c` update docs
- `f8ca2be` apply formatter
- `b478db9` apply formatter
- `176726d` apply formatter
- `53522c9` apply formatter
- `e21084c` update docs
- `fd36d38` update docs
- `5ef5805` run formatter
- `ff67c2e` refactor test for `KLMinWassFwdBwd`
- `2bd3bc8` run formatter
- `85dcaa9` improve docstrings callback section
- `4b34875` bump patch version
@@ -0,0 +1,77 @@
# [`KLMinWassFwdBwd`](@id klminwassfwdbwd)

## Description

This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running proximal gradient descent (also known as forward-backward splitting) in Wasserstein space[^DBCS2023].
(This algorithm is also sometimes referred to as "Wasserstein VI".)
Since `KLMinWassFwdBwd` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`), which make the measure-valued operations tractable.

```@docs
KLMinWassFwdBwd
```
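
Below is a minimal usage sketch. The target problem is a hypothetical one defined inline, and the `optimize` call and `FullRankGaussian` constructor arguments are assumptions about the package's general interface; consult the main usage documentation for the exact signatures.

```julia
using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

# Hypothetical target: a standard multivariate normal exposing first-order
# differentiation capability through the LogDensityProblems interface.
struct IsoGaussian
    dim::Int
end
LogDensityProblems.dimension(prob::IsoGaussian) = prob.dim
function LogDensityProblems.capabilities(::Type{IsoGaussian})
    return LogDensityProblems.LogDensityOrder{1}()
end
LogDensityProblems.logdensity(::IsoGaussian, x) = -sum(abs2, x)/2
function LogDensityProblems.logdensity_and_gradient(::IsoGaussian, x)
    return (-sum(abs2, x)/2, -x)
end

d    = 5
prob = IsoGaussian(d)
alg  = KLMinWassFwdBwd(; n_samples=8, stepsize=1e-2)

# Assumed constructor: mean vector and lower-triangular scale factor.
q_init = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))

# Assumed entry point: (rng, algorithm, iteration budget, target, initial q).
q, info, state = optimize(Random.default_rng(), alg, 1_000, prob, q_init)
```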

The associated objective value, which is the ELBO, can be estimated through the following:

```@docs; canonical=false
estimate_objective(
    ::Random.AbstractRNG,
    ::KLMinWassFwdBwd,
    ::MvLocationScale,
    ::Any;
    ::Int,
)
```
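
For example, continuing the sketch above (`alg`, `q`, and `prob` are the hypothetical names introduced there):

```julia
elbo = estimate_objective(Random.default_rng(), alg, q, prob; n_samples=256)
```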

[^DBCS2023]: Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In *International Conference on Machine Learning*. PMLR.

## [Methodology](@id klminwassfwdbwd_method)

This algorithm aims to solve the problem

```math
\mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{KL}\left(q, \pi\right)
```

where $\mathcal{Q}$ is some family of distributions, often called the variational family.
Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence.
Instead, we focus on minimizing a surrogate objective, the *free energy functional*, which corresponds to the negated evidence lower bound[^JGJS1999], defined as

```math
\mathcal{F}\left(q\right) \triangleq \mathcal{U}\left(q\right) + \mathcal{H}\left(q\right),
```

where

```math
\begin{aligned}
\mathcal{U}\left(q\right) &= \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right)
&&\text{(``potential energy'')}
\\
\mathcal{H}\left(q\right) &= \mathbb{E}_{\theta \sim q} \log q\left(\theta\right) .
&&\text{(``Boltzmann entropy'')}
\end{aligned}
```
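
Writing $\pi = \tilde{\pi}/Z$ for the unnormalized density $\tilde{\pi}$ and its normalizing constant $Z$, a short expansion shows why minimizing $\mathcal{F}$ is equivalent to minimizing the KL divergence:

```math
\mathcal{F}(q)
= \mathbb{E}_{\theta \sim q} \log \frac{q(\theta)}{\tilde{\pi}(\theta)}
= \mathrm{KL}\left(q, \pi\right) - \log Z ,
```

so the two objectives differ only by a constant independent of $q$.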

For solving this problem, `KLMinWassFwdBwd` relies on proximal stochastic gradient descent (PSGD)---also known as "forward-backward splitting"---which iterates

```math
q_{t+1} = \mathrm{JKO}_{\gamma_t \mathcal{H}}\big(
    q_{t} - \gamma_t \widehat{\nabla_{\mathrm{BW}} \mathcal{U}} (q_{t})
\big) ,
```

where $\widehat{\nabla_{\mathrm{BW}} \mathcal{U}}$ is a stochastic estimate of the Bures-Wasserstein measure-valued gradient of $\mathcal{U}$, the JKO (proximal) operator is defined as

```math
\mathrm{JKO}_{\gamma_t \mathcal{H}}(\mu)
=
\argmin_{\nu \in \mathcal{Q}} \left\{ \mathcal{H}(\nu) + \frac{1}{2 \gamma_t} \mathrm{W}_2 {(\mu, \nu)}^2 \right\} ,
```

and $\mathrm{W}_2$ is the Wasserstein-2 distance.
When $\mathcal{Q}$ is set to be the whole Wasserstein space over $\mathbb{R}^d$, this algorithm is referred to as the Jordan-Kinderlehrer-Otto (JKO) scheme[^JKO1998], which was originally developed to study gradient flows under Wasserstein metrics.
Within this context, `KLMinWassFwdBwd` can be viewed as a numerical realization of the JKO scheme obtained by restricting $\mathcal{Q}$ to a tractable parametric variational family.
Specifically, Diao *et al.*[^DBCS2023] derived the JKO update for multivariate Gaussians, which is implemented by `KLMinWassFwdBwd`.
`KLMinWassFwdBwd` also exactly corresponds to the measure-space analog of [KLMinRepGradProxDescent](@ref klminrepgradproxdescent).
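
For the full-rank Gaussian family, both steps have closed forms. Writing $q_t = \mathcal{N}(m_t, \Sigma_t)$ and denoting the Monte Carlo estimates of $\mathbb{E}_{q_t} \nabla \log \pi$ and $\mathbb{E}_{q_t} \nabla^2 \log \pi$ by $\widehat{g}_t$ and $\widehat{H}_t$, the update implemented in `step` (see the source below) reads

```math
\begin{aligned}
m_{t+1} &= m_t + \gamma_t \widehat{g}_t , \\
M_t &= I - \gamma_t \big(- \widehat{H}_t\big) , \qquad
\Sigma_{t+1/2} = M_t \Sigma_t M_t , \\
\Sigma_{t+1} &= \frac{1}{2} \left( \Sigma_{t+1/2} + 2 \gamma_t I + {\big( \Sigma_{t+1/2} \left( \Sigma_{t+1/2} + 4 \gamma_t I \right) \big)}^{1/2} \right) ,
\end{aligned}
```

where the last line is the Gaussian JKO (proximal) step.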

[^JKO1998]: Jordan, R., Kinderlehrer, D., & Otto, F. (1998). The variational formulation of the Fokker--Planck equation. *SIAM Journal on Mathematical Analysis*, 29(1).
[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. *Machine Learning*, 37, 183-233.
@@ -0,0 +1,213 @@
| """ | ||
| KLMinWassFwdBwd(n_samples, stepsize, subsampling) | ||
| KLMinWassFwdBwd(; n_samples, stepsize, subsampling) | ||
|
|
||
| KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023]. | ||
|
|
||
| Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. | ||
|
|
||
| # (Keyword) Arguments | ||
| - `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) | ||
| - `stepsize::Float64`: Step size of stochastic proximal gradient descent. | ||
| - `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. | ||
|
|
||
| !!! note | ||
| The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinWassFwdBwd` does not support amortization or structured variational families. | ||
|
|
||
| # Output | ||
| - `q`: The last iterate of the algorithm. | ||
|
|
||
| # Callback Signature | ||
| The `callback` function supplied to `optimize` needs to have the following signature: | ||
|
|
||
| callback(; rng, iteration, q, info) | ||
|
|
||
| The keyword arguments are as follows: | ||
| - `rng`: Random number generator internally used by the algorithm. | ||
| - `iteration`: The index of the current iteration. | ||
| - `q`: Current variational approximation. | ||
| - `info`: `NamedTuple` containing the information generated during the current iteration. | ||
|
|
||
| # Requirements | ||
| - The variational family is [`FullRankGaussian`](@ref FullRankGaussian). | ||
| - The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). | ||
| - The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. | ||
| """ | ||
| @kwdef struct KLMinWassFwdBwd{Sub<:Union{Nothing,<:AbstractSubsampling}} <: | ||
| AbstractVariationalAlgorithm | ||
| n_samples::Int = 1 | ||
| stepsize::Float64 | ||
| subsampling::Sub = nothing | ||
| end | ||
|
|
||
| """ | ||
| gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob) | ||
|
|
||
| Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity. | ||
|
|
||
| # Arguments | ||
| - `rng::Random.AbstractRNG`: Random number generator. | ||
| - `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over. | ||
| - `n_samples::Int`: Number of samples used for estimation. | ||
| - `grad_buf::AbstractVector`: Buffer for the gradient estimate. | ||
| - `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate. | ||
| - `prob`: `LogDensityProblem` associated with the log-density gradient and Hessian subject to expectation. | ||
| """ | ||
| function gaussian_expectation_gradient_and_hessian!( | ||
| rng::Random.AbstractRNG, | ||
| q::MvLocationScale{<:LowerTriangular,<:Normal,L}, | ||
| n_samples::Int, | ||
| grad_buf::AbstractVector{T}, | ||
| hess_buf::AbstractMatrix{T}, | ||
| prob, | ||
| ) where {T<:Real,L} | ||
| logπ_avg = zero(T) | ||
| fill!(grad_buf, zero(T)) | ||
| fill!(hess_buf, zero(T)) | ||
|
|
||
| if LogDensityProblems.capabilities(typeof(prob)) ≤ | ||
| LogDensityProblems.LogDensityOrder{1}() | ||
| # Use Stein's identity | ||
| d = LogDensityProblems.dimension(prob) | ||
| u = randn(rng, T, d, n_samples) | ||
| z = q.scale*u .+ q.location | ||
| for b in 1:n_samples | ||
| zb, ub = view(z, :, b), view(u, :, b) | ||
| logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb) | ||
| logπ_avg += logπ/n_samples | ||
| grad_buf += ∇logπ/n_samples | ||
| hess_buf += ub*(∇logπ/n_samples)' | ||
| end | ||
| return logπ_avg, grad_buf, hess_buf | ||
| else | ||
| # Use sample average of the Hessian. | ||
| z = rand(rng, q, n_samples) | ||
| for b in 1:n_samples | ||
| zb = view(z, :, b) | ||
| logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian( | ||
| prob, zb | ||
| ) | ||
| logπ_avg += logπ/n_samples | ||
| grad_buf += ∇logπ/n_samples | ||
| hess_buf += ∇2logπ/n_samples | ||
| end | ||
| return logπ_avg, grad_buf, hess_buf | ||
| end | ||
| end | ||
|
|
||
| struct KLMinWassFwdBwdState{Q,P,S,Sigma,GradBuf,HessBuf} | ||
| q::Q | ||
| prob::P | ||
| sigma::Sigma | ||
| iteration::Int | ||
| sub_st::S | ||
| grad_buf::GradBuf | ||
| hess_buf::HessBuf | ||
| end | ||
|
|
||
| function init( | ||
| rng::Random.AbstractRNG, | ||
| alg::KLMinWassFwdBwd, | ||
| q_init::MvLocationScale{<:LowerTriangular,<:Normal,L}, | ||
| prob, | ||
| ) where {L} | ||
| sub = alg.subsampling | ||
| n_dims = LogDensityProblems.dimension(prob) | ||
| capability = LogDensityProblems.capabilities(typeof(prob)) | ||
| if capability < LogDensityProblems.LogDensityOrder{1}() | ||
| throw( | ||
| ArgumentError( | ||
| "`KLMinWassFwdBwd` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", | ||
| ), | ||
| ) | ||
| end | ||
| sub_st = isnothing(sub) ? nothing : init(rng, sub) | ||
| grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) | ||
| hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) | ||
| return KLMinWassFwdBwdState(q_init, prob, cov(q_init), 0, sub_st, grad_buf, hess_buf) | ||
| end | ||
|
|
||
| output(::KLMinWassFwdBwd, state) = state.q | ||
|
|
||
| function step( | ||
| rng::Random.AbstractRNG, alg::KLMinWassFwdBwd, state, callback, objargs...; kwargs... | ||
| ) | ||
| (; n_samples, stepsize, subsampling) = alg | ||
| (; q, prob, sigma, iteration, sub_st, grad_buf, hess_buf) = state | ||
|
|
||
| m = mean(q) | ||
| Σ = sigma | ||
| η = convert(eltype(m), stepsize) | ||
| iteration += 1 | ||
|
|
||
| # Maybe apply subsampling | ||
| prob_sub, sub_st′, sub_inf = if isnothing(subsampling) | ||
| prob, sub_st, NamedTuple() | ||
| else | ||
| batch, sub_st′, sub_inf = step(rng, subsampling, sub_st) | ||
| prob_sub = subsample(prob, batch) | ||
| prob_sub, sub_st′, sub_inf | ||
| end | ||
|
|
||
| # Estimate the Wasserstein gradient | ||
| logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( | ||
| rng, q, n_samples, grad_buf, hess_buf, prob_sub | ||
| ) | ||
|
|
||
| m′ = m - η * (-grad_buf) | ||
| M = I - η*Hermitian(-hess_buf) | ||
| Σ_half = Hermitian(M*Σ*M) | ||
|
|
||
| # Compute the JKO proximal operator | ||
| Σ′ = (Σ_half + 2*η*I + sqrt(Hermitian(Σ_half*(Σ_half + 4*η*I))))/2 | ||
| q′ = MvLocationScale(m′, cholesky(Σ′).L, q.dist) | ||
|
|
||
| state = KLMinWassFwdBwdState(q′, prob, Σ′, iteration, sub_st′, grad_buf, hess_buf) | ||
| elbo = logπ_avg + entropy(q′) | ||
| info = merge((elbo=elbo,), sub_inf) | ||
|
|
||
| if !isnothing(callback) | ||
| info′ = callback(; rng, iteration, q, info) | ||
| info = !isnothing(info′) ? merge(info′, info) : info | ||
| end | ||
| state, false, info | ||
| end | ||
|
|
||
| """ | ||
| estimate_objective([rng,] alg, q, prob; n_samples) | ||
|
|
||
| Estimate the ELBO of the variational approximation `q` against the target log-density `prob`. | ||
|
|
||
| # Arguments | ||
| - `rng::Random.AbstractRNG`: Random number generator. | ||
| - `alg::KLMinWassFwdBwd`: Variational inference algorithm. | ||
| - `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation. | ||
| - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. | ||
|
|
||
| # Keyword Arguments | ||
| - `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the the number of samples used for estimating the gradient during optimization.) | ||
|
|
||
| # Returns | ||
| - `obj_est`: Estimate of the objective value. | ||
| """ | ||
| function estimate_objective( | ||
| rng::Random.AbstractRNG, | ||
| alg::KLMinWassFwdBwd, | ||
| q::MvLocationScale{S,<:Normal,L}, | ||
| prob; | ||
| n_samples::Int=alg.n_samples, | ||
| ) where {S,L} | ||
| obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy()) | ||
| if isnothing(alg.subsampling) | ||
| return estimate_objective(rng, obj, q, prob) | ||
| else | ||
| sub = alg.subsampling | ||
| sub_st = init(rng, sub) | ||
| return mapreduce(+, 1:length(sub)) do _ | ||
| batch, sub_st, _ = step(rng, sub, sub_st) | ||
| prob_sub = subsample(prob, batch) | ||
| estimate_objective(rng, obj, q, prob_sub) / length(sub) | ||
| end | ||
| end | ||
| end |
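
As a rough sanity check of `gaussian_expectation_gradient_and_hessian!` (a sketch: `IsoGaussian` is the hypothetical first-order target from the usage example above, and qualifying the call with `AdvancedVI.` assumes the function is unexported): for the standard normal log-density, the exact expected gradient under `q` is `-mean(q)`, and with an identity scale the Stein estimate of the expected Hessian should land near `-I`.

```julia
using AdvancedVI, LinearAlgebra, Random

rng = Random.default_rng()
d = 5
# Identity scale so the Stein Hessian estimate should be near -I.
q = FullRankGaussian(randn(rng, d), LowerTriangular(Matrix{Float64}(I, d, d)))
grad_buf = zeros(d)
hess_buf = zeros(d, d)

# IsoGaussian has only first-order capability, so the Stein's identity branch runs.
_, g, H = AdvancedVI.gaussian_expectation_gradient_and_hessian!(
    rng, q, 10_000, grad_buf, hess_buf, IsoGaussian(d)
)

isapprox(g, -mean(q); atol=0.05)  # true up to Monte Carlo error
```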