diff --git a/Project.toml b/Project.toml
index 25cdc6a05..0cc072d9a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ChainRulesCore"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.7.5"
+version = "0.8.0"
 
 [deps]
 MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
diff --git a/docs/src/index.md b/docs/src/index.md
index 0d1370fc9..96ded454d 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -126,26 +126,26 @@ This document will explain this point of view in some detail.
 ##### Some terminology/conventions.
 
 Let ``p`` be an element of type M, which is defined by some assignment of numbers ``x_1,...,x_m``,
-say ``(x_1,...,x_m) = (a_1,...,1_m)`` 
+say ``(x_1,...,x_m) = (a_1,...,1_m)``
 
 A _function_ ``f:M \to K`` on ``M`` is (for simplicity) a polynomial ``K[x_1, ... x_m]``
 
-The tangent space ``T_pM`` of ``T`` at point ``p`` is the ``K``-vector space spanned by derivations ``d/dx``. 
-The tangent space acts linearly on the space of functions. They act as usual on functions. Our starting point is 
+The tangent space ``T_pM`` of ``T`` at point ``p`` is the ``K``-vector space spanned by derivations ``d/dx``.
+The tangent space acts linearly on the space of functions. They act as usual on functions. Our starting point is
 that we know how to write down ``d/dx(f) = df/dx``.
 
 The collection of tangent spaces ``{T_pM}`` for ``p\in M`` is called the _tangent bundle_ of ``M``.
 
-Let ``df`` denote the first order information of ``f`` at each point. This is called the differential of ``f``. 
+Let ``df`` denote the first order information of ``f`` at each point. This is called the differential of ``f``.
 If the derivatives of ``f`` and ``g`` agree at ``p``, we say that ``df`` and ``dg`` represent the same cotangent at ``p``.
-The covectors ``dx_1, ..., dx_m`` form the basis of the cotangent space ``T^*_pM`` at ``p``. Notice that this vector space is 
+The covectors ``dx_1, ..., dx_m`` form the basis of the cotangent space ``T^*_pM`` at ``p``. Notice that this vector space is
 dual to ``T_p``
 
 The collection of cotangent spaces ``{T^*_pM}`` for ``p\in M`` is called the _cotangent bundle_ of ``M``.
 
 ##### Push-forwards and pullbacks
 
-Let ``N`` be another type, defined by numbers ``y_1,...,y_n``, and let ``g:M \to N`` be a _map_, that is, 
+Let ``N`` be another type, defined by numbers ``y_1,...,y_n``, and let ``g:M \to N`` be a _map_, that is,
 an ``n``-dimensional vector ``(g_1, ..., g_m)`` of functions on ``M``.
 
 We define the _push-forward_ ``g_*:TM \to TN`` between tangent bundles by ``g_*(X)(h) = X(g\circ h)`` for any tangent vector ``X`` and function ``f``.
@@ -154,7 +154,7 @@ We have ``g_*(d/dx_i)(y_j) = dg_j/dx_i``, so the push-forward corresponds to the
 Similarly, the pullback of the differential ``df`` is defined by
 ``g^*(df) = d(f\circ g)``. So for a coordinate differential ``dy_j``, we have
 ``g^*(dy_j) = d(g_j)``. Notice that this is a covector, and we could have defined the pullback by its action on vectors by
-``g^*(dh)(X) = g_*(X)(dh) = X(g\circ h)`` for any function ``f`` on ``N`` and ``X\in TM``. In particular, 
+``g^*(dh)(X) = g_*(X)(dh) = X(g\circ h)`` for any function ``f`` on ``N`` and ``X\in TM``. In particular,
 ``g^*(dy_j)(d/dx_i) = d(g_j)/dx_i``. If you work out the action in a basis of the cotangent space, you see that it acts
 by the adjoint of the Jacobian.
 
@@ -170,13 +170,13 @@ But pulling back gradients still should not be a thing.
 
 If the goal is to evaluate the gradient of a function ``f=g\circ h:M \to N \to K``, where ``g`` is a map and ``h`` is a function,
 we have two obvious options:
-First, we may push-forward a basis of ``M`` to ``TK`` which we identify with K itself. 
+First, we may push-forward a basis of ``M`` to ``TK`` which we identify with K itself.
 This results in ``m`` scalars, representing components of the gradient.
 Step-by-step in coordinates:
 1. Compute the push-forward of the basis of ``T_pM``, i.e. just the columns of the Jacobian ``dg_i/dx_j``.
 2. Compute the push-forward of the function ``h`` (consider it as a map, K is also a manifold!) to get ``h_*(g_*T_pM) = \sum_j dh/dy_i (dg_i/dx_j)``
 
-Second, we pull back the differential ``dh``: 
+Second, we pull back the differential ``dh``:
 1. compute ``dh = dh/dy_1,...,dh/dy_n`` in coordinates.
 2. pull back by (in coordinates) multiplying with the adjoint of the Jacobian, resulting in ``g_*(dh) = \sum_i(dg_i/dx_j)(dh/dy_i)``.
 
@@ -263,12 +263,14 @@ Similarly every `pullback` returns an extra `∂self`, which for things without
 - **Pullback**
     - returned by `rrule`
     - takes output space wobbles, gives input space wiggles
-    - 1 argument per original function return
+    - Argument structure matches structure of primal function output
+    - If primal function returns a tuple, then pullback takes in a tuple of differentials.
     - 1 return per original function argument + 1 for the function itself
 
 - **Pushforward:**
     - part of `frule`
     - takes input space wiggles, gives output space wobbles
+    - Argument structure matches primal function argument structure, but passed as a tuple at start of `frule`
     - 1 argument per original function argument + 1 for the function itself
     - 1 return per original function return
 
diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl
index 4edc8c88f..8da21fbd8 100644
--- a/src/rule_definition_tools.jl
+++ b/src/rule_definition_tools.jl
@@ -25,7 +25,7 @@ methods for `frule` and `rrule`:
     function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
         Ω = f(x₁, x₂, ...)
         \$(statement₁, statement₂, ...)
-        return Ω, (ΔΩ₁, ΔΩ₂, ...) -> (
+        return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> (
                 NO_FIELDS,
                 ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
                 ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
@@ -185,8 +185,10 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
         propagation_expr(Δs, ∂s)
     end
 
+    # Multi-output functions have pullbacks with a tuple input that will be destructured
+    pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...)
     pullback = quote
-        function $(propagator_name(f, :pullback))($(Δs...))
+        function $(propagator_name(f, :pullback))($pullback_input)
             return (NO_FIELDS, $(pullback_returns...))
         end
     end
diff --git a/test/rules.jl b/test/rules.jl
index a80000376..36a79b342 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -123,3 +123,17 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
         @inferred frule((Zero(), sx, sy), very_nice, 1, 2)
     end
 end
+
+
+simo(x) = (x, 2x)
+@scalar_rule(simo(x), 1, 2)
+
+@testset "@scalar_rule with multiple inputs" begin
+    y, simo_pb = rrule(simo, π)
+
+    @test simo_pb((10, 20)) == (NO_FIELDS, 50)
+
+    y, ẏ = frule((NO_FIELDS, 50), simo, π)
+    @test y == (π, 2π)
+    @test ẏ == (50, 100)
+end
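
Note on the calling-convention change in this diff: the pullback of a multi-output function now takes one tuple argument instead of one positional argument per output. The sketch below is a hand-written approximation of what `@scalar_rule(simo(x), 1, 2)` would be expected to produce under that convention. It runs in plain Julia without ChainRulesCore; `NoFieldsSketch` and `simo_rrule` are hypothetical stand-ins introduced only for illustration (for `NO_FIELDS` and `rrule(simo, x)` respectively), not names from the package.

```julia
# Sketch only: `NoFieldsSketch` is a hypothetical stand-in for ChainRulesCore's
# `NO_FIELDS`, so that the example runs without the package installed.
struct NoFieldsSketch end

simo(x) = (x, 2x)

# Hypothetical hand-written counterpart of `rrule(simo, x)`.
function simo_rrule(x)
    Ω = simo(x)
    # The pullback takes ONE argument: a tuple shaped like the primal output.
    # The tuple is destructured directly in the function signature.
    function simo_pullback((ΔΩ₁, ΔΩ₂))
        return (NoFieldsSketch(), 1 * ΔΩ₁ + 2 * ΔΩ₂)
    end
    return Ω, simo_pullback
end

# Mirrors the `pullback_input` line from the diff: a single output keeps a
# bare symbol, several outputs are wrapped in a tuple expression.
Δs = [:ΔΩ₁, :ΔΩ₂]
pullback_input = length(Δs) == 1 ? first(Δs) : Expr(:tuple, Δs...)

y, simo_pb = simo_rrule(3.0)
@assert y == (3.0, 6.0)
@assert simo_pb((10, 20)) == (NoFieldsSketch(), 50)   # 1*10 + 2*20
@assert pullback_input == :((ΔΩ₁, ΔΩ₂))
```

Calling `simo_pb(10, 20)` with two positional arguments would raise a `MethodError` in this sketch, which is the behavioural difference the version bump to 0.8.0 appears to signal.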