Fix TuringMvNormal's rand on GPU #108
It seems the name My suggestion would be to
|
Good point, I'll fix the element types
As this was not an issue here before I assumed that not many people need GPU support? Adding CUDA increases loading times from about 5s to 8s on my machine. It would ofc be nicer to have it as a proper dependency, but I was not sure if that was desired. |
The increased loading times are unfortunate. In my experience using Requires tends to increase loading times as well though and, probably even worse, doesn't allow us to specify any compatibility bounds (which is annoying for the AD backends as well BTW). An alternative would be a separate package, maybe named |
Ok, I think I have not understood entirely how the CPU/GPU
using DistributionsAD
using CUDA
using Random
using Zygote
Random.seed!(0)
μ = rand(Float32, 2) |> cu
σ = rand(Float32, 2) |> cu
d = DistributionsAD.TuringMvNormal(μ, σ)
display(rand(d))
f(μ, σ) = sum(rand(DistributionsAD.TuringMvNormal(μ, σ)))
for a in Zygote.gradient(f, μ, σ)
    display(a)
end
with the code from the commit above, fixing only the CPU seed fixes the output of fixing only the CPU seed and using DistributionsAD.jl/src/multivariate.jl Line 101 in 65c362d
seems to use cpu rng ...?
|
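The seeding question above can be illustrated on the CPU alone. The following is a minimal sketch, not the package's implementation: `sample_diag_normal` is a hypothetical helper, and two `MersenneTwister` instances merely stand in for a "CPU" and a "GPU" generator. It shows that re-seeding one generator leaves draws from the other unchanged, and that drawing the noise with the element type of `μ` keeps the sample in `Float32`.

```julia
using Random

# Hypothetical helper (an assumption, not DistributionsAD code): draw from a
# diagonal-covariance normal via x = μ .+ σ .* ε with ε ~ N(0, I).
# The noise uses the element type of μ and the RNG that is passed in.
function sample_diag_normal(rng::AbstractRNG, μ::AbstractVector{T},
                            σ::AbstractVector{T}) where {T<:AbstractFloat}
    ε = randn(rng, T, length(μ))
    return μ .+ σ .* ε
end

μ = rand(MersenneTwister(0), Float32, 2)
σ = rand(MersenneTwister(1), Float32, 2)

# Two independent generators standing in for the "CPU" and "GPU" RNGs.
cpu_rng = MersenneTwister(0)
gpu_rng = MersenneTwister(42)

x1 = sample_diag_normal(gpu_rng, μ, σ)

# Re-seeding only the "CPU" generator does not rewind the other stream.
Random.seed!(cpu_rng, 0)
x2 = sample_diag_normal(gpu_rng, μ, σ)

# Re-seeding the generator that is actually used reproduces the first draw.
Random.seed!(gpu_rng, 42)
x3 = sample_diag_normal(gpu_rng, μ, σ)
```

If the sampling code draws its noise from a different generator than the one being seeded, the output will only look reproducible when that other generator happens to be seeded as well, which may explain the behaviour described above.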
I don't remember exactly the original issue - was it caused mainly by |
the original issue (#98) was that |
Ok, so like this |
CUDA 1.3.3 is only compatible with julia 1.5
The latest release of CUDA even requires Julia 1.5. However, IMO we shouldn't drop support for Julia < 1.5 currently. |
Yes, agreed, the commit above restricts CUDA to <1.3.3 which is still compatible with julia 1.3 |
I think the tests are currently failing because of JuliaDiff/ChainRules.jl#262. So I think once JuliaDiff/ChainRules.jl#263 is merged we should be fine. Edit: actually, it doesn't seem like it. The gradient test tracebacks are a bit tricky to read ^^ |
No, we only test gradients of Edit: Just saw your PR, yes, I guess it would make sense to test gradients there as well if it is supported by some AD backend. |
I'll open an issue for this |
I am still a bit puzzled by the failing |
Possibly. Tests on master passed with ChainRules v0.7.14 and ChainRulesCore v0.9.6, this PR fails with ChainRules v0.7.17 and ChainRulesCore v0.9.7. The differences are JuliaDiff/ChainRulesCore.jl@v0.9.6...v0.9.7 and JuliaDiff/ChainRules.jl@v0.7.14...v0.7.17. |
I guess the problem might be the call of DistributionsAD.jl/src/univariate.jl Line 43 in f754a19
https://github.com/JuliaDiff/ChainRules.jl/blob/0baf7bba2a9a2f235bf2a85edfdc6209f8cf137d/src/rulesets/Base/nondiff.jl#L123. |
Yes, you are right, that's it. The latest commit contains a hacky fix, because |
LGTM, I just have the following comments left:
- Can you please bump the version number to 0.6.8 so we can make a new release with the bugfix and new feature? (Ideally, they should be separated but IMO it's OK to keep it in one PR now)
- Can you open an issue over at ChainRulesCore to inform them about the breakage of isapprox and raise awareness for keyword argument support in @non_differentiable?
- Can you open an issue in DistributionsAD that we should remove _isapprox as soon as the underlying problem is fixed upstream? Otherwise we might just not remember it.
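To make the keyword-argument problem concrete, here is a small sketch in plain Julia with no AD packages involved. `positional_rule` and `kw_rule` are hypothetical stand-ins for rule definitions, and the body of `_isapprox` is an assumption; only the name appears in this thread. The sketch shows that a method defined for the positional signature alone rejects a call that passes `atol`, while a wrapper that forwards keyword arguments accepts it.

```julia
# Hypothetical stand-in for a rule that only covers isapprox(x, y).
positional_rule(::typeof(isapprox), x, y) = isapprox(x, y)

# A call with a keyword argument has no matching method.
threw = try
    positional_rule(isapprox, 1.0, 1.0 + 1e-9; atol = 1e-6)
    false
catch err
    err isa MethodError
end

# Assumed shape of a package-local wrapper: forward everything,
# including keyword arguments, to Base.isapprox.
_isapprox(x, y; kwargs...) = isapprox(x, y; kwargs...)

# Hypothetical stand-in for a rule on the wrapper that accepts keywords.
kw_rule(::typeof(_isapprox), x, y; kwargs...) = _isapprox(x, y; kwargs...)

ok = kw_rule(_isapprox, 1.0, 1.0 + 1e-9; atol = 1e-6)
```

Whether the actual workaround in the commit looks like this is not stated here; the point is only why a wrapper of this kind would become unnecessary once keyword arguments are handled upstream.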
Great, LGTM! Thanks for the PR!
It seems to work but we should make sure to actually run some GPU tests in the future as well: https://github.com/JuliaGPU/gitlab-ci
My attempt to fix #98 for TuringMvNormal by @requireing CUDA. What do you think?