diff --git a/README.md b/README.md index 97f2762..4069190 100644 --- a/README.md +++ b/README.md @@ -54,12 +54,14 @@ Each distance corresponds to a *distance type*. You can always compute a certain ```julia r = evaluate(dist, x, y) +r = dist(x, y) ``` Here, dist is an instance of a distance type. For example, the type for Euclidean distance is ``Euclidean`` (more distance types will be introduced in the next section), then you can compute the Euclidean distance between ``x`` and ``y`` as ```julia r = evaluate(Euclidean(), x, y) +r = Euclidean()(x, y) ``` Common distances also come with convenient functions for distance evaluation. For example, you may also compute Euclidean distance between two vectors as below diff --git a/src/bhattacharyya.jl b/src/bhattacharyya.jl index 06ec9ba..60a7ae1 100644 --- a/src/bhattacharyya.jl +++ b/src/bhattacharyya.jl @@ -37,13 +37,11 @@ bhattacharyya_coeff(a::T, b::T) where {T <: Number} = throw("Bhattacharyya coeff # Bhattacharyya distance -evaluate(dist::BhattacharyyaDist, a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = -log(bhattacharyya_coeff(a, b)) -bhattacharyya(a::AbstractVector, b::AbstractVector) = evaluate(BhattacharyyaDist(), a, b) -evaluate(dist::BhattacharyyaDist, a::T, b::T) where {T <: Number} = throw("Bhattacharyya distance cannot be calculated for scalars") -bhattacharyya(a::T, b::T) where {T <: Number} = evaluate(BhattacharyyaDist(), a, b) +(::BhattacharyyaDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = -log(bhattacharyya_coeff(a, b)) +(::BhattacharyyaDist)(a::T, b::T) where {T <: Number} = throw("Bhattacharyya distance cannot be calculated for scalars") +bhattacharyya(a, b) = BhattacharyyaDist()(a, b) # Hellinger distance -evaluate(dist::HellingerDist, a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = sqrt(1 - bhattacharyya_coeff(a, b)) -hellinger(a::AbstractVector, b::AbstractVector) = evaluate(HellingerDist(), a, b) -evaluate(dist::HellingerDist, a::T, b::T) where {T <: Number} = throw("Hellinger distance cannot be calculated for scalars") -hellinger(a::T, b::T) where {T <: Number} = evaluate(HellingerDist(), a, b) +(::HellingerDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = sqrt(1 - bhattacharyya_coeff(a, b)) +(::HellingerDist)(a::T, b::T) where {T <: Number} = throw("Hellinger distance cannot be calculated for scalars") +hellinger(a, b) = HellingerDist()(a, b) diff --git a/src/bregman.jl b/src/bregman.jl index c4ae26e..afaa2ef 100644 --- a/src/bregman.jl +++ b/src/bregman.jl @@ -22,7 +22,7 @@ end Bregman(F, ∇) = Bregman(F, ∇, LinearAlgebra.dot) # Evaluation fuction -function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector) +function (dist::Bregman)(p::AbstractVector, q::AbstractVector) # Create cache vals. FP_val = dist.F(p); FQ_val = dist.F(q); @@ -45,4 +45,4 @@ function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector) end # Convenience function. -bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = evaluate(Bregman(F, ∇, inner), x, y) +bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = Bregman(F, ∇, inner)(x, y) diff --git a/src/generic.jl b/src/generic.jl index f296a01..1af118e 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -21,6 +21,7 @@ abstract type SemiMetric <: PreMetric end # abstract type Metric <: SemiMetric end +evaluate(dist::PreMetric, a, b) = dist(a, b) # Generic functions @@ -41,7 +42,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::Abs n = size(b, 2) length(r) == n || throw(DimensionMismatch("Incorrect size of r.")) @inbounds for j = 1:n - r[j] = evaluate(metric, a, view(b, :, j)) + r[j] = metric(a, view(b, :, j)) end r end @@ -50,7 +51,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs n = size(a, 2) length(r) == n || throw(DimensionMismatch("Incorrect size of r.")) @inbounds for j = 1:n - r[j] = evaluate(metric, view(a, :, j), b) + r[j] = metric(view(a, :, j), b) end r end @@ -59,7 +60,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs n = get_common_ncols(a, b) length(r) == n || throw(DimensionMismatch("Incorrect size of r.")) @inbounds for j = 1:n - r[j] = evaluate(metric, view(a, :, j), view(b, :, j)) + r[j] = metric(view(a, :, j), view(b, :, j)) end r end @@ -97,7 +98,7 @@ function _pairwise!(r::AbstractMatrix, metric::PreMetric, @inbounds for j = 1:size(b, 2) bj = view(b, :, j) for i = 1:size(a, 2) - r[i, j] = evaluate(metric, view(a, :, i), bj) + r[i, j] = metric(view(a, :, i), bj) end end r @@ -109,7 +110,7 @@ function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix) @inbounds for j = 1:n aj = view(a, :, j) for i = (j + 1):n - r[i, j] = evaluate(metric, view(a, :, i), aj) + r[i, j] = metric(view(a, :, i), aj) end r[j, j] = 0 for i = 1:(j - 1) diff --git a/src/haversine.jl b/src/haversine.jl index 018924e..8a08cff 100644 --- a/src/haversine.jl +++ b/src/haversine.jl @@ -12,7 +12,7 @@ end const VecOrLengthTwoTuple{T} = Union{AbstractVector{T}, NTuple{2, T}} -function evaluate(dist::Haversine, x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple) +function (dist::Haversine)(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple) length(x) == length(y) == 2 || haversine_error() @inbounds begin @@ -33,6 +33,6 @@ function evaluate(dist::Haversine, x::VecOrLengthTwoTuple, y::VecOrLengthTwoTupl 2 * dist.radius * asin( min(√a, one(a)) ) # take care of floating point errors end -haversine(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple, radius::Real) = evaluate(Haversine(radius), x, y) +haversine(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple, radius::Real) = Haversine(radius)(x, y) @noinline haversine_error() = throw(ArgumentError("expected both inputs to have length 2 in Haversine distance")) diff --git a/src/mahalanobis.jl b/src/mahalanobis.jl index 79ab8b0..d04790f 100644 --- a/src/mahalanobis.jl +++ b/src/mahalanobis.jl @@ -13,7 +13,7 @@ result_type(::SqMahalanobis{T}, ::Type, ::Type) where {T} = T # SqMahalanobis -function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real} +function (dist::SqMahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real} if length(a) != length(b) throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) end @@ -23,7 +23,7 @@ function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) return dot(z, Q * z) end -sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(SqMahalanobis(Q), a, b) +sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = SqMahalanobis(Q)(a, b) function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real} Q = dist.qmat @@ -83,11 +83,11 @@ end # Mahalanobis -function evaluate(dist::Mahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real} - sqrt(evaluate(SqMahalanobis(dist.qmat), a, b)) +function (dist::Mahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real} + sqrt(SqMahalanobis(dist.qmat)(a, b)) end -mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(Mahalanobis(Q), a, b) +mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = Mahalanobis(Q)(a, b) function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real} sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b)) diff --git a/src/metrics.jl b/src/metrics.jl index 2c714bc..6047b27 100644 --- a/src/metrics.jl +++ b/src/metrics.jl @@ -25,6 +25,7 @@ end struct Hamming <: Metric end struct CosineDist <: SemiMetric end +# CorrDist is excluded from `UnionMetrics` struct CorrDist <: SemiMetric end struct BrayCurtis <: SemiMetric end @@ -103,7 +104,8 @@ struct PeriodicEuclidean{W <: AbstractArray{<: Real}} <: Metric periods::W end -const UnionMetrics = Union{Euclidean,SqEuclidean,PeriodicEuclidean,Chebyshev,Cityblock,TotalVariation,Minkowski,Hamming,Jaccard,RogersTanimoto,CosineDist,CorrDist,ChiSqDist,KLDivergence,RenyiDivergence,BrayCurtis,JSDivergence,SpanNormDist,GenKLDivergence} +const metrics = (Euclidean,SqEuclidean,PeriodicEuclidean,Chebyshev,Cityblock,TotalVariation,Minkowski,Hamming,Jaccard,RogersTanimoto,CosineDist,ChiSqDist,KLDivergence,RenyiDivergence,BrayCurtis,JSDivergence,SpanNormDist,GenKLDivergence) +const UnionMetrics = Union{metrics...} """ Euclidean([thresh]) @@ -165,7 +167,7 @@ PeriodicEuclidean() = PeriodicEuclidean(Int[]) ########################################################### # -# Define Evaluate +# Implementations # ########################################################### @@ -173,8 +175,10 @@ const ArraySlice{T} = SubArray{T,1,Array{T,2},Tuple{Base.Slice{Base.OneTo{Int}}, @inline parameters(::UnionMetrics) = nothing +# breaks the implementation into eval_start, eval_op, eval_reduce and eval_end + # Specialized for Arrays and avoids a branch on the size -@inline Base.@propagate_inbounds function evaluate(d::UnionMetrics, a::Union{Array, ArraySlice}, b::Union{Array, ArraySlice}) +@inline Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::Union{Array, ArraySlice}, b::Union{Array, ArraySlice}) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) end @@ -205,7 +209,7 @@ const ArraySlice{T} = SubArray{T,1,Array{T,2},Tuple{Base.Slice{Base.OneTo{Int}}, end end -@inline function evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray) +@inline function _evaluate(d::UnionMetrics, a::AbstractArray, b::AbstractArray) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) end @@ -258,21 +262,24 @@ eval_start(d::UnionMetrics, a::AbstractArray, b::AbstractArray) = zero(result_type(d, a, b)) eval_end(d::UnionMetrics, s) = s -evaluate(dist::UnionMetrics, a::Number, b::Number) = eval_end(dist, eval_op(dist, a, b)) +for M in metrics + @eval @inline (dist::$M)(a::AbstractArray, b::AbstractArray) = _evaluate(dist, a, b) + @eval @inline (dist::$M)(a::Number, b::Number) = eval_end(dist, eval_op(dist, a, b)) +end # SqEuclidean @inline eval_op(::SqEuclidean, ai, bi) = abs2(ai - bi) @inline eval_reduce(::SqEuclidean, s1, s2) = s1 + s2 -sqeuclidean(a::AbstractArray, b::AbstractArray) = evaluate(SqEuclidean(), a, b) -sqeuclidean(a::Number, b::Number) = evaluate(SqEuclidean(), a, b) +sqeuclidean(a::AbstractArray, b::AbstractArray) = SqEuclidean()(a, b) +sqeuclidean(a::Number, b::Number) = SqEuclidean()(a, b) # Euclidean @inline eval_op(::Euclidean, ai, bi) = abs2(ai - bi) @inline eval_reduce(::Euclidean, s1, s2) = s1 + s2 eval_end(::Euclidean, s) = sqrt(s) -euclidean(a::AbstractArray, b::AbstractArray) = evaluate(Euclidean(), a, b) -euclidean(a::Number, b::Number) = evaluate(Euclidean(), a, b) +euclidean(a::AbstractArray, b::AbstractArray) = Euclidean()(a, b) +euclidean(a::Number, b::Number) = Euclidean()(a, b) # PeriodicEuclidean Base.eltype(d::PeriodicEuclidean) = eltype(d.periods) @@ -291,42 +298,42 @@ end @inline eval_reduce(::PeriodicEuclidean, s1, s2) = s1 + s2 @inline eval_end(::PeriodicEuclidean, s) = sqrt(s) peuclidean(a::AbstractArray, b::AbstractArray, p::AbstractArray{<: Real}) = - evaluate(PeriodicEuclidean(p), a, b) -peuclidean(a::Number, b::Number, p::Real) = evaluate(PeriodicEuclidean([p]), a, b) + PeriodicEuclidean(p)(a, b) +peuclidean(a::Number, b::Number, p::Real) = PeriodicEuclidean([p])(a, b) # Cityblock @inline eval_op(::Cityblock, ai, bi) = abs(ai - bi) @inline eval_reduce(::Cityblock, s1, s2) = s1 + s2 -cityblock(a::AbstractArray, b::AbstractArray) = evaluate(Cityblock(), a, b) -cityblock(a::Number, b::Number) = evaluate(Cityblock(), a, b) +cityblock(a::AbstractArray, b::AbstractArray) = Cityblock()(a, b) +cityblock(a::Number, b::Number) = Cityblock()(a, b) # Total variation @inline eval_op(::TotalVariation, ai, bi) = abs(ai - bi) @inline eval_reduce(::TotalVariation, s1, s2) = s1 + s2 eval_end(::TotalVariation, s) = s / 2 -totalvariation(a::AbstractArray, b::AbstractArray) = evaluate(TotalVariation(), a, b) -totalvariation(a::Number, b::Number) = evaluate(TotalVariation(), a, b) +totalvariation(a::AbstractArray, b::AbstractArray) = TotalVariation()(a, b) +totalvariation(a::Number, b::Number) = TotalVariation()(a, b) # Chebyshev @inline eval_op(::Chebyshev, ai, bi) = abs(ai - bi) @inline eval_reduce(::Chebyshev, s1, s2) = max(s1, s2) # if only NaN, will output NaN @inline Base.@propagate_inbounds eval_start(::Chebyshev, a::AbstractArray, b::AbstractArray) = abs(a[1] - b[1]) -chebyshev(a::AbstractArray, b::AbstractArray) = evaluate(Chebyshev(), a, b) -chebyshev(a::Number, b::Number) = evaluate(Chebyshev(), a, b) +chebyshev(a::AbstractArray, b::AbstractArray) = Chebyshev()(a, b) +chebyshev(a::Number, b::Number) = Chebyshev()(a, b) # Minkowski @inline eval_op(dist::Minkowski, ai, bi) = abs(ai - bi).^dist.p @inline eval_reduce(::Minkowski, s1, s2) = s1 + s2 eval_end(dist::Minkowski, s) = s.^(1 / dist.p) -minkowski(a::AbstractArray, b::AbstractArray, p::Real) = evaluate(Minkowski(p), a, b) -minkowski(a::Number, b::Number, p::Real) = evaluate(Minkowski(p), a, b) +minkowski(a::AbstractArray, b::AbstractArray, p::Real) = Minkowski(p)(a, b) +minkowski(a::Number, b::Number, p::Real) = Minkowski(p)(a, b) # Hamming @inline eval_op(::Hamming, ai, bi) = ai != bi ? 1 : 0 @inline eval_reduce(::Hamming, s1, s2) = s1 + s2 -hamming(a::AbstractArray, b::AbstractArray) = evaluate(Hamming(), a, b) -hamming(a::Number, b::Number) = evaluate(Hamming(), a, b) +hamming(a::AbstractArray, b::AbstractArray) = Hamming()(a, b) +hamming(a::Number, b::Number) = Hamming()(a, b) # Cosine dist @inline function eval_start(dist::CosineDist, a::AbstractArray, b::AbstractArray) @@ -343,32 +350,30 @@ function eval_end(::CosineDist, s) ab, a2, b2 = s max(1 - ab / (sqrt(a2) * sqrt(b2)), zero(eltype(ab))) end -cosine_dist(a::AbstractArray, b::AbstractArray) = evaluate(CosineDist(), a, b) +cosine_dist(a::AbstractArray, b::AbstractArray) = CosineDist()(a, b) # Correlation Dist _centralize(x::AbstractArray) = x .- mean(x) -evaluate(::CorrDist, a::AbstractArray, b::AbstractArray) = cosine_dist(_centralize(a), _centralize(b)) -# Ambiguity resolution -evaluate(::CorrDist, a::Array, b::Array) = cosine_dist(_centralize(a), _centralize(b)) -corr_dist(a::AbstractArray, b::AbstractArray) = evaluate(CorrDist(), a, b) +(dist::CorrDist)(a::AbstractArray, b::AbstractArray) = CosineDist()(_centralize(a), _centralize(b)) +corr_dist(a::AbstractArray, b::AbstractArray) = CorrDist()(a, b) result_type(::CorrDist, Ta::Type, Tb::Type) = result_type(CosineDist(), Ta, Tb) # ChiSqDist @inline eval_op(::ChiSqDist, ai, bi) = (d = abs2(ai - bi) / (ai + bi); ifelse(ai != bi, d, zero(d))) @inline eval_reduce(::ChiSqDist, s1, s2) = s1 + s2 -chisq_dist(a::AbstractArray, b::AbstractArray) = evaluate(ChiSqDist(), a, b) +chisq_dist(a::AbstractArray, b::AbstractArray) = ChiSqDist()(a, b) # KLDivergence @inline eval_op(dist::KLDivergence, ai, bi) = ai > 0 ? ai * log(ai / bi) : zero(eval_op(dist, oneunit(ai), bi)) @inline eval_reduce(::KLDivergence, s1, s2) = s1 + s2 -kl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(KLDivergence(), a, b) +kl_divergence(a::AbstractArray, b::AbstractArray) = KLDivergence()(a, b) # GenKLDivergence @inline eval_op(dist::GenKLDivergence, ai, bi) = ai > 0 ? ai * log(ai / bi) - ai + bi : oftype(eval_op(dist, oneunit(ai), bi), bi) @inline eval_reduce(::GenKLDivergence, s1, s2) = s1 + s2 -gkl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(GenKLDivergence(), a, b) +gkl_divergence(a::AbstractArray, b::AbstractArray) = GenKLDivergence()(a, b) # RenyiDivergence @inline Base.@propagate_inbounds function eval_start(::RenyiDivergence, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Real} @@ -415,7 +420,7 @@ function eval_end(dist::RenyiDivergence, s::Tuple{T,T,T,T}) where {T <: Real} end end -renyi_divergence(a::AbstractArray, b::AbstractArray, q::Real) = evaluate(RenyiDivergence(q), a, b) +renyi_divergence(a::AbstractArray, b::AbstractArray, q::Real) = RenyiDivergence(q)(a, b) # Combine docs with RenyiDivergence. Fetching the docstring with @doc causes # problems during package compilation; see # https://github.com/JuliaLang/julia/issues/31640 @@ -432,7 +437,7 @@ end ta + tb - tu end @inline eval_reduce(::JSDivergence, s1, s2) = s1 + s2 -js_divergence(a::AbstractArray, b::AbstractArray) = evaluate(JSDivergence(), a, b) +js_divergence(a::AbstractArray, b::AbstractArray) = JSDivergence()(a, b) # SpanNormDist @inline Base.@propagate_inbounds function eval_start(::SpanNormDist, a::AbstractArray, b::AbstractArray) @@ -450,7 +455,7 @@ end end eval_end(::SpanNormDist, s) = s[2] - s[1] -spannorm_dist(a::AbstractArray, b::AbstractArray) = evaluate(SpanNormDist(), a, b) +spannorm_dist(a::AbstractArray, b::AbstractArray) = SpanNormDist()(a, b) result_type(dist::SpanNormDist, Ta::Type, Tb::Type) = typeof(eval_op(dist, oneunit(Ta), oneunit(Tb))) @@ -475,7 +480,7 @@ end @inbounds v = 1 - (a[1] / a[2]) return v end -jaccard(a::AbstractArray, b::AbstractArray) = evaluate(Jaccard(), a, b) +jaccard(a::AbstractArray, b::AbstractArray) = Jaccard()(a, b) # BrayCurtis @@ -498,7 +503,7 @@ end @inbounds v = a[1] / a[2] return v end -braycurtis(a::AbstractArray, b::AbstractArray) = evaluate(BrayCurtis(), a, b) +braycurtis(a::AbstractArray, b::AbstractArray) = BrayCurtis()(a, b) # Tanimoto @@ -525,24 +530,24 @@ end @inbounds denominator = a[1] + a[4] + 2(a[2] + a[3]) numerator / denominator end -rogerstanimoto(a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Bool} = evaluate(RogersTanimoto(), a, b) +rogerstanimoto(a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Bool} = RogersTanimoto()(a, b) # Deviations -evaluate(::MeanAbsDeviation, a, b) = cityblock(a, b) / length(a) -meanad(a, b) = evaluate(MeanAbsDeviation(), a, b) +(dist::MeanAbsDeviation)(a, b) = cityblock(a, b) / length(a) +meanad(a, b) = MeanAbsDeviation()(a, b) -evaluate(::MeanSqDeviation, a, b) = sqeuclidean(a, b) / length(a) -msd(a, b) = evaluate(MeanSqDeviation(), a, b) +(dist::MeanSqDeviation)(a, b) = sqeuclidean(a, b) / length(a) +msd(a, b) = MeanSqDeviation()(a, b) -evaluate(::RMSDeviation, a, b) = sqrt(evaluate(MeanSqDeviation(), a, b)) -rmsd(a, b) = evaluate(RMSDeviation(), a, b) +(dist::RMSDeviation)(a, b) = sqrt(MeanSqDeviation()(a, b)) +rmsd(a, b) = RMSDeviation()(a, b) -function evaluate(::NormRMSDeviation, a, b) +function (dist::NormRMSDeviation)(a, b) amin, amax = extrema(a) - return evaluate(RMSDeviation(), a, b) / (amax - amin) + return RMSDeviation()(a, b) / (amax - amin) end -nrmsd(a, b) = evaluate(NormRMSDeviation(), a, b) +nrmsd(a, b) = NormRMSDeviation()(a, b) ########################################################### @@ -706,20 +711,3 @@ function _pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix) end r end - -# CorrDist -_centralize_colwise(x::AbstractVector) = x .- mean(x) -_centralize_colwise(x::AbstractMatrix) = x .- mean(x, dims=1) -function colwise!(r::AbstractVector, dist::CorrDist, a::AbstractMatrix, b::AbstractMatrix) - colwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b)) -end -function colwise!(r::AbstractVector, dist::CorrDist, a::AbstractVector, b::AbstractMatrix) - colwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b)) -end -function _pairwise!(r::AbstractMatrix, dist::CorrDist, - a::AbstractMatrix, b::AbstractMatrix) - _pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b)) -end -function _pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix) - _pairwise!(r, CosineDist(), _centralize_colwise(a)) -end diff --git a/src/wmetrics.jl b/src/wmetrics.jl index cb8c0c7..9e41bdc 100644 --- a/src/wmetrics.jl +++ b/src/wmetrics.jl @@ -30,18 +30,15 @@ struct WeightedHamming{W <: RealAbstractArray} <: Metric weights::W end - -const UnionWeightedMetrics{W} = Union{WeightedEuclidean{W},WeightedSqEuclidean{W},WeightedCityblock{W},WeightedMinkowski{W},WeightedHamming{W}} +const weightedmetrics = (WeightedEuclidean,WeightedSqEuclidean,WeightedCityblock,WeightedMinkowski,WeightedHamming) +const UnionWeightedMetrics{W} = Union{map(M->M{W}, weightedmetrics)...} Base.eltype(x::UnionWeightedMetrics) = eltype(x.weights) ########################################################### # -# Evaluate +# Implementations # ########################################################### -function evaluate(dist::UnionWeightedMetrics, a::Number, b::Number) - eval_end(dist, eval_op(dist, a, b, oneunit(eltype(dist)))) -end result_type(dist::UnionWeightedMetrics, Ta::Type, Tb::Type) = typeof(evaluate(dist, oneunit(Ta), oneunit(Tb))) @@ -52,7 +49,7 @@ eval_end(d::UnionWeightedMetrics, s) = s -@inline function evaluate(d::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray) +@inline function _evaluate(d::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) end @@ -83,32 +80,37 @@ eval_end(d::UnionWeightedMetrics, s) = s return eval_end(d, s) end +for M in weightedmetrics + @eval (dist::$M)(a::AbstractArray, b::AbstractArray) = _evaluate(dist, a, b) + @eval (dist::$M)(a::Number, b::Number) = eval_end(dist, eval_op(dist, a, b, oneunit(eltype(dist)))) +end + # Squared Euclidean @inline eval_op(::WeightedSqEuclidean, ai, bi, wi) = abs2(ai - bi) * wi @inline eval_reduce(::WeightedSqEuclidean, s1, s2) = s1 + s2 -wsqeuclidean(a::AbstractArray, b::AbstractArray, w::AbstractArray) = evaluate(WeightedSqEuclidean(w), a, b) +wsqeuclidean(a::AbstractArray, b::AbstractArray, w::AbstractArray) = WeightedSqEuclidean(w)(a, b) # Weighted Euclidean @inline eval_op(::WeightedEuclidean, ai, bi, wi) = abs2(ai - bi) * wi @inline eval_reduce(::WeightedEuclidean, s1, s2) = s1 + s2 @inline eval_end(::WeightedEuclidean, s) = sqrt(s) -weuclidean(a::AbstractArray, b::AbstractArray, w::AbstractArray) = evaluate(WeightedEuclidean(w), a, b) +weuclidean(a::AbstractArray, b::AbstractArray, w::AbstractArray) = WeightedEuclidean(w)(a, b) # City Block @inline eval_op(::WeightedCityblock, ai, bi, wi) = abs((ai - bi) * wi) @inline eval_reduce(::WeightedCityblock, s1, s2) = s1 + s2 -wcityblock(a::AbstractArray, b::AbstractArray, w::AbstractArray) = evaluate(WeightedCityblock(w), a, b) +wcityblock(a::AbstractArray, b::AbstractArray, w::AbstractArray) = WeightedCityblock(w)(a, b) # Minkowski @inline eval_op(dist::WeightedMinkowski, ai, bi, wi) = abs(ai - bi).^dist.p * wi @inline eval_reduce(::WeightedMinkowski, s1, s2) = s1 + s2 eval_end(dist::WeightedMinkowski, s) = s.^(1 / dist.p) -wminkowski(a::AbstractArray, b::AbstractArray, w::AbstractArray, p::Real) = evaluate(WeightedMinkowski(w, p), a, b) +wminkowski(a::AbstractArray, b::AbstractArray, w::AbstractArray, p::Real) = WeightedMinkowski(w, p)(a, b) # WeightedHamming @inline eval_op(::WeightedHamming, ai, bi, wi) = ai != bi ? wi : zero(eltype(wi)) @inline eval_reduce(::WeightedHamming, s1, s2) = s1 + s2 -whamming(a::AbstractArray, b::AbstractArray, w::AbstractArray) = evaluate(WeightedHamming(w), a, b) +whamming(a::AbstractArray, b::AbstractArray, w::AbstractArray) = WeightedHamming(w)(a, b) ########################################################### # diff --git a/test/runtests.jl b/test/runtests.jl index eb9b147..6afb39a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,7 @@ using LinearAlgebra using Random using Statistics +@test isempty(detect_ambiguities(Distances)) + include("F64.jl") include("test_dists.jl") diff --git a/test/test_dists.jl b/test/test_dists.jl index 1a31e60..06d2bc1 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -2,36 +2,38 @@ function test_metricity(dist, x, y, z) @testset "Test metricity of $(typeof(dist))" begin - dxy = evaluate(dist, x, y) - dxz = evaluate(dist, x, z) - dyz = evaluate(dist, y, z) + @test dist(x, y) == evaluate(dist, x, y) + + dxy = dist(x, y) + dxz = dist(x, z) + dyz = dist(y, z) if isa(dist, PreMetric) # Unfortunately small non-zero numbers (~10^-16) are appearing # in our tests due to accumulating floating point rounding errors. # We either need to allow small errors in our tests or change the # way we do accumulations... - @test evaluate(dist, x, x) + one(eltype(x)) ≈ one(eltype(x)) - @test evaluate(dist, y, y) + one(eltype(y)) ≈ one(eltype(y)) - @test evaluate(dist, z, z) + one(eltype(z)) ≈ one(eltype(z)) + @test dist(x, x) + one(eltype(x)) ≈ one(eltype(x)) + @test dist(y, y) + one(eltype(y)) ≈ one(eltype(y)) + @test dist(z, z) + one(eltype(z)) ≈ one(eltype(z)) @test dxy ≥ zero(eltype(x)) @test dxz ≥ zero(eltype(x)) @test dyz ≥ zero(eltype(x)) end if isa(dist, SemiMetric) - @test dxy ≈ evaluate(dist, y, x) - @test dxz ≈ evaluate(dist, z, x) - @test dyz ≈ evaluate(dist, y, z) + @test dxy ≈ dist(y, x) + @test dxz ≈ dist(z, x) + @test dyz ≈ dist(y, z) else # Not symmetric, so more PreMetric tests - @test evaluate(dist, y, x) ≥ zero(eltype(x)) - @test evaluate(dist, z, x) ≥ zero(eltype(x)) - @test evaluate(dist, z, y) ≥ zero(eltype(x)) + @test dist(y, x) ≥ zero(eltype(x)) + @test dist(z, x) ≥ zero(eltype(x)) + @test dist(z, y) ≥ zero(eltype(x)) end if isa(dist, Metric) # Again we have small rounding errors in accumulations @test dxz ≤ dxy + dyz || dxz ≈ dxy + dyz - dyx = evaluate(dist, y, x) + dyx = dist(y, x) @test dyz ≤ dyx + dxz || dyz ≈ dyx + dxz - dzy = evaluate(dist, z, y) + dzy = dist(z, y) @test dxy ≤ dxz + dzy || dxy ≈ dxz + dzy end end @@ -189,9 +191,9 @@ end @test whamming(a, b, w) === sum((a .!= b) .* w) # Minimal test of Jaccard - test return type stability. - @inferred evaluate(Jaccard(), rand(T, 3), rand(T, 3)) - @inferred evaluate(Jaccard(), [1, 2, 3], [1, 2, 3]) - @inferred evaluate(Jaccard(), [true, false, true], [false, true, true]) + @inferred Jaccard()(rand(T, 3), rand(T, 3)) + @inferred Jaccard()([1, 2, 3], [1, 2, 3]) + @inferred Jaccard()([true, false, true], [false, true, true]) # Test Bray-Curtis. Should be 1 if no elements are shared, 0 if all are the same @test braycurtis([1,0,3],[0,1,0]) == 1.0 @@ -295,8 +297,8 @@ end # testset @test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, q) @test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, mat22) @test_throws DimensionMismatch colwise!(mat23, Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), mat23, mat22) - @test_throws DimensionMismatch evaluate(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), [1, 2, 3], [1, 2]) - @test_throws DimensionMismatch evaluate(Bregman(x -> sqeuclidean(x, zero(x)), x -> [1, 2]), [1, 2, 3], [1, 2, 3]) + @test_throws DimensionMismatch Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x)([1, 2, 3], [1, 2]) + @test_throws DimensionMismatch Bregman(x -> sqeuclidean(x, zero(x)), x -> [1, 2])([1, 2, 3], [1, 2, 3]) end # testset @testset "Different input types" begin @@ -409,9 +411,9 @@ function test_colwise(dist, x, y, T) r2 = zeros(T, n) r3 = zeros(T, n) for j = 1:n - r1[j] = evaluate(dist, x[:, j], y[:, j]) - r2[j] = evaluate(dist, x[:, 1], y[:, j]) - r3[j] = evaluate(dist, x[:, j], y[:, 1]) + r1[j] = dist(x[:, j], y[:, j]) + r2[j] = dist(x[:, 1], y[:, j]) + r3[j] = dist(x[:, j], y[:, 1]) end # ≈ and all( .≈ ) seem to behave slightly differently for F64 @test all(colwise(dist, x, y) .≈ r1) @@ -485,10 +487,10 @@ function test_pairwise(dist, x, y, T) rxy = zeros(T, nx, ny) rxx = zeros(T, nx, nx) for j = 1:ny, i = 1:nx - rxy[i, j] = evaluate(dist, x[:, i], y[:, j]) + rxy[i, j] = dist(x[:, i], y[:, j]) end for j = 1:nx, i = 1:nx - rxx[i, j] = evaluate(dist, x[:, i], x[:, j]) + rxx[i, j] = dist(x[:, i], x[:, j]) end # As earlier, we have small rounding errors in accumulations @test pairwise(dist, x, y) ≈ rxy @@ -582,7 +584,7 @@ end q = rand(4) p = p/sum(p); q = q/sum(q); - @test evaluate(testDist, p, q) ≈ gkl_divergence(p, q) + @test testDist(p, q) ≈ gkl_divergence(p, q) # Test if Bregman() correctly implements the squared euclidean dist. between them. @test bregman(x -> norm(x)^2, x -> 2*x, p, q) ≈ sqeuclidean(p, q) # Test if Bregman() correctly implements the IS distance.