ReverseDiff support (#1170)
* reversediff support

* compat section update

* test fixes

* deprecate :forward_diff and :reverse_diff

* deprecation comment fix

* allow the caching of the compiled tape

* fix caching and test it

* use depwarn

* add one more caching test

* make Memoization an optional dep

* add type annotations to _get!
mohamed82008 committed Mar 24, 2020
1 parent bcd5ab9 commit e946541
Showing 14 changed files with 258 additions and 63 deletions.
10 changes: 6 additions & 4 deletions Project.toml
@@ -35,10 +35,10 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
AbstractMCMC = "0.5.2"
AdvancedHMC = "0.2.20"
AdvancedMH = "0.4"
Bijectors = "0.6.2"
Bijectors = "0.6.4"
BinaryProvider = "0.5.6"
Distributions = "0.22"
DistributionsAD = "0.4.3"
Distributions = "0.22, 0.23"
DistributionsAD = "0.4.8"
DocStringExtensions = "0.8"
DynamicPPL = "0.4"
EllipticalSliceSampling = "0.2"
@@ -62,12 +62,14 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[targets]
test = ["Pkg", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote"]
test = ["Pkg", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization"]
8 changes: 4 additions & 4 deletions benchmarks/benchmarks.jl
@@ -71,10 +71,10 @@ A = rand(Wishart(dim2, Matrix{Float64}(I, dim2, dim2)));
d = MvNormal(zeros(dim2), A)

# ForwardDiff
-Turing.setadbackend(:forward_diff)
-BenchmarkSuite["mnormal"]["forward_diff"] = @benchmarkable sample(mdemo($d, 1), HMC(0.1, 5), 5000)
+Turing.setadbackend(:forwarddiff)
+BenchmarkSuite["mnormal"]["forwarddiff"] = @benchmarkable sample(mdemo($d, 1), HMC(0.1, 5), 5000)


# BackwardDiff
-Turing.setadbackend(:reverse_diff)
-BenchmarkSuite["mnormal"]["reverse_diff"] = @benchmarkable sample(mdemo($d, 1), HMC(0.1, 5), 5000)
+Turing.setadbackend(:reversediff)
+BenchmarkSuite["mnormal"]["reversediff"] = @benchmarkable sample(mdemo($d, 1), HMC(0.1, 5), 5000)
5 changes: 2 additions & 3 deletions docs/src/using-turing/autodiff.md
@@ -8,10 +8,9 @@ title: Automatic Differentiation
## Switching AD Modes


-Turing supports two types of automatic differentiation (AD) in the back end during sampling. The current default AD mode is [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), but Turing also supports [Tracker](https://github.com/FluxML/Tracker.jl)-based differentiation.
+Turing supports four automatic differentiation (AD) packages in the back end during sampling. The default AD backend is [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) for forward-mode AD. Three reverse-mode AD backends are also supported, namely [Tracker](https://github.com/FluxML/Tracker.jl), [Zygote](https://github.com/FluxML/Zygote.jl) and [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl). `Zygote` and `ReverseDiff` are optional and must be loaded explicitly with `using Zygote` or `using ReverseDiff` alongside `using Turing`.


-To switch between `ForwardDiff` and `Tracker`, one can call the function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forward_diff` or `:reverse_diff`.
+To switch between the different AD backends, one can call the function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, first load [Memoization.jl](https://github.com/marius311/Memoization.jl) with `using Memoization` and then call `Turing.setcache(true)`. Note, however, that caching can lead to incorrect results and/or errors for some models. The compiled tape can only be safely cached for models with fixed-size loops and no run-time `if` statements; compile-time `if` statements are fine.
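For example, here is a minimal sketch of enabling `ReverseDiff` with tape caching; the toy model and its data are illustrative, not part of this commit:

```julia
using Turing
using ReverseDiff    # makes the :reversediff backend available
using Memoization    # must be loaded before enabling tape caching

Turing.setadbackend(:reversediff)
Turing.setcache(true)    # compile the tape once and reuse it

# Safe to cache: the loop has a fixed size and there is no run-time branching.
@model demo(x) = begin
    m ~ Normal(0, 1)
    for i in 1:length(x)
        x[i] ~ Normal(m, 1)
    end
end

chain = sample(demo(randn(10)), HMC(0.1, 5), 1000)
```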


## Compositional Sampling with Differing AD Modes
3 changes: 1 addition & 2 deletions docs/src/using-turing/guide.md
@@ -448,8 +448,7 @@ ForwardDiff (Turing's default AD backend) uses forward-mode chunk-wise AD. The c
#### AD Backend


-Since [#428](https://github.com/TuringLang/Turing.jl/pull/428), Turing.jl supports `Tracker` as backend for reverse mode autodiff. To switch between `ForwardDiff.jl` and `Tracker`, one can call function `setadbackend(backend_sym)`, where `backend_sym` can be `:forward_diff` or `:reverse_diff`.
+Turing supports four automatic differentiation (AD) packages in the back end during sampling. The default AD backend is [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) for forward-mode AD. Three reverse-mode AD backends are also supported, namely [Tracker](https://github.com/FluxML/Tracker.jl), [Zygote](https://github.com/FluxML/Zygote.jl) and [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl). `Zygote` and `ReverseDiff` are optional and must be loaded explicitly with `using Zygote` or `using ReverseDiff` alongside `using Turing`.
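As an illustrative sketch (not part of this changeset), loading one of the optional backends and switching to it looks like:

```julia
using Turing
using Zygote    # the :zygote backend is only registered once Zygote is loaded

Turing.setadbackend(:zygote)
```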

For more information on Turing's automatic differentiation backend, please see the [Automatic Differentiation]({{site.baseurl}}/docs/using-turing/autodiff) article.

11 changes: 6 additions & 5 deletions docs/src/using-turing/performancetips.md
@@ -34,14 +34,15 @@ end

## Choose your AD backend
Turing currently provides support for four different automatic differentiation (AD) backends.
-Generally, try to use `:forward_diff` for models with few parameters and `:reverse_diff` for models with large parameter vectors or linear algebra operations. See [Automatic Differentiation](autodiff) for details.
+Generally, try to use `:forwarddiff` for models with few parameters and `:reversediff`, `:tracker` or `:zygote` for models with large parameter vectors or linear algebra operations. See [Automatic Differentiation](autodiff) for details.


-## Special care for `reverse_diff`
+## Special care for `:tracker` and `:zygote`

-In case of `reverse_diff` it is necessary to avoid loops for now.
-This is mainly due to the reverse-mode AD backend `Tracker` which is inefficient for such cases.
-Therefore, it is often recommended to write a [custom distribution](advanced) which implements a multivariate version of the prior distribution.
+When using `:tracker` or `:zygote`, it is necessary to avoid loops for now.
+This is mainly because the reverse-mode AD backends `Tracker` and `Zygote` are inefficient for such cases. `ReverseDiff` handles loops better, but vectorized operations will still perform best.
+
+Loops can be avoided using `filldist(dist, N)` and `arraydist(dists)`. If `dist` is univariate, `filldist(dist, N)` creates a multivariate distribution composed of `N` identical and independent copies of `dist`; if `dist` is multivariate, it creates a matrix-variate distribution composed of `N` identical and independent copies of `dist`. `filldist(dist, N, M)` can also be used to create a matrix-variate distribution from a univariate distribution `dist`. `arraydist(dists)` is similar to `filldist` but takes an array of distributions `dists` as input. Writing a [custom distribution](advanced) with a custom adjoint is another option to avoid loops.
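The following sketch contrasts the two styles; the model and variable names are illustrative only:

```julia
using Turing

# Loop-based priors: slow under :tracker and :zygote.
@model loopy(x) = begin
    m = Vector{Real}(undef, length(x))
    for i in 1:length(x)
        m[i] ~ Normal(0, 1)
        x[i] ~ Normal(m[i], 1)
    end
end

# Loop-free equivalent: a single multivariate node via filldist.
# arraydist([Normal(0, i) for i in 1:length(x)]) works similarly for
# non-identical priors.
@model vectorized(x) = begin
    m ~ filldist(Normal(0, 1), length(x))
    x ~ MvNormal(m, 1.0)
end
```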


## Make your model type-stable
12 changes: 9 additions & 3 deletions src/core/Core.jl
@@ -17,9 +17,15 @@ using Requires

include("container.jl")
include("ad.jl")
-@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
-    include("compat/zygote.jl")
-    export ZygoteAD
+function __init__()
+    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
+        include("compat/zygote.jl")
+        export ZygoteAD
+    end
+    @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
+        include("compat/reversediff.jl")
+        export ReverseDiffAD, setcache
+    end
end

export @model,
16 changes: 10 additions & 6 deletions src/core/ad.jl
@@ -1,14 +1,18 @@
##############################
# Global variables/constants #
##############################
-const ADBACKEND = Ref(:forward_diff)
+const ADBACKEND = Ref(:forwarddiff)
setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
function setadbackend(::Val{:forward_diff})
+    Base.depwarn("`Turing.setadbackend(:forward_diff)` is deprecated. Please use `Turing.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend)
+    setadbackend(Val(:forwarddiff))
+end
+function setadbackend(::Val{:forwarddiff})
    CHUNKSIZE[] == 0 && setchunksize(40)
-    ADBACKEND[] = :forward_diff
+    ADBACKEND[] = :forwarddiff
end
-function setadbackend(::Val{:reverse_diff})
-    ADBACKEND[] = :reverse_diff
+function setadbackend(::Val{:tracker})
+    ADBACKEND[] = :tracker
end

const ADSAFE = Ref(false)
@@ -37,8 +41,8 @@ struct TrackerAD <: ADBackend end
ADBackend() = ADBackend(ADBACKEND[])
ADBackend(T::Symbol) = ADBackend(Val(T))

-ADBackend(::Val{:forward_diff}) = ForwardDiffAD{CHUNKSIZE[]}
-ADBackend(::Val{:reverse_diff}) = TrackerAD
+ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
+ADBackend(::Val{:tracker}) = TrackerAD
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")

"""
83 changes: 83 additions & 0 deletions src/core/compat/reversediff.jl
@@ -0,0 +1,83 @@
using .ReverseDiff: compile, GradientTape
using .ReverseDiff.DiffResults: GradientResult

struct ReverseDiffAD{cache} <: ADBackend end
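# Editor's note: the Bool type parameter `cache` selects the gradient path at
# dispatch time. `ReverseDiffAD{false}` records a fresh tape on every call,
# while `ReverseDiffAD{true}` reuses a memoized compiled tape (see the
# Memoization block below).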
const RDCache = Ref(false)
setcache(b::Bool) = RDCache[] = b
getcache() = RDCache[]
ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()}
function setadbackend(::Val{:reverse_diff})
    Base.depwarn("`Turing.setadbackend(:reverse_diff)` is deprecated. Please use `Turing.setadbackend(:tracker)` to use `Tracker` or `Turing.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend)
    setadbackend(Val(:reversediff))
end
function setadbackend(::Val{:reversediff})
    ADBACKEND[] = :reversediff
end

function gradient_logp(
    backend::ReverseDiffAD{false},
    θ::AbstractVector{<:Real},
    vi::VarInfo,
    model::Model,
    sampler::AbstractSampler = SampleFromPrior(),
)
    T = typeof(getlogp(vi))

    # Specify objective function.
    function f(θ)
        new_vi = VarInfo(vi, sampler, θ)
        return getlogp(runmodel!(model, new_vi, sampler))
    end
    tp, result = taperesult(f, θ)
    ReverseDiff.gradient!(result, tp, θ)
    l = DiffResults.value(result)
    ∂l∂θ = DiffResults.gradient(result)

    return l, ∂l∂θ
end

tape(f, x) = GradientTape(f, x)
function taperesult(f, x)
    return tape(f, x), GradientResult(x)
end

@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
    function gradient_logp(
        backend::ReverseDiffAD{true},
        θ::AbstractVector{<:Real},
        vi::VarInfo,
        model::Model,
        sampler::AbstractSampler = SampleFromPrior(),
    )
        T = typeof(getlogp(vi))

        # Specify objective function.
        function f(θ)
            new_vi = VarInfo(vi, sampler, θ)
            return getlogp(runmodel!(model, new_vi, sampler))
        end
        ctp, result = memoized_taperesult(f, θ)
        ReverseDiff.gradient!(result, ctp, θ)
        l = DiffResults.value(result)
        ∂l∂θ = DiffResults.gradient(result)

        return l, ∂l∂θ
    end

    # This makes sure we generate a single tape per Turing model and sampler.
    struct RDTapeKey{F, Tx}
        f::F
        x::Tx
    end
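    # Editor's note: `gradient_logp` builds a fresh closure `f` on every call,
    # so memoizing on the key object itself would never hit the cache. The
    # `_get!` overload below instead keys the cache on the closure's type plus
    # the input's type and size, which together identify the model, sampler
    # and parameter dimension.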
    function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Nothing})
        key = keys[1][1]
        return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
    end
    memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
    Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
        return compiledtape(k.f, k.x), GradientResult(k.x)
    end
    memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
    Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
    compiledtape(f, x) = compile(GradientTape(f, x))
end
45 changes: 44 additions & 1 deletion src/variational/VariationalInference.jl
@@ -38,10 +38,53 @@ function __init__()
        else
            - vo(alg, q(θ), model, args...)
        end
-       out .= Zygote.gradient(f, θ)
+       y, back = Tracker.pullback(f, θ)
+       dy = back(1.0)
+       DiffResults.value!(out, y)
+       DiffResults.gradient!(out, dy)
        return out
    end
end
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
    function Variational.grad!(
        vo,
        alg::VariationalInference{<:Turing.ReverseDiffAD{false}},
        q,
        model,
        θ::AbstractVector{<:Real},
        out::DiffResults.MutableDiffResult,
        args...
    )
        f(θ) = if (q isa VariationalPosterior)
            - vo(alg, update(q, θ), model, args...)
        else
            - vo(alg, q(θ), model, args...)
        end
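        # Caching disabled: record a fresh tape for every gradient evaluation.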
        tp = Turing.Core.tape(f, θ)
        ReverseDiff.gradient!(out, tp, θ)
        return out
    end
    @require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" begin
        function Variational.grad!(
            vo,
            alg::VariationalInference{<:Turing.ReverseDiffAD{true}},
            q,
            model,
            θ::AbstractVector{<:Real},
            out::DiffResults.MutableDiffResult,
            args...
        )
            f(θ) = if (q isa VariationalPosterior)
                - vo(alg, update(q, θ), model, args...)
            else
                - vo(alg, q(θ), model, args...)
            end
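            # Caching enabled: reuse the compiled tape memoized in Turing.Core.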
            ctp = Turing.Core.memoized_tape(f, θ)
            ReverseDiff.gradient!(out, ctp, θ)
            return out
        end
    end
end
end

export
28 changes: 19 additions & 9 deletions test/core/ad.jl
@@ -96,7 +96,7 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)

test_model_ad(wishart_ad(), logp3, [:v])
end
@turing_testset "Tracker & Zygote + logdet" begin
@turing_testset "Tracker, Zygote and ReverseDiff + logdet" begin
rng, N = MersenneTwister(123456), 7
ȳ, B = randn(rng), randn(rng, N, N)
test_reverse_mode_ad(B->logdet(cholesky(_to_cov(B))), ȳ, B; rtol=1e-8, atol=1e-6)
@@ -107,7 +107,7 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
test_reverse_mode_ad(x->fill(x, 7, 11), randn(rng, 7, 11), randn(rng))
test_reverse_mode_ad(x->fill(x, 7, 11, 13), rand(rng, 7, 11, 13), randn(rng))
end
@turing_testset "Tracker & Zygote + MvNormal" begin
@turing_testset "Tracker, Zygote and ReverseDiff + MvNormal" begin
rng, N = MersenneTwister(123456), 11
B = randn(rng, N, N)
m, A = randn(rng, N), B' * B + I
@@ -122,7 +122,7 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)

test_reverse_mode_ad((m, B, x)->logpdf(MvNormal(m, _to_cov(B)), x), randn(rng), m, B, x)
end
@turing_testset "Tracker & Zygote + Diagonal Normal" begin
@turing_testset "Tracker, Zygote and ReverseDiff + Diagonal Normal" begin
rng, N = MersenneTwister(123456), 11
m, σ = randn(rng, N), exp.(0.1 .* randn(rng, N)) .+ 1

@@ -135,7 +135,7 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)

test_reverse_mode_ad((m, σ, x)->logpdf(MvNormal(m, σ), x), randn(rng), m, σ, x)
end
@turing_testset "Tracker & Zygote + MvNormal Interface" begin
@turing_testset "Tracker, Zygote and ReverseDiff + MvNormal Interface" begin
# Note that we only test methods where the `MvNormal` ctor actually constructs
# a TuringDenseMvNormal.

@@ -266,29 +266,39 @@
)
test_reverse_mode_ad(b->logpdf(MvNormal(N, exp(b)), x), randn(rng), randn(rng))
end
@testset "Simplex Tracker & Zygote AD" begin
@testset "Simplex Tracker, Zygote and ReverseDiff (with and without caching) AD" begin
@model dir() = begin
theta ~ Dirichlet(1 ./ fill(4, 4))
end
Turing.setadbackend(:reverse_diff)
Turing.setadbackend(:tracker)
sample(dir(), HMC(0.01, 1), 1000);
Turing.setadbackend(:zygote)
sample(dir(), HMC(0.01, 1), 1000);
Turing.setadbackend(:reversediff)
Turing.setcache(false)
sample(dir(), HMC(0.01, 1), 1000);
Turing.setcache(true)
sample(dir(), HMC(0.01, 1), 1000);
end
@testset "PDMatDistribution Tracker AD" begin
# FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff
@testset "PDMatDistribution AD" begin
@model wishart() = begin
theta ~ Wishart(4, Matrix{Float64}(I, 4, 4))
end
Turing.setadbackend(:reverse_diff)
Turing.setadbackend(:tracker)
sample(wishart(), HMC(0.01, 1), 1000);
#Turing.setadbackend(:reversediff)
#sample(wishart(), HMC(0.01, 1), 1000);
Turing.setadbackend(:zygote)
sample(wishart(), HMC(0.01, 1), 1000);

@model invwishart() = begin
theta ~ InverseWishart(4, Matrix{Float64}(I, 4, 4))
end
Turing.setadbackend(:reverse_diff)
Turing.setadbackend(:tracker)
sample(invwishart(), HMC(0.01, 1), 1000);
#Turing.setadbackend(:reversediff)
#sample(invwishart(), HMC(0.01, 1), 1000);
Turing.setadbackend(:zygote)
sample(invwishart(), HMC(0.01, 1), 1000);
end
