New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix OneHotVector/Matrix performance on GPU #591
Conversation
Does this help with #582? |
For that, we want to ensure |
This should help with #582 |
LGTM, but are there any tests we can add for this? e.g. returns CuArray on the GPU. |
test/cuda/cuda.jl
Outdated
@testset "onecold gpu" begin | ||
x = rand(Float32, 10, 3) |> gpu; | ||
y = Flux.onehotbatch(1:3, 1:10) |> gpu; | ||
@test_nowarn Flux.onecold(x) .== Flux.onecold(y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@test_nowarn
doesn't work anymore.
julia> using Test
julia> @test_nowarn @warn "foo"
┌ Warning: foo
└ @ Main REPL[5]:1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops. Perhaps?
@testset "onecold gpu" begin
x = zeros(Float32, 10, 3) |> gpu;
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
res = Flux.onecold(x) .== Flux.onecold(y)
@test res isa CuArray
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That seems to indeed better test what we want to check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks bunches!
This PR aims to fix performance of OneHotMatrix performance on the GPU. It additionally adds some missing functionality to their behaviour.