-
-
Notifications
You must be signed in to change notification settings - Fork 211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Explain relationship of Zygote's complex gradients with the Wirtinger calculus #328
base: master
Are you sure you want to change the base?
Conversation
The history of the current phrasing is explained here (as a reference when considering this PR) #29 |
@MikeInnes Would you mind reviewing this? |
docs/src/adjoints.md
Outdated
@@ -56,7 +56,7 @@ julia> mygradient(sin, 0.5) | |||
|
|||
The rest of this section contains more technical detail. It can be skipped if you only need an intuition for pullbacks; you generally won't need to worry about it as a user. | |||
|
|||
If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = sin(x)'*ȳ`; the conjugation is usually moot but gives the correct behaviour for complex code. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints". | |||
If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = ȳ*cos(x)'`; the conjugation is usually moot but gives the correct behaviour for complex code, if `y(x)` is holomorphic. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints". |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there an intentional change here? If not, best to remove it from the diff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think GitHub just can't handle such long lines. See here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original wording was a bit unclear, but this gives the correct result for complex code in general, not just the holomorphic case (it is redundant for real code). The adjoint is what causes us to get back complex sensitivities (otherwise the output would be the conjugate of the sensitivity).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is still unclear to me, what you would be conjugating in the non-holomorphic case, since the complex derivative doesn't exist in this case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Individual pullbacks don't (and can't) know whether the function as a whole is holomorphic, which is a global property, and don't ever see complex derivatives. They only work with sensitivities, as we define them, and taking the adjoint is correct for sensitivities.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I probably should have clarified: I am only talking about the partial function being holomorphic. So if we're defining the pullback for exp
, for example, we can express the partial derivative of exp
as a complex derivative. We can't do that for abs2
, for example. Is there anywhere, where the term sensitivity
is mathematically defined? In my understanding, it is a vector, you pass to a differential form, so it basically specifies the linear combination of partial derivatives. If we're only limiting ourselves to scalars, this would just be the partial derivative of the output with respect to the current function in the chain.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, that's fair, but still an edge case (and even then, we still have the adjoint of the linear map, it's just expressed differently). These specific docs are a reference for adjoints rather than complex AD and I'd rather not overload people with the vagaries of that before they've gotten started with the real part :) So ideally this should just mention briefly that adjoints are generally relevant for complex AD, and have a link to the other docs for more detail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, sounds like a good idea
docs/src/complex.md
Outdated
@@ -2,7 +2,7 @@ | |||
|
|||
Complex numbers add some difficulty to the idea of a "gradient". To talk about `gradient(f, x)` here we need to talk a bit more about `f`. | |||
|
|||
If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c'`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box. | |||
If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c^*`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should stick with c'
consistently here, because that's what Julia uses (and it's used in a bunch of other places on this page). If we're using c*
elsewhere we can change that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It just felt a little bit weird to me, to use \overline
for derivatives and '
for the complex conjugate in mathematical notation. ^*
very commonly refers to just the complex conjugate in physics, so it felt more natural here and less confusing. Of course, I left it in snippets that are actual Julia code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair, but then we should mention the '
notation as well. Might be better to have a note on it in context, rather than trying to get it into this paragraph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where exactly are you talking about?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as in, we can clarify the notation when it's first used, rather than trying to define everything in a parenthetical.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, will do
The gradient definition Zygote uses can also be expressed in terms of the [Wirtinger calculus](https://en.wikipedia.org/wiki/Wirtinger_derivatives) using the operators ``\frac{\partial}{\partial z}`` and ``\frac{\partial}{\partial z^*}``: | ||
|
||
```math | ||
f: \mathbb{C} \rightarrow \mathbb{R}, \qquad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you mean C->C here? The below references Re(f)
and f*
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is C->R, but since f
is real, I'm using f = Re(f) = f*
here, to express this in terms of the Wirtinger derivatives
\left( \frac{\partial w}{\partial z} \right)^{\!*} | ||
= \overline{f} \cdot \left( \frac{\partial w}{\partial z} \right)^{\!*} | ||
\qquad \text{if $w(z)$ holomorphic} \Leftrightarrow \frac{\partial w}{\partial z^*} = 0 | ||
\end{align*} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I basically see what the above section is saying (though it might be nice to have an implementation in code for the sake of precision/clarity). This section strikes me as confusing though; a big equation dump there and while I'm sure it's all correct, I'm not sure what it's trying to get across.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can try explaining this a bit better in words. I'm trying to give the mathematical reason here, why we need to take the complex conjugate of the derivative in pullbacks.
Sorry, it took me so long to implement these changes. I hope that's better now. |
I was confused quite a bit at first about how Zygote handles complex gradients, and this wasn't quite as obvious from the documentation. This also came up in the development of
ChainRules.jl
.