Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug in predicting with KNN models for :brutetree algorithm and metric. #47

Merged
merged 5 commits into from
Jan 21, 2022

Conversation

OkonSamuel
Copy link
Member

@OkonSamuel OkonSamuel commented Jan 20, 2022

julia> using MLJ, NearestNeighborModels, DataFrames

julia> X_rand = rand(50,3);

julia> coef = [1; 0; 1];

julia> y_rand = X_rand*coef + rand(50,1);

julia> X_rand_df = DataFrames.DataFrame(X_rand, :auto) #convert to dataframe for MLJ inputs

julia> y_rand_df = DataFrames.DataFrame(y_rand, :auto);

julia> train_idx, test_idx = partition(eachindex(y_rand_df[:,1]), 0.7, shuffle=true);

julia> tree1 = KNNRegressor(algorithm=:brutetree, metric= Euclidean(), reorder = false, leafsize=0)
KNNRegressor(
    K = 5,
    algorithm = :brutetree,
    metric = Euclidean(0.0),
    leafsize = 0,
    reorder = false,
    weights = Uniform())

julia> tree2 = KNNRegressor(algorithm=:brutetree, metric= Cityblock(), reorder = false, leafsize=0)
KNNRegressor(
    K = 5,
    algorithm = :brutetree,
    metric = Cityblock(),
    leafsize = 0,
    reorder = false,
    weights = Uniform())

julia> tree3 = KNNRegressor(algorithm=:brutetree)
KNNRegressor(
    K = 5,
    algorithm = :brutetree,
    metric = Euclidean(0.0),
    leafsize = 0,
    reorder = false,
    weights = Uniform())

julia> mach1 = machine(tree1,  X_rand_df[train_idx, :], y_rand_df[train_idx, 1]);

julia> mach2 = machine(tree2,  X_rand_df[train_idx, :], y_rand_df[train_idx, 1]);

julia> mach3 = machine(tree3,  X_rand_df[train_idx, :], y_rand_df[train_idx, 1]);

julia> fit!(mach1)
[ Info: Training Machine{KNNRegressor,}.
Machine{KNNRegressor,} trained 1 time; caches data
  model: KNNRegressor
  args:
    1:  Source @344`Table{AbstractVector{Continuous}}`
    2:  Source @234`AbstractVector{Continuous}`


julia> fit!(mach2)
[ Info: Training Machine{KNNRegressor,}.
Machine{KNNRegressor,} trained 1 time; caches data
  model: KNNRegressor
  args:
    1:  Source @582`Table{AbstractVector{Continuous}}`
    2:  Source @640`AbstractVector{Continuous}`

julia> fit!(mach3)
[ Info: Training Machine{KNNRegressor,}.
Machine{KNNRegressor,} trained 1 time; caches data
  model: KNNRegressor
  args:
    1:  Source @492`Table{AbstractVector{Continuous}}`
    2:  Source @836`AbstractVector{Continuous}`

julia> preds1 = MLJ.predict(mach1, X_rand_df[test_idx,:]);

julia> preds2 = MLJ.predict(mach2, X_rand_df[test_idx,:]);

julia> preds1 == preds2
false

julia> preds3 = MLJ.predict(mach3, X_rand_df[test_idx,:])
15-element Vector{Float64}:
 1.7029412653743503
 1.4591779632945534
 1.582290743528469
 2.101230332732603
 2.2316740656510454
 1.4370807503964989
 1.70609816653188
 1.5000572634358513
 1.6903333540586924
 1.5000572634358516
 1.5193869250110221
 1.0055116437574894
 1.1873231287446022
 1.3289743625868122
 1.1843582933771475

closes #46, closes #45.

@codecov-commenter
Copy link

codecov-commenter commented Jan 20, 2022

Codecov Report

Merging #47 (d79bc0a) into dev (faca85c) will increase coverage by 0.91%.
The diff coverage is 94.87%.

Impacted file tree graph

@@            Coverage Diff             @@
##              dev      #47      +/-   ##
==========================================
+ Coverage   92.39%   93.31%   +0.91%     
==========================================
  Files           3        4       +1     
  Lines         342      389      +47     
==========================================
+ Hits          316      363      +47     
  Misses         26       26              
Impacted Files Coverage Δ
src/NearestNeighborModels.jl 100.00% <ø> (ø)
src/models.jl 92.42% <94.87%> (+2.03%) ⬆️
src/utils.jl 100.00% <0.00%> (ø)
src/kernels.jl 90.35% <0.00%> (+0.17%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update faca85c...d79bc0a. Read the comment docs.

@ablaom
Copy link
Member

ablaom commented Jan 20, 2022

Thanks for this!

I will try to review in the next 24 hours.

src/models.jl Show resolved Hide resolved
test/models.jl Show resolved Hide resolved
Copy link
Member

@ablaom ablaom left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just the minor comments on the doc-strings.

Very thorough refactoring of the default constructor, thanks.

Please merge, make a release PR onto master, and tag a new release at your leisure.

@OkonSamuel OkonSamuel merged commit eda1f05 into dev Jan 21, 2022
@OkonSamuel OkonSamuel mentioned this pull request Jan 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants