Stratified cross-validation #341
Conversation
LGTM, thanks, with just one minor question.

Side note: there should be an explicit keyword in the `partition` function so that you can do

```julia
train, test = partition(eachindex(y), ...; stratify=y, ...)
```

Granted, this can probably be obtained with `train_test_pairs`, but I think people will also expect a straightforward call that looks like the above. (Can be for another PR.)
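The suggested `stratify` keyword does not exist at this point; a minimal Python sketch of the intended semantics (the function name and signature here are hypothetical, for illustration only, not MLJ's API) might look like:

```python
import random
from collections import defaultdict

def stratified_partition(indices, fraction, y, rng=None):
    """Split `indices` into (train, test) so that each level of `y`
    contributes roughly `fraction` of its members to `train`."""
    rng = rng or random.Random(0)
    by_class = defaultdict(list)
    for i in indices:
        by_class[y[i]].append(i)
    train, test = [], []
    for members in by_class.values():
        members = members[:]          # copy before shuffling
        rng.shuffle(members)
        cut = round(fraction * len(members))
        train.extend(members[:cut])
        test.extend(members[cut:])
    # restore the original ordering of the input indices
    return sorted(train), sorted(test)

y = ["a", "a", "a", "a", "b", "b", "b", "b"]
train, test = stratified_partition(range(len(y)), 0.5, y)
```

Note that the per-class counts in `train` are deterministic even though the membership within each class is shuffled.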
src/resampling.jl (Outdated)

> `evaluate!`, `evaluate` and in tuning. Applies only to classification
> problems (`OrderedFactor` or `Multiclass` targets).
>
>     train_test_pairs(stratified_cv, rows, X, y) # X is ignored
Since `X` is ignored, could we not have a 3-argument call, `(scv, rows, y)`? Or did you want 4 arguments for consistency with something else?
No, we can have a 3-argument call. Will implement.

Thanks for the review. Re `partition`: sounds like a good idea. Since
resolves #108
I have also rewritten the resampling strategy docstrings to make them more technically concise. More "instructional" documentation should go in the manual (a little less concise) and MLJTutorials (the most verbose).
From the new docstring:
Stratified cross-validation resampling strategy, for use in `evaluate!`, `evaluate` and in tuning. Applies only to classification problems (`OrderedFactor` or `Multiclass` targets).

Returns an `nfolds`-length iterator of `(train, test)` pairs of vectors (row indices) where each `train` and `test` is a sub-vector of `rows`. The `test` vectors are mutually exclusive and exhaust `rows`. Each `train` vector is the complement of the corresponding `test` vector.

Unlike regular cross-validation, the distribution of the levels of the target `y` corresponding to each `train` and `test` is constrained, as far as possible, to replicate that of `y[rows]` as a whole.

Specifically, the data is split into a number of groups on which `y` is constant, and each individual group is resampled according to the ordinary cross-validation strategy `CV(nfolds=nfolds)`. To obtain the final `(train, test)` pairs of row indices, the per-group pairs are collated in such a way that each collated `train` and `test` respects the original order of `rows` (after shuffling, if `shuffle=true`).

If `rng` is an integer, then `MersenneTwister(rng)` is the random number generator used for shuffling rows. Otherwise some `AbstractRNG` object is expected.
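The per-group resampling and collation described in the docstring can be sketched as follows (a Python illustration of the logic only; the function names here are hypothetical, not MLJ's API, and shuffling is omitted):

```python
from collections import defaultdict

def kfold_test_indices(indices, nfolds):
    """Ordinary CV: split `indices` into `nfolds` contiguous test folds."""
    n = len(indices)
    sizes = [n // nfolds + (1 if k < n % nfolds else 0) for k in range(nfolds)]
    folds, start = [], 0
    for size in sizes:
        folds.append(indices[start:start + size])
        start += size
    return folds

def stratified_cv_pairs(rows, y, nfolds):
    """Collate per-group CV folds into (train, test) pairs that
    preserve the original order of `rows`."""
    groups = defaultdict(list)
    for r in rows:
        groups[y[r]].append(r)
    # test fold k = union of the k-th test folds of each group
    tests = [[] for _ in range(nfolds)]
    for members in groups.values():
        for k, fold in enumerate(kfold_test_indices(members, nfolds)):
            tests[k].extend(fold)
    order = {r: i for i, r in enumerate(rows)}
    pairs = []
    for k in range(nfolds):
        test = sorted(tests[k], key=order.get)   # respect order of `rows`
        test_set = set(test)
        train = [r for r in rows if r not in test_set]
        pairs.append((train, test))
    return pairs

y = ["a", "a", "b", "b", "a", "b"]
pairs = stratified_cv_pairs(list(range(6)), y, 2)
```

Each `test` vector here draws from every level of `y`, and together the `test` vectors exhaust `rows`, matching the contract stated in the docstring.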