cudnn complex convolution via gauss trick #517
Great, looks good to go. Since this is an important feature, maybe we should document it somewhere? the |
I noticed one difference between the CPU and CUDA versions: the CUDA version could only handle the case where both arguments are complex, whereas the CPU version can handle a mix (though with errors for some of the pullback (∇) functions). This is understandable when the mix of real and complex doesn't make sense. E.g. for

I've now added a few more functions to allow the CUDA conv functions to also handle a mix of real and complex inputs (in the same way that the CPU version can), and an additional sentence to the

There is one test that fails for the mixed case (
I would make sure the case with non-zero beta isn't a correctness issue. We have some handling of that for single
OK, I think the beta bug is an error in the CPU version. Here's an MWE for the mixed real-complex case:

    using NNlib, CUDA, cuDNN
    for T=(Float64, ComplexF64), beta=(0,1), flip=(false, true)
        @show T, beta, flip
        x_cpu = fill(T(1), 2, 1, 1)
        w_cpu = T.([1; -1;;;])
        x_gpu = CuArray(x_cpu)
        w_gpu = CuArray(w_cpu)
        cdims = NNlib.DenseConvDims(x_cpu, w_cpu; flipkernel=flip)
        y_cpu = fill(T(1), 1, 1, 1)
        y_gpu = CuArray(y_cpu)
        w_cpu_2 = NNlib.∇conv_filter!(copy(w_cpu), real(x_cpu), y_cpu, cdims, alpha=T(1), beta=T(beta))
        w_gpu_2 = NNlib.∇conv_filter!(copy(w_gpu), real(x_gpu), y_gpu, cdims, alpha=T(1), beta=T(beta))
        @show w_cpu_2
        @show w_gpu_2
        @show w_cpu_2 ≈ Array(w_gpu_2)
    end

Output:
So only the last case fails with

Edit: I investigated this further in #518; the solution is in #519. The PR should be good to go once that is merged.
I noticed that we already have GPU-related notes in https://github.com/FluxML/NNlib.jl/blob/v0.9.3/docs/src/reference.md?plain=1#L79. It would be good to add what @CarloLucibello mentioned there too. Otherwise this LGTM.
It's great to have this feature now, thanks!
Addresses #510 using the same method as in PyTorch (Gauss's trick: complex conv via 3 real convs).
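For readers unfamiliar with Gauss's trick, here is a minimal sketch in plain Julia. It is not the code from this PR and does not call NNlib or cuDNN: `xcorr` is a hypothetical 1-D "valid" cross-correlation standing in for a real-valued convolution kernel, and `gauss_xcorr` is an illustrative name. It only shows how the real and imaginary parts of a complex convolution can be assembled from three real ones instead of four.

```julia
# Hypothetical stand-in for a real-valued convolution kernel:
# a plain 1-D "valid" cross-correlation (no kernel flip, no conjugation).
function xcorr(x::AbstractVector, w::AbstractVector)
    n = length(x) - length(w) + 1
    return [sum(x[i + j - 1] * w[j] for j in eachindex(w)) for i in 1:n]
end

# Gauss's trick: with x = a + ib and w = c + id,
#   k1 = (a + b) ⋆ c,   k2 = a ⋆ (d - c),   k3 = b ⋆ (c + d)
#   real part = k1 - k3 = a⋆c - b⋆d
#   imag part = k1 + k2 = a⋆d + b⋆c
# so only three real cross-correlations are needed.
function gauss_xcorr(x::AbstractVector{<:Complex}, w::AbstractVector{<:Complex})
    a, b = real(x), imag(x)
    c, d = real(w), imag(w)
    k1 = xcorr(a .+ b, c)
    k2 = xcorr(a, d .- c)
    k3 = xcorr(b, c .+ d)
    return complex.(k1 .- k3, k1 .+ k2)
end

# Compare against the direct complex computation.
x = ComplexF64[1 + 2im, 3 - 1im, -2 + 0.5im, 0.5 + 1im]
w = ComplexF64[1 - 1im, 2 + 3im]
@assert gauss_xcorr(x, w) ≈ xcorr(x, w)
```

The trade-off is one fewer convolution at the cost of a few extra elementwise additions and temporaries. When one argument is real, linearity means two real convolutions (one per real/imaginary part of the complex argument) are enough.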
PR Checklist
conv docstring updated