Skip to content

Commit

Permalink
Add tests for shared updates in JAX and NUMBA backends
Browse files Browse the repository at this point in the history
* Also fixes bug in JITCompiler when first output of inner fgraph is an input variable, as can happen in some specific functions with updates
  • Loading branch information
Ricardo Vieira committed Sep 22, 2022
1 parent 5d97ffa commit f9a47fa
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 4 deletions.
6 changes: 3 additions & 3 deletions aesara/link/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,8 @@ def create_jitable_thunk(
The JITed function that performs the computations.
"""
output_nodes = [o.owner for o in self.fgraph.outputs]
# This is a bit hackish, but we only return one of the output nodes
output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][1:]

converted_fgraph = self.fgraph_convert(
self.fgraph,
Expand Down Expand Up @@ -678,8 +679,7 @@ def thunk(

thunks.append(thunk)

# This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1], fgraph_jit
return thunks, output_nodes, fgraph_jit

def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
Expand Down
2 changes: 1 addition & 1 deletion aesara/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def gc_helper(node_list: List[Apply]):
-------
2-tuple
FIRST, the set of Variable instances which are computed by node_list,
and SECOND a dictionary that maps each Variable instance to a the last
and SECOND a dictionary that maps each Variable instance to the last
node to use Variable as an input.
Extended Summary
Expand Down
16 changes: 16 additions & 0 deletions tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,22 @@ def test_shared():
np.testing.assert_allclose(jax_res, new_a_value * 2)


def test_shared_updates():
a = shared(0)

aesara_jax_fn = function([], a, updates={a: a + 1}, mode="JAX")
res1, res2 = aesara_jax_fn(), aesara_jax_fn()
assert res1 == 0
assert res2 == 1
assert a.get_value() == 2

a.set_value(5)
res1, res2 = aesara_jax_fn(), aesara_jax_fn()
assert res1 == 5
assert res2 == 6
assert a.get_value() == 7


def test_jax_ifelse():

true_vals = np.r_[1, 2, 3]
Expand Down
16 changes: 16 additions & 0 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,22 @@ def test_shared():
np.testing.assert_allclose(numba_res, new_a_value * 2)


def test_shared_updates():
a = shared(0)

aesara_numba_fn = function([], a, updates={a: a + 1}, mode="NUMBA")
res1, res2 = aesara_numba_fn(), aesara_numba_fn()
assert res1 == 0
assert res2 == 1
assert a.get_value() == 2

a.set_value(5)
res1, res2 = aesara_numba_fn(), aesara_numba_fn()
assert res1 == 5
assert res2 == 6
assert a.get_value() == 7


# We were seeing some weird results in CI where the following two almost
# sign-swapped results were being return from Numba and Python, respectively.
# The issue might be related to https://github.com/numba/numba/issues/4519.
Expand Down

0 comments on commit f9a47fa

Please sign in to comment.