Skip to content

Commit

Permalink
Merge pull request #114 from mikeingold/isapprox
Browse files Browse the repository at this point in the history
Improved Base.isapprox support
  • Loading branch information
MilesCranmer committed Mar 11, 2024
2 parents 73df126 + 750b66e commit 4b4ba62
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 7 deletions.
5 changes: 4 additions & 1 deletion ext/DynamicQuantitiesLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
module DynamicQuantitiesLinearAlgebraExt

using DynamicQuantities: UnionAbstractQuantity, QuantityArray, ustrip, dimension, new_quantity

import LinearAlgebra: norm
import DynamicQuantities: UnionAbstractQuantity, ustrip, dimension, new_quantity
import DynamicQuantities: _norm

_norm(u::AbstractArray) = norm(u)
norm(q::UnionAbstractQuantity, p::Real=2) = new_quantity(typeof(q), norm(ustrip(q), p), dimension(q))

end
40 changes: 40 additions & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,46 @@ end
Base.fill(x::UnionAbstractQuantity, dims::Dims...) = QuantityArray(fill(ustrip(x), dims...), dimension(x), typeof(x))
Base.fill(x::UnionAbstractQuantity, t::Tuple{}) = QuantityArray(fill(ustrip(x), t), dimension(x), typeof(x))

# Will be overloaded by `DynamicQuantitiesLinearAlgebraExt`:
_norm(_) = error("Please load the `LinearAlgebra.jl` package.")

# Define isapprox for vectors of Quantity's
struct AutoTolerance end

atoldefault(_, atol) = atol
rtoldefault(_, _, _, rtol) = rtol

function atoldefault(el, ::AutoTolerance)
return zero(el)
end
function rtoldefault(::Union{T1,Type{T1}}, ::Union{T2,Type{T2}}, atol, ::AutoTolerance) where {T1,T2}
rtol = max(Base.rtoldefault(real(T1)), Base.rtoldefault(real(T2)))
return iszero(atol) ? rtol : zero(rtol)
end

all_dimensions_equal(A::QuantityArray, B::QuantityArray) = dimension(A) == dimension(B)
all_dimensions_equal(A::QuantityArray, B::AbstractArray{<:UnionAbstractQuantity}) = all(i -> dimension(A) == dimension(B[i]), eachindex(B))
all_dimensions_equal(A::AbstractArray{<:UnionAbstractQuantity}, B::QuantityArray) = all(i -> dimension(B) == dimension(A[i]), eachindex(A))
function all_dimensions_equal(A::AbstractArray{<:UnionAbstractQuantity}, B::AbstractArray{<:UnionAbstractQuantity})
d = dimension(first(A))
return d == dimension(first(B)) && all(i -> d == dimension(A[i]), eachindex(A)) && all(i -> d == dimension(B[i]), eachindex(B))
end

function Base.isapprox(
u::Union{QuantityArray,AbstractArray{<:UnionAbstractQuantity}},
v::Union{QuantityArray,AbstractArray{<:UnionAbstractQuantity}};
atol=AutoTolerance(),
rtol=AutoTolerance(),
norm::F=_norm
) where {F<:Function}
all_dimensions_equal(u, v) || throw(DimensionError(u, v))
d = norm(u .- v)
_atol = atoldefault(first(u), atol)
_rtol = rtoldefault(ustrip(first(u)), ustrip(first(v)), _atol, rtol)
return iszero(_rtol) ? d <= _atol : d <= max(_atol, _rtol*max(norm(u), norm(v)))
end

# Unit functions
ulength(q::QuantityArray) = ulength(dimension(q))
umass(q::QuantityArray) = umass(dimension(q))
utime(q::QuantityArray) = utime(dimension(q))
Expand Down
15 changes: 12 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ for (type, _, _) in ABSTRACT_QUANTITY_TYPES
end
Base.keys(q::UnionAbstractQuantity) = keys(ustrip(q))

