A unified derivative operator #973

Open
shashi opened this issue Sep 16, 2023 · 4 comments
shashi commented Sep 16, 2023

Now that we are thinking about matrix calculus and symbolic array functions [fn2], I thought it was time to finally come up with a unified derivative operator/API. Here I'm going to use "derivative" to mean scalar derivatives, gradients, Jacobians -- anything that is a Fréchet derivative in general.

The general API I'm thinking of is:

  • ∂(x) is the "derivative with respect to x" operator; x is not restricted to being a Real symbol but can be an array of symbols, a symbol that represents an array, or a nesting of these. See [fn3] for why ∂.
  • ∂(x)(f(x)) is the derivative linear map, with an added syntax-regularizing rule that it must support right-multiplication with an element from the vector space that x is from. This allows us to return 3 as the derivative of 3x, and also return a function-like LinOp(Δ ↦ A*Δ + Δ*A) which applies the function Δ ↦ A*Δ + Δ*A by treating its right multiplicand as Δ. In the below examples, x and y are scalars, u is a vector and A is a square matrix.
    • ∂(x)(3x) = 3 -- scalar derivative
    • ∂([x, y])(x + y) = [1 1] -- gradient
    • ∂(u)(A*u) = A -- jacobian
    • ∂(A)(A*A) = LinOp(Δ ↦ A*Δ + Δ*A) which is some object that substitutes the multiplicand for Δ on right-multiplication. One possible variation on this is in [fn1].
    • ∂(..) has some recursive structure, for example ∂([A, u])(A*u) = [∂(A)(A*u) A]; this object can be right-multiplied with a vector of elements of the same shape as [A, u]
    • shape information can be propagated by means of the lambda, right-multiplication/application will disallow incorrect inputs since we store sizes of symbolic arrays.
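
To make the right-multiplication rule concrete, here is a minimal plain-Julia sketch; `LinOp` and the `∂(A)(A*A)` behavior are proposals from this issue, not existing Symbolics.jl API:

```julia
# Hypothetical sketch: a LinOp wraps a function and applies it on
# right-multiplication, so a derivative can be a lazy linear map.
struct LinOp{F}
    f::F
end
Base.:*(L::LinOp, Δ) = L.f(Δ)

A = [1 2; 3 4]
∂A = LinOp(Δ -> A*Δ + Δ*A)    # what ∂(A)(A*A) could return
Δ = [0 1; 0 0]
∂A * Δ == A*Δ + Δ*A            # right-multiplication applies the map
```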

A cool consequence is that ∂(x)(f(x)) * .. is the "jvp"; for example, LinOp(Δ ↦ A*Δ + Δ*A) would be the jvp-calculating function. Gaurav used the example of a function x -> x[1] + x[4], whose derivative is LinOp(Δ ↦ Δ[1] + Δ[4]); this would have a smaller memory footprint than the gradient if x has a million elements (assuming ∂ could choose to return the LinOp instead of the gradient). It's possible to make it so that * "materializes" a linear operator in general, whenever possible. @YingboMa this is the API we thought of for ForwardDiff2. (TODO: think about how * would do the chain rule in the case of LinOp.)
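
Assuming a minimal `LinOp` wrapper (a hypothetical name from this proposal) that applies its function on right-multiplication, the memory-footprint point can be sketched as:

```julia
# Hypothetical LinOp sketch: the derivative is a closure, not a dense vector.
struct LinOp{F}
    f::F
end
Base.:*(L::LinOp, Δ) = L.f(Δ)

jvp = LinOp(Δ -> Δ[1] + Δ[4])   # what ∂(x)(x[1] + x[4]) could return
x = randn(10^6)
jvp * x == x[1] + x[4]           # no million-element gradient is materialized
```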

Present state of affairs: we have Differential (from pre-Symbolics days), which represents scalar differentiation; we don't really represent derivatives of other kinds (gradients, Jacobians, etc.) in the expression trees, so expand_derivatives only expands scalar derivatives. I think expand_derivatives should be part of simplify. Just to be clear, the current derivative, gradient and jacobian functions will continue to work as they do now.
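
For reference, the existing scalar path described above looks like this with the current Symbolics.jl API (Differential and expand_derivatives are real, exported names; the exact printed form may differ by version):

```julia
using Symbolics

@variables x y
Dx = Differential(x)
Dy = Differential(y)

ex = Dy(Dx(x^2 * y))     # derivative nodes stay unexpanded in the tree
expand_derivatives(ex)    # 2x -- only scalar derivatives are expanded
```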

Higher-order notation could be:

∂(y)(∂(x)(f(x, y))) = (∂(y)*∂(x))(f(x, y)) = (∂(y, x))(f(x, y))

fn1: Lambdas: Gaurav suggested there could be an object that behaves like a "hole" for the right multiplicand -- like in APL. I think that should be syntactic sugar for a lambda. I've been thinking of adding lambdas for a while, and I think it's going to be interesting to have lambdas with unbound closed symbolic variables in them that take part in partial evaluation and get bound later, if and when the expression is compiled into code.

fn2: Array function registration (#292, #753, etc.) needed some syntax, and the ∂ scheme comes in handy in allowing the definition of various derivatives while registering a function to be a symbolic primitive. This would look something like:

@register_symbolic foo(x::AbstractMatrix,
                       y::AbstractMatrix) begin
    size=(size(x)[1], size(y)[2])
    type=AbstractArray{Real,2}

    ∂(x)(foo(x, y)) = ...
    ∂(y)(foo(x, y)) = ...
    ∂(y)(∂(x)(foo(x, y))) = ...
end

We can just add the derivative rules as rewrite rules into expand_derivatives (in spirit; in practice we can also do something faster).

cc @alanedelman @stevengj @YingboMa @Roni-Edwin

@brianguenter I didn't think too much about how this would play with FastDifferentiation... We can talk more over slack.

fn3: I'm sure most of us would prefer ∂ as the exported symbol, as opposed to D. It makes expressions nicely readable.

Thanks to @gaurav-arya and @avik-pal for helping me brainstorm this, and to scmutils for allowing some experiments.

@shashi shashi changed the title All kinds of derivatives A unified derivative operator Sep 16, 2023
xtalax commented Sep 17, 2023

You also need to add support for taking the cross/dot product with these, and likely with arrays in general.

brianguenter commented

@shashi @YingboMa now would be a good time for us to talk about tensor derivatives. It might give you ideas for how to structure your new tensor derivative API.

I've prepared a few pages describing how this works, but before we talk I wanted to add one more example showing convolution, and another section explaining why the method works.

However, this has already taken me too long - tracking down bugs in FastDifferentiation has chewed up all my available discretionary time. We could get started on tensor derivatives with what exists now and I can expand it over time. Are you both available this week?

YingboMa commented

I am flexible after 11 am EST this week.

shashi commented Sep 19, 2023

Yes, let's talk soon. I'll message on slack. I can also help you debug the q(t) issue on a call sometime.
