From 68162929ea63ddb5d32c8752f26c8c3e1e78e912 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 19 Nov 2025 11:19:19 -0500 Subject: [PATCH 01/18] add batch and match --- src/AdvancedVI.jl | 4 +- src/algorithms/fisherminbatchmatch.jl | 195 ++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 src/algorithms/fisherminbatchmatch.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d9f1fb26f..1d07fd975 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -358,7 +358,9 @@ include("algorithms/gauss_expected_grad_hess.jl") include("algorithms/klminwassfwdbwd.jl") include("algorithms/klminsqrtnaturalgraddescent.jl") include("algorithms/klminnaturalgraddescent.jl") +include("algorithms/fisherminbatchmatch.jl") -export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent +export KLMinWassFwdBwd, + KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent, FisherMinBatchMatch end diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl new file mode 100644 index 000000000..aa4bd72fc --- /dev/null +++ b/src/algorithms/fisherminbatchmatch.jl @@ -0,0 +1,195 @@ + +""" + FisherMinBatchMatch(n_samples, subsampling) + FisherMinBatchMatch(; n_samples, subsampling) + +Covariance-weighted fisher divergence minimization via the batch-and-match algorithm[^DBCS2023]. + +# (Keyword) Arguments +- `n_samples::Int`: Number of samples (batchsize) used to compute the moments required for the batch-and-match update. (default: `32`) +- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`) + +!!! note + The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `FisherMinBatchMatch` does not support amortization or structured variational families. + +# Output +- `q`: The last iterate of the algorithm. + +# Callback Signature +The `callback` function supplied to `optimize` needs to have the following signature: + + callback(; rng, iteration, q, info) + +The keyword arguments are as follows: +- `rng`: Random number generator internally used by the algorithm. +- `iteration`: The index of the current iteration. +- `q`: Current variational approximation. +- `info`: `NamedTuple` containing the information generated during the current iteration. + +# Requirements +- The variational family is [`FullRankGaussian`](@ref FullRankGaussian). +- The target distribution has unconstrained support. +- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. +""" +@kwdef struct FisherMinBatchMatch{Sub<:Union{Nothing,<:AbstractSubsampling}} <: + AbstractVariationalAlgorithm + n_samples::Int = 32 + subsampling::Sub = nothing +end + +struct BatchMatchState{Q,P,Sigma,Sub,GradBuf} + q::Q + prob::P + sigma::Sigma + iteration::Int + sub_st::Sub + grad_buf::GradBuf +end + +function init( + rng::Random.AbstractRNG, + alg::FisherMinBatchMatch, + q::MvLocationScale{<:LowerTriangular,<:Normal,L}, + prob, +) where {L} + (; n_samples, subsampling) = alg + capability = LogDensityProblems.capabilities(typeof(prob)) + if capability < LogDensityProblems.LogDensityOrder{1}() + throw( + ArgumentError( + "`FisherMinBatchMatch` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", + ), + ) + end + sub_st = isnothing(subsampling) ? 
nothing : init(rng, subsampling)
    params, _ = Optimisers.destructure(q)
    n_dims = LogDensityProblems.dimension(prob)
    grad_buf = Matrix{eltype(params)}(undef, n_dims, n_samples)
    return BatchMatchState(q, prob, cov(q), 0, sub_st, grad_buf)
end

output(::FisherMinBatchMatch, state) = state.q

function step(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    state,
    callback,
    objargs...;
    kwargs...,
)
    (; n_samples, subsampling) = alg
    (; q, prob, sigma, iteration, sub_st, grad_buf) = state

    d = LogDensityProblems.dimension(prob)
    μ = q.location
    C = q.scale
    Σ = sigma
    iteration += 1

    # Maybe apply subsampling
    prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
        prob, sub_st, NamedTuple()
    else
        batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
        prob_sub = subsample(prob, batch)
        prob_sub, sub_st′, sub_inf
    end

    u = randn(rng, eltype(μ), d, n_samples)
    z = C*u .+ μ
    logπ_avg = 0
    for b in 1:n_samples
        logπb, gb = LogDensityProblems.logdensity_and_gradient(prob_sub, view(z, :, b))
        grad_buf[:, b] = gb
        logπ_avg += logπb/n_samples
    end

    # Estimate objective values
    #
    # WF = E[| ∇log(q/π) (z) |_{CC'}^2]                    (definition)
    #    = E[| C' (∇logq(z) - ∇logπ(z)) |^2]               (Σ = CC')
    #    = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2] (z = Cu + μ)
    #    = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2]
    #    = E[| -u - C'∇logπ(z)) |^2]
    weighted_fisher = sum(abs2, -u .- (C'*grad_buf))/n_samples
    elbo = logπ_avg + entropy(q)

    # BaM updates
    zbar, C = mean_and_cov(z, 2)
    gbar, Γ = mean_and_cov(grad_buf, 2)

    μmz = μ - zbar
    λ = convert(eltype(μ), d*n_samples / iteration)

    U = Symmetric(λ*Γ + (λ/(1 + λ)*gbar)*gbar')
    V = Symmetric(Σ + λ*C + (λ/(1 + λ)*μmz)*μmz')

    Σ′ = Hermitian(2*V/(I + real(sqrt(I + 4*U*V))))
    μ′ = 1/(1 + λ)*μ + λ/(1 + λ)*(Σ′*gbar + zbar)
    q′ = MvLocationScale(μ′[:, 1], cholesky(Σ′).L, q.dist)

    info = (iteration=iteration, weighted_fisher=weighted_fisher, elbo=elbo)

    state = BatchMatchState(q′, prob, Σ′, iteration, sub_st′, grad_buf)

    if !isnothing(callback)
        info′ = callback(; rng, iteration, q, state)
        info = !isnothing(info′) ? merge(info′, info) : info
    end
    state, false, info
end

function estimate_covweight_fisher(
    rng::Random.AbstractRNG,
    n_samples::Int,
    q::MvLocationScale{S,<:Normal,L},
    prob,
    grad_buf::Matrix=Matrix{eltype(params)}(
        undef, LogDensityProblems.dimension(prob), n_samples
    ),
) where {S,L}
end

"""
    estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the covariance-weighted Fisher divergence of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::FisherMinBatchMatch`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the number of samples used for estimating the gradient during optimization.)

# Returns
- `obj_est`: Estimate of the objective value.
+""" +function estimate_objective( + rng::Random.AbstractRNG, + alg::FisherMinBatchMatch, + q::MvLocationScale{S,<:Normal,L}, + prob; + n_samples::Int=alg.n_samples, +) where {S,L} + d = LogDensityProblems.dimension(prob) + grad_buf = Matrix{eltype(params)}(undef, d, n_samples) + d = LogDensityProblems.dimension(prob) + μ = q.location + C = q.scale + u = randn(rng, eltype(μ), d, n_samples) + z = C*u .+ μ + for b in 1:n_samples + _, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b)) + grad_buf[:, b] = gb + end + # WF = E[| ∇log(q/π) (z) |_{CC'}^2] (definition) + # = E[| C' (∇logq(z) - ∇logπ(z)) |^2] (Σ = CC') + # = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2] (z = Cu + μ) + # = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2] + # = E[| -u - C'∇logπ(z)) |^2] + return sum(abs2, -u .- (C'*grad_buf))/n_samples +end diff --git a/test/runtests.jl b/test/runtests.jl index c8f5f66b2..64371ca47 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,6 +69,7 @@ if GROUP == "All" || GROUP == "GENERAL" include("algorithms/klminwassfwdbwd.jl") include("algorithms/klminsqrtnaturalgraddescent.jl") include("algorithms/klminnaturalgraddescent.jl") + include("algorithms/fisherminbatchmatch.jl") end if GROUP == "All" || GROUP == "AD" From 09317e2a5efdaaaf22fcd90ea347460eff5fb5de Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 19 Nov 2025 11:20:29 -0500 Subject: [PATCH 02/18] update HISTORY --- HISTORY.md | 1 + 1 file changed, 1 insertion(+) diff --git a/HISTORY.md b/HISTORY.md index d07bf48fd..bd2ec880c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -8,6 +8,7 @@ Specifically, the following measure-space optimization algorithms have been adde - `KLMinWassFwdBwd` - `KLMinNaturalGradDescent` - `KLMinSqrtNaturalGradDescent` + - `FisherMinBatchMatch` ## Interface Change From 2b72d51bfb3a826bb9c22da0b59b02a942e57c2b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 19 Nov 2025 11:25:42 -0500 Subject: [PATCH 03/18] fun formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/algorithms/fisherminbatchmatch.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl index aa4bd72fc..88f714b49 100644 --- a/src/algorithms/fisherminbatchmatch.jl +++ b/src/algorithms/fisherminbatchmatch.jl @@ -148,8 +148,7 @@ function estimate_covweight_fisher( grad_buf::Matrix=Matrix{eltype(params)}( undef, LogDensityProblems.dimension(prob), n_samples ), -) where {S,L} -end +) where {S,L} end """ estimate_objective([rng,] alg, q, prob; n_samples) From b746ab77328662d9c555720a0b546bd6577a8138 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 19 Nov 2025 11:58:04 -0500 Subject: [PATCH 04/18] fix add missing test file --- test/algorithms/fisherminbatchmatch.jl | 142 +++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 test/algorithms/fisherminbatchmatch.jl diff --git a/test/algorithms/fisherminbatchmatch.jl b/test/algorithms/fisherminbatchmatch.jl new file mode 100644 index 000000000..3880b1c9d --- /dev/null +++ b/test/algorithms/fisherminbatchmatch.jl @@ -0,0 +1,142 @@ + +@testset "FisherMinBatchMatch" begin + begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=2) + (; model, n_dims, μ_true, L_true) = modelstats + + alg = FisherMinBatchMatch() + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + @testset "callback" begin + T = 10 + callback(; iteration, kwargs...) 
= (iteration_check=iteration,) + _, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS) + @test [i.iteration_check for i in info] == 1:T + end + + @testset "estimate_objective" begin + q_true = FullRankGaussian(μ_true, LowerTriangular(Matrix(L_true))) + + obj_est = estimate_objective(alg, q_true, model) + @test isfinite(obj_est) + + obj_est = estimate_objective(alg, q_true, model; n_samples=10^6) + @test obj_est ≈ 0 atol=1e-2 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + T = 10 + + q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + μ = q_avg.location + L = q_avg.scale + + rng_repl = StableRNG(seed) + q_avg, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) + μ_repl = q_avg.location + L_repl = q_avg.scale + @test μ == μ_repl + @test L == L_repl + end + end + + @testset "error low capability" begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=0) + (; model, n_dims) = modelstats + + alg = FisherMinBatchMatch() + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + @test_throws "first-order" optimize(alg, 1, model, q0) + end + + @testset "type stability type=$(realtype), capability=$(capability)" for realtype in [ + Float64, Float32 + ], + capability in [1, 2] + + modelstats = normal_meanfield(Random.default_rng(), realtype; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + alg = FisherMinBatchMatch() + T = 10 + + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(realtype, n_dims), L0) + + q, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + @test eltype(q.location) == eltype(μ_true) + @test eltype(q.scale) == eltype(L_true) + end + + @testset "convergence" begin + modelstats = normal_meanfield(Random.default_rng(), Float64) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + T = 1000 + alg = FisherMinBatchMatch() + + q_avg, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q_avg.location - μ_true) + sum(abs2, q_avg.scale - L_true) + + @test Δλ ≤ Δλ0/2 + end + + @testset "subsampling" begin + n_data = 8 + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 10 + batchsize = 3 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = FisherMinBatchMatch(; subsampling) + + q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS) + μ = q.location + L = q.scale + + rng_repl = StableRNG(seed) + q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + + @testset "convergence" begin + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 1000 + batchsize = 1 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = FisherMinBatchMatch(; subsampling) + + q, stats, _ = optimize(alg_sub, T, 
model, q0; show_progress=PROGRESS)

            Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
            Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true)

            @test Δλ ≤ Δλ0/2
        end
    end
end

From bb70d0276e91d67eb9a7a92efc239bc1f820b396 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 19 Nov 2025 15:27:17 -0500
Subject: [PATCH 05/18] add documentation and update docstring for
 batch-and-match

---
 docs/make.jl                          |   1 +
 docs/src/fisherminbatchmatch.md       | 105 ++++++++++++++++++++++++++
 docs/src/index.md                     |   1 +
 src/algorithms/fisherminbatchmatch.jl |   8 +++-
 4 files changed, 114 insertions(+), 1 deletion(-)
 create mode 100644 docs/src/fisherminbatchmatch.md

diff --git a/docs/make.jl b/docs/make.jl
index b28e42cad..34c0080d3 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -29,6 +29,7 @@ makedocs(;
             "`KLMinWassFwdBwd`" => "klminwassfwdbwd.md",
             "`KLMinNaturalGradDescent`" => "klminnaturalgraddescent.md",
             "`KLMinSqrtNaturalGradDescent`" => "klminsqrtnaturalgraddescent.md",
+            "`FisherMinBatchMatch`" => "fisherminbatchmatch.md",
         ],
         "Variational Families" => "families.md",
         "Optimization" => "optimization.md",
diff --git a/docs/src/fisherminbatchmatch.md b/docs/src/fisherminbatchmatch.md
new file mode 100644
index 000000000..5ab6c6026
--- /dev/null
+++ b/docs/src/fisherminbatchmatch.md
@@ -0,0 +1,105 @@
+# [`FisherMinBatchMatch`](@id fisherminbatchmatch)
+
+## Description
+
+This algorithm, known as batch-and-match (BaM) aims to minimize the covariance-weighted 2nd-order fisher divergence by running a proximal point-type method[^CMPMGBS24].
+On certain low-dimensional problems, BaM can converge very quickly without any tuning.
+Since `FisherMinBatchMatch` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`) that make the measure-valued operations tractable.
+
+```@docs
+FisherMinBatchMatch
+```
+
+The associated objective value can be estimated through the following:
+
+```@docs; canonical=false
+estimate_objective(
+    ::Random.AbstractRNG,
+    ::FisherMinBatchMatch,
+    ::MvLocationScale,
+    ::Any;
+    ::Int,
+)
+```
+
+[^CMPMGBS24]: Cai, D., Modi, C., Pillaud-Vivien, L., Margossian, C. C., Gower, R. M., Blei, D. M., & Saul, L. K. (2024). Batch and match: black-box variational inference with a score-based divergence. In *Proceedings of the International Conference on Machine Learning*.
+## [Methodology](@id fisherminbatchmatch_method)
+
+This algorithm aims to solve the problem
+
+```math
+    \mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{F}_{\mathrm{cov}}(q, \pi),
+```
+where $\mathcal{Q}$ is some family of distributions, often called the variational family, and $\mathrm{F}_{\mathrm{cov}}$ is a divergence defined as
+```math
+\mathrm{F}_{\mathrm{cov}}(q, \pi) = \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}_{\mathrm{Cov}(q)}^2 ,
+```
+where ${\lVert x \rVert}_{A}^2 = x^{\top} A x $ is a weighted norm.
+$\mathrm{F}_{\mathrm{cov}}$ can be viewed as a variant of the canonical 2nd order Fisher divergence defined as
+```math
+\mathrm{F}_{2}(q, \pi) = \sqrt{ \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}^2 }.
+```
+
+The use of the weighted norm ${\lVert \cdot \rVert}_{\mathrm{Cov}(q)}^2$ facilitates the use of a proximal point-type method for minimizing $\mathrm{F}_{2}(q, \pi)$.
+In particular, BaM iterates the update
+
+```math
+    q_{t+1} = \argmin_{q \in \mathcal{Q}} \left\{ \mathrm{F}_{\mathrm{cov}}(q, \pi) + \frac{2}{\lambda_t} \mathrm{KL}\left(q_t, q\right) \right\} .
+```
+Since $\mathrm{F}_{\mathrm{cov}}(q, \pi)$ is intractable, it is replaced with a Monte Carlo approximation with a number of samples `n_samples`.
+Furthermore, by restricting $\mathcal{Q}$ to a Gaussian variational family, the update rule admits a closed-form solution[^CMPMGBS24].
+Notice that the update does not involve the parameterization of $q_t$, which makes `FisherMinBatchMatch` a measure-space algorithm.
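+
+The following is a minimal, self-contained sketch of this closed-form update (the names `bam_update` and `logπ_grad` are illustrative, not part of the package API); it mirrors the update performed internally by the `step` method of `FisherMinBatchMatch`:
+
+```julia
+using LinearAlgebra, StatsBase
+
+# One BaM update for q = N(μ, Σ), given logπ_grad(z) = ∇logπ(z).
+# Internally, the algorithm sets the step size as λ_t = d*n_samples/t.
+function bam_update(μ, Σ, logπ_grad, n_samples, λ)
+    d = length(μ)
+    z = cholesky(Σ).L*randn(d, n_samples) .+ μ   # draw z_1, …, z_B ~ q
+    g = reduce(hcat, logπ_grad.(eachcol(z)))     # scores ∇logπ(z_b), column-wise
+    zbar, Czz = mean_and_cov(z, 2)               # sample mean and covariance
+    gbar, Γ = mean_and_cov(g, 2)                 # score mean and covariance
+    μmz = μ - zbar
+    U = Symmetric(λ*Γ + (λ/(1 + λ)*gbar)*gbar')
+    V = Symmetric(Σ + λ*Czz + (λ/(1 + λ)*μmz)*μmz')
+    # Σ′ solves the quadratic matrix equation Σ′*U*Σ′ + Σ′ = V
+    Σ′ = Hermitian(2*V/(I + real(sqrt(I + 4*U*V))))
+    μ′ = 1/(1 + λ)*μ + λ/(1 + λ)*(Σ′*gbar + zbar)
+    return μ′[:, 1], Matrix(Σ′)
+end
+```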
+
+Historically, the idea of using a proximal point-type update for minimizing a Fisher divergence-like objective was initially introduced as Gaussian score matching[^MGMYBS23].
+BaM can be viewed as a successor to this algorithm.
+
+[^MGMYBS23]: Modi, C., Gower, R., Margossian, C., Yao, Y., Blei, D., & Saul, L. (2023). Variational inference with Gaussian score matching. In *Advances in Neural Information Processing Systems*, 36.
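+
+## Example
+
+The following sketch illustrates basic usage on a toy target (the `IsoGauss` problem below is defined purely for illustration; any object implementing the first-order `LogDensityProblems` interface works):
+
+```julia
+using AdvancedVI, LinearAlgebra, LogDensityProblems
+
+# A standard Gaussian target implementing the LogDensityProblems interface.
+struct IsoGauss
+    dims::Int
+end
+LogDensityProblems.logdensity(::IsoGauss, x) = -sum(abs2, x)/2
+LogDensityProblems.dimension(prob::IsoGauss) = prob.dims
+function LogDensityProblems.capabilities(::Type{IsoGauss})
+    return LogDensityProblems.LogDensityOrder{1}()
+end
+function LogDensityProblems.logdensity_and_gradient(::IsoGauss, x)
+    return (-sum(abs2, x)/2, -x)
+end
+
+prob = IsoGauss(5)
+q0 = FullRankGaussian(zeros(5), LowerTriangular(Matrix{Float64}(I, 5, 5)))
+alg = FisherMinBatchMatch(; n_samples=32)
+q, info, _ = optimize(alg, 100, prob, q0)
+```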
diff --git a/docs/src/index.md b/docs/src/index.md
index 8fafa6e51..234c28d9f 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -20,3 +20,4 @@ For using the algorithms implemented in `AdvancedVI`, refer to the corresponding
   - [KLMinNaturalGradDescent](@ref klminnaturalgraddescent)
   - [KLMinSqrtNaturalGradDescent](@ref klminsqrtnaturalgraddescent)
   - [KLMinWassFwdBwd](@ref klminwassfwdbwd)
+  - [FisherMinBatchMatch](@ref fisherminbatchmatch)
diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl
index 88f714b49..55b44e74a 100644
--- a/src/algorithms/fisherminbatchmatch.jl
+++ b/src/algorithms/fisherminbatchmatch.jl
@@ -3,12 +3,18 @@
     FisherMinBatchMatch(n_samples, subsampling)
     FisherMinBatchMatch(; n_samples, subsampling)
 
-Covariance-weighted fisher divergence minimization via the batch-and-match algorithm[^DBCS2023].
+Covariance-weighted fisher divergence minimization via the batch-and-match algorithm[^DBCS2023], which is a proximal point-type optimization scheme.
 
 # (Keyword) Arguments
 - `n_samples::Int`: Number of samples (batchsize) used to compute the moments required for the batch-and-match update. (default: `32`)
 - `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`)
 
+!!! warning
+    `FisherMinBatchMatch` with subsampling enabled results in a biased algorithm and may not properly optimize the covariance-weighted fisher divergence.
+
+!!! note
+    `FisherMinBatchMatch` requires a sufficiently large `n_samples` to converge quickly.
+
 !!! note
     The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `FisherMinBatchMatch` does not support amortization or structured variational families.

From 010370277a1144e5658fc1a2d2efdf2b7f37bec6 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 19 Nov 2025 15:30:14 -0500
Subject: [PATCH 06/18] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 docs/src/fisherminbatchmatch.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/fisherminbatchmatch.md b/docs/src/fisherminbatchmatch.md
index 5ab6c6026..5488485c0 100644
--- a/docs/src/fisherminbatchmatch.md
+++ b/docs/src/fisherminbatchmatch.md
@@ -46,6 +46,7 @@ In particular, BaM iterates the update
 ```math
     q_{t+1} = \argmin_{q \in \mathcal{Q}} \left\{ \mathrm{F}_{\mathrm{cov}}(q, \pi) + \frac{2}{\lambda_t} \mathrm{KL}\left(q_t, q\right) \right\} .
 ```
+
 Since $\mathrm{F}_{\mathrm{cov}}(q, \pi)$ is intractable, it is replaced with a Monte Carlo approximation with a number of samples `n_samples`.
 Furthermore, by restricting $\mathcal{Q}$ to a Gaussian variational family, the update rule admits a closed-form solution[^CMPMGBS24].
 Notice that the update does not involve the parameterization of $q_t$, which makes `FisherMinBatchMatch` a measure-space algorithm.

From 867fc27c5f352d5ffc58d4bb373de611cfa06bd0 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 19 Nov 2025 15:30:21 -0500
Subject: [PATCH 07/18] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 docs/src/fisherminbatchmatch.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/fisherminbatchmatch.md b/docs/src/fisherminbatchmatch.md
index 5488485c0..8db47421e 100644
--- a/docs/src/fisherminbatchmatch.md
+++ b/docs/src/fisherminbatchmatch.md
@@ -30,6 +30,7 @@ This algorithm aims to solve the problem
 ```math
     \mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{F}_{\mathrm{cov}}(q, \pi),
 ```
+
 where $\mathcal{Q}$ is some family of distributions, often called the variational family, and $\mathrm{F}_{\mathrm{cov}}$ is a divergence defined as
 ```math
 \mathrm{F}_{\mathrm{cov}}(q, \pi) = \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}_{\mathrm{Cov}(q)}^2 ,

From 188489edaa23d5073913033cef93df68b75ebc25 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 19 Nov 2025 15:30:27 -0500
Subject: [PATCH 08/18] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 docs/src/fisherminbatchmatch.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/fisherminbatchmatch.md b/docs/src/fisherminbatchmatch.md
index 8db47421e..36070a843 100644
--- a/docs/src/fisherminbatchmatch.md
+++ b/docs/src/fisherminbatchmatch.md
@@ -32,6 +32,7 @@ where $\mathcal{Q}$ is some family of distributions, often calle
 ```
 
 where $\mathcal{Q}$ is some family of distributions, often called the variational family, and $\mathrm{F}_{\mathrm{cov}}$ is a divergence defined as
+
 ```math
 \mathrm{F}_{\mathrm{cov}}(q, \pi) = \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}_{\mathrm{Cov}(q)}^2 ,
 ```

From e3e4e3bc1f157988780e85b56311bcc6aa292ff9 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 19 Nov 2025 15:30:42 -0500
Subject: [PATCH 09/18] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 docs/src/fisherminbatchmatch.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/fisherminbatchmatch.md b/docs/src/fisherminbatchmatch.md
index 36070a843..e6a69e4b7 100644
--- a/docs/src/fisherminbatchmatch.md
+++ b/docs/src/fisherminbatchmatch.md
@@ -36,6 +36,7 @@ where $\mathcal{Q}$ is some family of distributions, often calle
 ```math
 \mathrm{F}_{\mathrm{cov}}(q, \pi) = \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}_{\mathrm{Cov}(q)}^2 ,
 ```
+
 where ${\lVert x \rVert}_{A}^2 = x^{\top} A x $ is a weighted norm.
$\mathrm{F}_{\mathrm{cov}}$ can be viewed as a variant of the canonical 2nd order Fisher divergence defined as ```math From 188489edaa23d5073913033cef93df68b75ebc25 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 19 Nov 2025 15:30:48 -0500 Subject: [PATCH 10/18] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/fisherminbatchmatch.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/fisherminbatchmatch.md b/docs/src/fisherminbatchmatch.md index e6a69e4b7..7045d5132 100644 --- a/docs/src/fisherminbatchmatch.md +++ b/docs/src/fisherminbatchmatch.md @@ -39,6 +39,7 @@ where $\mathcal{Q}$ is some family of distributions, often called the variationa where ${\lVert x \rVert}_{A}^2 = x^{\top} A x $ is a weighted norm. $\mathrm{F}_{\mathrm{cov}}$ can be viewed as a variant of the canonical 2nd order Fisher divergence defined as + ```math \mathrm{F}_{2}(q, \pi) = \sqrt{ \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}^2 }. ``` From b4da148630d17f1957b1b98d616f0eb26f6cb012 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 20 Nov 2025 10:35:53 -0500 Subject: [PATCH 11/18] fix docs Co-authored-by: Penelope Yong --- docs/src/fisherminbatchmatch.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/fisherminbatchmatch.md b/docs/src/fisherminbatchmatch.md index 7045d5132..a81a43bdb 100644 --- a/docs/src/fisherminbatchmatch.md +++ b/docs/src/fisherminbatchmatch.md @@ -2,7 +2,7 @@ ## Description -This algorithm, known as batch-and-match (BaM) aims to minimize the covariance-weighted 2nd-order fisher divergence by running a proximal point-type method[^CMPMGBS24]. +This algorithm, known as batch-and-match (BaM) aims to minimize the covariance-weighted 2nd-order Fisher divergence by running a proximal point-type method[^CMPMGBS24]. On certain low-dimensional problems, BaM can converge very quickly without any tuning. Since `FisherMinBatchMatch` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`) that make the measure-valued operations tractable. From 3032c6651bb28c229e2fdbff6fbb706155166ac3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 20 Nov 2025 10:36:06 -0500 Subject: [PATCH 12/18] fix docs Co-authored-by: Penelope Yong --- docs/src/fisherminbatchmatch.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/fisherminbatchmatch.md b/docs/src/fisherminbatchmatch.md index a81a43bdb..dae9b167f 100644 --- a/docs/src/fisherminbatchmatch.md +++ b/docs/src/fisherminbatchmatch.md @@ -38,7 +38,7 @@ where $\mathcal{Q}$ is some family of distributions, often called the variationa ``` where ${\lVert x \rVert}_{A}^2 = x^{\top} A x $ is a weighted norm. -$\mathrm{F}_{\mathrm{cov}}$ can be viewed as a variant of the canonical 2nd order Fisher divergence defined as +$\mathrm{F}_{\mathrm{cov}}$ can be viewed as a variant of the canonical 2nd-order Fisher divergence defined as ```math \mathrm{F}_{2}(q, \pi) = \sqrt{ \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}^2 }. 
From 5773df6afa813a886590285240a2b8b83034b773 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 20 Nov 2025 10:36:20 -0500 Subject: [PATCH 13/18] fix docs Co-authored-by: Penelope Yong --- src/algorithms/fisherminbatchmatch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl index 55b44e74a..3ca2da0ac 100644 --- a/src/algorithms/fisherminbatchmatch.jl +++ b/src/algorithms/fisherminbatchmatch.jl @@ -10,7 +10,7 @@ Covariance-weighted fisher divergence minimization via the batch-and-match algor - `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`) !!! warning - `FisherMinBatchMatch` with subsampling enabled results in a biased algorithm and may not properly optimize the covariance-weighted fisher divergence. + `FisherMinBatchMatch` with subsampling enabled results in a biased algorithm and may not properly optimize the covariance-weighted Fisher divergence. !!! note `FisherMinBatchMatch` requires a sufficiently large `n_samples` to converge quickly. From 1f1ff5a22c5fe37ad1512cbb084e15076aa827ef Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 20 Nov 2025 10:40:42 -0500 Subject: [PATCH 14/18] fix remove dead code --- src/algorithms/fisherminbatchmatch.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl index 55b44e74a..cfa4a6aaa 100644 --- a/src/algorithms/fisherminbatchmatch.jl +++ b/src/algorithms/fisherminbatchmatch.jl @@ -146,16 +146,6 @@ function step( state, false, info end -function estimate_covweight_fisher( - rng::Random.AbstractRNG, - n_samples::Int, - q::MvLocationScale{S,<:Normal,L}, - prob, - grad_buf::Matrix=Matrix{eltype(params)}( - undef, LogDensityProblems.dimension(prob), n_samples - ), -) where {S,L} end - """ estimate_objective([rng,] alg, q, prob; n_samples) From e19fb582d0e93055c0781e3bffb6bfba6226cf07 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 20 Nov 2025 10:42:25 -0500 Subject: [PATCH 15/18] fix compute average outside of loop for batch-and-match --- src/algorithms/fisherminbatchmatch.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl index cfa4a6aaa..3caaeee53 100644 --- a/src/algorithms/fisherminbatchmatch.jl +++ b/src/algorithms/fisherminbatchmatch.jl @@ -104,12 +104,13 @@ function step( u = randn(rng, eltype(μ), d, n_samples) z = C*u .+ μ - logπ_avg = 0 + logπ_sum = zero(eltype(μ)) for b in 1:n_samples logπb, gb = LogDensityProblems.logdensity_and_gradient(prob_sub, view(z, :, b)) grad_buf[:, b] = gb - logπ_avg += logπb/n_samples + logπ_sum += logπb end + logπ_avg = logπ_sum/n_samples # Estimate objective values # From 1ca0f00710984da630abb9859b09220d757fc562 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 20 Nov 2025 10:43:19 -0500 Subject: [PATCH 16/18] fix remove reference in docstring --- src/algorithms/fisherminbatchmatch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl index 34ce2144d..b51a133a2 100644 --- a/src/algorithms/fisherminbatchmatch.jl +++ b/src/algorithms/fisherminbatchmatch.jl @@ -3,7 +3,7 @@ FisherMinBatchMatch(n_samples, subsampling) FisherMinBatchMatch(; n_samples, subsampling) -Covariance-weighted fisher divergence minimization via the batch-and-match algorithm[^DBCS2023], which is a proximal 
point-type optimization scheme.
+Covariance-weighted fisher divergence minimization via the batch-and-match algorithm, which is a proximal point-type optimization scheme.
 
 # (Keyword) Arguments
 - `n_samples::Int`: Number of samples (batchsize) used to compute the moments required for the batch-and-match update. (default: `32`)

From 445f6c019fc2574eba365e260d92bf8b462a269b Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Thu, 20 Nov 2025 10:44:52 -0500
Subject: [PATCH 17/18] fix capitalization in docstring

---
 src/algorithms/fisherminbatchmatch.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl
index b51a133a2..cda1128dd 100644
--- a/src/algorithms/fisherminbatchmatch.jl
+++ b/src/algorithms/fisherminbatchmatch.jl
@@ -3,7 +3,7 @@
     FisherMinBatchMatch(n_samples, subsampling)
     FisherMinBatchMatch(; n_samples, subsampling)
 
-Covariance-weighted fisher divergence minimization via the batch-and-match algorithm, which is a proximal point-type optimization scheme.
+Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm, which is a proximal point-type optimization scheme.
 
 # (Keyword) Arguments
 - `n_samples::Int`: Number of samples (batchsize) used to compute the moments required for the batch-and-match update. (default: `32`)

From 40ad7e99146658d6957014741264daa538f87c4a Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Thu, 20 Nov 2025 11:41:06 -0500
Subject: [PATCH 18/18] refactor move duplicate code in batch match to a
 common function

---
 src/algorithms/fisherminbatchmatch.jl | 86 ++++++++++++++-------------
 1 file changed, 45 insertions(+), 41 deletions(-)

diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl
index cda1128dd..b794a12af 100644
--- a/src/algorithms/fisherminbatchmatch.jl
+++ b/src/algorithms/fisherminbatchmatch.jl
@@ -43,12 +43,13 @@ The keyword arguments are as follows:
     subsampling::Sub = nothing
 end
 
-struct BatchMatchState{Q,P,Sigma,Sub,GradBuf}
+struct BatchMatchState{Q,P,Sigma,Sub,UBuf,GradBuf}
     q::Q
     prob::P
     sigma::Sigma
     iteration::Int
     sub_st::Sub
+    u_buf::UBuf
     grad_buf::GradBuf
 end
 
@@ -70,12 +71,45 @@ function init(
     sub_st = isnothing(subsampling) ?
nothing : init(rng, subsampling) params, _ = Optimisers.destructure(q) n_dims = LogDensityProblems.dimension(prob) + u_buf = Matrix{eltype(params)}(undef, n_dims, n_samples) grad_buf = Matrix{eltype(params)}(undef, n_dims, n_samples) - return BatchMatchState(q, prob, cov(q), 0, sub_st, grad_buf) + return BatchMatchState(q, prob, cov(q), 0, sub_st, u_buf, grad_buf) end output(::FisherMinBatchMatch, state) = state.q +function rand_batch_match_samples_with_objective!( + rng::Random.AbstractRNG, + q::MvLocationScale, + n_samples::Int, + prob, + u_buf=Matrix{eltype(q)}(undef, LogDensityProblems.dimension(prob), n_samples), + grad_buf=Matrix{eltype(q)}(undef, LogDensityProblems.dimension(prob), n_samples), +) + μ = q.location + C = q.scale + u = Random.randn!(rng, u_buf) + z = C*u .+ μ + logπ_sum = zero(eltype(μ)) + for b in 1:n_samples + logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b)) + grad_buf[:, b] = gb + logπ_sum += logπb + end + logπ_avg = logπ_sum/n_samples + + # Estimate objective values + # + # F = E[| ∇log(q/π) (z) |_{CC'}^2] (definition) + # = E[| C' (∇logq(z) - ∇logπ(z)) |^2] (Σ = CC') + # = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2] (z = Cu + μ) + # = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2] + # = E[| -u - C'∇logπ(z)) |^2] + fisher = sum(abs2, -u_buf - (C'*grad_buf))/n_samples + + return u_buf, z, grad_buf, fisher, logπ_avg +end + function step( rng::Random.AbstractRNG, alg::FisherMinBatchMatch, @@ -85,7 +119,7 @@ function step( kwargs..., ) (; n_samples, subsampling) = alg - (; q, prob, sigma, iteration, sub_st, grad_buf) = state + (; q, prob, sigma, iteration, sub_st, u_buf, grad_buf) = state d = LogDensityProblems.dimension(prob) μ = q.location @@ -102,25 +136,9 @@ function step( prob_sub, sub_st′, sub_inf end - u = randn(rng, eltype(μ), d, n_samples) - z = C*u .+ μ - logπ_sum = zero(eltype(μ)) - for b in 1:n_samples - logπb, gb = LogDensityProblems.logdensity_and_gradient(prob_sub, view(z, :, b)) - grad_buf[:, b] = gb - logπ_sum += logπb - end - logπ_avg = logπ_sum/n_samples - - # Estimate objective values - # - # WF = E[| ∇log(q/π) (z) |_{CC'}^2] (definition) - # = E[| C' (∇logq(z) - ∇logπ(z)) |^2] (Σ = CC') - # = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2] (z = Cu + μ) - # = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2] - # = E[| -u - C'∇logπ(z)) |^2] - weighted_fisher = sum(abs2, -u .- (C'*grad_buf))/n_samples - elbo = logπ_avg + entropy(q) + u_buf, z, grad_buf, fisher, logπ_avg = rand_batch_match_samples_with_objective!( + rng, q, n_samples, prob_sub, u_buf, grad_buf + ) # BaM updates zbar, C = mean_and_cov(z, 2) @@ -136,9 +154,10 @@ function step( μ′ = 1/(1 + λ)*μ + λ/(1 + λ)*(Σ′*gbar + zbar) q′ = MvLocationScale(μ′[:, 1], cholesky(Σ′).L, q.dist) - info = (iteration=iteration, weighted_fisher=weighted_fisher, elbo=elbo) + elbo = logπ_avg + entropy(q) + info = (iteration=iteration, covweighted_fisher=fisher, elbo=elbo) - state = BatchMatchState(q′, prob, Σ′, iteration, sub_st′, grad_buf) + state = BatchMatchState(q′, prob, Σ′, iteration, sub_st′, u_buf, grad_buf) if !isnothing(callback) info′ = callback(; rng, iteration, q, state) @@ -171,21 +190,6 @@ function estimate_objective( prob; n_samples::Int=alg.n_samples, ) where {S,L} - d = LogDensityProblems.dimension(prob) - grad_buf = Matrix{eltype(params)}(undef, d, n_samples) - d = LogDensityProblems.dimension(prob) - μ = q.location - C = q.scale - u = randn(rng, eltype(μ), d, n_samples) - z = C*u .+ μ - for b in 1:n_samples - _, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b)) - 
grad_buf[:, b] = gb - end - # WF = E[| ∇log(q/π) (z) |_{CC'}^2] (definition) - # = E[| C' (∇logq(z) - ∇logπ(z)) |^2] (Σ = CC') - # = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2] (z = Cu + μ) - # = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2] - # = E[| -u - C'∇logπ(z)) |^2] - return sum(abs2, -u .- (C'*grad_buf))/n_samples + _, _, _, fisher, _ = rand_batch_match_samples_with_objective!(rng, q, n_samples, prob) + return fisher end