From 106a8f35f25c5e0501b7b9ec7cf6456e4effe947 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sun, 8 Oct 2023 14:24:40 +0200 Subject: [PATCH] Improve performance of selected operations on SE(n) (#655) * Improve performance of selected operations on SE(n) * fix tests, bump version * More special cases * faster hat * fix test * maybe improve coverage? * meh --- Project.toml | 2 +- src/groups/special_euclidean.jl | 38 ++++++++++++++++++++ src/manifolds/GeneralUnitaryMatrices.jl | 38 +++++++++++++++++++- test/groups/special_euclidean.jl | 48 +++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 97e5750fb2..2ffd46f757 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.8.79" +version = "0.8.80" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/groups/special_euclidean.jl b/src/groups/special_euclidean.jl index 5b563f0ff1..a5559d83e3 100644 --- a/src/groups/special_euclidean.jl +++ b/src/groups/special_euclidean.jl @@ -692,3 +692,41 @@ end function project!(M::SpecialEuclideanInGeneralLinear, Y, p, X) return copyto!(Y, project(M, p, X)) end + +### Special methods for better performance of selected operations + +function exp(M::SpecialEuclidean, p::ArrayPartition, X::ArrayPartition) + M1, M2 = M.manifold.manifolds + return ArrayPartition( + exp(M1.manifold, p.x[1], X.x[1]), + exp(M2.manifold, p.x[2], X.x[2]), + ) +end +function log(M::SpecialEuclidean, p::ArrayPartition, q::ArrayPartition) + M1, M2 = M.manifold.manifolds + return ArrayPartition( + log(M1.manifold, p.x[1], q.x[1]), + log(M2.manifold, p.x[2], q.x[2]), + ) +end +function vee(M::SpecialEuclidean, p::ArrayPartition, X::ArrayPartition) + M1, M2 = M.manifold.manifolds + return vcat(vee(M1.manifold, p.x[1], X.x[1]), vee(M2.manifold, p.x[2], X.x[2])) +end +function hat(M::SpecialEuclidean{2}, p::ArrayPartition, c::SVector) + M1, M2 = M.manifold.manifolds + return ArrayPartition( + get_vector_orthogonal(M1.manifold, p.x[1], c[SOneTo(2)], ℝ), + get_vector_orthogonal(M2.manifold, p.x[2], c[SA[3]], ℝ), + ) +end +function hat(M::SpecialEuclidean{3}, p::ArrayPartition, c::SVector) + M1, M2 = M.manifold.manifolds + return ArrayPartition( + get_vector_orthogonal(M1.manifold, p.x[1], c[SOneTo(3)], ℝ), + get_vector_orthogonal(M2.manifold, p.x[2], c[SA[4, 5, 6]], ℝ), + ) +end +function compose(::SpecialEuclidean, p::ArrayPartition, q::ArrayPartition) + return ArrayPartition(p.x[2] * q.x[1] + p.x[1], p.x[2] * q.x[2]) +end diff --git a/src/manifolds/GeneralUnitaryMatrices.jl b/src/manifolds/GeneralUnitaryMatrices.jl index 9a09e1df95..db1930c27d 100644 --- a/src/manifolds/GeneralUnitaryMatrices.jl +++ b/src/manifolds/GeneralUnitaryMatrices.jl @@ -220,6 +220,18 @@ end function exp(M::GeneralUnitaryMatrices{2,ℝ}, p::SMatrix, X::SMatrix, t::Real) return exp(M, p, t * X) end +function exp(M::GeneralUnitaryMatrices{3,ℝ}, p::SMatrix, X::SMatrix) + θ = norm(M, p, X) / sqrt(2) + if θ ≈ 0 + a = 1 - θ^2 / 6 + b = θ / 2 + else + a = sin(θ) / θ + b = (1 - cos(θ)) / θ^2 + end + pinvq = I + a .* X .+ b .* (X^2) + return p * pinvq +end function exp!(M::GeneralUnitaryMatrices{2,ℝ}, q, p, X) @assert size(q) == (2, 2) θ = get_coordinates(M, p, X, DefaultOrthogonalBasis())[1] @@ -323,7 +335,14 @@ function get_coordinates( ) return SA[X[2]] end - +function get_coordinates( + ::Manifolds.GeneralUnitaryMatrices{3,ℝ}, + p::SMatrix, + X::SMatrix, + ::DefaultOrthogonalBasis{ℝ,TangentSpaceType}, +) + return SA[X[3, 2], X[1, 3], X[2, 1]] +end function get_coordinates_orthogonal(M::GeneralUnitaryMatrices{n,ℝ}, p, X, N) where {n} Y = allocate_result(M, get_coordinates, p, X, DefaultOrthogonalBasis(N)) return get_coordinates_orthogonal!(M, Y, p, X, N) @@ -405,6 +424,9 @@ end function get_vector_orthogonal(::GeneralUnitaryMatrices{2,ℝ}, p::SMatrix, Xⁱ, ::RealNumbers) return @SMatrix [0 -Xⁱ[]; Xⁱ[] 0] end +function get_vector_orthogonal(::GeneralUnitaryMatrices{3,ℝ}, p::SMatrix, Xⁱ, ::RealNumbers) + return @SMatrix [0 -Xⁱ[3] Xⁱ[2]; Xⁱ[3] 0 -Xⁱ[1]; -Xⁱ[2] Xⁱ[1] 0] +end function get_vector_orthogonal!(::GeneralUnitaryMatrices{1,ℝ}, X, p, Xⁱ, N::RealNumbers) return X .= 0 @@ -588,6 +610,20 @@ function ManifoldsBase.log(M::GeneralUnitaryMatrices{2,ℝ}, p, q) @inbounds θ = atan(U[2], U[1]) return get_vector(M, p, θ, DefaultOrthogonalBasis()) end +function log(M::Manifolds.GeneralUnitaryMatrices{3,ℝ}, p::SMatrix, q::SMatrix) + U = transpose(p) * q + cosθ = (tr(U) - 1) / 2 + if cosθ ≈ -1 + eig = Manifolds.eigen_safe(U) + ival = findfirst(λ -> isapprox(λ, 1), eig.values) + inds = SVector{3}(1:3) + #TODO this is to stop convert error of ax as a complex number + ax::Vector{Float64} = eig.vectors[inds, ival] + return get_vector(M, p, π * ax, DefaultOrthogonalBasis()) + end + X = U ./ Manifolds.usinc_from_cos(cosθ) + return (X .- X') ./ 2 +end function log!(::GeneralUnitaryMatrices{n,ℝ}, X, p, q) where {n} U = transpose(p) * q X .= real(log_safe(U)) diff --git a/test/groups/special_euclidean.jl b/test/groups/special_euclidean.jl index d9aa56e845..d19c6df846 100644 --- a/test/groups/special_euclidean.jl +++ b/test/groups/special_euclidean.jl @@ -346,4 +346,52 @@ Random.seed!(10) @test isapprox(G, pts[1], hat(G, pts[1], fXp.data), fXp2) end end + + @testset "performance of selected operations" begin + for n in [2, 3] + SEn = SpecialEuclidean(n) + Rn = Rotations(n) + + p = SMatrix{n,n}(I) + + if n == 2 + t = SVector{2}.([1:2, 2:3, 3:4]) + ω = [[1.0], [2.0], [1.0]] + pts = [ + ArrayPartition(ti, exp(Rn, p, hat(Rn, p, ωi))) for (ti, ωi) in zip(t, ω) + ] + Xs = [ + ArrayPartition(SA[-1.0, 2.0], hat(Rn, p, SA[1.0])), + ArrayPartition(SA[1.0, -2.0], hat(Rn, p, SA[0.5])), + ] + elseif n == 3 + t = SVector{3}.([1:3, 2:4, 4:6]) + ω = [SA[pi, 0.0, 0.0], SA[0.0, 0.0, 0.0], SA[1.0, 3.0, 2.0]] + pts = [ + ArrayPartition(ti, exp(Rn, p, hat(Rn, p, ωi))) for (ti, ωi) in zip(t, ω) + ] + Xs = [ + ArrayPartition(SA[-1.0, 2.0, 1.0], hat(Rn, p, SA[1.0, 0.5, -0.5])), + ArrayPartition(SA[-2.0, 1.0, 0.5], hat(Rn, p, SA[-1.0, -0.5, 1.1])), + ] + end + exp(SEn, pts[1], Xs[1]) + compose(SEn, pts[1], pts[2]) + log(SEn, pts[1], pts[2]) + log(SEn, pts[1], pts[3]) + @test isapprox(SEn, log(SEn, pts[1], pts[1]), 0 .* Xs[1]; atol=1e-16) + @test isapprox(SEn, exp(SEn, pts[1], 0 .* Xs[1]), pts[1]) + vee(SEn, pts[1], Xs[2]) + csen = n == 2 ? SA[1.0, 2.0, 3.0] : SA[1.0, 0.0, 2.0, 2.0, -1.0, 1.0] + hat(SEn, pts[1], csen) + # @btime shows 0 but `@allocations` is inaccurate + @static if VERSION >= v"1.9-DEV" + @test (@allocations exp(SEn, pts[1], Xs[1])) <= 4 + @test (@allocations compose(SEn, pts[1], pts[2])) <= 4 + @test (@allocations log(SEn, pts[1], pts[2])) <= 28 + @test (@allocations vee(SEn, pts[1], Xs[2])) <= 13 + @test (@allocations hat(SEn, pts[1], csen)) <= 13 + end + end + end end