Skip to content

Commit

Permalink
Merge 73333aa into 785aaab
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan committed Feb 12, 2019
2 parents 785aaab + 73333aa commit 1830d3c
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 47 deletions.
18 changes: 11 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ r = euclidean(x, y)

#### Computing distances between corresponding columns

Suppose you have two ``m-by-n`` matrix ``X`` and ``Y``, then you can compute all distances between corresponding columns of X and Y in one batch, using the ``colwise`` function, as
Suppose you have two ``m-by-n`` matrix ``X`` and ``Y``, then you can compute all distances between corresponding columns of ``X`` and ``Y`` in one batch, using the ``colwise`` function, as

```julia
r = colwise(dist, X, Y)
Expand All @@ -81,31 +81,35 @@ Note that either of ``X`` and ``Y`` can be just a single vector -- then the ``co

#### Computing pairwise distances

Let ``X`` and ``Y`` respectively have ``m`` and ``n`` columns. Then the ``pairwise`` function computes distances between each pair of columns in ``X`` and ``Y``:
Let ``X`` and ``Y`` respectively have ``m`` and ``n`` columns. Then the ``pairwise`` function with the ``dims=2`` argument computes distances between each pair of columns in ``X`` and ``Y``:

```julia
R = pairwise(dist, X, Y)
R = pairwise(dist, X, Y, dims=2)
```

In the output, ``R`` is a matrix of size ``(m, n)``, such that ``R[i,j]`` is the distance between ``X[:,i]`` and ``Y[:,j]``. Computing distances for all pairs using ``pairwise`` function is often remarkably faster than evaluting for each pair individually.

If you just want to just compute distances between columns of a matrix ``X``, you can write

```julia
R = pairwise(dist, X)
R = pairwise(dist, X, dims=2)
```

This statement will result in an ``m-by-m`` matrix, where ``R[i,j]`` is the distance between ``X[:,i]`` and ``X[:,j]``.
``pairwise(dist, X)`` is typically more efficient than ``pairwise(dist, X, X)``, as the former will take advantage of the symmetry when ``dist`` is a semi-metric (including metric).

For performance reasons, it is recommended to use matrices with observations in columns (as shown above). Indeed,
the ``Array`` type in Julia is column-major, making it more efficient to access memory column by column. However,
matrices with observations stored in rows are also supported via the argument ``dims=1``.

#### Computing column-wise and pairwise distances inplace

If the vector/matrix to store the results are pre-allocated, you may use the storage (without creating a new array) using the following syntax:
If the vector/matrix to store the results are pre-allocated, you may use the storage (without creating a new array) using the following syntax (``i`` being either ``1`` or ``2``):

```julia
colwise!(r, dist, X, Y)
pairwise!(R, dist, X, Y)
pairwise!(R, dist, X)
pairwise!(R, dist, X, Y, dims=i)
pairwise!(R, dist, X, dims=i)
```

Please pay attention to the difference, the functions for inplace computation are ``colwise!`` and ``pairwise!`` (instead of ``colwise`` and ``pairwise``).
Expand Down
2 changes: 1 addition & 1 deletion src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end
function get_pairwise_dims(r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix)
ma, na = size(a)
mb, nb = size(b)
ma == mb || throw(DimensionMismatch("The numbers of rows in a and b must match."))
ma == mb || throw(DimensionMismatch("The numbers of rows or columns in a and b must match."))
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
return (ma, na, nb)
end
Expand Down
83 changes: 70 additions & 13 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ end

# Generic pairwise evaluation

function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix=a)
na = size(a, 2)
nb = size(b, 2)
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
Expand All @@ -94,11 +95,7 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::A
r
end

function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix)
pairwise!(r, metric, a, a)
end

function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
n = size(a, 2)
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
Expand All @@ -114,15 +111,75 @@ function pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
r
end

function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
m = size(a, 2)
n = size(b, 2)
function deprecated_dims(dims::Union{Nothing,Integer})
if dims === nothing
Base.depwarn("implicit `dims=2` argument now has to be passed explicitly " *
"to specify that distances between columns should be computed",
:pairwise!)
return 2
else
return dims
end
end

function pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
if dims == 1
na, ma = size(a)
nb, mb = size(b)
ma == mb || throw(DimensionMismatch("The numbers of columns in a and b " *
"must match (got $ma and $mb)."))
else
ma, na = size(a)
mb, nb = size(b)
ma == mb || throw(DimensionMismatch("The numbers of rows in a and b " *
"must match (got $ma and $mb)."))
end
size(r) == (na, nb) ||
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((na, nb)))."))
if dims == 1
_pairwise!(r, metric, transpose(a), transpose(b))
else
_pairwise!(r, metric, a, b)
end
end

function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
if dims == 1
n, m = size(a)
else
m, n = size(a)
end
size(r) == (n, n) ||
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((n, n)))."))
if dims == 1
_pairwise!(r, metric, transpose(a))
else
_pairwise!(r, metric, a)
end
end

function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
m = size(a, dims)
n = size(b, dims)
r = Matrix{result_type(metric, a, b)}(undef, m, n)
pairwise!(r, metric, a, b)
pairwise!(r, metric, a, b, dims=dims)
end

function pairwise(metric::PreMetric, a::AbstractMatrix)
n = size(a, 2)
function pairwise(metric::PreMetric, a::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
n = size(a, dims)
r = Matrix{result_type(metric, a, a)}(undef, n, n)
pairwise!(r, metric, a)
pairwise!(r, metric, a, dims=dims)
end
16 changes: 10 additions & 6 deletions src/mahalanobis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractVector, b
dot_percol!(r, Q * z, z)
end

function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T},
a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
Q = dist.qmat
m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b)

Expand All @@ -58,7 +59,8 @@ function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix,
r
end

function pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T}, a::AbstractMatrix) where {T <: Real}
function _pairwise!(r::AbstractMatrix, dist::SqMahalanobis{T},
a::AbstractMatrix) where {T <: Real}
Q = dist.qmat
m, n = get_pairwise_dims(size(Q, 1), r, a)

Expand Down Expand Up @@ -95,10 +97,12 @@ function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractVector, b::
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a, b))
function _pairwise!(r::AbstractMatrix, dist::Mahalanobis{T},
a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
sqrt!(_pairwise!(r, SqMahalanobis(dist.qmat), a, b))
end

function pairwise!(r::AbstractMatrix, dist::Mahalanobis{T}, a::AbstractMatrix) where {T <: Real}
sqrt!(pairwise!(r, SqMahalanobis(dist.qmat), a))
function _pairwise!(r::AbstractMatrix, dist::Mahalanobis{T},
a::AbstractMatrix) where {T <: Real}
sqrt!(_pairwise!(r, SqMahalanobis(dist.qmat), a))
end
24 changes: 14 additions & 10 deletions src/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@ nrmsd(a, b) = evaluate(NormRMSDeviation(), a, b)
###########################################################

# SqEuclidean
function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, dist::SqEuclidean,
a::AbstractMatrix, b::AbstractMatrix)
mul!(r, a', b)
sa2 = sum(abs2, a, dims=1)
sb2 = sum(abs2, b, dims=1)
Expand Down Expand Up @@ -498,7 +499,7 @@ function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix, b::A
r
end

function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
m, n = get_pairwise_dims(r, a)
mul!(r, a', a)
sa2 = sumsq_percol(a)
Expand Down Expand Up @@ -531,7 +532,8 @@ function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
end

# Euclidean
function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, dist::Euclidean,
a::AbstractMatrix, b::AbstractMatrix)
m, na, nb = get_pairwise_dims(r, a, b)
mul!(r, a', b)
sa2 = sumsq_percol(a)
Expand All @@ -558,7 +560,7 @@ function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix, b::Abs
r
end

function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
m, n = get_pairwise_dims(r, a)
mul!(r, a', a)
sa2 = sumsq_percol(a)
Expand Down Expand Up @@ -586,7 +588,8 @@ end

# CosineDist

function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, dist::CosineDist,
a::AbstractMatrix, b::AbstractMatrix)
m, na, nb = get_pairwise_dims(r, a, b)
mul!(r, a', b)
ra = sqrt!(sumsq_percol(a))
Expand All @@ -598,7 +601,7 @@ function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix, b::Ab
end
r
end
function pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, dist::CosineDist, a::AbstractMatrix)
m, n = get_pairwise_dims(r, a)
mul!(r, a', a)
ra = sqrt!(sumsq_percol(a))
Expand All @@ -623,9 +626,10 @@ end
function colwise!(r::AbstractVector, dist::CorrDist, a::AbstractVector, b::AbstractMatrix)
colwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
end
function pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix, b::AbstractMatrix)
pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
function _pairwise!(r::AbstractMatrix, dist::CorrDist,
a::AbstractMatrix, b::AbstractMatrix)
_pairwise!(r, CosineDist(), _centralize_colwise(a), _centralize_colwise(b))
end
function pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix)
pairwise!(r, CosineDist(), _centralize_colwise(a))
function _pairwise!(r::AbstractMatrix, dist::CorrDist, a::AbstractMatrix)
_pairwise!(r, CosineDist(), _centralize_colwise(a))
end
15 changes: 9 additions & 6 deletions src/wmetrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ whamming(a::AbstractArray, b::AbstractArray, w::AbstractArray) = evaluate(Weight
###########################################################

# SqEuclidean
function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean,
a::AbstractMatrix, b::AbstractMatrix)
w = dist.weights
m, na, nb = get_pairwise_dims(length(w), r, a, b)

Expand All @@ -131,7 +132,8 @@ function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatr
end
r
end
function pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean, a::AbstractMatrix)
function _pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean,
a::AbstractMatrix)
w = dist.weights
m, n = get_pairwise_dims(length(w), r, a)

Expand All @@ -157,9 +159,10 @@ end
function colwise!(r::AbstractArray, dist::WeightedEuclidean, a::AbstractVector, b::AbstractMatrix)
sqrt!(colwise!(r, WeightedSqEuclidean(dist.weights), a, b))
end
function pairwise!(r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix, b::AbstractMatrix)
sqrt!(pairwise!(r, WeightedSqEuclidean(dist.weights), a, b))
function _pairwise!(r::AbstractMatrix, dist::WeightedEuclidean,
a::AbstractMatrix, b::AbstractMatrix)
sqrt!(_pairwise!(r, WeightedSqEuclidean(dist.weights), a, b))
end
function pairwise!(r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix)
sqrt!(pairwise!(r, WeightedSqEuclidean(dist.weights), a))
function _pairwise!(r::AbstractMatrix, dist::WeightedEuclidean, a::AbstractMatrix)
sqrt!(_pairwise!(r, WeightedSqEuclidean(dist.weights), a))
end
11 changes: 7 additions & 4 deletions test/test_dists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,10 +441,13 @@ function test_pairwise(dist, x, y, T)
for j = 1:nx, i = 1:nx
rxx[i, j] = evaluate(dist, x[:, i], x[:, j])
end
# ≈ and all( .≈ ) seem to behave slightly differently for F64
# And, as earlier, we have small rounding errors in accumulations
@test all(pairwise(dist, x, y) .+ one(T) .≈ rxy .+ one(T))
@test all(pairwise(dist, x) .+ one(T) .≈ rxx .+ one(T))
# As earlier, we have small rounding errors in accumulations
@test pairwise(dist, x, y) rxy
@test pairwise(dist, x) rxx
@test pairwise(dist, x, y, dims=2) rxy
@test pairwise(dist, x, dims=2) rxx
@test pairwise(dist, permutedims(x), permutedims(y), dims=1) rxy
@test pairwise(dist, permutedims(x), dims=1) rxx
end
end

Expand Down

0 comments on commit 1830d3c

Please sign in to comment.