Make autodiff insertion more versatile #181

Open · wants to merge 7 commits into `master`
50 changes: 42 additions & 8 deletions docs/src/man/automatic_differentiation.md
@@ -103,12 +103,22 @@ derivative of a function which is part of a larger function to be
automatically differentiated.

Another use case is when the analytical derivative can be computed much more
efficiently than the automatically differentiated derivative, for example
when the function value is obtained through an iterative procedure.
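
As a minimal illustration of the iterative use case, consider a scalar
function whose value is found by fixed-point iteration of ``y = \cos(y) + x``.
The implicit function theorem then gives the derivative directly as
``\mathrm{d}y/\mathrm{d}x = 1/(1 + \sin(y))``. The sketch below is for
illustration only (the names `solve_y` and `dsolve_ydx` are not part of
the package):

```julia
# The value is obtained iteratively, so plain AD would differentiate
# through every iteration; the analytical derivative is a single formula.
function solve_y(x; tol = 1e-12, maxiter = 100)
    y = x
    for _ in 1:maxiter
        ynew = cos(y) + x
        abs(ynew - y) < tol && return ynew
        y = ynew
    end
    return y
end

# Return the value and the analytical derivative (implicit function theorem)
function dsolve_ydx(x)
    y = solve_y(x)
    return y, 1 / (1 + sin(y))
end

@implement_gradient solve_y dsolve_ydx
```

After this, `gradient(solve_y, 0.5)` calls `dsolve_ydx` instead of
differentiating through the iterations with dual numbers.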

```@docs
@implement_gradient
propagate_gradient
```

### Caveats
The method relies on multiple dispatch to run your gradient function instead
of calling the regular function with dual numbers. Julia will always prefer
the most specific method, but it can sometimes be hard to know which one that
is. It is therefore recommended to verify that your gradient function is
actually called, e.g. by inserting a print statement at its beginning as in
the example below.

### Example
Let's consider the function ``h(\mathbf{f}(\mathbf{g}(\mathbf{x})))``
where `h(x)=norm(x)`, `f(x)=x ⋅ x`, and `g(x)=dev(x)`. For `f(x)` we
@@ -117,38 +127,62 @@ then have the analytical derivative
```math
\frac{\partial f_{ij}}{\partial x_{kl}} = \delta_{ik} x_{lj} + x_{ik} \delta_{jl}
```
which we can insert as our known analytical derivative using the
`@implement_gradient` macro. Below, we compare with the result when
the full derivative is calculated using automatic differentiation.
The example with `f3` shows the important case where the input type of
the regular function is restricted: then the macro will not work, and
we have to use the function `propagate_gradient` instead, manually
specifying for which type we want to override. This type must be more
specific than the type specified for the regular function, e.g.
if the type specification is `::Tensor{2}`, then the special dual type
definition should be at least as specific as
`::Tensor{2,<:Any,<:Tensors.Dual}`.

```jldoctest
# Define functions
h(x) = norm(x)
f1(x) = x ⋅ x
f2(x) = f1(x)
f3(x::Tensor{2}) = f1(x)
g(x) = dev(x)

# Define composed functions
cfun1(x) = h(f1(g(x)))
cfun2(x) = h(f2(g(x)))
cfun3(x) = h(f3(g(x)))

# Define known derivative
function dfdx(x::Tensor{2,dim}) where {dim}
    println("Hello from dfdx") # Show that dfdx is called
    fval = f2(x)
    I2 = one(Tensor{2,dim})
    dfdx_val = otimesu(I2, transpose(x)) + otimesu(x, I2)
    return fval, dfdx_val
end

# Implement known derivative
@implement_gradient f2 dfdx
@implement_gradient f3 dfdx # Doesn't work because `Tensor{2}` is specified for f3

# Calculate gradients
x = rand(Tensor{2,2})

println("gradient of cfun2, with hello:")
println(gradient(cfun1, x) ≈ gradient(cfun2, x))
println("gradient of cfun3, no hello:")
println(gradient(cfun1, x) ≈ gradient(cfun3, x)) # No "Hello from dfdx" printed

f3(x::Tensor{2,<:Any,<:Tensors.Dual}) = propagate_gradient(dfdx, x)
println("gradient of cfun3, with hello:")
println(gradient(cfun1, x) ≈ gradient(cfun3, x))

# output
gradient of cfun2, with hello:
Hello from dfdx
true
gradient of cfun3, no hello:
true
gradient of cfun3, with hello:
Hello from dfdx
true
```
2 changes: 1 addition & 1 deletion src/Tensors.jl
@@ -20,7 +20,7 @@ export otimesu, otimesl
export minortranspose, majortranspose, isminorsymmetric, ismajorsymmetric
export tdot, dott, dotdot
export hessian, gradient, curl, divergence, laplace
export @implement_gradient, propagate_gradient
export basevec, eᵢ
export rotate, rotation_tensor
export tovoigt, tovoigt!, fromvoigt, tomandel, tomandel!, frommandel
31 changes: 29 additions & 2 deletions src/automatic_differentiation.jl
@@ -250,11 +250,26 @@ be of symmetric type

"""
macro implement_gradient(f, f_dfdx)
    return :($(esc(f))(x :: Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}) = propagate_gradient($(esc(f_dfdx)), x))
Member Author commented (suggested change):

-    return :($(esc(f))(x :: Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}) = propagate_gradient($(esc(f_dfdx)), x))
+    return :($(esc(f))(x :: Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}, args...) = propagate_gradient($(esc(f_dfdx)), x, args...))

Note from #197: requires a corresponding update to `propagate_gradient`.
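
For context, a sketch of what that corresponding update might look like
(an assumption only, mirroring the existing body of `propagate_gradient`
with the extra arguments forwarded):

```julia
# Hypothetical: forward extra, non-differentiated arguments to f_dfdx,
# matching the suggested macro change above.
function propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}, args...)
    fval, dfdx_val = f_dfdx(_extract_value(x), args...) # strip dual parts, call the user function
    _check_gradient_shape(fval, x, dfdx_val)            # verify the gradient has the expected order
    return _insert_gradient(fval, dfdx_val, x)          # re-insert the dual information
end
```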

end

"""
    propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual})

`propagate_gradient` takes the function `f_dfdx` that, given a tensor
input, calculates the value and the gradient of a function `f` with
respect to that input, i.e. `fval, dfdx_val = f_dfdx(y::AbstractTensor)`,
where `y` should not have dual entries.
`propagate_gradient` is used to override `f` when `f`'s input is a
tensor of dual numbers (i.e. when it is used to calculate a derivative).
It can, for example, be used as follows:
`f(x::Tensor{2,<:Any,<:Tensors.Dual}) = propagate_gradient(f_dfdx, x)`,
where the key part is that the type of `x` must be specified as
`Tensors.Dual` (which is equivalent to `ForwardDiff.Dual`).
"""
function propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual})
    fval, dfdx_val = f_dfdx(_extract_value(x))
    _check_gradient_shape(fval, x, dfdx_val)
    return _insert_gradient(fval, dfdx_val, x)
end

@@ -298,6 +313,18 @@ function _insert_gradient(f::Union{Number,AbstractTensor}, dfdg::Union{Number,Ab
    return _insert_full_gradient(f, dfdx, Tg())
end

function _check_gradient_shape(f, g, dfdg)
    expected_shape = _get_expected_gradient_shape(f, g)
    @assert isa(dfdg, expected_shape) "Gradient is a $(typeof(dfdg)), but should be a $expected_shape"
end

# _get_expected_gradient_shape(f_val, g_val): f is the function output and g is the function input
_get_expected_gradient_shape(::Number, ::Number) = Number
_get_expected_gradient_shape(::TT, ::Number) where{TT<:AbstractTensor} = get_base(TT)
_get_expected_gradient_shape(::Number, ::TT) where{TT<:AbstractTensor} = get_base(TT)
_get_expected_gradient_shape(::Tensor{forder,dim}, ::Tensor{gorder,dim}) where{forder,gorder,dim} = Tensor{forder+gorder,dim}
_get_expected_gradient_shape(::SymmetricTensor{forder,dim}, ::SymmetricTensor{gorder,dim}) where{forder,gorder,dim} = SymmetricTensor{forder+gorder,dim}
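
# Illustration (not part of the source): a 2nd-order tensor function of a
# 2nd-order tensor argument is expected to have a 4th-order gradient, e.g.
# _get_expected_gradient_shape(rand(Tensor{2,3}), rand(Tensor{2,3})) == Tensor{4,3}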

# Define helper function to figure out original input to gradient function
_get_original_gradient_input(::Dual{Tag{Tf,Tv}}) where{Tf,Tv} = zero(Tv)
_get_original_gradient_input(::AbstractTensor{<:Any,<:Any,<:Dual{Tag{Tf,Tv}}}) where{Tf,Tv} = zero(Tv)
10 changes: 9 additions & 1 deletion test/test_ad.jl
@@ -325,8 +325,16 @@ S(C) = S(C, μ, Kb)
@test gradient(ts_ts_ts, xs) ≈ gradient(ts_ts_ts_ana, xs)

end


# Test that AssertionError is thrown for erroneous user functions
test_self(x) = x
test_self_wronggradient(x) = test_self(x), one(x) # one(x) is the correct gradient only for scalars, not for tensors
@implement_gradient test_self test_self_wronggradient
@test gradient(test_self, rand()) ≈ 1.0 # Should work fine
@test_throws AssertionError gradient(test_self, rand(Tensor{2,3}))
@test_throws AssertionError gradient(test_self, rand(SymmetricTensor{2,3}))
end


end # testsection