Skip to content

Commit

Permalink
Support for Zygote and ReverseDiff gradients (#427)
Browse files Browse the repository at this point in the history
* Support for Zygote and ReverseDiff gradients

* Add Zygote test dependency

* bump ambiguity limit because of Zygote

* fix tests and Zygote backend

* bump Julia to 1.5

* fixing some issues on Julia 1.7-rc1

* more fixing for Julia 1.7

* formatting

* reduce tangent vector length in a test

since the approximation only works very locally (and they changed the default random number generator)

* Update Project.toml

Co-authored-by: Ronny Bergmann <git@ronnybergmann.net>

* bump atol on Rotations

* reduce tangent vector size even further.

* adapt one more tolerance

Co-authored-by: Ronny Bergmann <git@ronnybergmann.net>
  • Loading branch information
mateuszbaran and kellertuer committed Sep 23, 2021
1 parent 9fc772a commit 2a03fb1
Show file tree
Hide file tree
Showing 16 changed files with 118 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ["1.4", "1.6"]
julia-version: ["1.5", "1.6", "~1.7.0-0"]
os: [ubuntu-latest, macOS-latest]
steps:
- uses: actions/checkout@v2
Expand Down
7 changes: 4 additions & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Manifolds"
uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
authors = ["Seth Axen <seth.axen@gmail.com>", "Mateusz Baran <mateuszbaran89@gmail.com>", "Ronny Bergmann <manopt@ronnybergmann.net>", "Antoine Levitt <antoine.levitt@gmail.com>"]
version = "0.6.8"
version = "0.6.9"

[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand Down Expand Up @@ -42,7 +42,7 @@ SimpleWeightedGraphs = "1"
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
StaticArrays = "1.0"
StatsBase = "0.32, 0.33"
julia = "1.4"
julia = "1.5"

[extras]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand All @@ -62,6 +62,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
VisualRegressionTests = "34922c18-7c2a-561c-bac1-01e79b2c4c92"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff"]
test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff", "Zygote"]
19 changes: 15 additions & 4 deletions src/Manifolds.jl
Expand Up @@ -157,8 +157,8 @@ using RecursiveArrayTools: ArrayPartition
include("utils.jl")

include("product_representations.jl")
include("differentiation.jl")
include("riemannian_diff.jl")
include("differentiation/differentiation.jl")
include("differentiation/riemannian_diff.jl")

# Main Meta Manifolds
include("manifolds/ConnectionManifold.jl")
Expand Down Expand Up @@ -284,12 +284,12 @@ end
function __init__()
@require FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" begin
using .FiniteDiff
include("finite_diff.jl")
include("differentiation/finite_diff.jl")
end

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff
include("forward_diff.jl")
include("differentiation/forward_diff.jl")
end

@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
Expand All @@ -302,6 +302,11 @@ function __init__()
include("nlsolve.jl")
end

@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
using .ReverseDiff: ReverseDiff
include("differentiation/reverse_diff.jl")
end

@require Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" begin
using .Test: Test
include("tests/tests_general.jl")
Expand Down Expand Up @@ -332,6 +337,12 @@ function __init__()
include("recipes.jl")
end
end

@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
using .Zygote: Zygote
include("differentiation/zygote.jl")
end

return nothing
end

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 11 additions & 0 deletions src/differentiation/reverse_diff.jl
@@ -0,0 +1,11 @@
struct ReverseDiffBackend <: AbstractDiffBackend end

function Manifolds._gradient(f, p, ::ReverseDiffBackend)
return ReverseDiff.gradient(f, p)
end

function Manifolds._gradient!(f, X, p, ::ReverseDiffBackend)
return ReverseDiff.gradient!(X, f, p)
end

push!(Manifolds._diff_backends, ReverseDiffBackend())
File renamed without changes.
11 changes: 11 additions & 0 deletions src/differentiation/zygote.jl
@@ -0,0 +1,11 @@
struct ZygoteDiffBackend <: AbstractDiffBackend end

function Manifolds._gradient(f, p, ::ZygoteDiffBackend)
return Zygote.gradient(f, p)[1]
end

function Manifolds._gradient!(f, X, p, ::ZygoteDiffBackend)
return copyto!(X, Zygote.gradient(f, p)[1])
end

push!(Manifolds._diff_backends, ZygoteDiffBackend())
18 changes: 16 additions & 2 deletions test/ambiguities.jl
Expand Up @@ -4,12 +4,26 @@
# Interims solution until we follow what was proposed in
# https://discourse.julialang.org/t/avoid-ambiguities-with-individual-number-element-identity/62465/2
fmbs = filter(x -> !any(has_type_in_signature.(x, Identity)), mbs)
@test length(fmbs) <= 20
FMBS_LIMIT = 20
@test length(fmbs) <= FMBS_LIMIT
if length(fmbs) > FMBS_LIMIT
for amb in fmbs
println(amb)
println()
end
end
ms = Test.detect_ambiguities(Manifolds)
# Interims solution until we follow what was proposed in
# https://discourse.julialang.org/t/avoid-ambiguities-with-individual-number-element-identity/62465/2
fms = filter(x -> !any(has_type_in_signature.(x, Identity)), ms)
@test length(fms) <= 17
FMS_LIMIT = 21
if length(fms) > FMS_LIMIT
for amb in fms
println(amb)
println()
end
end
@test length(fms) <= FMS_LIMIT
# this test takes way too long to perform regularly
# @test length(our_base_ambiguities()) <= 4
else
Expand Down
4 changes: 3 additions & 1 deletion test/approx_inverse_retraction.jl
Expand Up @@ -3,6 +3,8 @@ using LinearAlgebra

include("utils.jl")

Random.seed!(10)

@testset "approximate inverse retractions" begin
@testset "NLsolveInverseRetraction" begin
@testset "constructor" begin
Expand Down Expand Up @@ -62,7 +64,7 @@ include("utils.jl")
NLsolveInverseRetraction(ProjectionRetraction(), X0; project_point=true)
X = inverse_retract(M, p, q, inv_retr_method)
@test is_vector(M, p, X; atol=1e-9)
@test X X_exp
@test X X_exp atol = 1e-8
@test_throws OutOfInjectivityRadiusError inverse_retract(
M,
p,
Expand Down
64 changes: 52 additions & 12 deletions test/differentiation.jl
Expand Up @@ -17,7 +17,7 @@ using LinearAlgebra: Diagonal, dot
fd51 = Manifolds.FiniteDifferencesBackend()
@testset "diff_backend" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 2
@test length(diff_backends()) == 3
@test diff_backends()[1] isa Manifolds.FiniteDifferencesBackend

@test length(fd51.method.grid) == 5
Expand All @@ -33,7 +33,7 @@ using LinearAlgebra: Diagonal, dot
fwd_diff = Manifolds.ForwardDiffBackend()
@testset "ForwardDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 2
@test length(diff_backends()) == 3
@test diff_backends()[1] isa Manifolds.FiniteDifferencesBackend
@test diff_backends()[2] == fwd_diff

Expand All @@ -52,8 +52,8 @@ using LinearAlgebra: Diagonal, dot
finite_diff = Manifolds.FiniteDiffBackend()
@testset "FiniteDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 3
@test diff_backends()[3] == finite_diff
@test length(diff_backends()) == 4
@test diff_backends()[4] == finite_diff

@test diff_backend!(finite_diff) == finite_diff
@test diff_backend() == finite_diff
Expand All @@ -65,6 +65,42 @@ using LinearAlgebra: Diagonal, dot
diff_backend!(fd51)
end

using ReverseDiff

reverse_diff = Manifolds.ReverseDiffBackend()
@testset "ReverseDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 4
@test diff_backends()[3] == reverse_diff

@test diff_backend!(reverse_diff) == reverse_diff
@test diff_backend() == reverse_diff
@test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend
@test diff_backend() isa Manifolds.FiniteDifferencesBackend

diff_backend!(reverse_diff)
@test diff_backend() == reverse_diff
diff_backend!(fd51)
end

using Zygote: Zygote

zygote_diff = Manifolds.ZygoteDiffBackend()
@testset "Zygote" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 5
@test diff_backends()[5] == zygote_diff

@test diff_backend!(zygote_diff) == zygote_diff
@test diff_backend() == zygote_diff
@test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend
@test diff_backend() isa Manifolds.FiniteDifferencesBackend

diff_backend!(zygote_diff)
@test diff_backend() == zygote_diff
diff_backend!(fd51)
end

@testset "gradient" begin
diff_backend!(fd51)
r2 = Euclidean(2)
Expand All @@ -74,11 +110,11 @@ using LinearAlgebra: Diagonal, dot
f2(x) = 3 * x[1] * x[2] + x[2]^3

@testset "Inference" begin
v = [-1.0, -1.0]
X = [-1.0, -1.0]
@test (@inferred _derivative(c1, 0.0, Manifolds.ForwardDiffBackend()))
[1.0, 0.0]
@test (@inferred _derivative!(c1, v, 0.0, Manifolds.ForwardDiffBackend())) === v
@test v [1.0, 0.0]
@test (@inferred _derivative!(c1, X, 0.0, Manifolds.ForwardDiffBackend())) === X
@test X [1.0, 0.0]

@test (@inferred _derivative(c1, 0.0, finite_diff)) [1.0, 0.0]
@test (@inferred _gradient(f1, [1.0, -1.0], finite_diff)) [1.0, -2.0]
Expand All @@ -87,12 +123,16 @@ using LinearAlgebra: Diagonal, dot
@testset for backend in [fd51, fwd_diff, finite_diff]
diff_backend!(backend)
@test _derivative(c1, 0.0) [1.0, 0.0]
v = [-1.0, -1.0]
@test _derivative!(c1, v, 0.0) === v
@test isapprox(v, [1.0, 0.0])
X = [-1.0, -1.0]
@test _derivative!(c1, X, 0.0) === X
@test isapprox(X, [1.0, 0.0])
end
@testset for backend in [fd51, fwd_diff, finite_diff, reverse_diff, zygote_diff]
diff_backend!(backend)
X = [-1.0, -1.0]
@test _gradient(f1, [1.0, -1.0]) [1.0, -2.0]
@test _gradient!(f1, v, [1.0, -1.0]) === v
@test v [1.0, -2.0]
@test _gradient!(f1, X, [1.0, -1.0]) === X
@test X [1.0, -2.0]
end
diff_backend!(Manifolds.NoneDiffBackend())
@testset for backend in [fd51, Manifolds.ForwardDiffBackend()]
Expand Down
2 changes: 2 additions & 0 deletions test/groups/special_euclidean.jl
Expand Up @@ -78,6 +78,7 @@ Random.seed!(10)
X_pts;
test_diff=true,
diff_convs=[(), (LeftAction(),), (RightAction(),)],
atol=1e-9,
)
end
end
Expand Down Expand Up @@ -128,6 +129,7 @@ Random.seed!(10)
test_diff=true,
test_lie_bracket=true,
diff_convs=[(), (LeftAction(),), (RightAction(),)],
atol=1e-9,
)
# specific affine tests
p = copy(G, pts[1])
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/power_manifold.jl
Expand Up @@ -240,7 +240,7 @@ end
rand_tvector_atol_multiplier=6.0,
retraction_atol_multiplier=12,
is_tangent_atol_multiplier=12.0,
exp_log_atol_multiplier=2 * prod(power_dimensions(Ms2)),
exp_log_atol_multiplier=3 * prod(power_dimensions(Ms2)),
test_inplace=true,
)
end
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/rotations.jl
Expand Up @@ -116,7 +116,7 @@ include("../utils.jl")
point_distributions=[ptd],
tvector_distributions=[tvd],
basis_types_to_from=basis_types,
exp_log_atol_multiplier=20,
exp_log_atol_multiplier=250,
retraction_atol_multiplier=12,
test_inplace=true,
)
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/stiefel.jl
Expand Up @@ -311,7 +311,7 @@ using Manifolds: default_metric_dispatch
M4 = MetricManifold(Stiefel(10, 2), CanonicalMetric())
p = Matrix{Float64}(I, 10, 2)
Random.seed!(42)
Z = project(base_manifold(M4), p, randn(size(p)))
Z = project(base_manifold(M4), p, 0.2 .* randn(size(p)))
s = exp(M4, p, Z)
Z2 = log(M4, p, s)
@test isapprox(M4, p, Z, Z2)
Expand Down

0 comments on commit 2a03fb1

Please sign in to comment.