Skip to content

Commit

Permalink
Replace Unitful aware givensAlgorithm with a generic version.
Browse files Browse the repository at this point in the history
The implementation is based on

Janovská, D., & Opfer, G. (2003) "Givens’ Transformation Applied to
Quaternion Valued Vectors."
  • Loading branch information
andreasnoack committed Apr 15, 2024
1 parent b9aeafa commit d367449
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 16 deletions.
27 changes: 15 additions & 12 deletions stdlib/LinearAlgebra/src/givens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,20 +250,23 @@ function givensAlgorithm(f::Complex{T}, g::Complex{T}) where T<:AbstractFloat
return cs, sn, r
end

# enable for unitful quantities
function givensAlgorithm(f::T, g::T) where T
fs = f / oneunit(T)
gs = g / oneunit(T)
typeof(fs) === T && typeof(gs) === T &&
!isa(fs, Union{AbstractFloat,Complex{<:AbstractFloat}}) &&
throw(MethodError(givensAlgorithm, (fs, gs)))

c, s, r = givensAlgorithm(fs, gs)
return c, s, r * oneunit(T)
# From Janovská, D., & Opfer, G. (2003). Givens’ Transformation Applied to Quaternion
# Valued Vectors. BIT Numerical Mathematics, 43(5), 991–1002.
# doi:10.1023/b:bitn.0000014561.58141.2c
function givensAlgorithm(f::Number, g::Number)
nrm = hypot(f, g)
c = abs(f) / nrm
s, u = if iszero(f)
-one(first(promote(f, g))), -g
else
# Note that the paper conjugates the argument in the definition of sign but the
# givens rotation implementation then ends up with conjugating twice.
signf̄ = f / abs(f)
(signf̄ * conj(g)) / nrm, nrm * signf̄
end
return c, s, u
end

givensAlgorithm(f, g) = givensAlgorithm(promote(float(f), float(g))...)

"""
givens(f::T, g::T, i1::Integer, i2::Integer) where {T} -> (G::Givens, r::T)
Expand Down
25 changes: 21 additions & 4 deletions stdlib/LinearAlgebra/test/givens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ module TestGivens
using Test, LinearAlgebra, Random
using LinearAlgebra: Givens, Rotation, givensAlgorithm

isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
using .Main.Quaternions

# Test givens rotations
@testset "Test Givens for $elty" for elty in (Float32, Float64, ComplexF32, ComplexF64)
if elty <: Real
Expand Down Expand Up @@ -95,15 +98,19 @@ end
const TNumber = Union{Float64,ComplexF64}
struct MockUnitful{T<:TNumber} <: Number
data::T
MockUnitful(data::T) where T<:TNumber = new{T}(data)
end
import Base: *, /, one, oneunit
*(a::MockUnitful{T}, b::T) where T<:TNumber = MockUnitful(a.data * b)
*(a::T, b::MockUnitful{T}) where T<:TNumber = MockUnitful(a * b.data)
import Base: *, /, convert, conj, float, abs, one, oneunit, promote_rule
*(a::MockUnitful, b::TNumber) = MockUnitful(a.data * b)
*(a::TNumber, b::MockUnitful) = MockUnitful(a * b.data)
*(a::MockUnitful{T}, b::MockUnitful{T}) where T<:TNumber = MockUnitful(a.data * b.data)
/(a::MockUnitful{T}, b::MockUnitful{T}) where T<:TNumber = a.data / b.data
abs(a::MockUnitful) = MockUnitful(abs(a.data))
conj(a::MockUnitful) = MockUnitful(conj(a.data))
convert(::Type{MockUnitful{T}}, x::MockUnitful) where {T} = MockUnitful(convert(T, x.data))
float(a::MockUnitful) = MockUnitful(float(a.data))
one(::Type{<:MockUnitful{T}}) where T = one(T)
oneunit(::Type{<:MockUnitful{T}}) where T = MockUnitful(one(T))
promote_rule(::Type{MockUnitful{T}}, ::Type{MockUnitful{S}}) where {T,S} = MockUnitful{promote_type(T, S)}

@testset "unitful givens rotation unitful $T " for T in (Float64, ComplexF64)
g, r = givens(MockUnitful(T(3)), MockUnitful(T(4)), 1, 2)
Expand All @@ -121,4 +128,14 @@ end
@test !isfinite(r)
end

@testset "givensAlgorithm with quaternions" for (x, y) in
(
(Quaternion(randn(4)...), Quaternion(randn(4)...)),
(0, Quaternion(randn(4)...)),
)
c, s, r = givensAlgorithm(x, y)
@test c * x + s * y r
@test c * y s' * x
end

end # module TestGivens
2 changes: 2 additions & 0 deletions test/testhelpers/Quaternions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ end
Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...)
Base.convert(::Type{Quaternion{T}}, s::Real) where {T <: Real} =
Quaternion{T}(convert(T, s), zero(T), zero(T), zero(T))
Base.promote_rule(::Type{Quaternion{T}}, ::Type{S}) where {T,S} = Quaternion{promote_type(T,S)}
Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3
Base.float(z::Quaternion{T}) where T = Quaternion(float(z.s), float(z.v1), float(z.v2), float(z.v3))
Base.abs(q::Quaternion) = sqrt(abs2(q))
Expand All @@ -25,6 +26,7 @@ Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)
Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3)
Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T))

Base.:(-)(q::Quaternion) = Quaternion(-q.s, -q.v1, -q.v2, -q.v3)
Base.:(+)(ql::Quaternion, qr::Quaternion) =
Quaternion(ql.s + qr.s, ql.v1 + qr.v1, ql.v2 + qr.v2, ql.v3 + qr.v3)
Base.:(-)(ql::Quaternion, qr::Quaternion) =
Expand Down

0 comments on commit d367449

Please sign in to comment.