
Add R-square (coefficient of determination) #679

Merged · 6 commits into JuliaAI:dev from the r-square branch · Nov 18, 2021

Conversation

@rikhuijzer (Member) commented Nov 17, 2021

This PR adds the R² metric.

R² is more informative for linear regressions than MSE and RMSE (Chicco et al., 2021) because it is dimensionless: it reads as a proportion of explained variance rather than in the units of the target.

The value -3 in the test was calculated manually by me. All looks good.
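For reference, a minimal sketch of the usual definition being implemented here; the function name rsquared, the argument names ŷ and y, and the example call are illustrative only and not the PR's actual code:

# R² = 1 - SS_res / SS_tot: dimensionless, 1 for a perfect fit, 0 when
# predicting the mean of y, and negative (like the -3 above) when the
# predictions are worse than simply predicting the mean.
function rsquared(ŷ, y)
    ȳ = sum(y) / length(y)
    ss_res = sum(abs2, y .- ŷ)
    ss_tot = sum(abs2, y .- ȳ)
    return 1 - ss_res / ss_tot
end

rsquared([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])  # 1.0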

EDIT: It doesn't work yet in CV. I get an error in _check(::RSquare, Vector{Int}). I'll debug it tonight or so.

@rikhuijzer mentioned this pull request on Nov 17, 2021
@rikhuijzer (Member, Author) commented Nov 17, 2021

I don't get it (yet):

~/git/MLJBase.jl> julia --project

julia> using MLJBase

julia> using MLJModelInterface

julia> MMI = MLJModelInterface;

julia> include("test/_models/Constant.jl")

julia> m = ConstantRegressor();

julia> X, y = make_regression();

julia> e = evaluate(m, X, y; resampling=CV(), measures=[rms]);
Evaluating over 6 folds: 100%[=========================] Time: 0:00:00

julia> e = evaluate(m, X, y; resampling=CV(), measures=[rsq]);
Evaluating over 6 folds: 100%[=========================] Time: 0:00:00
ERROR: MethodError: no method matching _check(::RSquared, ::Vector{Float64})
Closest candidates are:
  _check(::MLJBase.Measure, ::Any, ::Any, ::AbstractArray) at /home/rik/git/MLJBase.jl/src/measures/measures.jl:61
  _check(::MLJBase.Measure, ::Any, ::Any) at /home/rik/git/MLJBase.jl/src/measures/measures.jl:53
  _check(::MLJBase.Measure, ::Any, ::Any, ::Any) at /home/rik/git/MLJBase.jl/src/measures/measures.jl:57
  ...
Stacktrace:
  [1] (::RSquared)(args::Vector{Float64})
    @ MLJBase ~/git/MLJBase.jl/src/measures/measures.jl:116
  [2] aggregate(v::Vector{Float64}, measure::RSquared)
    @ MLJBase ~/git/MLJBase.jl/src/measures/measures.jl:164
  [3] (::MLJBase.var"#298#304"{Vector{RSquared}, Vector{Vector{Float64}}})(k::Int64)
    @ MLJBase ~/git/MLJBase.jl/src/resampling.jl:1157
  [4] iterate
    @ ./generator.jl:47 [inlined]

EDIT: This was resolved by not setting the aggregation field via metadata_measure.

@rikhuijzer (Member, Author) commented:

I've read measures/README.md and checked the code and tests, but I still have no idea whatsoever what an "aggregation" is. So, when reviewing this PR, please check carefully whether the aggregation is correctly specified (by being left empty).

@codecov-commenter commented Nov 17, 2021

Codecov Report

Merging #679 (5543cf6) into dev (8b75a34) will decrease coverage by 0.12%.
The diff coverage is 50.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##              dev     #679      +/-   ##
==========================================
- Coverage   85.29%   85.17%   -0.13%     
==========================================
  Files          39       39              
  Lines        3414     3426      +12     
==========================================
+ Hits         2912     2918       +6     
- Misses        502      508       +6     
Impacted Files               Coverage Δ
src/MLJBase.jl               100.00% <ø> (ø)
src/measures/continuous.jl   86.66% <50.00%> (-9.17%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 8b75a34...5543cf6

@OkonSamuel (Member) commented:

I've read measures/README.md and checked the code and tests, but I still have no idea whatsoever what an "aggregation" is. So, when reviewing this PR, please check carefully whether the aggregation is correctly specified (by being left empty).

Thanks @rikhuijzer for this PR.
"aggregation" is a concept used in resampling in two ways

  1. To aggregate the results obtained from calling unaggregated performance measures on observations (see the section on per_fold in the documentation for the evaluate! function).
  2. To aggregate the results from several measures into one common numerical value. This is useful when calling the evaluate/evaluate! functions with a vector of measures. (A sketch of where these aggregated values show up follows below.)
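For concreteness, a hedged sketch of where the per-fold and aggregated values appear when calling evaluate as in the REPL session earlier in the thread (assuming the aggregation issue above is fixed); the field name per_fold is the one the evaluate! documentation refers to, while measurement as the name of the aggregated field is my assumption:

m = ConstantRegressor()              # the test model included earlier in the thread
X, y = make_regression()
e = evaluate(m, X, y; resampling=CV(nfolds=3), measures=[rms, rsq])

e.per_fold      # one vector of 3 per-fold scores for each measure
e.measurement   # one number per measure, obtained by applying that measure's aggregation mode to its per-fold scores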

Now, the reason your previous commit errored is that you defined a wrong aggregation in your metadata function:

aggregation             = RSquared()

Although RSquared() is a valid Julia object (since you created it), it isn't a recognized StatisticalTraits.AggregationMode object. See StatisticalTraits for the definitions of the existing aggregation modes and for how to define new ones.
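So the fix described in the earlier edit presumably amounts to either dropping that entry (so the default applies) or replacing it with a recognized aggregation mode; a hedged one-line sketch, using the Mean() mode recommended in the review below:

aggregation             = Mean()   # a recognized StatisticalTraits.AggregationMode, unlike RSquared()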

@ablaom self-requested a review on November 17, 2021 at 21:03
Two review comments on src/measures/continuous.jl (outdated; resolved).
@ablaom (Member) left a comment:

Thanks for this contribution @rikhuijzer !

As @OkonSamuel correctly points out, "aggregation" refers to how scores are combined when we resample to estimate the expected value of the score on unseen data. So if I am doing 3-fold cross-validation, I'll get 3 R-squared scores, and the question is how I should combine them. Since R-squared is rather non-linear (in the sense I described above), I don't believe there's an obvious choice, so falling back to Mean(), as you currently have, is my recommendation.

If this were number-of-false-positives in classification, you would use Count().

For rms it's possible to define an aggregation function f with the property that f([rms(A1), rms(A2), rms(A3)]) = rms(cat(A1, A2, A3)) (assuming A1, A2, A3 are data sets of the same size), so that case has a special aggregator.
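For concreteness, a quick plain-Julia check of that identity; the helper rms_ below works directly on residual vectors and is defined only for this illustration (it is not MLJBase's rms):

rms_(errs) = sqrt(sum(abs2, errs) / length(errs))

A1, A2, A3 = randn(10), randn(10), randn(10)   # three equally sized "folds" of residuals

per_fold = [rms_(A1), rms_(A2), rms_(A3)]
f(scores) = sqrt(sum(abs2, scores) / length(scores))  # root-mean-square of the fold scores

f(per_fold) ≈ rms_(vcat(A1, A2, A3))   # true: the special aggregator recovers the pooled rms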

@rikhuijzer (Member, Author) commented:

Thanks both for the comments, @ablaom @OkonSamuel 😄. I've implemented them.

@rikhuijzer (Member, Author) commented:

Now that the weighted call is removed, should supports_weights or supports_class_weights also be set to false?

@ablaom (Member) left a comment:

Looks good to me. Thanks again 🙏🏾

@ablaom merged commit 1dcfdb4 into JuliaAI:dev on Nov 18, 2021
@rikhuijzer deleted the r-square branch on November 18, 2021 at 18:46
@ablaom (Member) commented Nov 18, 2021

Now that the weighted call is removed, should supports_weights or supports_class_weights also be set to false?

Oops, good catch. They should both be set to false. Can you push the change?
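For reference, a hedged sketch of what that follow-up change might look like; whether the traits are set inside the metadata_measure declaration or as explicit trait methods, and the exact method signatures below, are assumptions to be checked against measures/README.md:

# Inside the metadata_measure declaration for RSquared (mirroring the
# `aggregation = ...` entry quoted earlier), hypothetically:
#   supports_weights        = false,
#   supports_class_weights  = false,

# Or as explicit trait methods (form assumed, not verified):
MLJBase.supports_weights(::Type{<:RSquared}) = false
MLJBase.supports_class_weights(::Type{<:RSquared}) = false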

@ablaom (Member) commented Nov 18, 2021

Nevermind, I'll do it when I merge your other PR.
