Issues with ViT on the GPU #165
Comments
Should solve at least part of FluxML#165
For the reshape issue, have you looked at what is the output of |
Oh, you're right! That is indeed the issue. I'll see how I can make |
Typically, CUDA.jl will return Not sure exactly how we want to fix this. I feel that any change will be slower for the CPU path, and MLUtils.jl currently doesn't depend on CUDA.jl (we want to keep it this way). cc @CarloLucibello and @ToucheSir for some more eyes
If I'm not mistaken, Metalhead.jl/src/layers/attention.jl Line 47 in edf83e0
Oops so then |
Looks like it does for CuArrays. My understanding from testing with Cthulhu is that https://github.com/JuliaLang/julia/blob/v1.7.3/base/abstractarraymath.jl#L136 transforms what should be a For our purposes, I think replacing the |
can this be closed?
I think so, we can re-open if anything was missed.
With some experimentation, issues pop up when I use ViT on the GPU. Documenting these so that they can be tracked down and solved:
- Class tokens don't work on the GPU for now because `fill` doesn't automatically allocate on CPU/GPU as per the model (should be solved by "Use `MLUtils.ones_like` for `ClassTokens`" #166)
- `MLUtils.chunk` isn't GPU-friendly yet. (A hotfix landed with "Hotfix for ViT on GPU" #169, so ViTs should work on GPUs for now, but a long-term fix is pending)
- A scalar indexing warning comes up when I run the model (regression introduced in "Fix mutation error" #162 because of `selectdim` - indexing is unavoidable since data is not contiguous and so `selectdim` returns a view) (should be solved by "Use `MLUtils.ones_like` for `ClassTokens`" #166)
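To make the first and third points concrete, here is a minimal Base-only sketch that runs on the CPU, so no GPU package is needed. `device_ones` is a hypothetical helper written for this illustration; it is not the MLUtils or Metalhead implementation. It only shows the general idea of allocating "like" an existing array, which is presumably what `MLUtils.ones_like` provides. The second half shows that `selectdim` returns a view (`SubArray`) rather than a copy; on a `CuArray`, operations on such non-contiguous views are where scalar-indexing fallbacks can appear.

```julia
# Hypothetical helper (illustration only): allocate ones with the same
# array type and element type as `x`, instead of hard-coding a CPU Array.
device_ones(x::AbstractArray, dims::Int...) =
    fill!(similar(x, eltype(x), dims...), one(eltype(x)))

# Stand-in for a batch of activations. On the GPU this would be a CuArray
# (assumption; not exercised in this CPU-only sketch).
x = rand(Float32, 4, 3, 2)

# `fill` always builds a plain CPU `Array`, whatever array type the model uses.
cpu_only = fill(1f0, 4, 1, 2)

# `similar` follows the array type of `x`, so the result lives where `x` lives.
like_x = device_ones(x, 4, 1, 2)

# `selectdim` returns a view into the parent array, not a copy.
A = reshape(collect(Float32, 1:24), 4, 3, 2)
v = selectdim(A, 2, 1)   # SubArray sharing memory with A
c = A[:, 1, :]           # plain indexing materialises a new Array

println(typeof(like_x))
println(v isa SubArray)
```

Because `similar` dispatches on the input array, the same helper would return a device array when `x` is one, which is why a "like"-style constructor avoids the CPU/GPU mismatch described for class tokens.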