
Bug: adding a penalty term using 0.1*penalty fails with error ERROR: scalar getindex is disallowed #770

@xiaodaigh

Description

Consider this MWE. Actually, I am not 100% sure whether the bug is in Zygote or in code generated by CUDA.jl, but I suspect it is a Zygote issue: calling loss(xy...) directly works without error, so I think Zygote is generating scalar index access into CuArrays in the adjoint, and that is what triggers the error (see the gradient sketch further down). Hence I am reporting it here first.

using Flux, CUDA

m = Chain(
    Dense(10, 5, relu),
    Dropout(0.0),
    Dense(5, 5)) |> gpu

CUDA.allowscalar(false)  # error out on any scalar indexing of a CuArray
opt = Flux.Optimise.ADAM()

p = Flux.params(m)

# squared L2 norm of one parameter array
sqnorm(x) = sum(abs2, x)

# penalty added as-is
loss(x, y) = Flux.Losses.mse(m(x), y) + sum(sqnorm, p)

Flux.train!(loss, p, [(rand(10, 3), rand(5, 3)) |> gpu], opt)

The above works. But if I change loss slightly by multiplying the penalty by 0.1, it fails with ERROR: scalar getindex is disallowed:

loss(x, y) = Flux.Losses.mse(m(x), y) + sum(sqnorm, p) * 0.1

Flux.train!(loss, p, [(rand(10, 3), rand(5, 3)) |> gpu], opt)  # ERROR: scalar getindex is disallowed
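For what it's worth, the failure should be reproducible without Flux.train! at all, since train! computes its gradients through Zygote. A minimal sketch of that check (the error in the comment is what I expect to see, not a captured log):

using Zygote

x, y = rand(10, 3) |> gpu, rand(5, 3) |> gpu

loss(x, y)  # the forward pass alone runs fine, even with the *0.1 factor

# the error should only surface in the backward pass
gs = Zygote.gradient(() -> loss(x, y), p)  # expected: ERROR: scalar getindex is disallowed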

But if I modify it again, dividing by 10 instead (the same scaling mathematically), then it works:

loss(x, y) = Flux.Losses.mse(m(x), y) + sum(sqnorm, p) / 10

Flux.train!(loss, p, [(rand(10, 3), rand(5, 3)) |> gpu], opt)  # works
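Until the adjoint is fixed, another stopgap besides the / 10 form may be to fold the constant into the penalty function itself, so the scalar multiplication never wraps the summed penalty. A sketch of that idea (scaled_sqnorm is a hypothetical helper, and I have not verified that it sidesteps the bug):

# hypothetical workaround: scale each array's penalty before summing
scaled_sqnorm(x) = 0.1f0 * sum(abs2, x)

loss(x, y) = Flux.Losses.mse(m(x), y) + sum(scaled_sqnorm, p)

Flux.train!(loss, p, [(rand(10, 3), rand(5, 3)) |> gpu], opt)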
