tweak shuffle logic in resampling strategy constructors #258
ablaom committed Dec 9, 2019
1 parent 9d354d9 commit e288ff5
Showing 1 changed file with 72 additions and 59 deletions.
131 changes: 72 additions & 59 deletions src/resampling.jl
@@ -17,22 +17,46 @@ train_test_pairs(s::ResamplingStrategy, rows, X, y) =
train_test_pairs(s::ResamplingStrategy, rows, y) =
train_test_pairs(s, rows)


# Helper to interpret rng, shuffle in case either is `nothing` or if
# `rng` is an integer:
function shuffle_and_rng(shuffle, rng)
    if rng isa Integer
        rng = MersenneTwister(rng)
    end

    if shuffle === nothing
        shuffle = ifelse(rng===nothing, false, true)
    end

    if rng === nothing
        rng = Random.GLOBAL_RNG
    end

    return shuffle, rng
end
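The resolution rules can be read straight off this helper; as a quick sketch of the common cases (using only names defined in this file and the `Random` standard library):

shuffle_and_rng(nothing, nothing)              # (false, Random.GLOBAL_RNG): no shuffling by default
shuffle_and_rng(nothing, 123)                  # (true, MersenneTwister(123)): supplying an rng turns shuffling on
shuffle_and_rng(nothing, MersenneTwister(42))  # (true, MersenneTwister(42))
shuffle_and_rng(false, 123)                    # (false, MersenneTwister(123)): an explicit `shuffle` always wins
shuffle_and_rng(true, nothing)                 # (true, Random.GLOBAL_RNG)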

"""
holdout = Holdout(; fraction_train=0.7,
                    shuffle=false,
                    rng=Random.GLOBAL_RNG)
                    shuffle=nothing,
                    rng=nothing)
Holdout resampling strategy, for use in `evaluate!`, `evaluate` and in tuning.
train_test_pairs(holdout, rows)
Returns the pair `[(train, test)]`, where `train` and `test` are
vectors such that `rows=vcat(train, test)` and
`length(train)/length(test) ≈ fraction_train`.
`length(train)/length(test)` is approximately equal to `fraction_train`.
Pre-shuffling of `rows` is controlled by `rng` and `shuffle`. If `rng`
is an integer, then the `Holdout` keyword constructor resets it to
`MersenneTwister(rng)`. Otherwise some `AbstractRNG` object is
expected.
If `rng` is an integer, then `MersenneTwister(rng)` is the random
number generator used for shuffling rows. Otherwise some `AbstractRNG`
object is expected.
If `rng` is left unspecified, `rng` is reset to `Random.GLOBAL_RNG`,
in which case rows are only pre-shuffled if `shuffle=true` is
specified.
"""
struct Holdout <: ResamplingStrategy
@@ -48,27 +72,20 @@ struct Holdout <: ResamplingStrategy
end

# Keyword Constructor
function Holdout(; fraction_train::Float64=0.7,
                 shuffle::Bool=false,
                 rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG)
    Holdout(fraction_train, shuffle, rng)
end

Holdout(; fraction_train::Float64=0.7, shuffle=nothing, rng=nothing) =
    Holdout(fraction_train, shuffle_and_rng(shuffle, rng)...)

function train_test_pairs(holdout::Holdout, rows)
    if holdout.rng isa Integer
        rng = MersenneTwister(holdout.rng)
    else
        rng = holdout.rng
    end

    train, test = partition(rows, holdout.fraction_train,
                            shuffle=holdout.shuffle, rng=rng)
                            shuffle=holdout.shuffle, rng=holdout.rng)
    return [(train, test),]

end
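For orientation only, a minimal usage sketch (the row range `1:10` and the seed `1234` are illustrative and not part of this commit):

holdout = Holdout(fraction_train=0.8, rng=1234)    # integer seed, so rows are pre-shuffled
(train, test), = train_test_pairs(holdout, 1:10)
length(train), length(test)                        # (8, 2), as dictated by fraction_train
sort(vcat(train, test)) == collect(1:10)           # true: train and test together exhaust the rows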


"""
cv = CV(; nfolds=6, shuffle=false, rng=Random.GLOBAL_RNG)
cv = CV(; nfolds=6, shuffle=nothing, rng=nothing)
Cross-validation resampling strategy, for use in `evaluate!`,
`evaluate` and tuning.
@@ -78,16 +95,20 @@ Cross-validation resampling strategy, for use in `evaluate!`,
Returns an `nfolds`-length iterator of `(train, test)` pairs of
vectors (row indices), where each `train` and `test` is a sub-vector
of `rows`. The `test` vectors are mutually exclusive and exhaust
`rows`. Each `train` vector is the complement of the
corresponding `test` vector. With no shuffling, the order of `rows` is
`rows`. Each `train` vector is the complement of the corresponding
`test` vector. With no row pre-shuffling, the order of `rows` is
preserved, in the sense that `rows` coincides precisely with the
concatenation of the `test` vectors, in the order they are
generated. All but the last `test` vector have equal length.
Declaring `shuffle=true` results in `rows` being shuffled first. If
`rng` is an integer, then `MersenneTwister(rng)` is the random number
generator used for shuffling `rows`. Otherwise some `AbstractRNG`
object is expected.
Pre-shuffling of `rows` is controlled by `rng` and `shuffle`. If `rng`
is an integer, then the `CV` keyword constructor resets it to
`MersenneTwister(rng)`. Otherwise some `AbstractRNG` object is
expected.
If `rng` is left unspecified, `rng` is reset to `Random.GLOBAL_RNG`,
in which case rows are only pre-shuffled if `shuffle=true` is
explicitly specified.
"""
struct CV <: ResamplingStrategy
@@ -101,22 +122,16 @@ struct CV <: ResamplingStrategy
end

# Constructor with keywords
CV(; nfolds::Int=6, shuffle::Bool=false,
   rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG) =
    CV(nfolds, shuffle, rng)
CV(; nfolds::Int=6, shuffle=nothing, rng=nothing) =
    CV(nfolds, shuffle_and_rng(shuffle, rng)...)

function train_test_pairs(cv::CV, rows)
    if cv.rng isa Integer
        rng = MersenneTwister(cv.rng)
    else
        rng = cv.rng
    end

    n_observations = length(rows)
    nfolds = cv.nfolds

    if cv.shuffle
        rows=shuffle!(rng, collect(rows))
        rows=shuffle!(cv.rng, collect(rows))
    end

    # number of observations per fold
@@ -140,21 +155,21 @@ function train_test_pairs(cv::CV, rows)
end
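Likewise, a small illustrative sketch (rows `1:9` are chosen so the three folds divide evenly; not part of this commit):

cv = CV(nfolds=3)                            # no rng given, so no pre-shuffling
tt_pairs = collect(train_test_pairs(cv, 1:9))
length(tt_pairs)                             # 3
vcat(last.(tt_pairs)...) == collect(1:9)     # true: without shuffling the test folds reproduce `rows` in order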

"""
stratified_cv = StratifiedCV(; nfolds=6,
                               shuffle=false,
                               rng=Random.GLOBAL_RNG)
Stratified cross-validation resampling strategy, for use in
`evaluate!`, `evaluate` and in tuning. Applies only to classification
problems (`OrderedFactor` or `Multiclass` targets).
train_test_pairs(stratified_cv, rows, y)
Returns an `nfolds`-length iterator of `(train, test)` pairs of
vectors (row indices) where each `train` and `test` is a sub-vector of
`rows`. The `test` vectors are mutually exclusive and exhaust
`rows`. Each `train` vector is the complement of the corresponding
`test` vector.
Unlike regular cross-validation, the distribution of the levels of the
target `y` corresponding to each `train` and `test` is constrained, as
@@ -167,9 +182,14 @@ final `(train, test)` pairs of row indices, the per-group pairs are
collated in such a way that each collated `train` and `test` respects
the original order of `rows` (after shuffling, if `shuffle=true`).
If `rng` is an integer, then `MersenneTwister(rng)` is the random
number generator used for shuffling rows. Otherwise some `AbstractRNG`
object is expected.
Pre-shuffling of `rows` is controlled by `rng` and `shuffle`. If `rng`
is an integer, then the `StratifiedCV` keyword constructor resets it to
`MersenneTwister(rng)`. Otherwise some `AbstractRNG` object is
expected.
If `rng` is left unspecified, `rng` is reset to `Random.GLOBAL_RNG`,
in which case rows are only pre-shuffled if `shuffle=true` is
explicitly specified.
"""
struct StratifiedCV <: ResamplingStrategy
@@ -183,23 +203,16 @@ struct StratifiedCV <: ResamplingStrategy
end

# Constructor with keywords
StratifiedCV(; nfolds::Int=6, shuffle::Bool=false,
             rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG) =
    StratifiedCV(nfolds, shuffle, rng)

StratifiedCV(; nfolds::Int=6, shuffle=nothing, rng=nothing) =
    StratifiedCV(nfolds, shuffle_and_rng(shuffle, rng)...)

function train_test_pairs(stratified_cv::StratifiedCV, rows, X, y)
    if stratified_cv.rng isa Integer
        rng = MersenneTwister(stratified_cv.rng)
    else
        rng = stratified_cv.rng
    end

    n_observations = length(rows)
    nfolds = stratified_cv.nfolds

    if stratified_cv.shuffle
        rows=shuffle!(rng, collect(rows))
        rows=shuffle!(stratified_cv.rng, collect(rows))
    end

    st = scitype(y)
@@ -213,7 +226,7 @@ function train_test_pairs(stratified_cv::StratifiedCV, rows, X, y)
error("The number of observations for which the target takes on a "*
"given class must, for each class, exceed `nfolds`. Try "*
"reducing `nfolds`. ")

levels_seen = keys(freq_given_level) |> collect

cv = CV(nfolds=nfolds)
@@ -232,21 +245,21 @@ function train_test_pairs(stratified_cv::StratifiedCV, rows, X, y)
    # just the test rows in each level:
    tests_per_level = map(x -> last.(x),
                          train_test_pairs_per_level)

    # for each fold, concatenate the train rows over levels:
    trains_per_fold = map(x->vcat(x...), zip(trains_per_level...))

    # for each fold, concatenate the test rows over levels:
    tests_per_fold = map(x->vcat(x...), zip(tests_per_level...))

    # restore ordering specified by rows:
    trains_per_fold = map(trains_per_fold) do train
        filter(in(train), rows)
    end
    tests_per_fold = map(tests_per_fold) do test
        filter(in(test), rows)
    end

    # re-assemble:
    return zip(trains_per_fold, tests_per_fold) |> collect
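To make the stratification concrete, an illustrative sketch (the toy target `y` and the `nothing` passed for the `X` argument, which the code shown above does not appear to use, are assumptions of this example, not part of the commit; `categorical` is from CategoricalArrays):

using CategoricalArrays

y = categorical(repeat(['a', 'b'], 6))       # 12 rows, two perfectly balanced classes
scv = StratifiedCV(nfolds=3)
tt_pairs = train_test_pairs(scv, 1:12, nothing, y)
length.(last.(tt_pairs))                     # [4, 4, 4]: each test fold takes 2 rows per class,
                                             # so the 50/50 class balance of `y` is preserved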

@@ -368,7 +381,7 @@ restrict the data used in evaluation by specifying `rows`.
An optional `weights` vector may be passed for measures that support
sample weights (`MLJ.supports_weights(measure) == true`), which is
ignored by those that don't.
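For example, the check referred to above can be made directly on a measure instance (`rms` is just one measure a user might pass; shown as a sketch, not part of this commit):

using MLJ
MLJ.supports_weights(rms)   # true means a supplied `weights` vector will actually be used by this measure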
*Important:* If `mach` already wraps sample weights `w` (as in `mach =
machine(model, X, y, w)`) then these weights, which are used for
