diff --git a/src/Bijectors.jl b/src/Bijectors.jl index e435ebf7..6991d446 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -243,8 +243,6 @@ end include("interface.jl") -# Broadcasting here breaks Tracker for some reason -maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) # optional dependencies diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 61d61b03..9413cc35 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,7 +1,7 @@ module ReverseDiffCompat using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVector, - TrackedMatrix + TrackedMatrix, TrackedArray using Requires, LinearAlgebra using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian, @@ -46,6 +46,8 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end end +maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = f.(x...) + logabsdetjac(b::Log{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) @grad function logabsdetjac(b::Log{1}, x::AbstractVector) return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x)) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index cb8d6549..789d4f4c 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,6 +12,8 @@ using .Tracker: Tracker, using Compat: eachcol using LinearAlgebra +# Broadcasting here breaks Tracker for some reason +maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...) maporbroadcast(f, x::TrackedArray...) = f.(x...) function maporbroadcast( f, diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 7f017fe0..264ea727 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -550,4 +550,10 @@ ) end end + + @testset "Turing issue 1385" begin + dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) + x = ReverseDiff.track(rand(dist)) + @test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray + end end