Consider this MWE. Actually, I am not 100% sure whether the bug is in Zygote or in the code CUDA.jl generates, but I suspect it is a Zygote issue: calling `loss(x, y)` directly works without error, so I think Zygote is generating scalar index accesses on CuArrays in the adjoint, which causes the error. Hence I report it here first.
```julia
using Flux, CUDA

m = Chain(
    Dense(10, 5, relu),
    Dropout(0.0),
    Dense(5, 5)) |> gpu
CUDA.allowscalar(false)

opt = Flux.Optimise.ADAM()
p = Flux.params(m)

sqnorm(x) = sum(abs2, x)
loss(x, y) = Flux.Losses.mse(m(x), y) + sum(sqnorm, p)

Flux.train!(loss, p, [(rand(10, 3), rand(5, 3)) |> gpu], opt)
```
The above works. But if I modify `loss` slightly by multiplying the penalty by `0.1`, it fails with `ERROR: scalar getindex is disallowed`:
```julia
loss(x, y) = Flux.Losses.mse(m(x), y) + sum(sqnorm, p) * 0.1

Flux.train!(loss, p, [(rand(10, 3), rand(5, 3)) |> gpu], opt)
```
But if I instead divide the penalty by 10, it works again:
```julia
loss(x, y) = Flux.Losses.mse(m(x), y) + sum(sqnorm, p) / 10

Flux.train!(loss, p, [(rand(10, 3), rand(5, 3)) |> gpu], opt)
```
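In case it helps triage, here is a further-reduced sketch (untested; it assumes the problem lives in the adjoint of the regularizer term rather than in Flux itself) that tries to reproduce the same `*0.1` vs `/10` asymmetry with plain Zygote on a raw CuArray:

```julia
using CUDA, Zygote

CUDA.allowscalar(false)

# A single GPU parameter array standing in for the model weights.
w = CUDA.rand(5, 10)
ps = Zygote.Params([w])

# Division form, which reportedly works in the MWE above.
g_div = Zygote.gradient(() -> sum(abs2, w) / 10, ps)

# Multiplication form, which I would expect to hit the same
# "scalar getindex is disallowed" error if my suspicion is right.
g_mul = Zygote.gradient(() -> sum(abs2, w) * 0.1, ps)
```

If the second `gradient` call errors while the first does not, that would confirm the issue is independent of `Flux.train!` and the `Chain`.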
gabrielpreviato