Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Unitful aware givensAlgorithm with a generic version. #54078

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 17 additions & 14 deletions stdlib/LinearAlgebra/src/givens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ transpose(R::AbstractRotation) = error("transpose not implemented for $(typeof(R
(*)(R::AbstractRotation, A::AbstractMatrix) = _rot_mul_vecormat(R, A)
function _rot_mul_vecormat(R::AbstractRotation{T}, A::AbstractVecOrMat{S}) where {T,S}
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
lmul!(convert(AbstractRotation{TS}, R), copy_similar(A, TS))
lmul!(R, copy_similar(A, TS))
end

(*)(A::AbstractVector, R::AbstractRotation) = _vecormat_mul_rot(A, R)
(*)(A::AbstractMatrix, R::AbstractRotation) = _vecormat_mul_rot(A, R)
function _vecormat_mul_rot(A::AbstractVecOrMat{T}, R::AbstractRotation{S}) where {T,S}
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
rmul!(copy_similar(A, TS), convert(AbstractRotation{TS}, R))
rmul!(copy_similar(A, TS), R)
end

"""
Expand Down Expand Up @@ -250,20 +250,23 @@ function givensAlgorithm(f::Complex{T}, g::Complex{T}) where T<:AbstractFloat
return cs, sn, r
end

# enable for unitful quantities
function givensAlgorithm(f::T, g::T) where T
fs = f / oneunit(T)
gs = g / oneunit(T)
typeof(fs) === T && typeof(gs) === T &&
!isa(fs, Union{AbstractFloat,Complex{<:AbstractFloat}}) &&
throw(MethodError(givensAlgorithm, (fs, gs)))

c, s, r = givensAlgorithm(fs, gs)
return c, s, r * oneunit(T)
# From Janovská, D., & Opfer, G. (2003). Givens’ Transformation Applied to Quaternion
# Valued Vectors. BIT Numerical Mathematics, 43(5), 991–1002.
# doi:10.1023/b:bitn.0000014561.58141.2c
function givensAlgorithm(f::Number, g::Number)
nrm = hypot(f, g)
c = abs(f) / nrm
s, u = if iszero(f)
-one(first(promote(f, g))), -g
andreasnoack marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

@devmotion devmotion Apr 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just noticed that this is inconsistent with the existing definition for Float64:

julia> using LinearAlgebra

julia> LinearAlgebra.givensAlgorithm(0.0, 2.0)
(0.0, 1.0, 2.0)

It seems to be consistent with the existing definitions we would need something like

Suggested change
-one(first(promote(f, g))), -g
one(first(promote(f, g))), g

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure it's worth trying to match the values of c and s from the two versions translated from LAPACK. Then we'd essentially have to reimplement the version for complex numbers and I'm not sure if it necessarily works for quaternions. In any case, which of the allowed values of c and s doesn't really matter. It's always the result of applying the rotation that matters and it isn't affected by this choice (except for rounding).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought a bit more about this and concluded that you are right here. The fallback would have to make the same sign decisions as the specialized version so the approach of this PR can't work. We should probably try to make a version of the complex version that isn't restricted to Complex32/64.

else
# Note that the paper conjugates the argument in the definition of sign but the
# givens rotation implementation then ends up with conjugating twice.
signf̄ = sign(f)
(signf̄ * conj(g)) / nrm, nrm * signf̄
end
return c, s, u
end

givensAlgorithm(f, g) = givensAlgorithm(promote(float(f), float(g))...)

"""

givens(f::T, g::T, i1::Integer, i2::Integer) where {T} -> (G::Givens, r::T)
Expand Down
46 changes: 25 additions & 21 deletions stdlib/LinearAlgebra/test/givens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ module TestGivens
using Test, LinearAlgebra, Random
using LinearAlgebra: Givens, Rotation, givensAlgorithm

isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
using .Main.Quaternions

# Test givens rotations
@testset "Test Givens for $elty" for elty in (Float32, Float64, ComplexF32, ComplexF64)
if elty <: Real
Expand Down Expand Up @@ -88,28 +91,19 @@ const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :Furlongs) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Furlongs.jl"))
using .Main.Furlongs

@testset "testing dimensions with Furlongs" begin
@test_throws MethodError givens(Furlong(1.0), Furlong(2.0), 1, 2)
end

const TNumber = Union{Float64,ComplexF64}
struct MockUnitful{T<:TNumber} <: Number
data::T
MockUnitful(data::T) where T<:TNumber = new{T}(data)
end
import Base: *, /, one, oneunit
*(a::MockUnitful{T}, b::T) where T<:TNumber = MockUnitful(a.data * b)
*(a::T, b::MockUnitful{T}) where T<:TNumber = MockUnitful(a * b.data)
*(a::MockUnitful{T}, b::MockUnitful{T}) where T<:TNumber = MockUnitful(a.data * b.data)
/(a::MockUnitful{T}, b::MockUnitful{T}) where T<:TNumber = a.data / b.data
one(::Type{<:MockUnitful{T}}) where T = one(T)
oneunit(::Type{<:MockUnitful{T}}) where T = MockUnitful(one(T))

@testset "unitful givens rotation unitful $T " for T in (Float64, ComplexF64)
g, r = givens(MockUnitful(T(3)), MockUnitful(T(4)), 1, 2)
@test g.c ≈ 3/5
@test g.s ≈ 4/5
@test r.data ≈ 5.0

@testset "unitful givens rotation unitful $T " for T in (Float32, Float64, ComplexF32, ComplexF64)
g, r = givens(Furlong(T(3)), Furlong(T(4)), 1, 2)
@test g.c.val ≈ 3/5
@test g.c isa Furlong{0}
@test g.s.val ≈ 4/5
@test g.s isa Furlong{0}
@test r.val ≈ 5.0
@test r isa Furlong{1}
y = g * [Furlong(T(3)), Furlong(T(4))]
@test y[1].val ≈ r.val
@test y[2].val ≈ 0 atol = 10eps()
end

# 51554
Expand All @@ -121,4 +115,14 @@ end
@test !isfinite(r)
end

@testset "givensAlgorithm with quaternions" for (x, y) in
(
(Quaternion(randn(4)...), Quaternion(randn(4)...)),
(0, Quaternion(randn(4)...)),
)
c, s, r = givensAlgorithm(x, y)
@test c * x + s * y ≈ r
@test c * y ≈ s' * x
end

end # module TestGivens
2 changes: 2 additions & 0 deletions test/testhelpers/Quaternions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ end
Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...)
Base.convert(::Type{Quaternion{T}}, s::Real) where {T <: Real} =
Quaternion{T}(convert(T, s), zero(T), zero(T), zero(T))
Base.promote_rule(::Type{Quaternion{T}}, ::Type{S}) where {T,S} = Quaternion{promote_type(T,S)}
Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3
Base.float(z::Quaternion{T}) where T = Quaternion(float(z.s), float(z.v1), float(z.v2), float(z.v3))
Base.abs(q::Quaternion) = sqrt(abs2(q))
Expand All @@ -25,6 +26,7 @@ Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)
Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3)
Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T))

Base.:(-)(q::Quaternion) = Quaternion(-q.s, -q.v1, -q.v2, -q.v3)
Base.:(+)(ql::Quaternion, qr::Quaternion) =
Quaternion(ql.s + qr.s, ql.v1 + qr.v1, ql.v2 + qr.v2, ql.v3 + qr.v3)
Base.:(-)(ql::Quaternion, qr::Quaternion) =
Expand Down