# If atol specified in kwargs, validate its dimensions and then strip units
@inline function _validate_isapprox(dimcheck, kws)
if haskey(kws, :atol)
dimension(dimcheck) == dimension(kws[:atol]) || throw(DimensionError(dimcheck, kws[:atol]))
return (; kws..., atol=ustrip(kws[:atol]))
else
return kws
end
end

# Numeric checks
for op in (:(<=), :(<), :(>=), :(>), :isless), (type, true_base_type, _) in ABSTRACT_QUANTITY_TYPES
Expand Down Expand Up @@ -224,15 +233,15 @@ for (type, true_base_type, _) in ABSTRACT_QUANTITY_TYPES
function Base.isapprox(l::$type, r::$type; kws...)
l, r = promote_except_value(l, r)
dimension(l) == dimension(r) || throw(DimensionError(l, r))
return isapprox(ustrip(l), ustrip(r); kws...)
return isapprox(ustrip(l), ustrip(r); _validate_isapprox(l, kws)...)
end
function Base.isapprox(l::$base_type, r::$type; kws...)
iszero(dimension(r)) || throw(DimensionError(l, r))
return isapprox(l, ustrip(r); kws...)
return isapprox(l, ustrip(r); _validate_isapprox(r, kws)...)
end
function Base.isapprox(l::$type, r::$base_type; kws...)
iszero(dimension(l)) || throw(DimensionError(l, r))
return isapprox(ustrip(l), r; kws...)
return isapprox(ustrip(l), r; _validate_isapprox(l, kws)...)
end
end
for (type2, _, _) in ABSTRACT_QUANTITY_TYPES
Expand Down
7 changes: 5 additions & 2 deletions test/test_unitful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ end
for T in [DEFAULT_VALUE_TYPE, Float16, Float32, Float64], R in [DEFAULT_DIM_BASE_TYPE, Rational{Int16}, Rational{Int32}, SimpleRatio{Int}, SimpleRatio{SafeInt16}]
D = DynamicQuantities.Dimensions{R}
x = DynamicQuantities.Quantity(T(0.2), D, length=1, amount=2, current=-1 // 2, luminosity=2 // 5)
tol_dq = DynamicQuantities.Quantity(T(1e-6), D, length=1, amount=2, current=-1 // 2, luminosity=2 // 5)
x_unitful = T(0.2)u"m*mol^2*A^(-1//2)*cd^(2//5)"
tol_unitful = T(1e-6)u"m*mol^2*A^(-1//2)*cd^(2//5)"

@test risapprox(convert(Unitful.Quantity, x), x_unitful; atol=1e-6)
@test typeof(convert(DynamicQuantities.Quantity, convert(Unitful.Quantity, x))) <: DynamicQuantities.Quantity{T,DynamicQuantities.DEFAULT_DIM_TYPE}
@test isapprox(convert(DynamicQuantities.Quantity, convert(Unitful.Quantity, x)), x; atol=1e-6)
@test isapprox(convert(DynamicQuantities.Quantity, convert(Unitful.Quantity, x)), x; atol=tol_dq)
@test_throws DynamicQuantities.DimensionError isapprox(convert(DynamicQuantities.Quantity, convert(Unitful.Quantity, x)), x; atol=1e-6)

@test isapprox(convert(DynamicQuantities.Quantity{T,D}, x_unitful), x; atol=1e-6)
@test isapprox(convert(DynamicQuantities.Quantity{T,D}, x_unitful), x; atol=tol_dq)
@test risapprox(convert(Unitful.Quantity, convert(DynamicQuantities.Quantity{T,D}, x_unitful)), Unitful.upreferred(x_unitful); atol=1e-6)

@test typeof(convert(DynamicQuantities.Dimensions, Unitful.dimension(x_unitful))) == DynamicQuantities.Dimensions{DEFAULT_DIM_BASE_TYPE}
Expand Down
61 changes: 60 additions & 1 deletion test/unittests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ end
uX = X .* Quantity(2, length=2.5, luminosity=0.5)
@test sum(X) == 0.5 * ustrip(sum(uX))

@test isapprox(uX, uX, atol=Quantity{T,D}(1e-6, length=2.5, luminosity=0.5))
@test_throws DimensionError isapprox(uX, uX, atol=1e-6)
@test_throws ErrorException DynamicQuantities._norm(1.0)
VERSION >= v"1.9" &&
@test_throws "Please load the `LinearAlgebra.jl` package." DynamicQuantities._norm(1.0)

x = GenericQuantity(ones(T, 32))
@test ustrip(x + ones(T, 32))[32] == 2
@test typeof(x + ones(T, 32)) <: GenericQuantity{Vector{T}}
Expand All @@ -297,6 +303,55 @@ end
@test ones(T, 32) / GenericQuantity(T(1), length=1) == GenericQuantity(ones(T, 32), length=-1)
end

@testset "isapprox" begin
A = QuantityArray([1, 2, 3], u"m")
B = QuantityArray([1, 2, 3], u"m")
@test isapprox(A, B)

A = QuantityArray([1, 2, 3], u"m")
B = QuantityArray([1, 2, 3], u"s")
@test_throws DimensionError isapprox(A, B)

A = QuantityArray([1, 2, 3], u"m")
B = QuantityArray([1.001, 2.001, 3.001], u"m")
@test !isapprox(A, B, atol=1e-4u"m")
@test isapprox(A, B, atol=1e-2u"m")

A = QuantityArray([1, 2, 3], u"m")
B = QuantityArray([1.1, 2.1, 3.1], u"m")
@test !isapprox(A, B, rtol=0.01)
@test isapprox(A, B, rtol=0.05)

A = [1u"m", 2u"m", 3u"m"]
B = QuantityArray([1, 2, 3], u"m")
@test isapprox(A, B)

A = [1u"m", 2u"m", 3u"s"]
B = QuantityArray([1, 2, 3], u"m")
@test_throws DimensionError isapprox(A, B)
@test_throws DimensionError isapprox(B, A)

A = [1u"m", 2u"m"]
B = [1u"m", 2u"s"]
@test_throws DimensionError isapprox(A, B)
@test_throws DimensionError isapprox(B, A)

# With different rtoldefault:
A = QuantityArray([1, 2, 3], Quantity{Float16}(u"m"))
B = QuantityArray([1.01, 2.01, 3.01], Quantity{Float64}(u"m"))
@test isapprox(A, B)
@test isapprox(B, A) # Because we get it from Float16
@test !isapprox(Quantity{Float64}.(A), B)
@test !isapprox(B, Quantity{Float64}.(A))

# With explicit atol=0, rtol=0
A = [1u"m", 2u"m", 3u"m"]
B = [1u"m", 2u"m", 3u"m"]
@test isapprox(A, B, atol=0u"m", rtol=0)
B = [1.00000001u"m", 2u"m", 3u"m"]
@test !isapprox(A, B, atol=0u"m", rtol=0)
end

x = randn(32) .* u"km/s"
@test ustrip.(x) == [ustrip(xi) for xi in x]
@test dimension.(x) == [dimension(u"km/s") for xi in x]
Expand Down Expand Up @@ -817,7 +872,7 @@ end
q = convert(Q{Float16}, 1.5u"g")
qs = uconvert(convert(Q{Float16}, us"g"), 5 * q)
@test typeof(qs) <: Q{Float16,<:SymbolicDimensions{<:Any}}
@test isapprox(qs, 7.5us"g"; atol=0.01)
@test isapprox(qs, 7.5us"g"; atol=0.01us"g")

# Arrays
x = [1.0, 2.0, 3.0] .* Q(u"kg")
Expand Down Expand Up @@ -1021,6 +1076,10 @@ end
@test dimension(output_s_x) == dimension(x)^2
fv_square2(x) = (xi -> xi^2).(x)
@inferred fv_square2(s_x)

# isapprox for QuantityArray's
@test isapprox(x, x, atol=Q(1e-6u"km/s"))
@test_throws DimensionError isapprox(x, x, atol=1e-6)
end

@testset "Copying $Q" begin
Expand Down

0 comments on commit 4b4ba62

Please sign in to comment.