Skip to content

Commit

Permalink
Merge 266d521 into 068f820
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf committed Oct 26, 2020
2 parents 068f820 + 266d521 commit 94301d4
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
69 changes: 67 additions & 2 deletions src/generic.jl
Expand Up @@ -33,8 +33,8 @@ Infer the result type of metric `dist` with input type `Ta` and `Tb`, or input
data `a` and `b`.
"""
result_type(::PreMetric, ::Type, ::Type) = Float64 # fallback
result_type(dist::PreMetric, a::AbstractArray, b::AbstractArray) = result_type(dist, eltype(a), eltype(b))

result_type(dist::PreMetric, a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = result_type(dist, eltype(a), eltype(b))
result_type(dist::PreMetric, a::AbstractArray, b::AbstractArray) = result_type(dist, eltype(first(a)), eltype(first(b)))

# Generic column-wise evaluation

Expand Down Expand Up @@ -120,6 +120,35 @@ function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
r
end

function _pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractVector, b::AbstractVector=a)
na = length(a)
nb = length(b)
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for (j, bj) = enumerate(b)
for (i, ai) = enumerate(a)
r[i, j] = metric(ai, bj)
end
end
r
end

function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractVector)
n = length(a)
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
for i = (j + 1):n
r[i, j] = metric(a[i], a[j])
end
r[j, j] = 0
for i = 1:(j - 1)
r[i, j] = r[j, i] # leveraging the symmetry of SemiMetric
end
end
r
end


function deprecated_dims(dims::Union{Nothing,Integer})
if dims === nothing
Base.depwarn("implicit `dims=2` argument now has to be passed explicitly " *
Expand All @@ -134,6 +163,8 @@ end
"""
pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix=a; dims)
pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractVector, b::AbstractVector=a)
Compute distances between each pair of rows (if `dims=1`) or columns (if `dims=2`)
in `a` and `b` according to distance `metric`, and store the result in `r`.
Expand Down Expand Up @@ -185,8 +216,27 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix;
end
end

function pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractVector, b::AbstractVector;
dims::Union{Nothing,Integer}=nothing)
na = length(a)
nb = length(b)
size(r) == (na, nb) ||
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((na, nb)))."))
_pairwise!(r, metric, a, b)
end

function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractVector;
dims::Union{Nothing,Integer}=nothing)
n = length(a)
size(r) == (n, n) ||
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((n, n)))."))
_pairwise!(r, metric, a)
end

"""
pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix=a; dims)
pairwise(metric::PreMetric, a::AbstractVector, b::AbstractVector=a; dims)
Compute distances between each pair of rows (if `dims=1`) or columns (if `dims=2`)
in `a` and `b` according to distance `metric`. If a single matrix `a` is provided,
Expand All @@ -212,3 +262,18 @@ function pairwise(metric::PreMetric, a::AbstractMatrix;
r = Matrix{result_type(metric, a, a)}(undef, n, n)
pairwise!(r, metric, a, dims=dims)
end


function pairwise(metric::PreMetric, a::AbstractVector, b::AbstractVector)
m = length(a)
n = length(b)
r = Matrix{result_type(metric, a, b)}(undef, m, n)
pairwise!(r, metric, a, b)
end


function pairwise(metric::PreMetric, a::AbstractVector)
n = length(a)
r = Matrix{result_type(metric, a, a)}(undef, n, n)
pairwise!(r, metric, a)
end
4 changes: 3 additions & 1 deletion src/metrics.jl
Expand Up @@ -487,8 +487,10 @@ js_divergence(a::AbstractArray, b::AbstractArray) = JSDivergence()(a, b)

# SpanNormDist

result_type(dist::SpanNormDist, a::AbstractArray, b::AbstractArray) =
result_type(dist::SpanNormDist, a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) =
typeof(eval_op(dist, oneunit(eltype(a)), oneunit(eltype(b))))
result_type(dist::SpanNormDist, a::AbstractArray, b::AbstractArray) =
typeof(eval_op(dist, oneunit(eltype(first(a))), oneunit(eltype(first(b)))))
Base.@propagate_inbounds function eval_start(::SpanNormDist, a::AbstractArray, b::AbstractArray)
a[1] - b[1], a[1] - b[1]
end
Expand Down
4 changes: 4 additions & 0 deletions test/test_dists.jl
Expand Up @@ -504,13 +504,17 @@ function test_pairwise(dist, x, y, T)
for j = 1:nx, i = 1:nx
rxx[i, j] = dist(x[:, i], x[:, j])
end
vecx = collect(x[:, i] for i in 1:nx)
vecy = collect(y[:, i] for i in 1:ny)
# As earlier, we have small rounding errors in accumulations
@test pairwise(dist, x, y, dims=2) rxy
@test pairwise(dist, x, dims=2) rxx
@test pairwise(dist, x, y, dims=2) rxy
@test pairwise(dist, x, dims=2) rxx
@test pairwise(dist, permutedims(x), permutedims(y), dims=1) rxy
@test pairwise(dist, permutedims(x), dims=1) rxx
@test pairwise(dist, vecx, vecy) rxy
@test pairwise(dist, vecx) rxx
end
end

Expand Down

0 comments on commit 94301d4

Please sign in to comment.