From a4c8610b967d9938ea8d87e4991260428e7d5315 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 19 Oct 2020 16:12:50 +0200 Subject: [PATCH 1/7] support multivariate kNN regression --- src/NearestNeighbors.jl | 2 +- test/NearestNeighbors.jl | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 023a89a9..110c9ec6 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -128,7 +128,7 @@ end function MMI.predict(m::KNNRegressor, (tree, y, w), X) Xmatrix = MMI.matrix(X, transpose=true) # NOTE: copies the data idxs, dists = NN.knn(tree, Xmatrix, m.K) - preds = zeros(length(idxs)) + preds = similar(y, length(idxs)) w_ = ones(m.K) diff --git a/test/NearestNeighbors.jl b/test/NearestNeighbors.jl index 30dca37f..020f7ab2 100644 --- a/test/NearestNeighbors.jl +++ b/test/NearestNeighbors.jl @@ -109,6 +109,20 @@ p2 = predict(knnr, f2, xtest) @test all(p[ntest+1:2*ntest] .≈ 2.0) @test all(p[2*ntest+1:end] .≈ -2.0) +y1v = fill( [0.0], n) +y2v = fill( [2.0], n) +y3v = fill([-2.0], n) + +yv = vcat(y1v, y2v, y3v) + +fv,_,_ = fit(knnr, 1, x, yv) +f2v,_,_ = fit(knnr, 1, x, yv, w) + +pv = predict(knnr, fv, xtest) + +@test all(pv[1:ntest] .≈ [[0.0]]) +@test all(pv[ntest+1:2*ntest] .≈ [[2.0]]) +@test all(pv[2*ntest+1:end] .≈ [[-2.0]]) From 0b8c3bd4d613862acd57b40df2c16ada7bde8f0e Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 20 Oct 2020 20:37:36 +0200 Subject: [PATCH 2/7] updated target of kNN regressor --- src/NearestNeighbors.jl | 2 +- test/NearestNeighbors.jl | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 110c9ec6..65f7b848 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -161,7 +161,7 @@ metadata_pkg.((KNNRegressor, KNNClassifier), metadata_model(KNNRegressor, input = Table(Continuous), - target = AbstractVector{Continuous}, + target = Union{Table(Continuous), AbstractVector{Continuous}}, weights = true, descr = KNNRegressorDescription ) diff --git a/test/NearestNeighbors.jl b/test/NearestNeighbors.jl index 020f7ab2..e62a5b35 100644 --- a/test/NearestNeighbors.jl +++ b/test/NearestNeighbors.jl @@ -142,8 +142,7 @@ infos[:docstring] infos = info_dict(knnr) @test infos[:input_scitype] == Table(Continuous) -@test infos[:target_scitype] == AbstractVector{Continuous} - +@test infos[:target_scitype] == Union{Table(Continuous), AbstractVector{Continuous}} infos[:docstring] end From 865a35200e858aa396f0dc807b7bcc568f0dd6e2 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 20 Oct 2020 21:39:33 +0200 Subject: [PATCH 3/7] changing target of KNNRegressor --- src/NearestNeighbors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 65f7b848..9ee57abc 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -161,7 +161,7 @@ metadata_pkg.((KNNRegressor, KNNClassifier), metadata_model(KNNRegressor, input = Table(Continuous), - target = Union{Table(Continuous), AbstractVector{Continuous}}, + target = Union{AbstractVector{Continuous}, AbstractVector{<:AbstractVector{Continuous}}}, weights = true, descr = KNNRegressorDescription ) From 3062a05d3de93298337b34b4abe371d7102b0d34 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 20 Oct 2020 21:43:43 +0200 Subject: [PATCH 4/7] target of kNN regressor again --- src/NearestNeighbors.jl | 2 +- test/NearestNeighbors.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 9ee57abc..ae6809ff 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -161,7 +161,7 @@ metadata_pkg.((KNNRegressor, KNNClassifier), metadata_model(KNNRegressor, input = Table(Continuous), - target = Union{AbstractVector{Continuous}, AbstractVector{<:AbstractVector{Continuous}}}, + target = Union{AbstractVector{Continuous}, AbstractVector{<:AbstractArray{Continuous}}}, weights = true, descr = KNNRegressorDescription ) diff --git a/test/NearestNeighbors.jl b/test/NearestNeighbors.jl index e62a5b35..c5cfdb9c 100644 --- a/test/NearestNeighbors.jl +++ b/test/NearestNeighbors.jl @@ -142,7 +142,7 @@ infos[:docstring] infos = info_dict(knnr) @test infos[:input_scitype] == Table(Continuous) -@test infos[:target_scitype] == Union{Table(Continuous), AbstractVector{Continuous}} +@test infos[:target_scitype] == Union{AbstractVector{Continuous}, AbstractVector{<:AbstractArray{Continuous}}} infos[:docstring] end From e52ea0aafaef89543159b608528c9f9d120926b1 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 21 Oct 2020 15:44:30 +0200 Subject: [PATCH 5/7] trying to make the multi-target kNN regressor work with tables --- src/NearestNeighbors.jl | 26 ++++++++++++++++++++------ test/NearestNeighbors.jl | 16 +++++++--------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index ae6809ff..3b810e1d 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -128,26 +128,40 @@ end function MMI.predict(m::KNNRegressor, (tree, y, w), X) Xmatrix = MMI.matrix(X, transpose=true) # NOTE: copies the data idxs, dists = NN.knn(tree, Xmatrix, m.K) - preds = similar(y, length(idxs)) + + if typeof(y) <: AbstractVector + ymat = reshape(y, length(y), 1) + preds = similar(ymat, length(idxs), 1) + else # for multi-target prediction + ymat = MMI.matrix(y) + preds = similar(ymat, length(idxs), size(ymat, 2)) + end w_ = ones(m.K) for i in eachindex(idxs) idxs_ = idxs[i] + println(idxs_) dists_ = dists[i] - values = y[idxs_] + values = [ymat[j,:] for j in idxs_] if w !== nothing w_ = w[idxs_] end + println(preds) if m.weights == :uniform - preds[i] = sum(values .* w_) / sum(w_) + preds[i,:] .= sum(values .* w_) / sum(w_) else - preds[i] = sum(values .* w_ .* (1.0 .- dists_ ./ sum(dists_))) / (sum(w_) - 1) + preds[i,:] .= sum(values .* w_ .* (1.0 .- dists_ ./ sum(dists_))) / (sum(w_) - 1) end end - return preds + if typeof(x) <: AbstractArray + return preds + else + return MMI.table(preds, names=Tables.schema(y).names, prototype=y) + end end + # ==== metadata_pkg.((KNNRegressor, KNNClassifier), @@ -161,7 +175,7 @@ metadata_pkg.((KNNRegressor, KNNClassifier), metadata_model(KNNRegressor, input = Table(Continuous), - target = Union{AbstractVector{Continuous}, AbstractVector{<:AbstractArray{Continuous}}}, + target = Union{AbstractVector{Continuous}, Table{Continuous}}, weights = true, descr = KNNRegressorDescription ) diff --git a/test/NearestNeighbors.jl b/test/NearestNeighbors.jl index c5cfdb9c..5d39dbec 100644 --- a/test/NearestNeighbors.jl +++ b/test/NearestNeighbors.jl @@ -109,21 +109,19 @@ p2 = predict(knnr, f2, xtest) @test all(p[ntest+1:2*ntest] .≈ 2.0) @test all(p[2*ntest+1:end] .≈ -2.0) -y1v = fill( [0.0], n) -y2v = fill( [2.0], n) -y3v = fill([-2.0], n) - -yv = vcat(y1v, y2v, y3v) +ymat = vcat(fill( 0.0, n, 2), fill(2.0, n, 2), fill(-2.0, n, 2)) +yv = Tables.table(ymat; header = [:a, :b]) fv,_,_ = fit(knnr, 1, x, yv) f2v,_,_ = fit(knnr, 1, x, yv, w) pv = predict(knnr, fv, xtest) -@test all(pv[1:ntest] .≈ [[0.0]]) -@test all(pv[ntest+1:2*ntest] .≈ [[2.0]]) -@test all(pv[2*ntest+1:end] .≈ [[-2.0]]) - +for col in [:a, :b] + @test all(pv[col][1:ntest] .≈ [0.0]) + @test all(pv[col][ntest+1:2*ntest] .≈ [2.0]) + @test all(pv[col][2*ntest+1:end] .≈ [-2.0]) +end From a794ade31464a255a10a228aa102f529a209c23a Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 21 Oct 2020 20:04:01 +0200 Subject: [PATCH 6/7] fixing kNN regressor --- src/NearestNeighbors.jl | 5 ++--- test/NearestNeighbors.jl | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 3b810e1d..ea15ebb4 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -10,6 +10,7 @@ const MMI = MLJModelInterface using Distances import ..NearestNeighbors +import ..Tables const NN = NearestNeighbors @@ -141,20 +142,18 @@ function MMI.predict(m::KNNRegressor, (tree, y, w), X) for i in eachindex(idxs) idxs_ = idxs[i] - println(idxs_) dists_ = dists[i] values = [ymat[j,:] for j in idxs_] if w !== nothing w_ = w[idxs_] end - println(preds) if m.weights == :uniform preds[i,:] .= sum(values .* w_) / sum(w_) else preds[i,:] .= sum(values .* w_ .* (1.0 .- dists_ ./ sum(dists_))) / (sum(w_) - 1) end end - if typeof(x) <: AbstractArray + if typeof(y) <: AbstractArray return preds else return MMI.table(preds, names=Tables.schema(y).names, prototype=y) diff --git a/test/NearestNeighbors.jl b/test/NearestNeighbors.jl index 5d39dbec..6ae0a8a0 100644 --- a/test/NearestNeighbors.jl +++ b/test/NearestNeighbors.jl @@ -7,6 +7,7 @@ using MLJModels.NearestNeighbors_ using CategoricalArrays using MLJBase using Random +using Tables Random.seed!(5151) @@ -140,7 +141,7 @@ infos[:docstring] infos = info_dict(knnr) @test infos[:input_scitype] == Table(Continuous) -@test infos[:target_scitype] == Union{AbstractVector{Continuous}, AbstractVector{<:AbstractArray{Continuous}}} +@test infos[:target_scitype] == Union{AbstractVector{Continuous}, Table{Continuous}} infos[:docstring] end From 6f280403fe0816fc30cacb760cee5c27ca1b2ab2 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 22 Oct 2020 10:20:37 +0200 Subject: [PATCH 7/7] code review fixes --- src/NearestNeighbors.jl | 4 ++-- test/NearestNeighbors.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index ea15ebb4..60b31753 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -153,7 +153,7 @@ function MMI.predict(m::KNNRegressor, (tree, y, w), X) preds[i,:] .= sum(values .* w_ .* (1.0 .- dists_ ./ sum(dists_))) / (sum(w_) - 1) end end - if typeof(y) <: AbstractArray + if typeof(y) <: AbstractVector return preds else return MMI.table(preds, names=Tables.schema(y).names, prototype=y) @@ -174,7 +174,7 @@ metadata_pkg.((KNNRegressor, KNNClassifier), metadata_model(KNNRegressor, input = Table(Continuous), - target = Union{AbstractVector{Continuous}, Table{Continuous}}, + target = Union{AbstractVector{Continuous}, Table(Continuous)}, weights = true, descr = KNNRegressorDescription ) diff --git a/test/NearestNeighbors.jl b/test/NearestNeighbors.jl index 6ae0a8a0..4048499a 100644 --- a/test/NearestNeighbors.jl +++ b/test/NearestNeighbors.jl @@ -141,7 +141,7 @@ infos[:docstring] infos = info_dict(knnr) @test infos[:input_scitype] == Table(Continuous) -@test infos[:target_scitype] == Union{AbstractVector{Continuous}, Table{Continuous}} +@test infos[:target_scitype] == Union{AbstractVector{Continuous}, Table(Continuous)} infos[:docstring] end