
Refactor fft_conv_general_dilated to match jax.lax.conv_general_dilated signature #10

Closed
ASEM000 opened this issue Nov 4, 2022 · 3 comments

Comments

ASEM000 (Owner) commented Nov 4, 2022

Currently, the FFT-based convolution function looks like this:

```python
import jax.numpy as jnp

def fft_conv_general_dilated(
    x: jnp.ndarray,
    w: jnp.ndarray,
    strides: tuple[int, ...],
    padding: tuple[tuple[int, int], ...],
    groups: int,
    dilation: tuple[int, ...],
) -> jnp.ndarray:
    ...
```
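For context, a minimal sketch of the FFT-convolution idea such a function builds on (a hypothetical 1-D helper for real inputs, not the actual serket implementation):

```python
import jax.numpy as jnp

def fft_corr_1d_valid(x: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    # "Valid" cross-correlation via the convolution theorem: pointwise
    # products in frequency space cost O(n log n), versus O(n * k) for a
    # direct sliding window, which is why FFT wins for large kernels.
    n = x.shape[-1] + w.shape[-1] - 1
    xf = jnp.fft.rfft(x, n)
    wf = jnp.fft.rfft(w[::-1], n)  # flip the kernel: correlation, not convolution
    full = jnp.fft.irfft(xf * wf, n)
    return full[..., w.shape[-1] - 1 : x.shape[-1]]  # keep the "valid" region
```

For 1-D inputs this matches `jnp.convolve(x, w[::-1], mode="valid")` up to floating-point tolerance.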

An improvement would be to match the jax.lax.conv_general_dilated signature; this requires adding an lhs_dilation option and renaming the function arguments, as sketched below.
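A sketch of what the matched signature could look like (parameter names copied from jax.lax.conv_general_dilated; the argument mapping in the comments and the elided body are assumptions, not the final design):

```python
import jax.numpy as jnp

def fft_conv_general_dilated(
    lhs: jnp.ndarray,                             # was: x
    rhs: jnp.ndarray,                             # was: w
    window_strides: tuple[int, ...],              # was: strides
    padding: tuple[tuple[int, int], ...],
    lhs_dilation: tuple[int, ...] | None = None,  # new: input (transposed) dilation
    rhs_dilation: tuple[int, ...] | None = None,  # was: dilation
    feature_group_count: int = 1,                 # was: groups
) -> jnp.ndarray:
    ...
```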

Related issues and motivation for the FFT implementation:

google/jax#6343
google/jax#5227
google/jax#7284 -> serket.nn.FFTFilter2D can improve on this
google-deepmind/dm_pix#46 -> serket.nn.FFTFilter2D can improve on this for large kernel sizes

ASEM000 closed this as not planned Mar 19, 2023
chrisflesher commented Apr 6, 2023

Was wondering if you have any plans to integrate fft_conv_general_dilated into the jax repo? It would be nice to speed up jax.lax.conv_general_dilated for CPUs...
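For illustration, a drop-in check one might run on CPU, assuming the refactored signature sketched above (the fft_conv_general_dilated call is commented out since that signature is only proposed):

```python
import jax
import jax.numpy as jnp

# Hypothetical usage: both calls share jax.lax.conv_general_dilated's default
# convention (NCW input, OIW kernel).
x = jax.random.normal(jax.random.PRNGKey(0), (8, 3, 1024))  # batch, channel, width
w = jax.random.normal(jax.random.PRNGKey(1), (16, 3, 129))  # out, in, large kernel

ref = jax.lax.conv_general_dilated(x, w, window_strides=(1,), padding="VALID")
# out = fft_conv_general_dilated(x, w, window_strides=(1,), padding="VALID")
# assert jnp.allclose(out, ref, atol=1e-4)  # FFT result should match direct conv
```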

ASEM000 (Owner) commented Apr 6, 2023

Sounds reasonable, I will reopen the issue when time permits.

chrisflesher commented

Cool, thanks!

ASEM000 reopened this Aug 3, 2023
ASEM000 closed this as not planned Aug 3, 2023