
Stratified cross-validation #341

Merged
merged 4 commits into from
Nov 19, 2019

Conversation

@ablaom (Member) commented Nov 18, 2019

resolves #108

I have also rewritten the resampling strategy docstrings to make them more technically concise. More "instructional" documentation should go in the manual (a little less concise) and MLJTutorials (the most verbose).

From the new docstring:


```julia
stratified_cv = StratifiedCV(; nfolds=6, shuffle=false, rng=Random.GLOBAL_RNG)
```

Stratified cross-validation resampling strategy, for use in
`evaluate!`, `evaluate` and in tuning. Applies only to classification
problems (`OrderedFactor` or `Multiclass` targets).

```julia
train_test_pairs(stratified_cv, rows, X, y)    # X is ignored
```

Returns an `nfolds`-length iterator of `(train, test)` pairs of
vectors (row indices), where each `train` and `test` is a sub-vector
of `rows`. The `test` vectors are mutually exclusive and exhaust
`rows`. Each `train` vector is the complement of the corresponding
`test` vector.

Unlike regular cross-validation, the distribution of the levels of the
target `y` corresponding to each `train` and `test` is constrained, as
far as possible, to replicate that of `y[rows]` as a whole.

Specifically, the data is split into a number of groups on which `y`
is constant, and each individual group is resampled according to the
ordinary cross-validation strategy `CV(nfolds=nfolds)`. To obtain the
final `(train, test)` pairs of row indices, the per-group pairs are
collated in such a way that each collated `train` and `test` respects
the original order of `rows` (after shuffling, if `shuffle=true`).

If `rng` is an integer, then `MersenneTwister(rng)` is the random
number generator used for shuffling `rows`. Otherwise some
`AbstractRNG` object is expected.
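The group-then-collate scheme described in the docstring can be sketched as follows. This is an illustrative Python translation of the algorithm as described, not MLJ's actual (Julia) implementation; the function name and fold-boundary details are assumptions.

```python
from collections import defaultdict

def stratified_cv_pairs(rows, y, nfolds):
    """Return nfolds (train, test) pairs of row indices whose test sets
    are mutually exclusive, exhaust `rows`, and approximately preserve
    the class distribution of y over `rows`."""
    # 1. Split rows into groups on which y is constant.
    groups = defaultdict(list)
    for r in rows:
        groups[y[r]].append(r)

    # 2. Resample each group with ordinary K-fold CV (contiguous
    #    folds) and collate the per-group test sets fold by fold.
    tests = [[] for _ in range(nfolds)]
    for members in groups.values():
        n = len(members)
        for k in range(nfolds):
            tests[k].extend(members[k * n // nfolds:(k + 1) * n // nfolds])

    # 3. Each collated test respects the original order of rows;
    #    each train is the complement of the corresponding test.
    pairs = []
    for k in range(nfolds):
        test_set = set(tests[k])
        test = [r for r in rows if r in test_set]
        train = [r for r in rows if r not in test_set]
        pairs.append((train, test))
    return pairs
```

With a perfectly balanced two-class `y` over 12 rows and `nfolds=3`, each test fold then contains 4 rows, 2 per class.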

@ablaom requested a review from @tlienart on November 18, 2019, 05:41
@tlienart (Collaborator) left a comment
LGTM thanks, with just one minor question.

Side note: there should be an explicit keyword in the `partition` function so that you can do

```julia
train, test = partition(eachindex(y), ...; stratify=y, ...)
```

Granted, this can probably be obtained with `train_test_pairs`, but I think people will also expect a straightforward one-call version like the above. (Can be for another PR.)
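For a single train/test split, the proposed `stratify` keyword would amount to splitting each class separately and concatenating. A hedged Python sketch of that idea, with a made-up function name and signature (the actual MLJ `partition` API may differ):

```python
from collections import defaultdict

def stratified_partition(indices, frac, y):
    """Split `indices` into (train, test) so that each class contributes
    roughly a fraction `frac` of its members to train."""
    # Group indices by class label.
    groups = defaultdict(list)
    for i in indices:
        groups[y[i]].append(i)

    # Take the leading `frac` of each class for train, the rest for test.
    train, test = [], []
    for members in groups.values():
        cut = round(frac * len(members))
        train.extend(members[:cut])
        test.extend(members[cut:])
    return sorted(train), sorted(test)
```

For example, with 8 indices, two equally frequent classes, and `frac=0.75`, each class contributes 3 of its 4 members to train, so the split is 6/2 with both classes represented in each part.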

`evaluate!`, `evaluate` and in tuning. Applies only to classification
problems (`OrderedFactor` or `Multiclass` targets).

train_test_pairs(stratified_cv, rows, X, y) # X is ignored
@tlienart (Collaborator) commented Nov 18, 2019

Since `X` is ignored, could we not have a 3-argument call, `(scv, rows, y)`? Did you want 4 arguments for consistency with something else?

@ablaom (Member, Author)
No, we can have a 3-argument call. Will implement.

@ablaom (Member, Author) commented Nov 18, 2019

Thanks for the review.

Re `partition`: sounds like a good idea. Since `partition` currently lives in MLJBase, let's make this a separate PR. Raised issue here

Development

Successfully merging this pull request may close these issues.

Add stratified sampling
2 participants