From b53c8a45d2f7c67935a3d5ef72cbc37e9dfd1284 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 11 Oct 2020 05:10:58 +1100 Subject: [PATCH] fix Tracker + ReverseDiff case --- src/compat/reversediff.jl | 2 -- src/compat/tracker.jl | 20 +++++++++++--------- test/ad/distributions.jl | 6 ------ test/runtests.jl | 7 +++++++ 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 9413cc35..d9fcb587 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -46,8 +46,6 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end end -maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = f.(x...) - logabsdetjac(b::Log{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) @grad function logabsdetjac(b::Log{1}, x::AbstractVector) return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x)) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 789d4f4c..f0dea430 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,16 +12,18 @@ using .Tracker: Tracker, using Compat: eachcol using LinearAlgebra -# Broadcasting here breaks Tracker for some reason -maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...) +# Broadcasting here breaks Tracker +const TrackedT = Union{TrackedArray, AbstractArray{<:TrackedReal}} maporbroadcast(f, x::TrackedArray...) = f.(x...) -function maporbroadcast( - f, - x1::TrackedArray{T, N}, - x::AbstractArray{<:TrackedReal}..., -) where {T, N} - return f.(convert(Array{TrackedReal{T}, N}, x1), x...) -end +maporbroadcast(f, x::Union{TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray) = map(f, x1, x2) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT) = map(f, x1, x2) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3) +maporbroadcast(f, x1::TrackedT, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::TrackedT) = map(f, x1, x2, x3) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3) _eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T) function Base.minimum(d::LocationScale{<:TrackedReal}) diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 264ea727..7f017fe0 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -550,10 +550,4 @@ ) end end - - @testset "Turing issue 1385" begin - dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) - x = ReverseDiff.track(rand(dist)) - @test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray - end end diff --git a/test/runtests.jl b/test/runtests.jl index 47f36d95..2c17c6f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,5 +35,12 @@ end if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") include("ad/distributions.jl") + if AD == "ReverseDiff" + @testset "Turing issue 1385" begin + dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) + x = ReverseDiff.track(rand(dist)) + @test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray + end + end end