Skip to content

Commit

Permalink
Merge 89a8494 into 777f023
Browse files Browse the repository at this point in the history
  • Loading branch information
rmcaixeta committed Oct 7, 2020
2 parents 777f023 + 89a8494 commit 32ad5b0
Showing 1 changed file with 27 additions and 20 deletions.
47 changes: 27 additions & 20 deletions src/mahalanobis.jl
@@ -1,19 +1,26 @@
# Mahalanobis distances

struct Mahalanobis{T} <: Metric
qmat::Matrix{T}
struct Mahalanobis{M<:AbstractMatrix} <: Metric
qmat::M
end

struct SqMahalanobis{T} <: SemiMetric
qmat::Matrix{T}
struct SqMahalanobis{M<:AbstractMatrix} <: SemiMetric
qmat::M
end

result_type(::Mahalanobis{T}, ::Type, ::Type) where {T} = T
result_type(::SqMahalanobis{T}, ::Type, ::Type) where {T} = T
function result_type(d::Mahalanobis, ::Type{T1}, ::Type{T2}) where {T1,T2}
z = zero(T1) - zero(T2)
return typeof(sqrt(z * zero(eltype(d.qmat)) * z))
end

function result_type(d::SqMahalanobis, ::Type{T1}, ::Type{T2}) where {T1,T2}
z = zero(T1) - zero(T2)
return typeof(z * zero(eltype(d.qmat)) * z)
end

# SqMahalanobis

function (dist::SqMahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real}
function (dist::SqMahalanobis)(a::AbstractVector, b::AbstractVector)
if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -25,23 +32,23 @@ end

sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = SqMahalanobis(Q)(a, b)

function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractMatrix)
Q = dist.qmat
m, n = get_colwise_dims(size(Q, 1), r, a, b)
z = a - b
dot_percol!(r, Q * z, z)
end

function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: Real}
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractVector, b::AbstractMatrix)
Q = dist.qmat
m, n = get_colwise_dims(size(Q, 1), r, a, b)
z = a .- b
Qz = Q * z
dot_percol!(r, Q * z, z)
end

function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T},
a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis,
a::AbstractMatrix, b::AbstractMatrix)
Q = dist.qmat
m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b)

Expand All @@ -59,8 +66,8 @@ function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T},
r
end

function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T},
a::AbstractMatrix) where {T <: Real}
function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis,
a::AbstractMatrix)
Q = dist.qmat
m, n = get_pairwise_dims(size(Q, 1), r, a)

Expand All @@ -83,26 +90,26 @@ end

# Mahalanobis

function (dist::Mahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real}
function (dist::Mahalanobis)(a::AbstractVector, b::AbstractVector)
sqrt(SqMahalanobis(dist.qmat)(a, b))
end

mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = Mahalanobis(Q)(a, b)

function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractMatrix, b::AbstractMatrix)
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractVector, b::AbstractMatrix) where {T <: Real}
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractVector, b::AbstractMatrix)
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function _pairwise!(r::AbstractMatrix, dist::Mahalanobis{T},
a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
function _pairwise!(r::AbstractMatrix, dist::Mahalanobis,
a::AbstractMatrix, b::AbstractMatrix)
sqrt!(_pairwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function _pairwise!(r::AbstractMatrix, dist::Mahalanobis{T},
a::AbstractMatrix) where {T <: Real}
function _pairwise!(r::AbstractMatrix, dist::Mahalanobis,
a::AbstractMatrix)
sqrt!(_pairwise!(r, SqMahalanobis(dist.qmat), a))
end

0 comments on commit 32ad5b0

Please sign in to comment.