From d5f6b7711109036937f5733e6cfde88e2238b703 Mon Sep 17 00:00:00 2001 From: JulianTrommer Date: Thu, 7 Aug 2025 10:04:04 +0200 Subject: [PATCH 1/2] Added tests for scatter gradient --- test/nn/nnlib.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 27b1a01d0b..24d8c3a84e 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -634,6 +634,49 @@ end test_scatter(dsts, srcs, idxs, res; dims=[0, 1]) end + + @testset "scatter gradient" begin + + dst = Float32[3 3 4 4 5 + 5 5 6 6 7] + dst_ca = Reactant.to_rarray(dst) + + src = ones(Float32, 2, 5) + src_ca = Reactant.to_rarray(src) + + idx = [4, 2, 1, 5, 3] + idx_ca = Reactant.to_rarray(idx) + + function test_scatter(dsts, srcs, idxs) + return sum(NNlib.scatter!(-, dsts, srcs, idxs)) + end + + function test_gradient(objective_function, dsts, srcs, idxs) + derivs, val = Enzyme.gradient( + Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), + Const(objective_function), + dsts, + srcs, + idxs, + ) + return derivs, val + end + + test_gradient_compiled = @compile test_gradient(test_scatter, dst_ca, src_ca, idx_ca) + + grads_enz, loss_enz = Enzyme.gradient( + Enzyme.ReverseWithPrimal, + Const(test_scatter), + dst, + src, + idx + ) + grads_ca, loss_ca = test_gradient_compiled(test_scatter, dst_ca, src_ca, idx_ca) + + @test grads_enz[1] ≈ Array(grads_ca[1]) + @test grads_enz[2] ≈ Array(grads_ca[2]) + @test loss_enz ≈ loss_ca + end end @testset "∇conv(D = $ndim)" for ndim in 1:3 From e3ab84388d5d71a317f34657ae50ae47c5600f5e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 30 Aug 2025 19:49:55 -0500 Subject: [PATCH 2/2] fix: specify new commit --- deps/ReactantExtra/WORKSPACE | 2 +- test/nn/nnlib.jl | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 99ce3b9a69..bf37c00fc9 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "6cef09efae7c8c1c7f17314469f1b4d9ace91c96" +ENZYMEXLA_COMMIT = "a5e8c0331a32f63ad80f50d84b7337af94f134ad" ENZYMEXLA_SHA256 = "" diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 24d8c3a84e..240c52c155 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -636,9 +636,10 @@ end end @testset "scatter gradient" begin - - dst = Float32[3 3 4 4 5 - 5 5 6 6 7] + dst = Float32[ + 3 3 4 4 5 + 5 5 6 6 7 + ] dst_ca = Reactant.to_rarray(dst) src = ones(Float32, 2, 5) @@ -648,7 +649,7 @@ end idx_ca = Reactant.to_rarray(idx) function test_scatter(dsts, srcs, idxs) - return sum(NNlib.scatter!(-, dsts, srcs, idxs)) + return sum(NNlib.scatter!(+, dsts, srcs, idxs)) end function test_gradient(objective_function, dsts, srcs, idxs) @@ -662,14 +663,12 @@ end return derivs, val end - test_gradient_compiled = @compile test_gradient(test_scatter, dst_ca, src_ca, idx_ca) + test_gradient_compiled = @compile test_gradient( + test_scatter, dst_ca, src_ca, idx_ca + ) grads_enz, loss_enz = Enzyme.gradient( - Enzyme.ReverseWithPrimal, - Const(test_scatter), - dst, - src, - idx + Enzyme.ReverseWithPrimal, Const(test_scatter), dst, src, idx ) grads_ca, loss_ca = test_gradient_compiled(test_scatter, dst_ca, src_ca, idx_ca)