
implement custom gradient with multi-argument functions #197

Open
KnutAM opened this issue Mar 28, 2023 · 5 comments

KnutAM (Member) commented Mar 28, 2023

From a Slack comment by @koehlerson: how to implement a custom gradient calculation for a multi-argument function.
This is a common case for autodiff, so it would be good to have a clear way of doing it.
The solution I can come up with now is:

using Tensors
import ForwardDiff: Dual

# General setup for any function f(x, args...)
struct Foo{F,T<:Tuple} <: Function # <:Function optional
    f::F
    args::T
end
struct FooGrad{FT<:Foo} <: Function # <: Function required
    foo::FT
end

function (foo::Foo)(x)
    println("Foo with Any: ", typeof(x))  # To show that it works
    return foo.f(x, foo.args...)
end
function (foo::Foo)(x::AbstractTensor{<:Any,<:Any,<:Dual})
    println("Foo with Dual: ", typeof(x))  # To show that it works
    return Tensors._propagate_gradient(FooGrad(foo), x)
end
function (fg::FooGrad)(x)
    println("FooGrad: ", typeof(x)) # To show that it works
    return f_dfdx(fg.foo.f, x, fg.foo.args...)
end

# Specific example to setup for bar(x, a, b), must then also define f_dfdx(::typeof(bar), x, a, b):
bar(x, a, b) = norm(a*x)^b 
dbar_dx(x, a, b) = b*(a^b)*norm(x)^(b-2)*x
f_dfdx(::typeof(bar), args...) = (bar(args...), dbar_dx(args...))

# At the location in the code where the derivative will be calculated
t = rand(SymmetricTensor{2,3}); a = π; b = 2 # Typically inputs
foo = Foo(bar, (a, b))
gradient(foo, t) == dbar_dx(t, a, b)

But this is quite cumbersome, especially if it is only needed for a single function, so a better method would be good.
(Tensors._propagate_gradient is renamed to propagate_gradient, exported, and documented in #181.)
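
For comparison, the single-argument mechanism documented in #181, the @implement_gradient macro, already covers the case without extra arguments. A minimal sketch, assuming the documented @implement_gradient(f, f_dfdx) form where f_dfdx(x) returns (fval, dfdx_val); baz and baz_dbaz_dx are made-up names:

using Tensors

baz(x) = norm(x)^2
baz_dbaz_dx(x) = (norm(x)^2, 2*x)             # returns (value, analytic gradient)
Tensors.@implement_gradient(baz, baz_dbaz_dx) # hook the analytic gradient into AD

gradient(baz, rand(SymmetricTensor{2,3}))     # uses baz_dbaz_dx instead of differentiating through baz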

KristofferC (Collaborator) commented Mar 28, 2023

I don't understand why a closure over a and b wouldn't work here.

x->bar(x, a, b)
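
For context, a minimal sketch of what the closure alone gives, i.e. plain forward-mode AD through bar with no custom derivative involved:

using Tensors

bar(x, a, b) = norm(a*x)^b
t = rand(SymmetricTensor{2,3}); a = π; b = 2
gradient(x -> bar(x, a, b), t)  # differentiates through bar itself; dbar_dx is never called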

KnutAM (Member, Author) commented Mar 29, 2023

I'm not sure that I follow; that would only define one method, for x::Any. Do you have a complete example?
If working directly on bar, I think it is necessary to write a custom propagate_gradient using Tensors._extract_value and Tensors._insert_gradient. Alternatively, we could extend propagate_gradient to accept args...:

function propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}, args...)
    fval, dfdx_val = f_dfdx(_extract_value(x), args...)
    _check_gradient_shape(fval,x,dfdx_val)
    return _insert_gradient(fval, dfdx_val, x)
end
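
For completeness, a minimal sketch of the "working directly on bar" alternative mentioned above, using the internal helpers Tensors._extract_value and Tensors._insert_gradient (internal API, so the exact behavior is an assumption; bar and dbar_dx are as defined in the first comment):

import Tensors: _extract_value, _insert_gradient, Dual

function bar(x::AbstractTensor{<:Any, <:Any, <:Dual}, a, b)
    xval = _extract_value(x)                    # strip the Dual numbers
    fval = bar(xval, a, b)                      # value from the plain method
    dfdx_val = dbar_dx(xval, a, b)              # analytic gradient
    return _insert_gradient(fval, dfdx_val, x)  # re-attach the gradient to the Duals
end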

KristofferC (Collaborator) commented

Okay, I missed the point:

implement custom gradient calculation for a multi-argument function.

Carry on..

KnutAM changed the title from "implement gradient with multi-argument functions" to "implement custom gradient with multi-argument functions" on Mar 29, 2023
koehlerson (Member) commented

Initially, I planned to do a custom layer for energy densities, something like:

energy(F, material, state) = # something

analytic_or_AD(energy::FUN, F, material, state) where {FUN<:Function} = Tensors.hessian(x -> energy(x, material, state), F)

where the generic dispatch uses Tensors.hessian, and for known analytic parts you call another dispatch. However, @implement_gradient should be capable of handling this, imo, and otherwise it feels like I am reinventing the wheel. I also don't think the dispatch-based approach can substitute only pieces of the derivative, i.e. mix and match analytic and automatic differentiation when the energy function in turn calls something that is known analytically, e.g. strain energy densities.
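
To make the dispatch-based pattern concrete, a minimal sketch with a made-up DemoMaterial and energy (illustrative names, not from Tensors.jl):

using Tensors

struct DemoMaterial
    μ::Float64
end

energy(F, m::DemoMaterial, state) = m.μ / 2 * norm(F)^2  # placeholder energy density

# Generic fallback: compute the tangent with automatic differentiation
analytic_or_AD(energy::FUN, F, material, state) where {FUN<:Function} =
    Tensors.hessian(x -> energy(x, material, state), F)

# For a material with a known analytic tangent, a specialized method would replace the AD call:
# analytic_or_AD(::typeof(energy), F, m::DemoMaterial, state) = <analytic fourth-order tangent>

m = DemoMaterial(1.0); F = rand(Tensor{2,3})
analytic_or_AD(energy, F, m, nothing)  # fourth-order tangent of energy w.r.t. F via AD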

KnutAM (Member, Author) commented Mar 29, 2023

But I think the approach of allowing args... in propagate_gradient could be nice for this:

using Tensors
import Tensors: _extract_value, _insert_gradient, Dual
# Change in Tensors.jl
function propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}, args...)
    fval, dfdx_val = f_dfdx(_extract_value(x), args...)
    # _check_gradient_shape(fval,x,dfdx_val) # PR181
    return _insert_gradient(fval, dfdx_val, x)
end

# User code:
# - Definitions
bar(x, a, b) = norm(a*x)^b
dbar_dx(x, a, b) = b*(a^b)*norm(x)^(b-2)*x
bar_dbar_dx(x, a, b) = (bar(x, a, b), dbar_dx(x, a, b))
bar(x::AbstractTensor{<:Any, <:Any, <:Dual}, args...) = (println("DualBar"); propagate_gradient(bar_dbar_dx, x, args...))
# - At call-site
t = rand(SymmetricTensor{2,3}); a = π; b = 2 # Typically inputs
gradient(x->bar(x, a, b), t)
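
As a quick check, continuing the example above (a sketch mirroring the check in the first comment), the call takes the custom path and should reproduce the analytic derivative:

gradient(x -> bar(x, a, b), t) == dbar_dx(t, a, b)  # prints "DualBar"; expected to return true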
