Was wondering if you have any plans to integrate `fft_conv_general_dilated` into the JAX repo? It would be nice to speed up `jax.lax.conv_general_dilated` on CPUs...
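Below is a minimal sketch of the idea behind an FFT-based convolution. It is not the serket implementation or `fft_conv_general_dilated`, and `fft_conv2d_valid` is a hypothetical name. It assumes stride 1, no dilation, `'VALID'` padding, and the default NCHW/OIHW layout of `jax.lax.conv_general_dilated`. Because `lax` convolutions are cross-correlations, the kernel is flipped before the FFT product. Both operands are zero-padded to the full linear-convolution size so the circular FFT does not wrap around.

```python
import jax
import jax.numpy as jnp
from jax import lax


def fft_conv2d_valid(x, w):
    """Hypothetical FFT-based 2D cross-correlation (stride 1, 'VALID' padding).

    x: (N, C, H, W) input, w: (O, C, KH, KW) kernel, i.e. the default
    NCHW / OIHW dimension numbers of lax.conv_general_dilated.
    """
    h, wd = x.shape[-2:]
    kh, kw = w.shape[-2:]
    # Full linear-convolution size; padding to it avoids circular wrap-around.
    fh, fw = h + kh - 1, wd + kw - 1
    # Flip the kernel spatially so the FFT's true convolution becomes the
    # cross-correlation that lax.conv_general_dilated computes.
    w_flipped = jnp.flip(w, axis=(-2, -1))
    x_f = jnp.fft.rfft2(x, s=(fh, fw))          # (N, C, FH, FW // 2 + 1)
    w_f = jnp.fft.rfft2(w_flipped, s=(fh, fw))  # (O, C, FH, FW // 2 + 1)
    # Pointwise product in frequency space, summed over input channels.
    y_f = jnp.einsum("nchw,ochw->nohw", x_f, w_f)
    y_full = jnp.fft.irfft2(y_f, s=(fh, fw))    # (N, O, FH, FW)
    # Keep only positions where the kernel fully overlaps the input ('VALID').
    return y_full[..., kh - 1:h, kw - 1:wd]


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 3, 32, 32))
w = jax.random.normal(key_w, (4, 3, 9, 9))

y_fft = fft_conv2d_valid(x, w)
# Direct reference; HIGHEST precision avoids reduced-precision accelerator paths.
y_ref = lax.conv_general_dilated(x, w, window_strides=(1, 1), padding="VALID",
                                 precision=lax.Precision.HIGHEST)
print(y_fft.shape)  # (2, 4, 24, 24)
print(jnp.max(jnp.abs(y_fft - y_ref)))  # small float32 round-off
```

The trade-off this illustrates is that a direct convolution does work proportional to the kernel area at every output position, while the FFT route costs roughly the same regardless of kernel size once the transforms are paid for. That is why the FFT approach tends to pay off mainly for large kernels.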
Currently, the FFT-based convolution function looks like this:
An improvement would be to match the `jax.lax.conv_general_dilated` signature; this requires adding the `lhs_dilation` option and renaming the function arguments.

Related issues and motivation for the FFT implementation:
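One possible way to support `lhs_dilation` in a stride-1 FFT convolution is to dilate the input first. This is a sketch under that assumption, not the planned implementation. `lhs_dilation` inserts `d - 1` zeros between neighbouring input elements, which is the same as interior padding with `lax.pad`. The check below compares this against `lax.conv_general_dilated` itself, using explicit `(low, high)` padding. `dilate_lhs` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dilate_lhs(x, lhs_dilation):
    """Hypothetical helper: zero-interleave the spatial dims of an NCHW input.

    A stride-1 FFT convolution could emulate lhs_dilation by applying this
    to the input before the FFT step.
    """
    dh, dw = lhs_dilation
    zero = jnp.array(0, dtype=x.dtype)
    # lax.pad takes (low, high, interior) per axis; interior padding places
    # (d - 1) padding values between consecutive elements.
    config = [(0, 0, 0), (0, 0, 0), (0, 0, dh - 1), (0, 0, dw - 1)]
    return lax.pad(x, zero, config)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 2, 8, 8))
w = jax.random.normal(key_w, (3, 2, 3, 3))
no_pad = [(0, 0), (0, 0)]  # explicit (low, high) padding for each spatial dim

# Reference: lax applies the input (lhs) dilation itself.
y_ref = lax.conv_general_dilated(x, w, window_strides=(1, 1), padding=no_pad,
                                 lhs_dilation=(2, 2),
                                 precision=lax.Precision.HIGHEST)
# Same computation as a plain stride-1 convolution over the dilated input.
y_manual = lax.conv_general_dilated(dilate_lhs(x, (2, 2)), w,
                                    window_strides=(1, 1), padding=no_pad,
                                    precision=lax.Precision.HIGHEST)
print(y_ref.shape)  # (1, 3, 13, 13)
```

With this reduction, the FFT path would only have to handle a dilated input of size `(H - 1) * d + 1` per spatial dimension. The rest of the `conv_general_dilated` arguments could then be mapped onto it.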
- google/jax#6343
- google/jax#5227
- google/jax#7284 -> `serket.nn.FFTFilter2D` can improve on this
- google-deepmind/dm_pix#46 -> `serket.nn.FFTFilter2D` can improve on this for large kernel sizes