diff --git a/Project.toml b/Project.toml index a73f18aa..54738173 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.1.3" +version = "0.1.4" [deps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 28906f9a..09180388 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -142,10 +142,13 @@ function grad!( - vo(alg, q(θ_), model, args...) end - chunk_size = getchunksize(typeof(alg)) # Set chunk size and do ForwardMode. - chunk = ForwardDiff.Chunk(min(length(θ), chunk_size)) - config = ForwardDiff.GradientConfig(f, θ, chunk) + chunk_size = getchunksize(typeof(alg)) + config = if chunk_size == 0 + ForwardDiff.GradientConfig(f, θ) + else + ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + end ForwardDiff.gradient!(out, f, θ, config) end diff --git a/src/ad.jl b/src/ad.jl index 59c69cb0..62e785e1 100644 --- a/src/ad.jl +++ b/src/ad.jl @@ -8,7 +8,6 @@ function setadbackend(::Val{:forward_diff}) setadbackend(Val(:forwarddiff)) end function setadbackend(::Val{:forwarddiff}) - CHUNKSIZE[] == 0 && setchunksize(40) ADBACKEND[] = :forwarddiff end @@ -26,13 +25,11 @@ function setadsafe(switch::Bool) ADSAFE[] = switch end -const CHUNKSIZE = Ref(40) # default chunksize used by AD +const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically function setchunksize(chunk_size::Int) - if ~(CHUNKSIZE[] == chunk_size) - @info("[AdvancedVI]: AD chunk size is set as $chunk_size") - CHUNKSIZE[] = chunk_size - end + @info("[AdvancedVI]: AD chunk size is set as $chunk_size") + CHUNKSIZE[] = chunk_size end abstract type ADBackend end