From 16e6e429e389a11d8cc90f95166afca454765a0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 26 Oct 2020 14:51:11 +0100 Subject: [PATCH 1/4] Added generic function for abstractvectors --- src/generic.jl | 67 ++++++++++++++++++++++++++++++++++++++++++++-- test/test_dists.jl | 4 +++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/generic.jl b/src/generic.jl index 1af118e..a5c624f 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -33,8 +33,8 @@ Infer the result type of metric `dist` with input type `Ta` and `Tb`, or input data `a` and `b`. """ result_type(::PreMetric, ::Type, ::Type) = Float64 # fallback -result_type(dist::PreMetric, a::AbstractArray, b::AbstractArray) = result_type(dist, eltype(a), eltype(b)) - +result_type(dist::PreMetric, a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = result_type(dist, eltype(a), eltype(b)) +result_type(dist::PreMetric, a::AbstractArray, b::AbstractArray) = result_type(dist, eltype(first(a)), eltype(first(b))) # Generic column-wise evaluation @@ -120,6 +120,35 @@ function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix) r end +function _pairwise!(r::AbstractMatrix, metric::PreMetric, + a::AbstractVector, b::AbstractVector=a) + na = length(a) + nb = length(b) + size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r.")) + @inbounds for (j, bj) = enumerate(b) + for (i, ai) = enumerate(a) + r[i, j] = metric(ai, bj) + end + end + r +end + +function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractVector) + n = length(a) + size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r.")) + @inbounds for j = 1:n + for i = (j + 1):n + r[i, j] = metric(a[i], a[j]) + end + r[j, j] = 0 + for i = 1:(j - 1) + r[i, j] = r[j, i] # leveraging the symmetry of SemiMetric + end + end + r +end + + function deprecated_dims(dims::Union{Nothing,Integer}) if dims === nothing Base.depwarn("implicit `dims=2` argument now has to be passed explicitly " * @@ -185,8 +214,27 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix; end end +function pairwise!(r::AbstractMatrix, metric::PreMetric, + a::AbstractVector, b::AbstractVector; + dims::Union{Nothing,Integer}=nothing) + na = length(a) + nb = length(b) + size(r) == (na, nb) || + throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((na, nb))).")) + _pairwise!(r, metric, a, b) +end + +function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractVector; + dims::Union{Nothing,Integer}=nothing) + n = length(a) + size(r) == (n, n) || + throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((n, n))).")) + _pairwise!(r, metric, a) +end + """ pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix=a; dims) + pairwise(metric::PreMetric, a::AbstractVector, b::AbstractVector=a; dims) Compute distances between each pair of rows (if `dims=1`) or columns (if `dims=2`) in `a` and `b` according to distance `metric`. If a single matrix `a` is provided, @@ -212,3 +260,18 @@ function pairwise(metric::PreMetric, a::AbstractMatrix; r = Matrix{result_type(metric, a, a)}(undef, n, n) pairwise!(r, metric, a, dims=dims) end + + +function pairwise(metric::PreMetric, a::AbstractVector, b::AbstractVector) + m = length(a) + n = length(b) + r = Matrix{result_type(metric, a, b)}(undef, m, n) + pairwise!(r, metric, a, b) +end + + +function pairwise(metric::PreMetric, a::AbstractVector) + m = length(a) + r = Matrix{result_type(metric, a, a)}(undef, m, n) + pairwise!(r, metric, a, b) +end \ No newline at end of file diff --git a/test/test_dists.jl b/test/test_dists.jl index 53c4b9a..789e320 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -504,6 +504,8 @@ function test_pairwise(dist, x, y, T) for j = 1:nx, i = 1:nx rxx[i, j] = dist(x[:, i], x[:, j]) end + vecx = collect(eachcol(x)) + vecy = collect(eachcol(y)) # As earlier, we have small rounding errors in accumulations @test pairwise(dist, x, y, dims=2) ≈ rxy @test pairwise(dist, x, dims=2) ≈ rxx @@ -511,6 +513,8 @@ function test_pairwise(dist, x, y, T) @test pairwise(dist, x, dims=2) ≈ rxx @test pairwise(dist, permutedims(x), permutedims(y), dims=1) ≈ rxy @test pairwise(dist, permutedims(x), dims=1) ≈ rxx + @test pairwise(dist, vecx, vecy) ≈ rxy + @test pairwise(dist, vecx) ≈ rxx end end From f43ceaac2040290082aae23ad8e5eef02d76b95e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 26 Oct 2020 15:16:51 +0100 Subject: [PATCH 2/4] Solving ambiguities --- src/metrics.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/metrics.jl b/src/metrics.jl index 3ddd284..80574ed 100644 --- a/src/metrics.jl +++ b/src/metrics.jl @@ -487,8 +487,10 @@ js_divergence(a::AbstractArray, b::AbstractArray) = JSDivergence()(a, b) # SpanNormDist -result_type(dist::SpanNormDist, a::AbstractArray, b::AbstractArray) = +result_type(dist::SpanNormDist, a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = typeof(eval_op(dist, oneunit(eltype(a)), oneunit(eltype(b)))) +result_type(dist::SpanNormDist, a::AbstractArray, b::AbstractArray) = + typeof(eval_op(dist, oneunit(eltype(first(a))), oneunit(eltype(first(b))))) Base.@propagate_inbounds function eval_start(::SpanNormDist, a::AbstractArray, b::AbstractArray) a[1] - b[1], a[1] - b[1] end From cbec7d1bd6bd3c10bf436b6f2909980c32fa3997 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 26 Oct 2020 15:37:57 +0100 Subject: [PATCH 3/4] Fixed docs and error --- src/generic.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/generic.jl b/src/generic.jl index a5c624f..85ad890 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -163,6 +163,8 @@ end """ pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix=a; dims) + pairwise!(r::AbstractMatrix, metric::PreMetric, + a::AbstractVector, b::AbstractVector=a) Compute distances between each pair of rows (if `dims=1`) or columns (if `dims=2`) in `a` and `b` according to distance `metric`, and store the result in `r`. @@ -271,7 +273,7 @@ end function pairwise(metric::PreMetric, a::AbstractVector) - m = length(a) - r = Matrix{result_type(metric, a, a)}(undef, m, n) - pairwise!(r, metric, a, b) + n = length(a) + r = Matrix{result_type(metric, a, a)}(undef, n, n) + pairwise!(r, metric, a) end \ No newline at end of file From 266d52134a03470db4a617c6c13e8bd4d9177b38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 26 Oct 2020 15:56:25 +0100 Subject: [PATCH 4/4] Replaced eachcol --- test/test_dists.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_dists.jl b/test/test_dists.jl index 789e320..ed41399 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -504,8 +504,8 @@ function test_pairwise(dist, x, y, T) for j = 1:nx, i = 1:nx rxx[i, j] = dist(x[:, i], x[:, j]) end - vecx = collect(eachcol(x)) - vecy = collect(eachcol(y)) + vecx = collect(x[:, i] for i in 1:nx) + vecy = collect(y[:, i] for i in 1:ny) # As earlier, we have small rounding errors in accumulations @test pairwise(dist, x, y, dims=2) ≈ rxy @test pairwise(dist, x, dims=2) ≈ rxx