Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use flat AST-generated functions for JAX FunctionGraph conversion #371

Merged
merged 5 commits into from
Apr 14, 2021

Conversation

brandonwillard
Copy link
Member

@brandonwillard brandonwillard commented Apr 12, 2021

This PR changes the JAX translation process so that it converts FunctionGraphs into flattened AST-generated functions instead of nested functions. The resulting code and JAXified functions are much simpler than the previous implementation's.

Here's an example:

import inspect

import aesara.tensor as at

from aesara.link.jax.dispatch import jax_funcify
from aesara.graph.fg import FunctionGraph


x = at.vector("x")
y = at.vector("y")

out = at.exp(x**2 + x * y + 1)

out_fg = FunctionGraph([x, y], [out])

jax_fn = jax_funcify(out_fg)
>>> print(inspect.getsource(jax_fn))

def jax_funcified_fgraph(x, y):
    auto_31 = dimshuffle(auto_30)
    auto_32 = elemwise(x, y)
    auto_34 = dimshuffle1(auto_33)
    auto_35 = power(x, auto_34)
    auto_36 = elemwise1(auto_35, auto_32)
    auto_37 = elemwise2(auto_36, auto_31)
    auto_38 = exp(auto_37)
    return (auto_38,)

This implementation uses the same basic idea as the prototype in #365, with the key differences being that it is very streamlined, requires minimal changes to be used with a simple Linker base class, and it generates the source first then the AST from that. Because of this, it requires no external libraries and produces easily debuggable functions.

This change addresses concerns over potential recomputation of outputs that are referenced more than once in a graph, or that are one of many outputs of a single Apply node. A test has also been added to this PR that directly confirms such outputs are always reused and never recomputed (in both cases).

Local tests replacing the Elemwise JAXification with jax.vmap also appear to work, so, now, we should seriously consider using jax.vmap in at least some cases (e.g. for Composite functions). This PR now contains a commit that uses jax.numpy.vectorize for Elemwise Ops with Composite scalar Ops.

@brandonwillard brandonwillard added the JAX Involves JAX transpilation label Apr 12, 2021
@brandonwillard brandonwillard self-assigned this Apr 12, 2021
@brandonwillard brandonwillard changed the title Create custom functions via AST for FunctionGraphs Use custom AST-generated functions for JAX FunctionGraph conversion Apr 12, 2021
@brandonwillard brandonwillard force-pushed the ast-jax-linker branch 2 times, most recently from 8716f02 to fcd4515 Compare April 12, 2021 03:15
@brandonwillard brandonwillard marked this pull request as ready for review April 12, 2021 03:15
@brandonwillard brandonwillard force-pushed the ast-jax-linker branch 2 times, most recently from 9b0bcb4 to b861c01 Compare April 12, 2021 03:34
@brandonwillard brandonwillard added the enhancement New feature or request label Apr 12, 2021
@codecov
Copy link

codecov bot commented Apr 12, 2021

Codecov Report

Merging #371 (d866d8e) into master (334c86f) will increase coverage by 0.02%.
The diff coverage is 83.49%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #371      +/-   ##
==========================================
+ Coverage   71.91%   71.94%   +0.02%     
==========================================
  Files         166      168       +2     
  Lines       54736    54753      +17     
==========================================
+ Hits        39365    39390      +25     
+ Misses      15371    15363       -8     
Impacted Files Coverage Δ
aesara/link/jax/jax_dispatch.py 0.00% <0.00%> (-82.52%) ⬇️
aesara/link/jax/jax_linker.py 0.00% <0.00%> (-81.71%) ⬇️
aesara/link/jax/dispatch.py 80.93% <80.93%> (ø)
aesara/link/basic.py 89.23% <88.88%> (+0.28%) ⬆️
aesara/link/jax/__init__.py 100.00% <100.00%> (ø)
aesara/link/jax/linker.py 100.00% <100.00%> (ø)
aesara/link/utils.py 63.12% <100.00%> (+10.66%) ⬆️

@brandonwillard brandonwillard force-pushed the ast-jax-linker branch 2 times, most recently from 2cb9f0a to 15526e8 Compare April 12, 2021 04:17
@brandonwillard brandonwillard changed the title Use custom AST-generated functions for JAX FunctionGraph conversion Use flat AST-generated functions for JAX FunctionGraph conversion Apr 12, 2021
@brandonwillard brandonwillard force-pushed the ast-jax-linker branch 2 times, most recently from 0acded9 to a4bffbd Compare April 12, 2021 04:34
@brandonwillard brandonwillard force-pushed the ast-jax-linker branch 2 times, most recently from 4788268 to 9472eaf Compare April 12, 2021 05:18
@brandonwillard brandonwillard mentioned this pull request Apr 13, 2021
@brandonwillard brandonwillard merged commit 1549649 into aesara-devs:master Apr 14, 2021
@brandonwillard brandonwillard linked an issue May 19, 2021 that may be closed by this pull request
@brandonwillard brandonwillard deleted the ast-jax-linker branch November 10, 2021 20:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request important JAX Involves JAX transpilation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use AST-based approach to Python graph conversion
1 participant