Skip to content

Commit

Permalink
Merge c238ce6 into 3387a40
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Oct 10, 2020
2 parents 3387a40 + c238ce6 commit ed6fb8e
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 3 deletions.
2 changes: 0 additions & 2 deletions src/Bijectors.jl
Expand Up @@ -243,8 +243,6 @@ end

include("interface.jl")

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
maporbroadcast(f, x::AbstractArray...) = f.(x...)

# optional dependencies
Expand Down
4 changes: 3 additions & 1 deletion src/compat/reversediff.jl
@@ -1,7 +1,7 @@
module ReverseDiffCompat

using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVector,
TrackedMatrix
TrackedMatrix, TrackedArray
using Requires, LinearAlgebra

using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
Expand Down Expand Up @@ -46,6 +46,8 @@ function Base.maximum(d::LocationScale{<:TrackedReal})
end
end

maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = f.(x...)

logabsdetjac(b::Log{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x)
@grad function logabsdetjac(b::Log{1}, x::AbstractVector)
return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x))
Expand Down
2 changes: 2 additions & 0 deletions src/compat/tracker.jl
Expand Up @@ -12,6 +12,8 @@ using .Tracker: Tracker,
using Compat: eachcol
using LinearAlgebra

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...)
maporbroadcast(f, x::TrackedArray...) = f.(x...)
function maporbroadcast(
f,
Expand Down
6 changes: 6 additions & 0 deletions test/ad/distributions.jl
Expand Up @@ -550,4 +550,10 @@
)
end
end

@testset "Turing issue 1385" begin
dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0))
x = ReverseDiff.track(rand(dist))
@test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray
end
end

0 comments on commit ed6fb8e

Please sign in to comment.