Skip to content

Commit

Permalink
fix Tracker + ReverseDiff case
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Oct 10, 2020
1 parent c238ce6 commit b53c8a4
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 17 deletions.
2 changes: 0 additions & 2 deletions src/compat/reversediff.jl
Expand Up @@ -46,8 +46,6 @@ function Base.maximum(d::LocationScale{<:TrackedReal})
end
end

maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = f.(x...)

logabsdetjac(b::Log{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x)
@grad function logabsdetjac(b::Log{1}, x::AbstractVector)
return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x))
Expand Down
20 changes: 11 additions & 9 deletions src/compat/tracker.jl
Expand Up @@ -12,16 +12,18 @@ using .Tracker: Tracker,
using Compat: eachcol
using LinearAlgebra

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...)
# Broadcasting here breaks Tracker
const TrackedT = Union{TrackedArray, AbstractArray{<:TrackedReal}}
maporbroadcast(f, x::TrackedArray...) = f.(x...)
function maporbroadcast(
f,
x1::TrackedArray{T, N},
x::AbstractArray{<:TrackedReal}...,
) where {T, N}
return f.(convert(Array{TrackedReal{T}, N}, x1), x...)
end
maporbroadcast(f, x::Union{TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray) = map(f, x1, x2)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT) = map(f, x1, x2)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3)
maporbroadcast(f, x1::TrackedT, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::TrackedT) = map(f, x1, x2, x3)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3)

_eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T)
function Base.minimum(d::LocationScale{<:TrackedReal})
Expand Down
6 changes: 0 additions & 6 deletions test/ad/distributions.jl
Expand Up @@ -550,10 +550,4 @@
)
end
end

@testset "Turing issue 1385" begin
dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0))
x = ReverseDiff.track(rand(dist))
@test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray
end
end
7 changes: 7 additions & 0 deletions test/runtests.jl
Expand Up @@ -35,5 +35,12 @@ end

if !is_TRAVIS && (GROUP == "All" || GROUP == "AD")
include("ad/distributions.jl")
if AD == "ReverseDiff"
@testset "Turing issue 1385" begin
dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0))
x = ReverseDiff.track(rand(dist))
@test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray
end
end
end

0 comments on commit b53c8a4

Please sign in to comment.