The following example shows an inaccuracy of `.==` when executed on the GPU within gradient computation:

```julia
using Metal, CUDA
using Flux

device = gpu_device()
# device = cpu_device()

f = x -> begin
    y = [0, 1, 2] |> device
    mask = y .== 1
    return sum(x[mask])
end

x = Float32[1, 2, 3] |> device
grad = Flux.gradient(f, x)  # should be [0.0, 1.0, 0.0], got [0.0, 0.0, 0.0]
```

Per this discussion, this is due to the specific way broadcasted functions are differentiated on the GPU using ForwardDiff.
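To illustrate the mechanism, here is a minimal sketch of forward-mode broadcast differentiation with ForwardDiff's `Dual` numbers (illustrative only, not Zygote's actual `broadcast_forward` code): each element is wrapped in a `Dual` carrying a partial, `f` is evaluated once, and value and derivative are read back out. For a `Bool`-valued function like `==`, the result carries no partials at all:

```julia
using ForwardDiff: Dual, value, partials

xs = Float32[1, 2, 3]

# Differentiable f: evaluating on Duals yields values plus partials.
out = map(x -> 2 * Dual(x, one(x)), xs)
value.(out)        # Float32[2.0, 4.0, 6.0] -- forward values
partials.(out, 1)  # Float32[2.0, 2.0, 2.0] -- d(2x)/dx

# Bool-valued f: the comparison acts on the value and returns a plain
# Bool, so there is nothing for reverse mode to propagate through.
map(x -> Dual(x, one(x)) == 1, xs)  # Bool[1, 0, 0]
```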
The problem can be avoided by replacing

```julia
mask = y .== 1
```

with

```julia
mask = Flux.@ignore_derivatives y .== 1
```
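For completeness, here is the example above with the workaround applied (the mask is a non-differentiable `Bool` array anyway, so ignoring derivatives for it does not change the intended gradient):

```julia
f_fixed = x -> begin
    y = [0, 1, 2] |> device
    # Compute the mask outside of differentiation so the GPU broadcast
    # adjoint is never invoked for the comparison:
    mask = Flux.@ignore_derivatives y .== 1
    return sum(x[mask])
end

Flux.gradient(f_fixed, x)  # now yields the expected [0.0, 1.0, 0.0]
```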
The problem seems to be automatically circumvented on CPU by the `Bool` fast path in Zygote's generic broadcasting adjoint:

Zygote.jl/src/lib/broadcast.jl, lines 206 to 211 at 1b914d9:

```julia
@adjoint broadcasted(::AbstractArrayStyle, f::F, args...) where {F} = _broadcast_generic(__context__, f, args...)

@inline function _broadcast_generic(__context__, f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    # Avoid generic broadcasting in two easy cases:
    if T == Bool
        return (f.(args...), _ -> nothing)
```
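Indeed, a quick check on CPU (using Zygote directly, which `Flux.gradient` delegates to) returns the correct result:

```julia
using Zygote

y = [0, 1, 2]
Zygote.gradient(x -> sum(x[y .== 1]), Float32[1, 2, 3])
# (Float32[0.0, 1.0, 0.0],) -- the Bool branch above returns
# (f.(args...), _ -> nothing), so the mask is computed exactly
# and treated as non-differentiable
```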
On GPU, however, there is no such fast path; every broadcast goes through `broadcast_forward`:
Zygote.jl/src/lib/broadcast.jl, lines 359 to 363 at 1b914d9:

```julia
# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
# so perhaps this can be deleted? Possible edge case here:
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
@adjoint broadcasted(::AbstractGPUArrayStyle, f, args...) =
    broadcast_forward(f, args...)
```
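For comparison, the same check on a GPU device (a sketch assuming a functional CUDA setup; the original report used `gpu_device()`, which also covers Metal) reproduces the wrong gradient:

```julia
using CUDA, Zygote

yg = cu([0, 1, 2])
Zygote.gradient(x -> sum(x[yg .== 1]), cu(Float32[1, 2, 3]))
# reported: ([0.0, 0.0, 0.0],) -- expected ([0.0, 1.0, 0.0],)
```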
Because this unexpected behavior can be difficult to spot, it may be worth considering this a bug that warrants fixing.