
Implement batched convolve1d #1318

Merged: 1 commit into pymc-devs:main on Mar 27, 2025

Conversation

ricardoV94 (Member) commented Mar 24, 2025

Spin-off from #548. It was obvious the old "fast" Ops were not really fast, just not as slow as the AbstractConv Ops. I suggest we start removing them and replacing them with saner implementations.

This PR implements a simple convolve1d (that can be batched by Blockwise). This is not actually offered by either scipy.signal or numpy.convolve: the scipy implementation has no concept of batch dims, so if you pass two 2d inputs you get a convolve2d, and the numpy one doesn't support batch dims, period.
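
For illustration, a minimal numpy sketch (not this PR's code) of what "batched convolve1d" means here: a 1d convolution applied independently along a leading batch dimension, which is what Blockwise provides symbolically. The shapes and the explicit loop are just assumptions for the example.

import numpy as np

# Hypothetical illustration only: batched 1d convolution, one kernel per row.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 10))  # batch of 4 signals of length 10
k = rng.normal(size=(4, 3))   # batch of 4 kernels of length 3

batched = np.stack([np.convolve(xi, ki, mode="full") for xi, ki in zip(x, k)])
print(batched.shape)  # (4, 12): each row convolved independently

# scipy.signal.convolve, by contrast, treats two 2d inputs as a 2d convolution:
# scipy.signal.convolve(x, k).shape == (7, 12), not a batch of 1d results.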

Because the 1d case implemented here matches what you would get from calling scipy's convolve, I added a small wrapper that lets you go that route. Consider it WIP.

Also note that "same" mode behaves differently between numpy and scipy: numpy returns an output with the shape of the larger of the two inputs, while scipy returns one with the shape of the first input. The implementation in this PR behaves like scipy. This may be a good argument for going down the JAX route and implementing a pytensor.numpy and a pytensor.scipy namespace, because such differences are likely to hit users, especially if they don't know which one we are mimicking.
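
A quick check of the difference described above, using numpy and scipy directly (this only illustrates the discrepancy, it does not exercise this PR's code):

import numpy as np
from scipy.signal import convolve as sp_convolve

x = np.arange(5.0)  # first input, length 5
k = np.ones(9)      # second input, length 9 (larger than x)

# numpy: "same" output takes the shape of the larger input.
print(np.convolve(x, k, mode="same").shape)   # (9,)

# scipy: "same" output takes the shape of the first input.
print(sp_convolve(x, k, mode="same").shape)   # (5,)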

This Op can be used to replace batched_convolution in pymc-marketing with much better graph compile time and runtime (especially as the size of the convolution grows).


📚 Documentation preview 📚: https://pytensor--1318.org.readthedocs.build/en/1318/

ricardoV94 (Member Author):
One day I'll murder floatX="float32", I swear

codecov bot commented Mar 24, 2025

Codecov Report

Attention: Patch coverage is 96.73913% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.03%. Comparing base (8a7356c) to head (39ab436).
Report is 4 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/tensor/signal/conv.py | 97.05% | 1 Missing and 1 partial ⚠️
pytensor/link/numba/dispatch/signal/conv.py | 90.90% | 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (96.73%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1318      +/-   ##
==========================================
+ Coverage   81.98%   82.03%   +0.04%     
==========================================
  Files         188      193       +5     
  Lines       48489    48622     +133     
  Branches     8673     8684      +11     
==========================================
+ Hits        39756    39886     +130     
- Misses       6582     6584       +2     
- Partials     2151     2152       +1     
Files with missing lines | Coverage Δ
pytensor/link/jax/dispatch/__init__.py | 100.00% <100.00%> (ø)
pytensor/link/jax/dispatch/signal/__init__.py | 100.00% <100.00%> (ø)
pytensor/link/jax/dispatch/signal/conv.py | 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/__init__.py | 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/signal/__init__.py | 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/signal/conv.py | 90.90% <90.90%> (ø)
pytensor/tensor/signal/conv.py | 97.05% <97.05%> (ø)

... and 1 file with indirect coverage changes


return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))


def convolve(
Member:

I'm not in love with the name and docstring. The docstring says it works on two n-dim arrays, but then raises if they're not 1d? But the name doesn't imply anything about the shapes?

Is this something that is meant to be a generic front-end to more functionality in the future?

Member Author:

This is supposed to match the scipy one, but right now it only supports 1d. It raises a NotImplementedError, not a ValueError. The idea is that when we add arbitrary-order convolutions, this function will dispatch to those.
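
A hypothetical sketch of the kind of dispatch being described (only the 1d branch exists in this PR; the signature and error message here are made up for illustration):

# Hypothetical dispatcher; only the 1d path is implemented in this PR.
def convolve(in1, in2, mode="full"):
    if in1.ndim == 1 and in2.ndim == 1:
        return convolve1d(in1, in2, mode=mode)
    # Future: dispatch to higher-dimensional convolutions once they exist.
    raise NotImplementedError("convolve currently only supports 1d inputs")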

Member Author:

Want me to remove this until we actually support all the cases the scipy one does?

Member:

I think that's better. As it stands, it doesn't add anything (except confusion)

in1_batch_shape = tuple(in1.shape)[:-1]
zeros_left = in2.shape[0] // 2
zeros_right = (in2.shape[0] - 1) // 2
in1 = join(
Member:

pad wasn't useful here?

Member Author:

I don't want to use pad until we figure out inlining for it. I want PyTensor to optimize across the boundary, especially when gradients get involved.
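
For context, a small numpy/scipy check (not PyTensor code) of the padding arithmetic in the snippet above: padding the signal with k // 2 zeros on the left and (k - 1) // 2 on the right and then taking a "valid" convolution reproduces scipy's "same" output, for both odd and even kernel lengths.

import numpy as np
from scipy.signal import convolve as sp_convolve

rng = np.random.default_rng(0)
x = rng.normal(size=10)
for k in (3, 4, 7):  # odd and even kernel lengths
    w = rng.normal(size=k)
    padded = np.concatenate([np.zeros(k // 2), x, np.zeros((k - 1) // 2)])
    np.testing.assert_allclose(
        np.convolve(padded, w, mode="valid"),
        sp_convolve(x, w, mode="same"),
    )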

rng = np.random.default_rng()
test_x = rng.normal(size=(3, 5))
test_y = rng.normal(size=(7, 11))
# Object mode is not supported for numba
Member:

Why not?

Member Author:

Because we didn't implement it. This path is only used for coverage and wouldn't cover the relevant code anyway, since it's all low-level numba.

Member Author:

Specifically, it's the numba dispatch of Blockwise that doesn't have an object-mode implementation. I'll tweak the comment.

Member Author:

Clarified comment

out = convolve1d(x[None], y[:, None], mode=mode)

rng = np.random.default_rng()
test_x = rng.normal(size=(3, 5))
Member:

Suggested change:
- test_x = rng.normal(size=(3, 5))
+ test_x = rng.normal(size=(3, 5)).astype(floatX)

?

Member Author:

Nope, no stupid float32 in jax/numba tests; also, I explicitly used dmatrix.

y = matrix("kernel")
out = convolve1d(x, y)

# Convolution is unchanged by order
Member:

I don't understand the comment

Member Author:

It doesn't matter whether it's x, y or y, x. I'm using that property to confirm batching is working as expected.
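
A minimal numpy illustration (not the test in this PR) of the property being relied on: 1d convolution is commutative, so swapping the inputs must give the same result, which makes it a cheap consistency check once batching enters the picture.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5)
y = rng.normal(size=3)

# Convolution is commutative: conv(x, y) == conv(y, x).
np.testing.assert_allclose(np.convolve(x, y), np.convolve(y, x))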

Member Author:

Clarified comment

zaxtax (Contributor) commented Mar 25, 2025 via email

jessegrabowski (Member) left a comment:

Remove convolve for now and I think this is good to merge.

If you don't want to lose track of the code, just removing it from __all__ in signal/__init__.py is probably good enough, plus a one-time-per-runtime warning that the function is a work in progress and users should prefer convolve1d for now.
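
A rough sketch of the kind of one-time warning being suggested (hypothetical code, not what was merged; the module-level flag and message are assumptions):

import warnings

_warned_about_convolve = False  # hypothetical module-level flag

def convolve(in1, in2, mode="full"):
    """Work-in-progress wrapper; prefer convolve1d for now."""
    global _warned_about_convolve
    if not _warned_about_convolve:
        warnings.warn(
            "convolve is a work in progress; prefer convolve1d for now.",
            UserWarning,
        )
        _warned_about_convolve = True
    return convolve1d(in1, in2, mode=mode)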

ricardoV94 (Member Author):
I'll remove it

Co-authored-by: Rob Zinkov <zaxtax@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
@jessegrabowski jessegrabowski merged commit 0b56ed9 into pymc-devs:main Mar 27, 2025
72 of 73 checks passed
3 participants