Commit e008172
Merge 9054ab3 into ab4c02c
mohamed82008 committed Apr 9, 2020
2 parents ab4c02c + 9054ab3 commit e008172
Showing 4 changed files with 271 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -8,6 +8,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -19,6 +20,7 @@ DiffResults = "1"
DiffRules = "0.1, 1"
ForwardDiff = "0.10"
FunctionWrappers = "1"
MacroTools = "0.5"
NaNMath = "0.3"
SpecialFunctions = "0.8, 0.9, 0.10"
StaticArrays = "0.10, 0.11, 0.12"
2 changes: 2 additions & 0 deletions src/ReverseDiff.jl
@@ -15,6 +15,8 @@ using ForwardDiff
using ForwardDiff: Dual, Partials
using StaticArrays

using MacroTools

# Not all operations will be valid over all of these types, but that's okay; such cases
# will simply error when they hit the original operation in the overloaded definition.
const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix)
122 changes: 122 additions & 0 deletions src/macros.jl
@@ -166,3 +166,125 @@ end
@inline (self::SkipOptimize{F})(a) where {F} = self.f(value(a))
@inline (self::SkipOptimize{F})(a, b) where {F} = self.f(value(a), value(b))
@inline (self::SkipOptimize{F})(a, b, c) where {F} = self.f(value(a), value(b), value(c))

"""
f(x) = dot(x, x)
f(x::ReverseDiff.TrackedVector) = ReverseDiff.track(f, x)
ReverseDiff.@grad function f(x)
xv = ReverseDiff.value(x)
return dot(xv, xv), Δ -> (Δ * 2 * xv,)
end
The `@grad` macro provides a way for the users to define custom adjoints for single-output functions wrt to their input numbers or arrays.
"""
macro grad(expr)
    d = MacroTools.splitdef(expr)
    f = d[:name]
    closure = gensym(f)
    d[:name] = closure
    closure_ex = MacroTools.combinedef(d)

    @gensym tp output_value output back args kwargs
    args_ex = getargs_expr(d[:args])
    kwargs_ex = getkwargs_expr(d[:kwargs])
    return quote
        function $ReverseDiff.track(::typeof($f), $(d[:args]...); $(d[:kwargs]...)) where {$(d[:whereparams]...),}
            $closure_ex
            $args = $args_ex
            $kwargs = $kwargs_ex
            $tp = $ReverseDiff.tape($args...)
            $output_value, $back = $closure($args...; $kwargs...)
            $output = $ReverseDiff.track($output_value, $tp)
            $ReverseDiff.record!(
                $tp,
                $ReverseDiff.SpecialInstruction,
                $f,
                $args,
                $output,
                ($back, $closure, $kwargs),
            )
            return $output
        end

        if !hasmethod(
            $ReverseDiff.special_reverse_exec!,
            Tuple{$ReverseDiff.SpecialInstruction{typeof($f)}},
        )
            @noinline function $ReverseDiff.special_reverse_exec!(instruction::$ReverseDiff.SpecialInstruction{typeof($f)})
                output = instruction.output
                input = instruction.input
                back = instruction.cache[1]
                input_derivs = back($ReverseDiff.deriv(output))
                @assert input_derivs isa Tuple
                $ReverseDiff.add_to_deriv!.(input, input_derivs)
                $ReverseDiff.unseed!(output)
                return nothing
            end
        end

        if !hasmethod(
            $ReverseDiff.special_forward_exec!,
            Tuple{$ReverseDiff.SpecialInstruction{typeof($f)}},
        )
            @noinline function $ReverseDiff.special_forward_exec!(instruction::$ReverseDiff.SpecialInstruction{typeof($f)})
                output, input = instruction.output, instruction.input
                $ReverseDiff.pull_value!.(input)
                pullback = instruction.cache[2]
                kwargs = instruction.cache[3]
                out_value = pullback(input...; kwargs...)[1]
                $ReverseDiff.value!(output, out_value)
                return nothing
            end
        end
    end |> esc
end
add_to_deriv!(d1, d2) = nothing
function add_to_deriv!(d1::Union{TrackedReal, TrackedArray}, d2)
    increment_deriv!(d1, d2)
end
function getargs_expr(args_with_types)
    expr = Expr(:tuple)
    for at in args_with_types
        x, tosplat = remove_tp(at)
        if tosplat
            push!(expr.args, :($x...))
        else
            push!(expr.args, x)
        end
    end
    return expr
end
function getkwargs_expr(kwargs_with_types)
    syms = []
    final = nothing
    for at in kwargs_with_types
        final isa Nothing || throw("Invalid kwargs.")
        x, tosplat = remove_tp(at)
        if tosplat
            final = x
        else
            push!(syms, x)
        end
    end
    expr = length(syms) == 0 ? :(NamedTuple()) : Expr(:tuple, [:($f = $f) for f in syms]...)
    final = final == nothing ? :(NamedTuple()) : final
    return :(Base.merge($expr, $final))
end
function remove_tp(t)
    if @capture(t, X_::T_...)
        return X, true
    elseif @capture(t, X_::T_)
        return X, false
    elseif @capture(t, X_::T_ = V_)
        return X, false
    elseif @capture(t, ::typeof(T_)...)
        return T, true
    elseif @capture(t, ::typeof(T_))
        return T, false
    elseif @capture(t, X_...)
        return X, true
    elseif @capture(t, X_ = V_)
        return X, false
    else
        return t, false
    end
end
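
For orientation, here is a minimal usage sketch of the macro defined above. The function name `mynorm` and its definition are illustrative only and not part of this commit; the sketch assumes a ReverseDiff build that includes this change:

using LinearAlgebra
using ReverseDiff: ReverseDiff, TrackedVector, @grad

# Plain method for untracked inputs.
mynorm(x) = sqrt(dot(x, x))

# Route tracked inputs through the custom adjoint machinery.
mynorm(x::TrackedVector) = ReverseDiff.track(mynorm, x)

# Forward value plus pullback: d/dx sqrt(x'x) = x / norm(x).
@grad function mynorm(x::AbstractVector)
    xv = ReverseDiff.value(x)
    nv = sqrt(dot(xv, xv))
    return nv, Δ -> (Δ .* xv ./ nv,)
end

ReverseDiff.gradient(x -> mynorm(x)^2, rand(3))  # ≈ 2 .* x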
145 changes: 145 additions & 0 deletions test/MacrosTests.jl
@@ -181,4 +181,149 @@ ReverseDiff.@skip g6 = (a, b) -> sqrt(a^2 + b^2)
test_println("@skip anonymous functions", g6)
test_skip(g6, a, b, tp)

#########
# @grad #
#########

using LinearAlgebra
using ReverseDiff: @grad, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray

@testset "@grad macro" begin
x = rand(3);
A = rand(3, 3);
A_x = [vec(A); x];
global custom_grad_called

f1(x) = dot(x, x)
f1(x::TrackedVector) = ReverseDiff.track(f1, x)
@grad function f1(x::AbstractVector)
global custom_grad_called = true
xv = ReverseDiff.value(x)
dot(xv, xv), Δ -> (Δ * 2 * xv,)
end

custom_grad_called = false
g1 = ReverseDiff.gradient(f1, x)
g2 = ReverseDiff.gradient(x -> dot(x, x), x)
@test g1 == g2
@test custom_grad_called

f2(A, x) = A * x
f2(A, x::TrackedVector) = ReverseDiff.track(f2, A, x)
f2(A::TrackedMatrix, x) = ReverseDiff.track(f2, A, x)
f2(A::TrackedMatrix, x::TrackedVector) = ReverseDiff.track(f2, A, x)
@grad function f2(A::AbstractMatrix, x::AbstractVector)
global custom_grad_called = true
Av = ReverseDiff.value(A)
xv = ReverseDiff.value(x)
Av * xv, Δ -> (Δ * xv', Av' * Δ)
end

custom_grad_called = false
g1 = ReverseDiff.gradient(x -> sum(f2(A, x)), x)
g2 = ReverseDiff.gradient(x -> sum(A * x), x)
@test g1 == g2
@test custom_grad_called

custom_grad_called = false
g1 = ReverseDiff.gradient(A -> sum(f2(A, x)), A)
g2 = ReverseDiff.gradient(A -> sum(A * x), A)
@test g1 == g2
@test custom_grad_called

custom_grad_called = false
g1 = ReverseDiff.gradient(A_x -> sum(f2(reshape(A_x[1:9], 3, 3), A_x[10:end])), A_x)
g2 = ReverseDiff.gradient(A_x -> sum(reshape(A_x[1:9], 3, 3) * A_x[10:end]), A_x)
@test g1 == g2
@test custom_grad_called

f3(A; dims) = sum(A, dims = dims)
f3(A::TrackedMatrix; dims) = ReverseDiff.track(f3, A; dims = dims)
@grad function f3(A::AbstractMatrix; dims)
global custom_grad_called = true
Av = ReverseDiff.value(A)
sum(Av, dims = dims), Δ -> (zero(Av) .+ Δ,)
end
custom_grad_called = false
g1 = ReverseDiff.gradient(A -> sum(f3(A, dims = 1)), A)
g2 = ReverseDiff.gradient(A -> sum(sum(A, dims = 1)), A)
@test g1 == g2
@test custom_grad_called

f4(::typeof(log), A; dims) = sum(log, A, dims = dims)
f4(::typeof(log), A::TrackedMatrix; dims) = ReverseDiff.track(f4, log, A; dims = dims)
@grad function f4(::typeof(log), A::AbstractMatrix; dims)
global custom_grad_called = true
Av = ReverseDiff.value(A)
sum(log, Av, dims = dims), Δ -> (nothing, 1 ./ Av .* Δ)
end
custom_grad_called = false
g1 = ReverseDiff.gradient(A -> sum(f4(log, A, dims = 1)), A)
g2 = ReverseDiff.gradient(A -> sum(sum(log, A, dims = 1)), A)
@test g1 == g2
@test custom_grad_called

f5(x) = log(x)
f5(x::TrackedReal) = ReverseDiff.track(f5, x)
@grad function f5(x::Real)
global custom_grad_called = true
xv = ReverseDiff.value(x)
log(xv), Δ -> (1 / xv * Δ,)
end
custom_grad_called = false
g1 = ReverseDiff.gradient(x -> f5(x[1]) * f5(x[2]) + exp(x[3]), x)
g2 = ReverseDiff.gradient(x -> log(x[1]) * log(x[2]) + exp(x[3]), x)
@test g1 == g2
@test custom_grad_called

f6(x) = sum(x)
f6(x::TrackedArray{<:AbstractFloat}) = ReverseDiff.track(f6, x)
@grad function f6(x::TrackedArray{T}) where {T <: AbstractFloat}
global custom_grad_called = true
xv = ReverseDiff.value(x)
sum(xv), Δ -> (one.(xv) .* Δ,)
end

custom_grad_called = false
g1 = ReverseDiff.gradient(f6, x)
g2 = ReverseDiff.gradient(sum, x)
@test g1 == g2
@test custom_grad_called

x2 = round.(Int, x)
custom_grad_called = false
g1 = ReverseDiff.gradient(f6, x2)
g2 = ReverseDiff.gradient(sum, x2)
@test g1 == g2
@test !custom_grad_called
f6(x::TrackedArray) = ReverseDiff.track(f6, x)
@test_throws MethodError ReverseDiff.gradient(f6, x2)

f7(x...) = +(x...)
f7(x::TrackedReal{<:AbstractFloat}...) = ReverseDiff.track(f7, x...)
@grad function f7(x::TrackedReal{T}...) where {T <: AbstractFloat}
global custom_grad_called = true
xv = ReverseDiff.value.(x)
+(xv...), Δ -> one.(xv) .* Δ
end
custom_grad_called = false
g1 = ReverseDiff.gradient(x -> f7(x...), x)
g2 = ReverseDiff.gradient(sum, x)
@test g1 == g2
@test custom_grad_called

f8(A; kwargs...) = sum(A; kwargs...)
f8(A::TrackedMatrix; kwargs...) = ReverseDiff.track(f8, A; kwargs...)
@grad function f8(A::AbstractMatrix; kwargs...)
global custom_grad_called = true
Av = ReverseDiff.value(A)
sum(Av; kwargs...), Δ -> (zero(Av) .+ Δ,)
end
custom_grad_called = false
g1 = ReverseDiff.gradient(A -> sum(f8(A, dims = 1)), A)
g2 = ReverseDiff.gradient(A -> sum(sum(A, dims = 1)), A)
@test g1 == g2
@test custom_grad_called
end

end # module
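
As a reading aid, the macro above roughly expands, for a user function `f`, to code along the following lines. This is a simplified, de-gensym'd sketch of the expansion in src/macros.jl (the name `grad_body` stands in for the generated closure), not the literal generated code:

function ReverseDiff.track(::typeof(f), args...; kwargs...)
    tp = ReverseDiff.tape(args...)                      # tape the tracked inputs live on
    output_value, back = grad_body(args...; kwargs...)  # user-supplied value and pullback
    output = ReverseDiff.track(output_value, tp)
    ReverseDiff.record!(tp, ReverseDiff.SpecialInstruction, f, args, output,
                        (back, grad_body, kwargs))
    return output
end

# Reverse pass: push the output adjoint through the pullback and accumulate
# the returned tuple into each input's derivative.
function ReverseDiff.special_reverse_exec!(instr::ReverseDiff.SpecialInstruction{typeof(f)})
    input_derivs = instr.cache[1](ReverseDiff.deriv(instr.output))
    ReverseDiff.add_to_deriv!.(instr.input, input_derivs)
    ReverseDiff.unseed!(instr.output)
    return nothing
end

# Forward pass (re-execution): refresh input values and recompute the primal.
function ReverseDiff.special_forward_exec!(instr::ReverseDiff.SpecialInstruction{typeof(f)})
    ReverseDiff.pull_value!.(instr.input)
    out_value = instr.cache[2](instr.input...; instr.cache[3]...)[1]
    ReverseDiff.value!(instr.output, out_value)
    return nothing
end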
