migrate to AbstractDifferentiation
Red-Portal committed Jun 9, 2023
1 parent 4040149 commit 2a4514e
Showing 4 changed files with 13 additions and 109 deletions.
3 changes: 1 addition & 2 deletions Project.toml
@@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.2.3"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -15,7 +16,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[compat]
Bijectors = "0.11, 0.12"
@@ -27,7 +27,6 @@ ProgressMeter = "1.0.0"
Requires = "0.5, 1.0"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
julia = "1.6"

[extras]
101 changes: 5 additions & 96 deletions src/AdvancedVI.jl
@@ -9,14 +9,16 @@ using ProgressMeter, LinearAlgebra

using LogDensityProblems

using ForwardDiff
using Tracker

using Distributions
using DistributionsAD

using StatsFuns

using ForwardDiff
import AbstractDifferentiation as AD

value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...)

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
@info("[AdvancedVI]: global PROGRESS is set as $switch")
@@ -35,58 +37,6 @@ function __init__()
Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
end
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("compat/zygote.jl")
export ZygoteAD

function AdvancedVI.grad!(
f::Function,
::Type{<:ZygoteAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
y, back = Zygote.pullback(f, λ)
dy = first(back(1.0))
DiffResults.value!(out, y)
DiffResults.gradient!(out, dy)
return out
end
end
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("compat/reversediff.jl")
export ReverseDiffAD

function AdvancedVI.grad!(
f::Function,
::Type{<:ReverseDiffAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
tp = AdvancedVI.tape(f, λ)
ReverseDiff.gradient!(out, tp, λ)
return out
end
end
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("compat/enzyme.jl")
export EnzymeAD

function AdvancedVI.grad!(
f::Function,
::Type{<:EnzymeAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
# Use `Enzyme.ReverseWithPrimal` once it is released:
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
y = f(λ)
DiffResults.value!(out, y)
dy = DiffResults.gradient(out)
fill!(dy, 0)
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy))
return out
end
end
end

export
@@ -97,16 +47,6 @@ export

const VariationalPosterior = Distribution{Multivariate, Continuous}


"""
grad!(f, λ, out)
Computes the gradients of the objective f. Default implementation is provided for
`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
This implicitly also gives a default implementation of `optimize!`.
"""
function grad! end

"""
vi(model, alg::VariationalInference)
vi(model, alg::VariationalInference, q::VariationalPosterior)
@@ -126,37 +66,6 @@ function vi end

function update end

# default implementations
function grad!(
f::Function,
adtype::Type{<:ForwardDiffAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
)
# Set chunk size and do ForwardMode.
chunk_size = getchunksize(adtype)
config = if chunk_size == 0
ForwardDiff.GradientConfig(f, λ)
else
ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size))
end
ForwardDiff.gradient!(out, f, λ, config)
end

function grad!(
f::Function,
::Type{<:TrackerAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
)
λ_tracked = Tracker.param(λ)
y = f(λ_tracked)
Tracker.back!(y, 1.0)

DiffResults.value!(out, Tracker.data(y))
DiffResults.gradient!(out, Tracker.grad(λ_tracked))
end

# estimators
abstract type AbstractVariationalObjective end

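For context, a minimal sketch of the AbstractDifferentiation call that the new value_and_gradient wrapper forwards to; the toy objective, input vector, and choice of backend below are illustrative and not part of this commit. The wrapper puts the function argument first so that Julia's do-block syntax can supply the closure, as in estimate_gradient! further down.

import AbstractDifferentiation as AD
using ForwardDiff                      # loading ForwardDiff makes the ForwardDiff backend usable

f(λ) = sum(abs2, λ)                    # toy stand-in for the negative ELBO
λ = randn(3)

y, grads = AD.value_and_gradient(AD.ForwardDiffBackend(), f, λ)
# `grads` is a tuple with one gradient per differentiated argument,
# so the gradient with respect to λ is first(grads).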
10 changes: 4 additions & 6 deletions src/objectives/elbo/elbo.jl
@@ -22,15 +22,14 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int)
end

function estimate_gradient!(
adbackend::AD.AbstractBackend,
rng::Random.AbstractRNG,
objective::ELBO,
λ::Vector{<:Real},
rebuild,
out::DiffResults.MutableDiffResult)
rebuild)

n_samples = objective.n_samples

grad!(ADBackend(), λ, out) do λ′
nelbo, grad = value_and_gradient(λ; adbackend) do λ′
q_η = rebuild(λ′)
ηs = rand(rng, q_η, n_samples)

@@ -39,6 +38,5 @@
elbo = 𝔼ℓ +
-elbo
end
nelbo = DiffResults.value(out)
(elbo=-nelbo,)
first(grad), (elbo=-nelbo,)
end
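Illustration only: the new do-block call in estimate_gradient!, shown in isolation with a toy closure standing in for the Monte Carlo ELBO estimate; the wrapper line repeats the definition from src/AdvancedVI.jl above.

import AbstractDifferentiation as AD
using ForwardDiff

value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...)

λ = randn(4)
nelbo, grad = value_and_gradient(λ; adbackend = AD.ForwardDiffBackend()) do λ′
    sum(abs2, λ′)               # toy negative ELBO; the real closure rebuilds q and samples from it
end
first(grad), (elbo = -nelbo,)   # what estimate_gradient! now returns: the gradient and the stats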
8 changes: 3 additions & 5 deletions src/vi.jl
@@ -15,7 +15,8 @@ function optimize(
n_max_iter::Int,
λ::AbstractVector{<:Real};
optimizer = TruncatedADAGrad(),
rng = Random.GLOBAL_RNG
rng = Random.default_rng(),
adbackend = AD.ForwardDiffBackend()
)
# TODO: really need a better way to warn the user about potentially
# not using the correct accumulator
@@ -24,19 +25,16 @@
@info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ)
end

grad_buf = DiffResults.GradientResult(λ)

i = 0
prog = ProgressMeter.Progress(
n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[])

# add criterion? A running mean maybe?
time_elapsed = @elapsed begin
for i = 1:n_max_iter
stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
Δλ, stats = estimate_gradient!(adbackend, rng, objective, λ, rebuild)

# apply update rule
Δλ = DiffResults.gradient(grad_buf)
Δλ = apply!(optimizer, λ, Δλ)
@. λ = λ - Δλ

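Finally, a self-contained sketch of the update loop that optimize runs after this change, with a fixed step size standing in for TruncatedADAGrad and a toy objective in place of the variational objective; an illustration under those assumptions, not the package's API.

import AbstractDifferentiation as AD
using ForwardDiff

f(λ) = sum(abs2, λ)                   # toy negative ELBO
λ = randn(5)
adbackend = AD.ForwardDiffBackend()   # the new default; any AbstractDifferentiation backend works
η = 0.1                               # fixed step size instead of TruncatedADAGrad

for i in 1:100
    nelbo, grads = AD.value_and_gradient(adbackend, f, λ)
    Δλ = first(grads)
    @. λ = λ - η * Δλ                 # mirrors `@. λ = λ - Δλ` after apply! in vi.jl
end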
