Replace uses of Rebroadcast by SpecifyShape #915

ricardoV94
Contributor

@ricardoV94 ricardoV94 commented Apr 19, 2022

This one is not as straightforward as I was hoping, mainly because of the current dependency on the ability to "unbroadcast" variables (i.e., to mask known broadcastable dimensions). This was done via unbroadcast and patternbroadcast, the latter of which combined unbroadcast and addbroadcast.

"Unbroadcasting" is still used explicitly in Scan and broadcast_like.

Changes

For now, explicit "unbroadcasting" goes through the more limited MaskBroadcastable Op, adapted from the old Rebroadcast. Everything else that used to rely on Rebroadcast now uses SpecifyShape or the new thin wrapper specify_broadcastable.
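To make the terminology concrete, here is a minimal sketch in plain Python that models a broadcastable pattern as a tuple of booleans (True meaning "statically known to have length 1"). The helper bodies are assumptions written for illustration; they borrow the names used in this description but are not Aesara's implementations and do not touch any real graph.

```python
# Toy model only: a broadcastable pattern is a tuple of booleans, one per
# dimension, where True means "statically known to have length 1".
# These helpers are hypothetical stand-ins, not Aesara's implementations.


def unbroadcast(pattern, *axes):
    """Mask the given axes: forget that they are known to have length 1."""
    return tuple(False if i in axes else b for i, b in enumerate(pattern))


def addbroadcast(pattern, *axes):
    """Mark the given axes as known to have length 1."""
    return tuple(True if i in axes else b for i, b in enumerate(pattern))


def patternbroadcast(pattern, new_pattern):
    """Set every axis at once, combining addbroadcast and unbroadcast."""
    to_add = [i for i, b in enumerate(new_pattern) if b]
    to_mask = [i for i, b in enumerate(new_pattern) if not b]
    return unbroadcast(addbroadcast(pattern, *to_add), *to_mask)


row = (True, False)  # models a tensor whose first dimension is known to be 1
masked = unbroadcast(row, 0)  # the static knowledge about axis 0 is dropped
restored = addbroadcast(masked, 0)
swapped = patternbroadcast(row, (False, True))
```

In this model "unbroadcasting" only ever removes static information, which is why a more limited Op can cover it, while adding information is the part that a shape specification can express.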

Addresses part of #748 and #917

Closes #955

@brandonwillard
Member

> Ideally we wouldn't need this Op at all but I fear there are some deeply-ingrained dependencies on this...

Sounds like this is the core of our work: finding these dependencies and addressing them directly.

@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch from a865f2d to cb0eba7 Compare May 9, 2022 12:31
@ricardoV94
Contributor Author

The last places that seem to make explicit use of "unbroadcasting" are these:

# the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable:
    rval = unbroadcast(
        rval,
        *[
            i
            for i in range(rval.ndim)
            if rval.broadcastable[i] and not template.broadcastable[i]
        ],
    )
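The axis selection in the excerpt above can be sketched on its own, using plain tuples of booleans in place of the `broadcastable` attributes. The function name `axes_to_unbroadcast` is hypothetical and introduced only for this illustration.

```python
def axes_to_unbroadcast(rval_broadcastable, template_broadcastable):
    """Return the axes that are broadcastable in rval but not in the template."""
    return [
        i
        for i, (r, t) in enumerate(zip(rval_broadcastable, template_broadcastable))
        if r and not t
    ]


# rval is statically known to have length 1 on axes 0 and 1, while the
# template only has that static knowledge for axis 1 (it may still happen
# to have length 1 on axis 0 at runtime).
axes = axes_to_unbroadcast((True, True, False), (False, True, False))

# When both patterns agree, nothing needs to be masked.
no_axes = axes_to_unbroadcast((True, False), (True, False))
```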

aesara/aesara/scan/basic.py

Lines 750 to 758 in 6cca25e

# We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function
# defined in scan utils
sit_sot_scan_inputs.append(
    expand_empty(
        at.unbroadcast(shape_padleft(actual_arg), 0),
        actual_n_steps,
    )
)

aesara/aesara/scan/basic.py

Lines 878 to 885 in 6cca25e

# we need to see if we need to pad our sequences with an
# unbroadcastable dimension; case example : we return an
# output for which we want all intermediate. If n_steps is 1
# then, if we return the output as given by the inner function
# this will represent only a slice and it will have one
# dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
    outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0)
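One way to read the excerpt above: padding adds a leading dimension of static length 1, and the mask then drops that static knowledge so the leading "steps" dimension is not typed as length 1. The sketch below models this with static shapes as tuples, where an int is a known length and None is unknown until runtime. Both helpers are hypothetical stand-ins named after the calls in the excerpt, not the library's implementations.

```python
# Toy model: a static shape is a tuple of ints (known lengths) and None
# (unknown until runtime). Hypothetical helpers, for illustration only.


def shape_padleft(static_shape, n_ones=1):
    """Prepend n_ones dimensions of known length 1."""
    return (1,) * n_ones + tuple(static_shape)


def unbroadcast(static_shape, *axes):
    """Forget that the given axes are known to have length 1."""
    return tuple(
        None if i in axes and s == 1 else s for i, s in enumerate(static_shape)
    )


inner_out = (3, 4)  # static shape of a single step's output
padded = shape_padleft(inner_out)  # leading axis is statically 1
steps_first = unbroadcast(padded, 0)  # leading axis is now unknown
```

Under this reading, the masked leading axis leaves room for the number of steps to differ from 1 without contradicting the variable's type.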

aesara/aesara/scan/basic.py

Lines 1012 to 1017 in 6cca25e

sit_sot_scan_inputs.append(
    expand_empty(
        at.unbroadcast(shape_padleft(input.variable), 0),
        actual_n_steps,
    )
)

@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch 12 times, most recently from 8a4b5a9 to 794801e Compare May 10, 2022 16:03
@@ -751,7 +751,7 @@ def wrap_into_list(x):
         # defined in scan utils
         sit_sot_scan_inputs.append(
             expand_empty(
-                at.unbroadcast(shape_padleft(actual_arg), 0),
+                mask_broadcastable(shape_padleft(actual_arg), 0),
Contributor Author

The issue here seems to be the need to create an AllocEmpty whose leftmost dimension does not have a fixed shape of 1. I didn't dig into why, or whether, this is actually needed.

@codecov

codecov bot commented May 10, 2022

Codecov Report

Merging #915 (4cfaaa5) into main (ccfe2d3) will decrease coverage by 0.02%.
The diff coverage is 91.44%.

❗ Current head 4cfaaa5 differs from pull request most recent head 8a6a871. Consider uploading reports for the commit 8a6a871 to get more accurate results

@@            Coverage Diff             @@
##             main     #915      +/-   ##
==========================================
- Coverage   79.27%   79.24%   -0.03%     
==========================================
  Files         152      152              
  Lines       47965    47848     -117     
  Branches    10923    10902      -21     
==========================================
- Hits        38023    37919     -104     
+ Misses       7438     7423      -15     
- Partials     2504     2506       +2     
Impacted Files | Coverage Δ
aesara/compile/function/pfunc.py | 83.41% <ø> (ø)
aesara/scan/op.py | 85.39% <ø> (-0.09%) ⬇️
aesara/tensor/blas.py | 79.63% <33.33%> (-0.08%) ⬇️
aesara/tensor/nnet/conv.py | 79.40% <33.33%> (-0.57%) ⬇️
aesara/tensor/nnet/opt.py | 43.02% <66.66%> (+0.05%) ⬆️
aesara/tensor/nnet/batchnorm.py | 76.92% <81.81%> (-0.12%) ⬇️
aesara/tensor/basic_opt.py | 85.66% <83.33%> (-0.50%) ⬇️
aesara/ifelse.py | 49.71% <100.00%> (-0.15%) ⬇️
aesara/link/jax/dispatch.py | 80.20% <100.00%> (-1.69%) ⬇️
aesara/link/numba/dispatch/tensor_basic.py | 100.00% <100.00%> (+2.06%) ⬆️
... and 36 more

@ricardoV94
Contributor Author

Tests are passing!

@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch 2 times, most recently from e021750 to 6d2a42b Compare May 11, 2022 10:07
@ricardoV94 ricardoV94 marked this pull request as ready for review May 11, 2022 10:09
@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch 2 times, most recently from 637827f to 953d410 Compare May 12, 2022 10:43
@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch 2 times, most recently from 2cb7222 to d93f915 Compare May 17, 2022 16:44
@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch 2 times, most recently from 09de773 to 217f985 Compare May 28, 2022 10:19
@ricardoV94
Contributor Author

ricardoV94 commented May 28, 2022

I pushed a commit that removes the use of "unbroadcasting" in broadcast_like. This seems to have been introduced a long time ago in 9a45333, perhaps because of old rewrite limitations on changing a variable's static broadcastable shape. It might no longer be relevant after #711, but unfortunately there were no tests targeting this functionality, so it is hard to reason about it more directly.

In that same commit, several alloc calls were replaced by broadcast_like, which could mean that most of the remaining uses of broadcast_like are no longer needed. Removing them would facilitate #288.

If "unbroadcasting" is not necessary in broadcast_like, the only remaining use is in the creation of Scan. This is still clearly needed at the moment (even though it should not!) as indicated by the tests mentioned above.

@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch from 217f985 to 1e26263 Compare May 28, 2022 16:50
@ricardoV94 ricardoV94 marked this pull request as draft May 28, 2022 17:54
@brandonwillard
Member

> If "unbroadcasting" is not necessary in broadcast_like, the only remaining use is in the creation of Scan. This is still clearly needed at the moment (even though it should not!) as indicated by the tests mentioned above.

Let's remove it from Scan. I'll take a look today.

@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch 2 times, most recently from a31c9b0 to d3e2275 Compare May 31, 2022 07:30
@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch from 4cfaaa5 to dab821b Compare June 7, 2022 10:06
@ricardoV94
Contributor Author

ricardoV94 commented Jun 7, 2022

>> If "unbroadcasting" is not necessary in broadcast_like, the only remaining use is in the creation of Scan. This is still clearly needed at the moment (even though it should not!) as indicated by the tests mentioned above.

> Let's remove it from Scan. I'll take a look today.

All the tests passed without the use of unbroadcast in broadcast_like, so Scan is now the only place where it is used.

@ricardoV94 ricardoV94 marked this pull request as ready for review June 7, 2022 11:20
@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch from dab821b to fb3a880 Compare July 7, 2022 12:48
@ricardoV94
Contributor Author

Renamed MaskBroadcastable to Unbroadcast

Ricardo added 9 commits July 7, 2022 15:37
This rewrite predated partial shape specification in SpecifyShape, and as such ignored possible shape refinement over consecutive SpecifyShapes. It now merges information across consecutive SpecifyShapes. Unlike the original, preference is given to the outer SpecifyShape, similar to what local_rebroadcast_lift does.
Adds condition in convert_variable_test which would fail before this change
The behavior was already accounted for by filter_variable, which is called directly as a fallback by the optimizer routines
This change was introduced in 9a45333 and did not include specific tests. It was likely introduced to cope with the old restrictions regarding rewrite substitution of variables with different static broadcastable shape information, which was alleviated in aesara-devs#711
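The first commit message above describes merging shape information across consecutive SpecifyShapes with preference given to the outer one. A minimal sketch of that merging rule on plain tuples follows, where None stands for an unspecified dimension. The function `merge_specified_shapes` is hypothetical, and letting the outer value win on a conflict is an assumption of this sketch; how the actual rewrite treats conflicting entries is not stated here.

```python
def merge_specified_shapes(inner, outer):
    """Combine two partial shapes; None means "not specified".

    Entries specified by the outer shape take precedence, and the inner
    shape only fills in dimensions that the outer one leaves unspecified.
    """
    if len(inner) != len(outer):
        raise ValueError("both shapes must have the same number of dimensions")
    return tuple(i if o is None else o for i, o in zip(inner, outer))


# Models an inner specification of (None, 5) wrapped by an outer
# specification of (2, None): each one refines a different dimension.
merged = merge_specified_shapes(inner=(None, 5), outer=(2, None))

# Assumed behavior of this sketch when both specify the same dimension:
# the outer value is kept.
conflict = merge_specified_shapes(inner=(3, 5), outer=(2, None))
```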
@ricardoV94 ricardoV94 force-pushed the replace_rebroadcast_by_specifyshape branch from fb3a880 to 8a6a871 Compare July 7, 2022 13:37
@brandonwillard brandonwillard merged commit 7f8af9b into aesara-devs:main Jul 7, 2022
@brandonwillard
Member

brandonwillard commented Jul 7, 2022

This was/is a big improvement!

Labels
important, refactor (This issue involves refactoring), shape inference
Development

Successfully merging this pull request may close these issues.

MergeOptimization tries to merge nodes with different static shapes
2 participants