Implement at.diff using basic slicing and subtraction Ops #901

Merged: 6 commits, Apr 22, 2022

Conversation

@larryshamalama (Contributor) commented Apr 13, 2022

Closes #860.

I'm happy to receive any pointers or suggestions on this PR. Hopefully it can spark some discussion, since there are gaps in my understanding.

Two questions:

  • Currently, I did not remove the DiffOp class because the JAX and Numba backends seem to use it. Should I delete DiffOp altogether, or just modify at.diff as in the current (first) commit?
  • Should the loop be replaced by an implementation that builds on aesara.scan?

@ricardoV94 (Contributor) commented Apr 13, 2022

  • Currently, I did not remove the DiffOp class because the JAX and Numba backends seem to use it. Should I delete DiffOp altogether, or just modify at.diff as in the current (first) commit?

Remove it altogether. We don't need a dispatch for at.diff: we already have dispatches for the slicing and subtraction Ops, so the new graph will also work fine with those backends. That's yet another reason why we prefer to keep the number of Ops small: less dispatch code to maintain.

  • Should the loop be replaced by an implementation that builds on aesara.scan?

I don't think so. aesara.scan would introduce some overhead, and I don't think there's a big need for a large or symbolic n (e.g., nobody has asked for it before).
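For illustration, an at.diff built only from basic slicing and elementwise subtraction could look roughly like the sketch below (this is just an outline of the approach being discussed, not necessarily the exact code in this PR):

```python
import aesara.tensor as at


def diff_sketch(x, n=1, axis=-1):
    # Sketch only: build the n-th order difference from basic slicing and
    # elementwise subtraction instead of a dedicated DiffOp.
    x = at.as_tensor_variable(x)
    axis = axis % x.ndim  # normalize a negative axis

    # Slices that drop the first / last entry along ``axis``
    tail = [slice(None)] * x.ndim
    head = [slice(None)] * x.ndim
    tail[axis] = slice(1, None)
    head[axis] = slice(None, -1)

    for _ in range(n):
        # First-order difference along ``axis``: x[1:] - x[:-1]
        x = x[tuple(tail)] - x[tuple(head)]

    return x
```

For a vector v, this builds the graph v[1:] - v[:-1], which only involves slicing and elementwise subtraction nodes, both of which already have JAX and Numba dispatches.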

@brandonwillard added the "refactor" label on Apr 13, 2022
@brandonwillard (Member) left a comment

The current error in TestDiffOp.test_output_type looks like a static shape inference problem: i.e. the test assumes that all the Op.make_node and/or Op.infer_shape methods involved can exactly infer the output Type's shape using only static information.

Previously, at.diff would return a single DiffOp that would perform the static shape inference in DiffOp.make_node and return an output Type with the expected Type.shape values. Now, at.diff returns a graph that contains a few different Ops that are expected to produce the same end result. We need to find out how/where those Ops are losing the static information along the way to the graph's output Type.

Since all the new Ops that make up at.diff are differences and *Subtensor*s, my guess is that one of the *Subtensor*s is losing that information in its Op.make_node. I don't think those Ops have been updated to use the newly available Type.shape information, so that could easily be the issue.

In other words, the error may have nothing to do with the changes you made, and everything to do with some shortcomings in a *Subtensor*.make_node implementation.

These are good things to come across, because an improvement to a *Subtensor* Op could affect a lot more than just at.diff.

My advice is to run TestDiffOp.test_output_type locally, stop in the debugger and aesara.dprint(out, print_type=True). That will show you the entire graph and the types of each term in the graph. From there you can see which nodes/Ops produced Types that lost shape information. You may need to reason about that yourself, though (e.g. given a vector y, z = y[slice(None)] implies z.shape == y.shape, so, if y.type.shape == (s,) then z.type.shape == (s,)).

Actually, take a look at this:

import aesara
import aesara.tensor as at

y = at.tensor(dtype="float32", shape=(2,))
y.type.shape
# (2,)

# We know that the shape of this should be equal to the shape of `y`
z = y[slice(None)]

aesara.dprint(z, print_type=True)
# Subtensor{::} [id A] <TensorType(float32, (None,))> ''   
#  |<TensorType(float32, (2,))> [id B] <TensorType(float32, (2,))>

z.type.shape
# (None,)

The Subtensor node is returning an output variable, z, with a Type of TensorType(float32, (None,)). Since the shape value None corresponds to "not-known-at-compile-time", it's this Op that's breaking the test.
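To make that reasoning concrete, the static length of a sliced dimension can in principle be computed from the slice object and the statically known input length alone; a hypothetical helper (not an existing Aesara function) could look like this:

```python
def static_slice_length(length, slc):
    # Hypothetical helper, for illustration only: given the statically known
    # ``length`` of a dimension (or None if unknown) and a basic ``slice``,
    # return the static length of ``range(length)[slc]``.
    if length is None:
        return None
    # ``slice.indices`` clips start/stop/step to the known length, so the
    # resulting length is exactly computable at graph-construction time.
    return len(range(*slc.indices(length)))


# For ``y`` with static shape (2,):
assert static_slice_length(2, slice(None)) == 2       # y[:]   -> shape (2,)
assert static_slice_length(2, slice(None, -1)) == 1   # y[:-1] -> shape (1,)
assert static_slice_length(None, slice(1, None)) is None  # unknown stays unknown
```

A *Subtensor*.make_node that performed this kind of computation whenever the input's Type.shape is known could return an output Type with an exact shape instead of None.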

At this point we should at least create an issue for the Subtensor problem. As far as this PR/issue is concerned, you can address the problem here, if it's not too complicated, or we can set that particular test case to xfail and address the problem in a separate PR. (A separate PR can also be worked on simultaneously and this one put on hold.) Either way, I'll leave that up to you.
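For reference, marking that test case as an expected failure could look roughly like the following (the reason string is only illustrative, and the test name matches the one mentioned earlier in the thread):

```python
import pytest


class TestDiffOp:
    @pytest.mark.xfail(
        reason="Subtensor does not yet propagate static shape information"
    )
    def test_output_type(self):
        ...
```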

@brandonwillard changed the title from "change at.diff to slicing + subtraction" to "Implement at.diff using basic slicing and subtraction Ops" on Apr 14, 2022
@ricardoV94 (Contributor)

The failing test was introduced in #855 because there used to be an actual error in the output type of the Diff Op, not just a less precise output type, as happens now with the reliance on the Subtensor Op.

I think it is fine to remove that test (or xfail it), and open a separate issue to make the output type of Subtensor more precise when possible, as @brandonwillard suggested.

@ricardoV94 (Contributor)

To understand the conversation about the output type, it's useful to read the documentation here: https://aesara.readthedocs.io/en/latest/extending/type.html

@larryshamalama (Contributor, Author)

Hi all, revisiting this PR after a bit of a hiatus. I just opened an issue to address what @brandonwillard explained above (see issue #922). Thank you for the clear instructions.

As for this PR, I have now removed the DiffOp class altogether, along with the corresponding dispatch code and tests that build on it. I temporarily marked test_output_type as xfail, as @ricardoV94 suggested, but perhaps the Subtensor static shape inference problem could be addressed first. A minor note: I renamed TestDiffOp to TestDiff since the Op was removed.

@codecov bot commented Apr 21, 2022

Codecov Report

Merging #901 (55f2d09) into main (e5ebf26) will decrease coverage by 0.03%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main     #901      +/-   ##
==========================================
- Coverage   78.94%   78.91%   -0.04%     
==========================================
  Files         152      152              
  Lines       47712    47654      -58     
  Branches    10858    10851       -7     
==========================================
- Hits        37667    37606      -61     
- Misses       7546     7548       +2     
- Partials     2499     2500       +1     
| Impacted Files | Coverage Δ |
|---|---|
| aesara/link/jax/dispatch.py | 80.86% <ø> (-0.23%) ⬇️ |
| aesara/link/numba/dispatch/extra_ops.py | 98.00% <ø> (-0.20%) ⬇️ |
| aesara/tensor/extra_ops.py | 88.88% <100.00%> (-0.47%) ⬇️ |
| aesara/gradient.py | 77.20% <0.00%> (-0.27%) ⬇️ |
| aesara/link/numba/dispatch/basic.py | 92.06% <0.00%> (-0.25%) ⬇️ |

@brandonwillard (Member) left a comment

This looks good; thanks!

tests/tensor/test_extra_ops.py (review comment, resolved)
@brandonwillard marked this pull request as ready for review on April 22, 2022
@brandonwillard merged commit 99e9600 into aesara-devs:main on Apr 22, 2022
@brandonwillard (Member)

I forgot to remove those "fix pre-commit" commits before merging; regardless, such commits shouldn't exist. If pre-commit is correctly set up in a project, it shouldn't be possible to commit when there are pre-commit errors.

@larryshamalama, make sure pre-commit is correctly installed in your development environment, and, if you need to fix pre-commit errors in existing commits, amend/rebase onto the commit that introduced the error and include the fixes there.

In general, we would like it to be possible for all tests (pre-commit and otherwise) to pass on every commit. When that's not the case, it's difficult to revert changes and/or determine which commit introduces a specific error.

@larryshamalama (Contributor, Author)

Thanks for the suggestion. I have a habit of stacking my commits without much thought... Whenever I fix pre-commit errors, it's because I forgot to check before pushing.

I'm looking at the repository commit history and I see many of my "fix pre-commit" commits. Generally, is it good practice to separate progress into multiple commits? I just learned about rebasing commits into one, so I'll be sure to do this next time.

Also, for this PR, I initially left it as a draft because I was not sure if I needed to look into issue #922 first.

@larryshamalama deleted the remove_diffop branch on April 23, 2022
@brandonwillard (Member)

> I'm looking at the repository commit history and I see many of my "fix pre-commit" commits. Generally, is it good practice to separate progress into multiple commits? I just learned about rebasing commits into one, so I'll be sure to do this next time.

Take a look at this page for some descriptions and examples of good commit structuring.

> Also, for this PR, I initially left it as a draft because I was not sure if I needed to look into issue #922 first.

That's a more important issue, but we don't need it to hold up this PR.

Labels: refactor (This issue involves refactoring)
Merging this pull request closed: Remove Diff Op (#860)
3 participants