Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix None output node bug in JITLinker.create_jitable_thunk #1205

Merged
merged 1 commit into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions aesara/link/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,8 @@ def create_jitable_thunk(
The JITed function that performs the computations.

"""
output_nodes = [o.owner for o in self.fgraph.outputs]
# This is a bit hackish, but we only return one of the output nodes
output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
Comment on lines +644 to +645
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't changed by this PR, but I wonder does the hack actually work? There's a whole bunch of logic downstream concerning garbage collection and whatnot that is supposed to rely on this. Too dense for me to take a look though :D

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't changed by this PR, but I wonder does the hack actually work? There's a whole bunch of logic downstream concerning garbage collection and whatnot that is supposed to rely on this. Too dense for me to take a look though :D

It must be working; how else have we been using Numba/JAX to evaluate graphs?

The hackish part is that we were and still are (to some extent) using the old Function/VM framework to do something that it wasn't designed to do. For example, we don't have a thunk for each node in the graph, like we normally would.

To be clear, shared variable and updates support for the JIT backends has always been incomplete, and this falls squarely into the latter.


converted_fgraph = self.fgraph_convert(
self.fgraph,
Expand Down Expand Up @@ -678,8 +679,7 @@ def thunk(

thunks.append(thunk)

# This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1], fgraph_jit
return thunks, output_nodes, fgraph_jit

def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
Expand Down
2 changes: 1 addition & 1 deletion aesara/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def gc_helper(node_list: List[Apply]):
-------
2-tuple
FIRST, the set of Variable instances which are computed by node_list,
and SECOND a dictionary that maps each Variable instance to a the last
and SECOND a dictionary that maps each Variable instance to the last
node to use Variable as an input.

Extended Summary
Expand Down
16 changes: 16 additions & 0 deletions tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,22 @@ def test_shared():
np.testing.assert_allclose(jax_res, new_a_value * 2)


def test_shared_updates():
a = shared(0)

aesara_jax_fn = function([], a, updates={a: a + 1}, mode="JAX")
res1, res2 = aesara_jax_fn(), aesara_jax_fn()
assert res1 == 0
assert res2 == 1
assert a.get_value() == 2

a.set_value(5)
res1, res2 = aesara_jax_fn(), aesara_jax_fn()
assert res1 == 5
assert res2 == 6
assert a.get_value() == 7


def test_jax_ifelse():

true_vals = np.r_[1, 2, 3]
Expand Down
16 changes: 16 additions & 0 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,22 @@ def test_shared():
np.testing.assert_allclose(numba_res, new_a_value * 2)


def test_shared_updates():
a = shared(0)

aesara_numba_fn = function([], a, updates={a: a + 1}, mode="NUMBA")
res1, res2 = aesara_numba_fn(), aesara_numba_fn()
assert res1 == 0
assert res2 == 1
assert a.get_value() == 2

a.set_value(5)
res1, res2 = aesara_numba_fn(), aesara_numba_fn()
assert res1 == 5
assert res2 == 6
assert a.get_value() == 7


# We were seeing some weird results in CI where the following two almost
# sign-swapped results were being return from Numba and Python, respectively.
# The issue might be related to https://github.com/numba/numba/issues/4519.
Expand Down