Add dims argument to pairwise
This allows specifying whether observations are stored as columns (dims=2) or as rows (dims=1).
All functions are still written internally to perform computations over columns;
computations over rows are supported by calling transpose on the matrices.

Add a deprecation to require specifying explicitly dims=1 or dims=2: contrary to this package,
other implementations most frequently default to dims=1, and even in Julia the cov and cor functions
default to dims=1.
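
As a rough illustration of the new keyword (a sketch, not part of the commit; the matrix names and sizes below are made up), the same observations can be passed in either layout and should give the same distance matrix. It assumes the Distances package is installed and uses its `Euclidean` metric:

```julia
using Distances

X = rand(3, 4)   # 4 observations stored as columns, 3 features each
Y = rand(3, 5)   # 5 observations stored as columns

# Observations in columns: R_cols[i, j] is the distance between X[:, i] and Y[:, j]
R_cols = pairwise(Euclidean(), X, Y, dims=2)

# Same observations stored as rows: pass the permuted matrices with dims=1
R_rows = pairwise(Euclidean(), permutedims(X), permutedims(Y), dims=1)

println(size(R_cols))        # one row per observation in X, one column per observation in Y
println(R_cols ≈ R_rows)     # both layouts describe the same distances
```

Per the commit message, omitting `dims` still behaves like `dims=2` but emits a deprecation warning.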
nalimilan committed Feb 9, 2019
1 parent 785aaab commit af66c32
Showing 7 changed files with 115 additions and 49 deletions.
18 changes: 11 additions & 7 deletions README.md
@@ -69,7 +69,7 @@ r = euclidean(x, y)

#### Computing distances between corresponding columns

Suppose you have two ``m-by-n`` matrices ``X`` and ``Y``, then you can compute all distances between corresponding columns of X and Y in one batch, using the ``colwise`` function, as
Suppose you have two ``m-by-n`` matrices ``X`` and ``Y``, then you can compute all distances between corresponding columns of ``X`` and ``Y`` in one batch, using the ``colwise`` function, as

```julia
r = colwise(dist, X, Y)
@@ -81,31 +81,35 @@ Note that either of ``X`` and ``Y`` can be just a single vector -- then the ``co

#### Computing pairwise distances

Let ``X`` and ``Y`` respectively have ``m`` and ``n`` columns. Then the ``pairwise`` function computes distances between each pair of columns in ``X`` and ``Y``:
Let ``X`` and ``Y`` respectively have ``m`` and ``n`` columns. Then the ``pairwise`` function with the ``dims=2`` argument computes distances between each pair of columns in ``X`` and ``Y``:

```julia
R = pairwise(dist, X, Y)
R = pairwise(dist, X, Y, dims=2)
```

In the output, ``R`` is a matrix of size ``(m, n)``, such that ``R[i,j]`` is the distance between ``X[:,i]`` and ``Y[:,j]``. Computing distances for all pairs using the ``pairwise`` function is often remarkably faster than evaluating each pair individually.

If you just want to compute distances between columns of a matrix ``X``, you can write

```julia
R = pairwise(dist, X)
R = pairwise(dist, X, dims=2)
```

This statement will result in an ``m-by-m`` matrix, where ``R[i,j]`` is the distance between ``X[:,i]`` and ``X[:,j]``.
``pairwise(dist, X)`` is typically more efficient than ``pairwise(dist, X, X)``, as the former will take advantage of the symmetry when ``dist`` is a semi-metric (including metric).

For performance reasons, it is recommended to use matrices with observations in columns (as shown above). Indeed,
the ``Array`` type in Julia is column-major, making it more efficient to access memory column by column. However,
matrices with observations stored in rows are also supported via the argument ``dims=1``.

#### Computing column-wise and pairwise distances inplace

If the vector/matrix to store the results is pre-allocated, you may reuse that storage (without creating a new array) using the following syntax:
If the vector/matrix to store the results is pre-allocated, you may reuse that storage (without creating a new array) using the following syntax (``i`` being either ``1`` or ``2``):

```julia
colwise!(r, dist, X, Y)
pairwise!(R, dist, X, Y)
pairwise!(R, dist, X)
pairwise!(R, dist, X, Y, dims=i)
pairwise!(R, dist, X, dims=i)
```

Please pay attention to the difference: the functions for inplace computation are ``colwise!`` and ``pairwise!`` (instead of ``colwise`` and ``pairwise``).
11 changes: 7 additions & 4 deletions src/common.jl
@@ -35,14 +35,15 @@ end
function get_pairwise_dims(r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix)
ma, na = size(a)
mb, nb = size(b)
ma == mb || throw(DimensionMismatch("The numbers of rows in a and b must match."))
ma == mb || throw(DimensionMismatch("The numbers of rows or columns in a and b must match."))
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
return (ma, na, nb)
end

function get_pairwise_dims(r::AbstractMatrix, a::AbstractMatrix)
m, n = size(a)
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
size(r) == (n, n) ||
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((n, n)))."))
return (m, n)
end

@@ -92,8 +93,10 @@ end
###########################################################

function sqrt!(a::AbstractArray)
@simd for i in eachindex(a)
@inbounds a[i] = sqrt(a[i])
@inbounds @simd for i in eachindex(a)
x = a[i]
# > 0 is there to tolerate small accumulation errors
a[i] = x > 0 ? sqrt(x) : zero(x)
end
a
end
66 changes: 55 additions & 11 deletions src/generic.jl
@@ -81,7 +81,8 @@ end

# Generic pairwise evaluation

function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix)
na = size(a, 2)
nb = size(b, 2)
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
@@ -94,11 +95,11 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::A
r
end

function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix)
pairwise!(r, metric, a, a)
function _pairwise!(::Val{2}, r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix)
_pairwise!(Val(2), r, metric, a, a)
end

function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
n = size(a, 2)
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
@@ -114,15 +115,58 @@ function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
r
end

function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
m = size(a, 2)
n = size(b, 2)
function deprecated_dims(dims::Union{Nothing,Integer})
if dims === nothing
Base.depwarn("implicit `dims=2` argument now has to be passed explicitly " *
"to specify that distances between columns should be computed",
:pairwise!)
return 2
else
return dims
end
end

function _pairwise!(::Val{1}, r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix=a)
_pairwise!(Val(2), r, metric, transpose(a), transpose(b))
end

function pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix=a;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
if dims == 1
na, ma = size(a)
nb, mb = size(b)
ma == mb || throw(DimensionMismatch("The numbers of columns in a and b " *
"must match (got $ma and $mb)."))
else
ma, na = size(a)
mb, nb = size(b)
ma == mb || throw(DimensionMismatch("The numbers of rows in a and b " *
"must match (got $ma and $mb)."))
end
size(r) == (na, nb) ||
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((na, nb)))."))
_pairwise!(Val(dims), r, metric, a, b)
end

function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
m = size(a, dims)
n = size(b, dims)
r = Matrix{result_type(metric, a, b)}(undef, m, n)
pairwise!(r, metric, a, b)
pairwise!(r, metric, a, b, dims=dims)
end

function pairwise(metric::PreMetric, a::AbstractMatrix)
n = size(a, 2)
function pairwise(metric::PreMetric, a::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
n = size(a, dims)
r = Matrix{result_type(metric, a, a)}(undef, n, n)
pairwise!(r, metric, a)
pairwise!(r, metric, a, dims=dims)
end
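
The dispatch pattern used in this file (a column kernel selected by `Val{2}`, and a `Val{1}` method that forwards lazily transposed inputs to it) can be sketched without the package. All names below (`_toy_pairwise!`, `toy_pairwise`, `sqdist`) are hypothetical, and the result type is fixed to `Float64` for simplicity rather than computed with `result_type`:

```julia
# Core kernel: observations are stored as columns
function _toy_pairwise!(::Val{2}, r::AbstractMatrix, f,
                        a::AbstractMatrix, b::AbstractMatrix)
    for j in axes(b, 2), i in axes(a, 2)
        r[i, j] = f(view(a, :, i), view(b, :, j))
    end
    return r
end

# Row layout: reuse the column kernel on lazily transposed inputs (no copy is made)
_toy_pairwise!(::Val{1}, r, f, a, b) =
    _toy_pairwise!(Val(2), r, f, transpose(a), transpose(b))

# Public entry point: validate dims, allocate the result, then dispatch on Val(dims)
function toy_pairwise(f, a::AbstractMatrix, b::AbstractMatrix; dims::Integer)
    dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
    r = Matrix{Float64}(undef, size(a, dims), size(b, dims))
    return _toy_pairwise!(Val(dims), r, f, a, b)
end

# Squared Euclidean distance between two vectors
sqdist(x, y) = sum(abs2, x .- y)

a = [1.0 2.0; 3.0 4.0; 5.0 6.0]               # 3×2: two observations as columns
b = [0.0 1.0 2.0; 0.0 1.0 2.0; 0.0 1.0 2.0]   # 3×3: three observations as columns

r_cols = toy_pairwise(sqdist, a, b; dims=2)
r_rows = toy_pairwise(sqdist, permutedims(a), permutedims(b); dims=1)
println(r_cols == r_rows)
```

Since only the `Val{2}` kernel contains the actual loop, specialized metrics need to implement just that one method and get row support through the `Val{1}` fallback.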
16 changes: 10 additions & 6 deletions src/mahalanobis.jl
@@ -40,7 +40,8 @@ function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractVector, b
dot_percol!(r, Q * z, z)
end

function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::SqMahalanobis{T},
a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
Q = dist.qmat
m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b)

@@ -58,7 +59,8 @@ function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix,
r
end

function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix) where {T <: Real}
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::SqMahalanobis{T},
a::AbstractMatrix) where {T <: Real}
Q = dist.qmat
m, n = get_pairwise_dims(size(Q, 1), r, a)

@@ -95,10 +97,12 @@ function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractVector, b::
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a, b))
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::Mahalanobis{T},
a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
sqrt!(_pairwise!(Val(2), r, SqMahalanobis(dist.qmat), a, b))
end

function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix) where {T <: Real}
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a))
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::Mahalanobis{T},
a::AbstractMatrix) where {T <: Real}
sqrt!(_pairwise!(Val(2), r, SqMahalanobis(dist.qmat), a))
end
27 changes: 16 additions & 11 deletions src/metrics.jl
@@ -268,7 +268,8 @@ cosine_dist(a::AbstractArray, b::AbstractArray) = evaluate(CosineDist(), a, b)
_centralize(x::AbstractArray) = x .- mean(x)
evaluate(::CorrDist, a::AbstractArray, b::AbstractArray) = cosine_dist(_centralize(a), _centralize(b))
# Ambiguity resolution
evaluate(::CorrDist, a::Array, b::Array) = cosine_dist(_centralize(a), _centralize(b))
evaluate(::CorrDist, a::Union{Array, ArraySlice}, b::Union{Array, ArraySlice}) =
cosine_dist(_centralize(a), _centralize(b))
corr_dist(a::AbstractArray, b::AbstractArray) = evaluate(CorrDist(), a, b)
result_type(::CorrDist, a::AbstractArray, b::AbstractArray) = result_type(CosineDist(), a, b)

@@ -462,7 +463,8 @@ nrmsd(a, b) = evaluate(NormRMSDeviation(), a, b)
###########################################################

# SqEuclidean
function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::SqEuclidean,
a::AbstractMatrix, b::AbstractMatrix)
mul!(r, a', b)
sa2 = sum(abs2, a, dims=1)
sb2 = sum(abs2, b, dims=1)
@@ -498,7 +500,7 @@ function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix, b::A
r
end

function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
m, n = get_pairwise_dims(r, a)
mul!(r, a', a)
sa2 = sumsq_percol(a)
@@ -531,7 +533,8 @@ function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
end

# Euclidean
function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::Euclidean,
a::AbstractMatrix, b::AbstractMatrix)
m, na, nb = get_pairwise_dims(r, a, b)
mul!(r, a', b)
sa2 = sumsq_percol(a)
@@ -558,7 +561,7 @@ function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix, b::Abs
r
end

function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
m, n = get_pairwise_dims(r, a)
mul!(r, a', a)
sa2 = sumsq_percol(a)
@@ -586,7 +589,8 @@ end

# CosineDist

function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::CosineDist,
a::AbstractMatrix, b::AbstractMatrix)
m, na, nb = get_pairwise_dims(r, a, b)
mul!(r, a', b)
ra = sqrt!(sumsq_percol(a))
@@ -598,7 +602,7 @@ function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix, b::Ab
end
r
end
function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix)
m, n = get_pairwise_dims(r, a)
mul!(r, a', a)
ra = sqrt!(sumsq_percol(a))
@@ -623,9 +627,10 @@ end
function colwise!(r::AbstractVector, dist::CorrDist, a::AbstractVector, b::AbstractMatrix)
colwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
end
function pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix, b::AbstractMatrix)
pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::CorrDist,
a::AbstractMatrix, b::AbstractMatrix)
_pairwise!(Val(2), r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
end
function pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix)
pairwise!(r, CosineDist(), _centralize_colwise(a))
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix)
_pairwise!(Val(2), r, CosineDist(), _centralize_colwise(a))
end
15 changes: 9 additions & 6 deletions src/wmetrics.jl
@@ -117,7 +117,8 @@ whamming(a::AbstractArray, b::AbstractArray, w::AbstractArray) = evaluate(Weight
###########################################################

# SqEuclidean
function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::WeightedSqEuclidean,
a::AbstractMatrix, b::AbstractMatrix)
w = dist.weights
m, na, nb = get_pairwise_dims(length(w), r, a, b)

@@ -131,7 +132,8 @@ function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatr
end
r
end
function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatrix)
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::WeightedSqEuclidean,
a::AbstractMatrix)
w = dist.weights
m, n = get_pairwise_dims(length(w), r, a)

@@ -157,9 +159,10 @@ end
function colwise!(r::AbstractArray, dist::WeightedEuclidean, a::AbstractVector, b::AbstractMatrix)
sqrt!(colwise!(r, WeightedSqEuclidean(dist.weights), a, b))
end
function pairwise!(r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix, b::AbstractMatrix)
sqrt!(pairwise!(r, WeightedSqEuclidean(dist.weights), a, b))
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::WeightedEuclidean,
a::AbstractMatrix, b::AbstractMatrix)
sqrt!(_pairwise!(Val(2), r, WeightedSqEuclidean(dist.weights), a, b))
end
function pairwise!(r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix)
sqrt!(pairwise!(r, WeightedSqEuclidean(dist.weights), a))
function _pairwise!(::Val{2}, r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix)
sqrt!(_pairwise!(Val(2), r, WeightedSqEuclidean(dist.weights), a))
end
11 changes: 7 additions & 4 deletions test/test_dists.jl
@@ -441,10 +441,13 @@ function test_pairwise(dist, x, y, T)
for j = 1:nx, i = 1:nx
rxx[i, j] = evaluate(dist, x[:, i], x[:, j])
end
# ≈ and all( .≈ ) seem to behave slightly differently for F64
# And, as earlier, we have small rounding errors in accumulations
@test all(pairwise(dist, x, y) .+ one(T) .≈ rxy .+ one(T))
@test all(pairwise(dist, x) .+ one(T) .≈ rxx .+ one(T))
# As earlier, we have small rounding errors in accumulations
@test pairwise(dist, x, y) ≈ rxy
@test pairwise(dist, x) ≈ rxx
@test pairwise(dist, x, y, dims=2) ≈ rxy
@test pairwise(dist, x, dims=2) ≈ rxx
@test pairwise(dist, permutedims(x), permutedims(y), dims=1) ≈ rxy
@test pairwise(dist, permutedims(x), dims=1) ≈ rxx
end
end
