Implement batched convolve1d #1318
Conversation
Codecov Report
❌ Your patch status has failed because the patch coverage (96.73%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

@@ Coverage Diff @@
## main #1318 +/- ##
==========================================
+ Coverage 81.98% 82.03% +0.04%
==========================================
Files 188 193 +5
Lines 48489 48622 +133
Branches 8673 8684 +11
==========================================
+ Hits 39756 39886 +130
- Misses 6582 6584 +2
- Partials 2151 2152 +1
pytensor/tensor/signal/conv.py
Outdated
return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))


def convolve(
I'm not in love with the name and docstring. The docstring says it works on two n-dim arrays, but then raises if they're not 1d? But the name doesn't imply anything about the shapes?
Is this something that is meant to be a generic front-end to more functionality in the future?
This is supposed to match the scipy one, but right now it only supports 1d. It raises a NotImplementedError, not a ValueError. The idea is that when we add an arbitrary-order conv, this function will dispatch to it.
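For context, a minimal sketch of what such a dispatching front-end could look like; the signature and behavior here are illustrative assumptions, not the code in this PR:

# Hypothetical sketch: a scipy-like front-end that dispatches on core
# dimensionality and raises NotImplementedError for cases not yet supported.
from pytensor.tensor.signal.conv import convolve1d  # module path taken from this PR


def convolve(in1, in2, mode="full"):
    if in1.ndim == 1 and in2.ndim == 1:
        return convolve1d(in1, in2, mode=mode)
    # NotImplementedError (not ValueError): higher-order convolutions are
    # expected to be added later and dispatched from here.
    raise NotImplementedError("Only 1d convolution is implemented for now")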
Want me to remove this until we actually support all the cases the scipy one does?
I think that's better. As it stands, it doesn't add anything (except confusion)
in1_batch_shape = tuple(in1.shape)[:-1]
zeros_left = in2.shape[0] // 2
zeros_right = (in2.shape[0] - 1) // 2
in1 = join(
pad wasn't useful here?
I don't want pad until we figure out inlining for it. I want PyTensor to optimize across the boundary, especially when gradients get involved
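For reference, a small NumPy/SciPy check of the trick used in the code above: pad in1 with k//2 zeros on the left and (k-1)//2 on the right, then do a "valid" convolution, which reproduces scipy-style "same" (output has the shape of in1). This is an illustration, not code from the PR:

import numpy as np
from scipy.signal import convolve as scipy_convolve

rng = np.random.default_rng(0)
in1 = rng.normal(size=10)
in2 = rng.normal(size=4)  # even-length kernel exercises the asymmetric padding
k = len(in2)

# "same" as "valid" on a zero-padded in1
padded = np.concatenate([np.zeros(k // 2), in1, np.zeros((k - 1) // 2)])
np.testing.assert_allclose(
    np.convolve(padded, in2, mode="valid"),
    scipy_convolve(in1, in2, mode="same"),
)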
tests/link/numba/signal/test_conv.py
Outdated
rng = np.random.default_rng()
test_x = rng.normal(size=(3, 5))
test_y = rng.normal(size=(7, 11))
# Object mode is not supported for numba
Why not?
Because we didn't implement it. This is only used for coverage and wouldn't cover the relevant code, since it's all low-level numba
Specifically, it's the numba dispatch of Blockwise that doesn't have an object-mode implementation. I'll tweak the comment
Clarified comment
out = convolve1d(x[None], y[:, None], mode=mode)

rng = np.random.default_rng()
test_x = rng.normal(size=(3, 5))
Suggested change:
- test_x = rng.normal(size=(3, 5))
+ test_x = rng.normal(size=(3, 5)).astype(floatX)

?
Nope, no stupid float32 in jax/numba tests, and also I explicitly used dmatrix
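For anyone following along, the point about dmatrix is just that its dtype is pinned to float64 regardless of pytensor.config.floatX, so no astype(floatX) cast is needed. A tiny illustration, assuming the standard constructors:

from pytensor.tensor import dmatrix, matrix

x = dmatrix("data")   # dtype is always float64
y = matrix("kernel")  # dtype follows pytensor.config.floatX (float64 or float32)
print(x.dtype, y.dtype)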
tests/tensor/signal/test_conv.py
Outdated
y = matrix("kernel")
out = convolve1d(x, y)

# Convolution is unchanged by order
I don't understand the comment
Doesn't matter if it's (x, y) or (y, x). I'm using that property to confirm batching is working as expected
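To spell out the property being relied on, a NumPy-only illustration (not the actual test code): full 1d convolution is commutative, so swapping the inputs must leave a correctly batched result unchanged.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5))  # batch of 3 signals
y = rng.normal(size=(3, 7))  # batch of 3 kernels

# conv(x, y) and conv(y, x) agree entry-by-entry across the batch
out_xy = np.stack([np.convolve(a, b, mode="full") for a, b in zip(x, y)])
out_yx = np.stack([np.convolve(b, a, mode="full") for a, b in zip(x, y)])
np.testing.assert_allclose(out_xy, out_yx)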
Clarified comment
It would also be good to have checkboxes on an issue of what we still need to support.
In pytensor/tensor/signal/conv.py:
> + # We implement "same" as "valid" with padded `in1`.
+ in1_batch_shape = tuple(in1.shape)[:-1]
+ zeros_left = in2.shape[0] // 2
+ zeros_right = (in2.shape[0] - 1) // 2
+ in1 = join(
+ -1,
+ zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
+ in1,
+ zeros((*in1_batch_shape, zeros_right), dtype=in2.dtype),
+ )
+ mode = "valid"
+
+ return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
+
+
+def convolve(
> Want me to remove this until we actually support all the cases the scipy one does?
Remove convolve for now and I think this is good to merge.
If you don't want to lose track of the code, just removing it from signal.__init__.__all__ is probably good enough; also add a one-time-per-runtime warning that the function is a work in progress and users should prefer convolve1d for now
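A minimal sketch of the suggested one-time-per-runtime warning; the flag name and message are made up for illustration:

import warnings

_CONVOLVE_WIP_WARNED = False  # hypothetical module-level flag


def _warn_convolve_is_wip():
    """Warn at most once per Python session that convolve is a work in progress."""
    global _CONVOLVE_WIP_WARNED
    if not _CONVOLVE_WIP_WARNED:
        warnings.warn(
            "convolve is a work in progress; prefer convolve1d for now",
            UserWarning,
            stacklevel=3,
        )
        _CONVOLVE_WIP_WARNED = True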
I'll remove it
Spinoff from #548. It was obvious the old fast Ops were not really fast, just not as slow as the AbstractConv Ops. I suggest we start removing those in favor of more sane implementations.

This PR implements a simple convolve1d (that can be batched by Blockwise). This is actually not offered by either scipy.signal or numpy.convolve: the scipy impl has no concept of batch dims, so if you pass two 2d inputs you get a convolve2d, and the numpy one doesn't support batch dims, period.
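As a rough usage sketch, mirroring the call pattern in the tests above (the import path and exact details are assumptions, not documentation):

import numpy as np
from scipy.signal import convolve as scipy_convolve

import pytensor
from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d  # assuming it is re-exported here

x = dmatrix("data")    # batch of signals, e.g. test values of shape (3, 5)
y = dmatrix("kernel")  # batch of kernels, e.g. test values of shape (7, 11)

# Blockwise broadcasts the batch dims: (1, 3, 5) against (7, 1, 11) -> (7, 3, out)
out = convolve1d(x[None], y[:, None], mode="full")
fn = pytensor.function([x, y], out)

rng = np.random.default_rng()
test_x = rng.normal(size=(3, 5))
test_y = rng.normal(size=(7, 11))
res = fn(test_x, test_y)

# Each batch entry matches scipy's plain 1d convolution
np.testing.assert_allclose(res[2, 1], scipy_convolve(test_x[1], test_y[2], mode="full"))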
Because the 1d case implemented here matches the one you would get from calling scipy.convolve, I added a small wrapper that allows you to go that route. Consider it WIP.

Also note that same behaves differently between numpy and scipy: numpy returns something with the shape of the largest of the two inputs, scipy with the shape of the first input. The one in this PR behaves like scipy. This may be a good argument to go down the jax route and implement a pytensor.numpy and a pytensor.scipy, because such differences are likely to hit users, especially if they don't know which one we are mimicking.
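A quick illustration of that same-mode difference, using plain numpy/scipy and independent of this PR:

import numpy as np
from scipy.signal import convolve as scipy_convolve

a = np.arange(5.0)  # first input, length 5
v = np.ones(8)      # second input, length 8 (the larger of the two)

print(np.convolve(a, v, mode="same").shape)     # (8,)  numpy: shape of the largest input
print(scipy_convolve(a, v, mode="same").shape)  # (5,)  scipy: shape of the first input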
This Op can be used to replace batched_convolution in pymc-marketing with much better graph compile time and runtime (especially as the size of the convolution grows).

📚 Documentation preview 📚: https://pytensor--1318.org.readthedocs.build/en/1318/