Skip to content

Commit

Permalink
test_broken for GaussAdjoint, also test gradient equality
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor committed Aug 1, 2024
1 parent de43449 commit 36032a8
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions test/prob_kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,18 @@ let callback_count1 = 0, callback_count2 = 0
end

@testset "Callback duplication check" begin
u0p = [2.0, 3.0]
for adjoint_type in [
ForwardDiffSensitivity(), ReverseDiffAdjoint(), TrackerAdjoint(),
BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint()]
BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint(), GaussAdjoint()]
count1 = 0
count2 = 0
u0p = [2.0, 3.0]
Zygote.gradient(x -> f1(x, adjoint_type), u0p)
Zygote.gradient(x -> f2(x, adjoint_type), u0p)

@test callback_count1 == callback_count2
if adjoint_type == GaussAdjoint()
@test_broken Zygote.gradient(x -> f1(x, adjoint_type), u0p) == Zygote.gradient(x -> f2(x, adjoint_type), u0p)
else
@test Zygote.gradient(x -> f1(x, adjoint_type), u0p) == Zygote.gradient(x -> f2(x, adjoint_type), u0p)
@test callback_count1 == callback_count2
end
end
end
end

0 comments on commit 36032a8

Please sign in to comment.