Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move implementations into type overloading (aka. functor) #139

Merged
merged 17 commits into from
Aug 15, 2019
Merged
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ Each distance corresponds to a *distance type*. You can always compute a certain

```julia
r = evaluate(dist, x, y)
r = dist(x, y)
```

Here, dist is an instance of a distance type. For example, the type for Euclidean distance is ``Euclidean`` (more distance types will be introduced in the next section), then you can compute the Euclidean distance between ``x`` and ``y`` as

```julia
r = evaluate(Euclidean(), x, y)
r = Euclidean()(x, y)
```

Common distances also come with convenient functions for distance evaluation. For example, you may also compute Euclidean distance between two vectors as below
Expand Down
14 changes: 6 additions & 8 deletions src/bhattacharyya.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@ bhattacharyya_coeff(a::T, b::T) where {T <: Number} = throw("Bhattacharyya coeff


# Bhattacharyya distance
evaluate(dist::BhattacharyyaDist, a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = -log(bhattacharyya_coeff(a, b))
bhattacharyya(a::AbstractVector, b::AbstractVector) = evaluate(BhattacharyyaDist(), a, b)
evaluate(dist::BhattacharyyaDist, a::T, b::T) where {T <: Number} = throw("Bhattacharyya distance cannot be calculated for scalars")
bhattacharyya(a::T, b::T) where {T <: Number} = evaluate(BhattacharyyaDist(), a, b)
(::BhattacharyyaDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = -log(bhattacharyya_coeff(a, b))
(::BhattacharyyaDist)(a::T, b::T) where {T <: Number} = throw("Bhattacharyya distance cannot be calculated for scalars")
bhattacharyya(a, b) = BhattacharyyaDist()(a, b)

# Hellinger distance
evaluate(dist::HellingerDist, a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = sqrt(1 - bhattacharyya_coeff(a, b))
hellinger(a::AbstractVector, b::AbstractVector) = evaluate(HellingerDist(), a, b)
evaluate(dist::HellingerDist, a::T, b::T) where {T <: Number} = throw("Hellinger distance cannot be calculated for scalars")
hellinger(a::T, b::T) where {T <: Number} = evaluate(HellingerDist(), a, b)
(::HellingerDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = sqrt(1 - bhattacharyya_coeff(a, b))
(::HellingerDist)(a::T, b::T) where {T <: Number} = throw("Hellinger distance cannot be calculated for scalars")
hellinger(a, b) = HellingerDist()(a, b)
4 changes: 2 additions & 2 deletions src/bregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end
Bregman(F, ∇) = Bregman(F, ∇, LinearAlgebra.dot)

# Evaluation fuction
function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)
function (dist::Bregman)(p::AbstractVector, q::AbstractVector)
# Create cache vals.
FP_val = dist.F(p);
FQ_val = dist.F(q);
Expand All @@ -45,4 +45,4 @@ function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)
end

# Convenience function.
bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = evaluate(Bregman(F, ∇, inner), x, y)
bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = Bregman(F, ∇, inner)(x, y)
11 changes: 6 additions & 5 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ abstract type SemiMetric <: PreMetric end
#
abstract type Metric <: SemiMetric end

evaluate(dist::PreMetric, a, b) = dist(a, b)

# Generic functions

Expand All @@ -41,7 +42,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::Abs
n = size(b, 2)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
r[j] = evaluate(metric, a, view(b, :, j))
r[j] = metric(a, view(b, :, j))
end
r
end
Expand All @@ -50,7 +51,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs
n = size(a, 2)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
r[j] = evaluate(metric, view(a, :, j), b)
r[j] = metric(view(a, :, j), b)
end
r
end
Expand All @@ -59,7 +60,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs
n = get_common_ncols(a, b)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
r[j] = evaluate(metric, view(a, :, j), view(b, :, j))
r[j] = metric(view(a, :, j), view(b, :, j))
end
r
end
Expand Down Expand Up @@ -97,7 +98,7 @@ function _pairwise!(r::AbstractMatrix, metric::PreMetric,
@inbounds for j = 1:size(b, 2)
bj = view(b, :, j)
for i = 1:size(a, 2)
r[i, j] = evaluate(metric, view(a, :, i), bj)
r[i, j] = metric(view(a, :, i), bj)
end
end
r
Expand All @@ -109,7 +110,7 @@ function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
@inbounds for j = 1:n
aj = view(a, :, j)
for i = (j + 1):n
r[i, j] = evaluate(metric, view(a, :, i), aj)
r[i, j] = metric(view(a, :, i), aj)
end
r[j, j] = 0
for i = 1:(j - 1)
Expand Down
4 changes: 2 additions & 2 deletions src/haversine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end

const VecOrLengthTwoTuple{T} = Union{AbstractVector{T}, NTuple{2, T}}

function evaluate(dist::Haversine, x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple)
function (dist::Haversine)(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple)
length(x) == length(y) == 2 || haversine_error()

@inbounds begin
Expand All @@ -33,6 +33,6 @@ function evaluate(dist::Haversine, x::VecOrLengthTwoTuple, y::VecOrLengthTwoTupl
2 * dist.radius * asin( min(√a, one(a)) ) # take care of floating point errors
end

haversine(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple, radius::Real) = evaluate(Haversine(radius), x, y)
haversine(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple, radius::Real) = Haversine(radius)(x, y)

@noinline haversine_error() = throw(ArgumentError("expected both inputs to have length 2 in Haversine distance"))
10 changes: 5 additions & 5 deletions src/mahalanobis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ result_type(::SqMahalanobis{T}, ::Type, ::Type) where {T} = T

# SqMahalanobis

function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
function (dist::SqMahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real}
if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -23,7 +23,7 @@ function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector)
return dot(z, Q * z)
end

sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(SqMahalanobis(Q), a, b)
sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = SqMahalanobis(Q)(a, b)

function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
Q = dist.qmat
Expand Down Expand Up @@ -83,11 +83,11 @@ end

# Mahalanobis

function evaluate(dist::Mahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
sqrt(evaluate(SqMahalanobis(dist.qmat), a, b))
function (dist::Mahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real}
sqrt(SqMahalanobis(dist.qmat)(a, b))
end

mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(Mahalanobis(Q), a, b)
mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = Mahalanobis(Q)(a, b)

function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
Expand Down