Skip to content

Commit

Permalink
Rename modules jax_linker to linker and jax_dispatch to dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Apr 12, 2021
1 parent 1b6786e commit 0acded9
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aesara/link/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from aesara.link.jax.jax_linker import JAXLinker
from aesara.link.jax.linker import JAXLinker
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create_jax_thunks(
"""
import jax

from aesara.link.jax.jax_dispatch import jax_funcify, jax_typify
from aesara.link.jax.dispatch import jax_funcify, jax_typify

output_nodes = [o.owner for o in self.fgraph.outputs]

Expand Down
8 changes: 4 additions & 4 deletions doc/JaxOps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ logic.
return res if n_outs > 1 else res[0]
*Code in context:*
https://github.com/pymc-devs/aesara/blob/master/aesara/link/jax/jax_dispatch.py#L583
https://github.com/pymc-devs/aesara/blob/master/aesara/link/jax/dispatch.py#L583

Step 3: Register the function with the jax_funcify dispatcher
=============================================================
Expand All @@ -49,9 +49,9 @@ function with the Aesara JAX Linker. This is done through the dispatcher
decorator and closure as seen below. If unsure how dispatching works a
short tutorial on dispatching is at the bottom.

The linker functions should be added to ``jax_dispatch`` module linked
The linker functions should be added to ``dispatch`` module linked
below.
https://github.com/pymc-devs/aesara/blob/master/aesara/link/jax/jax_dispatch.py
https://github.com/pymc-devs/aesara/blob/master/aesara/link/jax/dispatch.py

Here’s an example for the Eye Op.

Expand All @@ -69,7 +69,7 @@ Here’s an example for the Eye Op.
return eye
*Code in context:*
https://github.com/pymc-devs/aesara/blob/master/aesara/link/jax/jax_dispatch.py#L1071
https://github.com/pymc-devs/aesara/blob/master/aesara/link/jax/dispatch.py#L1071

Step 4: Write tests
===================
Expand Down
2 changes: 1 addition & 1 deletion tests/link/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from aesara.graph.optdb import Query
from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker
from aesara.link.jax.jax_dispatch import jax_funcify
from aesara.link.jax.dispatch import jax_funcify
from aesara.scan.basic import scan
from aesara.tensor import basic as aet
from aesara.tensor import blas as aet_blas
Expand Down

0 comments on commit 0acded9

Please sign in to comment.