Skip to content

Commit

Permalink
Merge 7f7ba11 into b0c4be3
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Apr 25, 2024
2 parents b0c4be3 + 7f7ba11 commit 4e35eab
Show file tree
Hide file tree
Showing 11 changed files with 306 additions and 131 deletions.
50 changes: 50 additions & 0 deletions .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: IntegrationTest

on:
push:
branches:
- master
merge_group:
types: [checks_requested]
pull_request:
branches: [v0.2-backport]

jobs:
test:
name: ${{ matrix.package.repo }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
package:
- {user: TuringLang, repo: Turing.jl}

steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: 1
arch: x64
- uses: julia-actions/julia-buildpkg@latest
- name: Clone Downstream
uses: actions/checkout@v2
with:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
- name: Load this and run the downstream tests
shell: julia --color=yes --project=downstream {0}
run: |
using Pkg
try
# force it to use this PR's version of the package
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
Pkg.update()
Pkg.test(julia_args=["--depwarn=no"]) # resolver may fail with test time deps
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
# If we can't resolve that means this is incompatible by SemVer and this is fine
# It means we marked this as a breaking change, so we don't need to worry about
# Mistakenly introducing a breaking change, as we have intentionally made one
@info "Not compatible with this release. No problem." exception=err
exit(0) # Exit immediately, as a success
end
33 changes: 26 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.2.4"
version = "0.2.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -16,22 +18,39 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIEnzymeExt = ["Enzyme"]
AdvancedVIFluxExt = ["Flux"]
AdvancedVIReverseDiffExt = ["ReverseDiff"]
AdvancedVIZygoteExt = ["Zygote"]

[compat]
Bijectors = "0.11, 0.12, 0.13"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.12"
LinearAlgebra = "1.6"
ForwardDiff = "0.10.3"
Flux = "0.14"
ProgressMeter = "1.0.0"
Requires = "0.5, 1.0"
Random = "1.6"
Requires = "1"
ReverseDiff = "1"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
Zygote = "0.6"
julia = "1.6"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Pkg", "Test"]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
42 changes: 42 additions & 0 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
module AdvancedVIEnzymeExt

if isdefined(Base, :get_extension)
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
using Enzyme: Enzyme
else
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
using ..Enzyme: Enzyme
end

AdvancedVI.ADBackend(::Val{:enzyme}) = ADTypes.AutoEnzyme()
function AdvancedVI.setadbackend(::Val{:enzyme})
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
AdvancedVI.ADBACKEND[] = :enzyme
end

function AdvancedVI.grad!(
vo,
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoEnzyme},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) =
if (q isa Distributions.Distribution)
-vo(alg, AdvancedVI.update(q, θ), model, args...)
else
-vo(alg, q(θ), model, args...)
end
# Use `Enzyme.ReverseWithPrimal` once it is released:
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
y = f(θ)
DiffResults.value!(out, y)
dy = DiffResults.gradient(out)
fill!(dy, 0)
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
return out
end

end
13 changes: 13 additions & 0 deletions ext/AdvancedVIFluxExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module AdvancedVIFluxExt

if isdefined(Base, :get_extension)
using AdvancedVI: AdvancedVI
using Flux: Flux
else
using ..AdvancedVI: AdvancedVI
using ..Flux: Flux
end

AdvancedVI.apply!(o::Flux.Optimise.AbstractOptimiser, x, Δ) = Flux.Optimise.apply!(o, x, Δ)

end
40 changes: 40 additions & 0 deletions ext/AdvancedVIReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module AdvancedVIReverseDiffExt

if isdefined(Base, :get_extension)
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
using ReverseDiff: ReverseDiff
else
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
using ..ReverseDiff: ReverseDiff
end

AdvancedVI.ADBackend(::Val{:reversediff}) = ADTypes.AutoReverseDiff()

function AdvancedVI.setadbackend(::Val{:reversediff})
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
AdvancedVI.ADBACKEND[] = :reversediff
end

tape(f, x) = ReverseDiff.GradientTape(f, x)

function AdvancedVI.grad!(
vo,
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoReverseDiff},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) =
if (q isa Distributions.Distribution)
-vo(alg, AdvancedVI.update(q, θ), model, args...)
else
-vo(alg, q(θ), model, args...)
end
tp = tape(f, θ)
ReverseDiff.gradient!(out, tp, θ)
return out
end

end
39 changes: 39 additions & 0 deletions ext/AdvancedVIZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
module AdvancedVIZygoteExt

if isdefined(Base, :get_extension)
using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
using Zygote: Zygote
else
using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
using ..Zygote: Zygote
end

AdvancedVI.ADBackend(::Val{:zygote}) = ADTypes.AutoZygote()
function AdvancedVI.setadbackend(::Val{:zygote})
Base.depwarn("`setadbackend` is deprecated. Please pass a `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
AdvancedVI.ADBACKEND[] = :zygote
end

function AdvancedVI.grad!(
vo,
alg::AdvancedVI.VariationalInference{<:ADTypes.AutoZygote},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) =
if (q isa Distributions.Distribution)
-vo(alg, AdvancedVI.update(q, θ), model, args...)
else
-vo(alg, q(θ), model, args...)
end
y, back = Zygote.pullback(f, θ)
dy = first(back(1.0))
DiffResults.value!(out, y)
DiffResults.gradient!(out, dy)
return out
end

end

0 comments on commit 4e35eab

Please sign in to comment.