Skip to content

Commit

Permalink
Make convert and constructors do the same thing.
Browse files Browse the repository at this point in the history
For immutable types like Rotations, they should really just do the same thing (ref e.g.
JuliaDiff/ForwardDiff.jl#342 (comment),
https://discourse.julialang.org/t/recommended-style-for-conversion-vs-constructors-in-v0-7/11561/5).

Do the actual work in constructors, and have `convert` just call
the constructors, similar to how Base.Number subtypes work now.
  • Loading branch information
tkoolen committed Sep 13, 2018
1 parent dc41718 commit 0b5bd21
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 58 deletions.
4 changes: 2 additions & 2 deletions perf/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ for (from, to) in product(rotationtypes, rotationtypes)
if from != to
name = "$(string(from)) -> $(string(to))"
# use eval here because of https://github.com/JuliaCI/BenchmarkTools.jl/issues/50#issuecomment-318673288
noneuler[name] = eval(:(@benchmarkable convert($to, rot) setup = rot = rand($from)))
noneuler[name] = eval(:(@benchmarkable $to(rot) setup = rot = rand($from)))
end
end

Expand All @@ -28,7 +28,7 @@ for from in eulertypes
to = RotMatrix3{T}
name = "$(string(from)) -> $(string(to))"
# use eval here because of https://github.com/JuliaCI/BenchmarkTools.jl/issues/50#issuecomment-318673288
euler[name] = eval(:(@benchmarkable convert($to, rot) setup = rot = rand($from)))
euler[name] = eval(:(@benchmarkable $to(rot) setup = rot = rand($from)))
end


Expand Down
39 changes: 15 additions & 24 deletions src/angleaxis_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,11 @@ end
end

# These functions are enough to satisfy the entire StaticArrays interface:
@inline (::Type{AA})(t::NTuple{9}) where {AA <: AngleAxis} = convert(AA, Quat(t))
@inline Base.getindex(aa::AngleAxis, i::Int) = convert(Quat, aa)[i]
@inline Tuple(aa::AngleAxis) = Tuple(convert(RotMatrix, aa))
@inline (::Type{AA})(t::NTuple{9}) where {AA <: AngleAxis} = AA(Quat(t)) # TODO: consider going directly from tuple (RotMatrix) to AngleAxis
@inline Base.getindex(aa::AngleAxis, i::Int) = Quat(aa)[i]

@inline function Base.convert(::Type{R}, aa::AngleAxis) where R <: RotMatrix
@inline function Base.Tuple(aa::AngleAxis{T}) where T
# Rodrigues' rotation formula.
T = eltype(aa)

s, c = sincos(aa.theta)
c1 = one(T) - c

Expand All @@ -68,17 +65,17 @@ end
sz = s * aa.axis_z

# Note that the RotMatrix constructor argument order makes this look transposed:
R(one(T) - c1y2 - c1z2, c1xy + sz, c1xz - sy,
c1xy - sz, one(T) - c1x2 - c1z2, c1yz + sx,
c1xz + sy, c1yz - sx, one(T) - c1x2 - c1y2)
(one(T) - c1y2 - c1z2, c1xy + sz, c1xz - sy,
c1xy - sz, one(T) - c1x2 - c1z2, c1yz + sx,
c1xz + sy, c1yz - sx, one(T) - c1x2 - c1y2)
end

@inline function Base.convert(::Type{Q}, aa::AngleAxis) where Q <: Quat
@inline function (::Type{Q})(aa::AngleAxis) where Q <: Quat
s, c = sincos(aa.theta / 2)
return Q(c, s * aa.axis_x, s * aa.axis_y, s * aa.axis_z, false)
end

@inline function Base.convert(::Type{AA}, q::Quat) where AA <: AngleAxis
@inline function (::Type{AA})(q::Quat) where AA <: AngleAxis
s2 = q.x * q.x + q.y * q.y + q.z * q.z
sin_t2 = sqrt(s2)
theta = 2 * atan(sin_t2, q.w)
Expand Down Expand Up @@ -147,32 +144,29 @@ end
@inline RodriguesVec(x::X, y::Y, z::Z) where {X,Y,Z} = RodriguesVec{promote_type(promote_type(X, Y), Z)}(x, y, z)

# These functions are enough to satisfy the entire StaticArrays interface:
@inline (::Type{RV})(t::NTuple{9}) where {RV <: RodriguesVec} = convert(RV, Quat(t))
@inline Base.getindex(aa::RodriguesVec, i::Int) = convert(Quat, aa)[i]
@inline Base.Tuple(rv::RodriguesVec) = Tuple(convert(Quat, rv))

# define its interaction with other angle representations
@inline Base.convert(::Type{R}, rv::RodriguesVec) where {R <: RotMatrix} = convert(R, AngleAxis(rv))
@inline (::Type{RV})(t::NTuple{9}) where {RV <: RodriguesVec} = RV(Quat(t)) # TODO: go through AngleAxis once it's faster
@inline Base.getindex(aa::RodriguesVec, i::Int) = Quat(aa)[i]
@inline Base.Tuple(rv::RodriguesVec) = Tuple(Quat(rv))

function Base.convert(::Type{AA}, rv::RodriguesVec) where AA <: AngleAxis
function (::Type{AA})(rv::RodriguesVec) where AA <: AngleAxis
# TODO: consider how to deal with derivative near theta = 0. There should be a first-order expansion here.
theta = rotation_angle(rv)
return theta > 0 ? AA(theta, rv.sx / theta, rv.sy / theta, rv.sz / theta, false) : AA(zero(theta), one(theta), zero(theta), zero(theta), false)
end

function Base.convert(::Type{RV}, aa::AngleAxis) where RV <: RodriguesVec
function (::Type{RV})(aa::AngleAxis) where RV <: RodriguesVec
return RV(aa.theta * aa.axis_x, aa.theta * aa.axis_y, aa.theta * aa.axis_z)
end

function Base.convert(::Type{Q}, rv::RodriguesVec) where Q <: Quat
function (::Type{Q})(rv::RodriguesVec) where Q <: Quat
theta = rotation_angle(rv)
qtheta = cos(theta / 2)
#s = abs(1/2 * sinc((theta / 2) / pi))
s = (1/2 * sinc((theta / 2) / pi)) # TODO check this (I removed an abs)
return Q(qtheta, s * rv.sx, s * rv.sy, s * rv.sz, false)
end

Base.convert(::Type{RV}, q::Quat) where {RV <: RodriguesVec} = convert(RV, convert(AngleAxis, q))
(::Type{RV})(q::Quat) where {RV <: RodriguesVec} = RV(AngleAxis(q))

function Base.:*(rv::RodriguesVec{T1}, v::StaticVector{3, T2}) where {T1,T2}
theta = rotation_angle(rv)
Expand Down Expand Up @@ -200,9 +194,6 @@ end
@inline Base.:^(rv::RodriguesVec, t::Real) = RodriguesVec(rv.sx*t, rv.sy*t, rv.sz*t)
@inline Base.:^(rv::RodriguesVec, t::Integer) = RodriguesVec(rv.sx*t, rv.sy*t, rv.sz*t) # to avoid ambiguity




# rotation properties
@inline rotation_angle(rv::RodriguesVec) = sqrt(rv.sx * rv.sx + rv.sy * rv.sy + rv.sz * rv.sz)
function rotation_axis(rv::RodriguesVec) # what should this return for theta = 0?
Expand Down
4 changes: 3 additions & 1 deletion src/core_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ Base.transpose(r::Rotation{N,T}) where {N,T<:Real} = inv(r)
rotation_angle(r::Rotation) = rotation_angle(AngleAxis(r))
rotation_axis(r::Rotation) = rotation_axis(AngleAxis(r))

# `convert` goes through the constructors, similar to e.g. `Number`
Base.convert(::Type{R}, rot::Rotation{N}) where {N,R<:Rotation{N}} = R(rot)

# Rotation matrices should be orthoginal/unitary. Only the operations we define,
# like multiplication, will stay as Rotations, otherwise users will get an
# SMatrix{3,3} (e.g. rot1 + rot2 -> SMatrix)
Expand Down Expand Up @@ -82,7 +85,6 @@ end
RotMatrix(x::SMatrix{N,N,T,L}) where {N,T,L} = RotMatrix{N,T,L}(x)

# These functions (plus size) are enough to satisfy the entire StaticArrays interface:
# @inline (::Type{R}){R<:RotMatrix}(t::Tuple) = error("No precise constructor found. Length of input was $(length(t)).")
for N = 2:3
L = N*N
RotMatrixN = Symbol(:RotMatrix, N)
Expand Down
9 changes: 0 additions & 9 deletions src/euler_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ for axis in [:X, :Y, :Z]
@inline $RotType(theta::T) where {T} = $RotType{T}(theta)
@inline $RotType(r::$RotType{T}) where {T} = $RotType{T}(r)

@inline Base.convert(::Type{R}, r::$RotType) where {R<:$RotType} = R(r)
@inline Base.convert(::Type{R}, r::R) where {R<:$RotType} = r

@inline (::Type{R})(t::NTuple{9}) where {R<:$RotType} = error("Cannot construct a cardinal axis rotation from a matrix")

@inline Base.:*(r1::$RotType, r2::$RotType) = $RotType(r1.theta + r2.theta)
Expand Down Expand Up @@ -228,9 +225,6 @@ for axis1 in [:X, :Y, :Z]
@inline $RotType(theta1::T1, theta2::T2) where {T1, T2} = $RotType{promote_type(T1, T2)}(theta1, theta2)
@inline $RotType(r::$RotType{T}) where {T} = $RotType{T}(r)

@inline Base.convert(::Type{R}, r::$RotType) where {R<:$RotType} = R(r)
@inline Base.convert(::Type{R}, r::R) where {R<:$RotType} = r

@inline function Base.getindex(r::$RotType{T}, i::Int) where T
Tuple(r)[i] # Slow...
end
Expand Down Expand Up @@ -511,9 +505,6 @@ for axis1 in [:X, :Y, :Z]
@inline $RotType(theta1::T1, theta2::T2, theta3::T3) where {T1, T2, T3} = $RotType{promote_type(promote_type(T1, T2), T3)}(theta1, theta2, theta3)
@inline $RotType(r::$RotType{T}) where {T} = $RotType{T}(r)

@inline Base.convert(::Type{R}, r::$RotType) where {R<:$RotType} = R(r)
@inline Base.convert(::Type{R}, r::R) where {R<:$RotType} = r

@inline function Base.getindex(r::$RotType{T}, i::Int) where T
Tuple(r)[i] # Slow...
end
Expand Down
2 changes: 1 addition & 1 deletion src/principal_value.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ the following properties:
"""
principal_value(r::RotMatrix) = r
principal_value(q::Quat{T}) where {T} = q.w < zero(T) ? Quat{T}(-q.w, -q.x, -q.y, -q.z) : q
principal_value(spq::SPQuat{T}) where {T} = convert(SPQuat, principal_value(convert(Quat, spq)))
principal_value(spq::SPQuat{T}) where {T} = SPQuat(principal_value(Quat(spq)))

function principal_value(aa::AngleAxis{T}) where {T}
theta = mod_minus_pi_to_pi(aa.theta)
Expand Down
20 changes: 2 additions & 18 deletions src/quaternion_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ end
@inline function Quat(w::W, x::X, y::Y, z::Z, normalize::Bool = true) where {W, X, Y, Z}
Quat{promote_type(promote_type(promote_type(W, X), Y), Z)}(w, x, y, z, normalize)
end
@inline Quat(q::Quat{T}) where {T} = Quat{T}(q)

@inline Base.convert(::Type{Q}, q::Quat) where {Q<:Quat} = Q(q)
@inline Base.convert(::Type{Q}, q::Q) where {Q<:Quat} = q

# These 3 functions are enough to satisfy the entire StaticArrays interface:
function (::Type{Q})(t::NTuple{9}) where Q<:Quat
Expand Down Expand Up @@ -246,21 +242,13 @@ struct SPQuat{T} <: Rotation{3,T}
end

@inline SPQuat(x::X, y::Y, z::Z) where {X,Y,Z} = SPQuat{promote_type(promote_type(X, Y), Z)}(x, y, z)
@inline SPQuat(spq::SPQuat{T}) where {T} = SPQuat{T}(spq)

@inline Base.convert(::Type{SPQ}, spq::SPQuat) where {SPQ<:SPQuat} = SPQ(spq)
@inline Base.convert(::Type{SPQ}, spq::SPQ) where {SPQ<:SPQuat} = spq

# These functions are enough to satisfy the entire StaticArrays interface:
@inline (::Type{SPQ})(t::NTuple{9}) where {SPQ <: SPQuat} = convert(SPQ, Quat(t))
@inline Base.getindex(spq::SPQuat, i::Int) = convert(Quat, spq)[i]
@inline Base.Tuple(spq::SPQuat) = Tuple(convert(Quat, spq))

# Optimizations for going between Quat and SPQuat
@inline (::Type{SPQ})(q::Quat) where {SPQ <: SPQuat} = convert(SPQ, q)
@inline (::Type{Q})(spq::SPQuat) where {Q <: Quat} = convert(Q, spq)

@inline function Base.convert(::Type{Q}, spq::SPQuat) where Q <: Quat
@inline function (::Type{Q})(spq::SPQuat) where Q <: Quat
# Equation (45) in
# Terzakis et al., "A Recipe on the Parameterization of Rotation Matrices
# for Non-Linear Optimization using Quaternions":
Expand All @@ -269,7 +257,7 @@ end
Q((1 - alpha2) / (alpha2 + 1), scale * spq.x, scale * spq.y, scale * spq.z, false)
end

@inline function Base.convert(::Type{SPQ}, q::Quat) where SPQ <: SPQuat
@inline function (::Type{SPQ})(q::Quat) where SPQ <: SPQuat
# Simplification of (46) and (47) in
# Terzakis et al., "A Recipe on the Parameterization of Rotation Matrices
# for Non-Linear Optimization using Quaternions":
Expand All @@ -295,7 +283,3 @@ end

@inline Base.one(::Type{SPQuat}) = SPQuat(0.0, 0.0, 0.0)
@inline Base.one(::Type{SPQuat{T}}) where {T} = SPQuat{T}(zero(T), zero(T), zero(T))

# rotation properties
@inline rotation_angle(spq::SPQuat) = rotation_angle(Quat(spq))
@inline rotation_axis(spq::SPQuat) = rotation_axis(Quat(spq))
6 changes: 3 additions & 3 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ using ForwardDiff

# test jacobian to a Rotation matrix
R_jac = Rotations.jacobian(Quat, spq)
FD_jac = ForwardDiff.jacobian(x -> (q = convert(Quat, SPQuat(x[1],x[2],x[3]));
FD_jac = ForwardDiff.jacobian(x -> (q = Quat(SPQuat(x[1],x[2],x[3]));
SVector(q.w, q.x, q.y, q.z)),
SVector(spq.x, spq.y, spq.z))

Expand All @@ -45,7 +45,7 @@ using ForwardDiff
for spq = [SPQuat(1.0, 0.0, 0.0), SPQuat(0.0, 1.0, 0.0), SPQuat(0.0, 0.0, 1.0)]
# test jacobian to a Rotation matrix
R_jac = Rotations.jacobian(Quat, spq)
FD_jac = ForwardDiff.jacobian(x -> (q = convert(Quat, SPQuat(x[1],x[2],x[3]));
FD_jac = ForwardDiff.jacobian(x -> (q = Quat(SPQuat(x[1],x[2],x[3]));
SVector(q.w, q.x, q.y, q.z)),
SVector(spq.x, spq.y, spq.z))

Expand All @@ -61,7 +61,7 @@ using ForwardDiff

# test jacobian to a Rotation matrix
R_jac = Rotations.jacobian(SPQuat, q)
FD_jac = ForwardDiff.jacobian(x -> (spq = convert(SPQuat, Quat(x[1], x[2], x[3], x[4]));
FD_jac = ForwardDiff.jacobian(x -> (spq = SPQuat(Quat(x[1], x[2], x[3], x[4]));
SVector(spq.x, spq.y, spq.z)),
SVector(q.w, q.x, q.y, q.z))

Expand Down

0 comments on commit 0b5bd21

Please sign in to comment.