diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 6dbfdb829..6d4bdf21f 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -272,7 +272,8 @@ end
   end
 
-  @adjoint (::Type{T})(xs::Array) where {T <: CUDA.CuArray} =
+  _eyelike(y::CUDA.CuVector{T}) where T = CUDA.CuArray{T}(I, length(y), length(y))
+  @adjoint (::Type{T})(xs::AbstractArray) where {T <: CUDA.CuArray} =
     T(xs), Δ -> (convert(Array, Δ), )
 
   @adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index a522d685a..ad2feb54a 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -173,12 +173,7 @@
 _jvec(x::Number) = _jvec(vcat(x))
 _jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))"))
 _jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not accept complex output"))
-_eyelike(y::Vector) = Matrix{eltype(y)}(I, length(y), length(y))
-function _eyelike(y::AbstractVector) # version which works on GPU
-  out = fill!(similar(y, length(y), length(y)), 0)
-  out[LinearAlgebra.diagind(out)] .= 1
-  out
-end
+_eyelike(y::AbstractVector{T}) where T = T.(I(length(y)))
 
 _gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src)
 _gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src)
diff --git a/test/cuda.jl b/test/cuda.jl
index 5cb1c8cdc..586d2384a 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -87,6 +87,13 @@ end
   @test j2[v1] ≈ cu(res2)
 end
 
+@testset "UniformScaling" begin
+  r = cu(rand(3))
+  @test gradient(r) do r
+    sum(Zygote._eyelike(r) .+ r)
+  end == (cu(fill(3.f0,3)), )
+end
+
 @testset "gradient algebra" begin
   w, b = rand(2) |> cu, rand(2) |> cu
   x1, x2 = rand(2) |> cu, rand(2) |> cu
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index e37e0ea15..25901e3c8 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1,47 +1,3 @@
-using Zygote, Test, Random, LinearAlgebra, Statistics, SparseArrays, FillArrays,
-  AbstractFFTs, FFTW, Distances
-using Zygote: gradient
-using Base.Broadcast: broadcast_shape
-using Distributed: pmap, CachingPool, workers
-import FiniteDifferences
-
-function ngradient(f, xs::AbstractArray...)
-  grads = zero.(xs)
-  for (x, Δ) in zip(xs, grads), i in 1:length(x)
-    δ = sqrt(eps())
-    tmp = x[i]
-    x[i] = tmp - δ/2
-    y1 = f(xs...)
-    x[i] = tmp + δ/2
-    y2 = f(xs...)
-    x[i] = tmp
-    Δ[i] = (y2-y1)/δ
-  end
-  return grads
-end
-
-function gradcheck(f, xs...)
-  grad_zygote = gradient(f, xs...)
-  grad_finite_difference = ngradient(f, xs...)
-  return all(isapprox.(grad_zygote, grad_finite_difference; rtol = 1e-5, atol = 1e-5))
-end
-
-gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
-gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
- -# utilities for using gradcheck with complex matrices -_splitreim(A) = (real(A),) -_splitreim(A::AbstractArray{<:Complex}) = reim(A) - -_joinreim(A, B) = complex.(A, B) -_joinreim(A) = A - -function _dropimaggrad(A) - back(Δ) = real(Δ) - back(Δ::Nothing) = nothing - return Zygote.hook(back, A) -end - Random.seed!(0) @testset "println, show, string, etc" begin diff --git a/test/gradcheck_utils.jl b/test/gradcheck_utils.jl new file mode 100644 index 000000000..63ed4bbd8 --- /dev/null +++ b/test/gradcheck_utils.jl @@ -0,0 +1,43 @@ +using Zygote, Test, Random, LinearAlgebra, Statistics, SparseArrays, FillArrays, + AbstractFFTs, FFTW, Distances +using Zygote: gradient +using Base.Broadcast: broadcast_shape +using Distributed: pmap, CachingPool, workers +import FiniteDifferences + +function ngradient(f, xs::AbstractArray...) + grads = zero.(xs) + for (x, Δ) in zip(xs, grads), i in 1:length(x) + δ = sqrt(eps()) + tmp = x[i] + x[i] = tmp - δ/2 + y1 = f(xs...) + x[i] = tmp + δ/2 + y2 = f(xs...) + x[i] = tmp + Δ[i] = (y2-y1)/δ + end + return grads +end + +function gradcheck(f, xs...) + grad_zygote = gradient(f, xs...) + grad_finite_difference = ngradient(f, xs...) + return all(isapprox.(grad_zygote, grad_finite_difference; rtol = 1e-5, atol = 1e-5)) +end + +gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) +gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) + +# utilities for using gradcheck with complex matrices +_splitreim(A) = (real(A),) +_splitreim(A::AbstractArray{<:Complex}) = reim(A) + +_joinreim(A, B) = complex.(A, B) +_joinreim(A) = A + +function _dropimaggrad(A) + back(Δ) = real(Δ) + back(Δ::Nothing) = nothing + return Zygote.hook(back, A) +end diff --git a/test/runtests.jl b/test/runtests.jl index 565ad182f..1f3facb43 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Zygote, Test using Zygote: gradient, ZygoteRuleConfig using CUDA using CUDA: has_cuda +include("gradcheck_utils.jl") @testset "all" begin # Overall testset ensures it keeps running after failure
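
A rough sketch of the behaviour the new `_eyelike` methods are expected to have (the names `y` and `r` below are illustrative, not part of the diff). On the CPU, `T.(I(length(y)))` returns a lazy `Diagonal{T}` rather than a dense matrix, since broadcasting a zero-preserving function over `I(n)` keeps the diagonal structure; `jacobian` then slices one seed column per output element without ever materialising a dense identity. The `CuVector` method instead builds a dense identity on the device, because a CPU-side `Diagonal` cannot seed a GPU pullback.

using LinearAlgebra, Zygote

y = rand(Float32, 3)
Zygote._eyelike(y)                  # Diagonal{Float32}, i.e. Float32.(I(3))
Zygote.jacobian(x -> x .^ 2, y)[1]  # 3×3 Matrix with 2 .* y on the diagonal

# With a functional CUDA device (untested sketch):
# using CUDA
# r = cu(rand(3))
# Zygote._eyelike(r)                # dense 3×3 CuArray identity
# Zygote.jacobian(x -> x .^ 2, r)[1]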