Use implicit gradient Ops #1275

Open
brandonwillard opened this issue Oct 24, 2022 · 5 comments
Labels: enhancement, graph rewriting, help wanted, important, refactor, request discussion

Comments

@brandonwillard (Member) commented Oct 24, 2022

We can delay the construction of explicit gradient graphs (i.e. the use of Op.grad and the like) by employing implicit gradient Ops that are later replaced with explicit sub-graphs (e.g. similar to how OpFromGraphs can be "in-lined").

The approach would look as follows:

from functools import wraps

import aesara
import aesara.tensor as at

from aesara.graph.basic import Apply
from aesara.graph.op import Op

from aesara.compile.mode import optdb
from aesara.graph.rewriting.basic import in2out, node_rewriter



class Gradient(Op):
    __props__ = ("grad_options",)

    def __init__(self, **grad_options):
        self.grad_options = tuple(grad_options.items())

    def make_node(self, cost, *wrt):

        # Only the output types are needed here, but, given the caching
        # involved and the assumption that most gradients are _eventually_
        # expanded as-is anyway, building the full gradient graph up front
        # seems somewhat less wasteful.
        grad_res = aesara.grad(cost, wrt, **dict(self.grad_options))

        if not isinstance(grad_res, (tuple, list)):
            grads = (grad_res,)
        else:
            grads = grad_res

        inputs = (cost,) + wrt
        outputs = [g.clone() for g in grads]

        return Apply(self, inputs, outputs)

    def perform(self, *args, **kwargs):
        raise NotImplementedError("This shouldn't ever be called")


@wraps(aesara.grad)
def grad(cost, wrt, **kwargs):

    if not isinstance(wrt, (list, tuple)):
        wrt = [wrt]

    GradOp = Gradient(**kwargs)
    return GradOp(cost, *wrt)


@node_rewriter([Gradient])
def expand_gradients(fgraph, node):
    op = node.op

    cost, *wrt = node.inputs
    grad_res = aesara.grad(cost, wrt, **dict(op.grad_options))

    if not isinstance(grad_res, (tuple, list)):
        grads = (grad_res,)
    else:
        grads = grad_res

    return grads


optdb.register(
    "expand_gradients",
    in2out(expand_gradients),
    "fast_compile",
    "fast_run",
    position=-0.01,
)


x = at.vector("x")
x_sum = x.sum()

x_grad = grad(x_sum, x)

aesara.dprint(x_grad)
# Gradient{grad_options=()} [id A]
#  |Sum{acc_dtype=float64} [id B]
#  | |x [id C]
#  |x [id C]


with aesara.config.change_flags(on_opt_error="raise"):
    x_grad_fn = aesara.function([x], x_grad)

aesara.dprint(x_grad_fn)
# Alloc [id A] 1
#  |TensorConstant{(1,) of 1.0} [id B]
#  |Shape_i{0} [id C] 0
#    |x [id D]

This also has the effect of enabling rewrites on gradient expressions and of providing more shape information to our gradient implementations.

For instance, this could be used to remove shape inference responsibilities and requirements from some Op.make_node and Op.grad implementations (e.g. Elemwise.L_op) by allowing access to ShapeFeatures and other compile/rewrite-time only information. Simply put, this is probably the easiest—and even best—way to guarantee that symbolic gradient implementations will always have the most Type and shape information available, and all without wastefully cloning shape graphs and re-performing rewrites (e.g. like constant folding) on them.
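As a hedged sketch of what that compile/rewrite-time information looks like from inside the expansion rewrite above (illustrative only: aesara.grad currently has no hook for receiving these shapes), a variant of expand_gradients could consult the FunctionGraph's ShapeFeature, whose shape graphs have already been canonicalized and constant-folded:

@node_rewriter([Gradient])
def expand_gradients_with_shape_info(fgraph, node):
    cost, *wrt = node.inputs

    shape_feature = getattr(fgraph, "shape_feature", None)
    if shape_feature is not None:
        # Compile/rewrite-time shape graphs for each differentiation target;
        # this is the information that gradient implementations could be
        # handed instead of re-deriving shapes themselves.
        wrt_shapes = [shape_feature.shape_of.get(w) for w in wrt]

    grad_res = aesara.grad(cost, wrt, **dict(node.op.grad_options))

    return list(grad_res) if isinstance(grad_res, (list, tuple)) else [grad_res]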

This approach was proposed in Theano/Theano#4452 and might also help with #682. As mentioned in the latter, we need to think carefully about when we make implicit gradient Ops explicit (i.e. "expand" them). Depending on exactly which rewrites are applied and when, the resulting gradient graphs could be quite different and have distinct and possibly unexpected numerical properties.

To keep things simple, we can expand implicit Ops right after the first pass of basic canonicalizations so that shape inference/ShapeFeature is useful and other rewrites (e.g. specializations) won't get in the way. If this approach helps with #682, then great, but, if not, I don't know if we should get into the details of further delayed or staged expansions just yet. Regardless, we'll have the machinery available to do that whenever we want.
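For reference, the snippet above registers the expansion at the very start of the rewrite pipeline (position=-0.01). A hedged sketch of the "right after canonicalization" variant described here could look like the following; the numeric positions are taken from aesara.compile.mode.optdb at the time of writing (where "canonicalize" sits near 0.1 and "specialize" near 2) and may differ between Aesara versions:

# Replaces the `position=-0.01` registration above.
optdb.register(
    "expand_gradients",
    in2out(expand_gradients),
    "fast_compile",
    "fast_run",
    position=1,  # after "canonicalize" (~0.1) and before "specialize" (~2)
)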

N.B. Performing expansions in this way still changes our gradient results so that they're dependent on our canonicalizations. In some ways, this relationship sounds good, since it seems to imply that the set of graphs we would be dealing with from then on would be more "regular". Over time, we could converge on a more concentrated and effective set of stabilizing rewrites for the exact kinds of gradients that our implementations and canonicalizations tend to produce, because we would have to deal less with the particulars of "random" user-formulated graphs.

brandonwillard added the enhancement, help wanted, graph rewriting, refactor, and request discussion labels on Oct 24, 2022
@brandonwillard (Member, Author) commented Oct 24, 2022

@rlouf, did you ask about this approach (i.e. using rewrites for gradients) in a Discussion or comment in Gitter? If so, let's link that here so we can highlight more potential connections, use-cases, considerations, etc.

@ricardoV94 (Contributor) commented Oct 24, 2022

This Aeppl PR highlighted one of the aforementioned disadvantages of eager grads, namely failing to automatically use the specialized gradients of specialized Ops: aesara-devs/aeppl#156

Similar issue for the gradient of the Softmax: #679
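As an illustrative sketch of that disadvantage (an editorial example, not taken from the linked PR; it assumes softmax lives in aesara.tensor.special and reuses the grad wrapper defined above): with eager gradients, the gradient graph of log(softmax(x)) is derived from Softmax's gradient implementation at definition time, so stabilizing/specializing rewrites can only operate on the already-expanded graph. With the implicit Gradient Op, the forward graph can first be rewritten (e.g. to a log-softmax form) and the specialized gradient derived from that.

import aesara
import aesara.tensor as at

from aesara.tensor.special import softmax

x = at.matrix("x")
loss = at.sum(at.log(softmax(x, axis=-1)))

# Eager: the gradient graph is fixed here, before any rewrites have run.
eager_grad = aesara.grad(loss, x)

# Implicit: only a placeholder `Gradient` node is created; expansion happens
# at compile/rewrite-time, after other rewrites have had a chance to act.
implicit_grad = grad(loss, x)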

@rlouf (Member) commented Oct 24, 2022

> @rlouf, did you ask about this approach (i.e. using rewrites for gradients) in a Discussion or comment in Gitter?

I remember it was mentioned during the last Aesara meeting, and if I find other discussions I'll link them here.

@brandonwillard (Member, Author) commented Oct 24, 2022

> This Aeppl PR highlighted one of the aforementioned disadvantages of eager grads, namely failing to automatically use the specialized gradients of specialized Ops: aesara-devs/aeppl#156
>
> Similar issue for the gradient of the Softmax: #679

Thanks; these are exactly what we needed to build a larger case for this change.

At this point, there seem to be only two ways to move shape-dependent/critical logic to compile-time, where ShapeFeatures are available: always running a "global" ShapeFeature somehow, or this approach. The latter is considerably more accessible and promising at this time.

The underlying idea is that it should be possible for all Op.make_node and Op.grad implementations to return outputs whose TensorTypes specify only the correct dtypes and numbers of dimensions. In other words, no shape inference should be needed by those methods, or at construction-time/the user-level in general. All of the shape inference logic should be implemented once and in one place, i.e. in Op.infer_shape, and applied efficiently to all the relevant graphs in a consistent manner at compile-time.

The primary obstacles are the errant designs employed by the broadcasting-related Ops and constructors like Elemwise, BroadcastTo, and broadcast_shape_iter, and possibly a couple of others.
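A hypothetical Op sketching that convention (an illustrative addition, not part of Aesara; dtype handling is simplified): make_node commits only to a dtype and a number of dimensions, while all of the shape reasoning lives in infer_shape, where it is evaluated once, at compile-time, through the ShapeFeature machinery.

import numpy as np

import aesara.tensor as at

from aesara.graph.basic import Apply
from aesara.graph.op import Op


class OuterProduct(Op):
    """Outer product of two vectors (illustrative only)."""

    __props__ = ()

    def make_node(self, x, y):
        x = at.as_tensor_variable(x)
        y = at.as_tensor_variable(y)
        # Only a dtype and the number of dimensions are specified here: no
        # static shapes and no broadcasting bookkeeping.
        out = at.matrix(dtype=x.type.dtype)
        return Apply(self, [x, y], [out])

    def perform(self, node, inputs, output_storage):
        x, y = inputs
        output_storage[0][0] = np.outer(x, y)

    def infer_shape(self, fgraph, node, input_shapes):
        # All of the shape logic, implemented once and used only at
        # compile-time (via `ShapeFeature`).
        (x_len,), (y_len,) = input_shapes
        return [(x_len, y_len)]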

@brandonwillard (Member, Author) commented Dec 2, 2022

Before I forget: in this case we can also use the gradient functionality offered by our transpilation targets. For example, we could convert an un-expanded/un-inlined Gradient Op directly into a gradient computation in JAX or Numba.
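A rough sketch of that idea, reusing the Gradient Op from the first comment (hedged: the dispatch helpers and calling conventions may differ between Aesara versions, and this assumes the wrt variables are the only non-constant inputs of the cost sub-graph): instead of expanding a Gradient node into an explicit Aesara gradient graph, lower it directly to jax.grad applied to the JAX translation of the cost sub-graph.

import jax

from aesara.graph.fg import FunctionGraph
from aesara.link.jax.dispatch import jax_funcify


@jax_funcify.register(Gradient)
def jax_funcify_Gradient(op, node=None, **kwargs):
    cost, *wrt = node.inputs

    # Translate the sub-graph that computes `cost` from `wrt` into a JAX
    # callable and let JAX differentiate it.
    cost_fgraph = FunctionGraph(list(wrt), [cost], clone=True)
    jax_cost_fn = jax_funcify(cost_fgraph)

    def scalar_cost(*wrt_values):
        res = jax_cost_fn(*wrt_values)
        # Funcified `FunctionGraph`s may return a tuple of outputs.
        return res[0] if isinstance(res, (list, tuple)) else res

    def gradient(cost_value, *wrt_values):
        # The already-computed cost value is ignored; JAX re-derives the
        # gradients from the cost function itself.
        grads = jax.grad(scalar_cost, argnums=tuple(range(len(wrt_values))))(
            *wrt_values
        )
        return grads if len(grads) > 1 else grads[0]

    return gradient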
