Skip to content

Commit

Permalink
Merge pull request #179 from ayush-1506/master
Browse files Browse the repository at this point in the history
Add rng to resampling methods (cv and holdout)
  • Loading branch information
ablaom authored Jul 15, 2019
2 parents fd6f2f5 + 19f9dd8 commit e7aeeb4
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions src/resampling.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## RESAMPLING STRATEGIES

abstract type ResamplingStrategy <: MLJType end

import Random
# resampling strategies are `==` if they have the same type and their
# field values are `==`:
function ==(s1::S, s2::S) where S<:ResamplingStrategy
Expand All @@ -15,21 +15,23 @@ end
"""
    Holdout(fraction_train, shuffle, rng)

Resampling strategy describing a single train/test split.

# Fields
- `fraction_train::Float64`: fraction handed to `partition` when the rows are
  split; must satisfy `0 < fraction_train < 1`.
- `shuffle::Bool`: forwarded to `partition` as its `shuffle` keyword.
- `rng::Union{Int,Random.AbstractRNG}`: source of randomness. During evaluation an
  integer is used as the seed of a fresh `MersenneTwister`; an RNG object is
  used as is.

Throws an `ErrorException` if `fraction_train` is not strictly between 0 and 1.
"""
mutable struct Holdout <: ResamplingStrategy
    fraction_train::Float64
    shuffle::Bool
    rng::Union{Int,Random.AbstractRNG}
    function Holdout(fraction_train, shuffle, rng)
        # Validate before construction so an invalid strategy can never exist.
        0 < fraction_train < 1 ||
            error("fraction_train must be between 0 and 1.")
        return new(fraction_train, shuffle, rng)
    end
end
# Keyword constructor with defaults: `fraction_train=0.7`, no shuffling, and the
# global RNG. `rng` may be an integer seed or an RNG object (see the `rng` field).
Holdout(; fraction_train=0.7, shuffle=false, rng=Random.GLOBAL_RNG) =
    Holdout(fraction_train, shuffle, rng)
# Extend MLJBase's `show_as_constructed` for `Holdout`. The function is named by
# its qualified path, as in the corresponding `CV` method, so the method is added
# to MLJBase's function regardless of how that name was brought into scope.
MLJBase.show_as_constructed(::Type{<:Holdout}) = true

"""
    CV(nfolds, parallel, shuffle, rng)

Cross-validation resampling strategy.

# Fields
- `nfolds::Int`: number of folds the rows are divided into.
- `parallel::Bool`: NOTE(review): its use is not visible in this part of the
  file; presumably toggles parallel evaluation of folds — confirm.
- `shuffle::Bool`: whether the rows are shuffled before being divided into folds.
- `rng::Union{Int,Random.AbstractRNG}`: source of randomness. During evaluation an
  integer is used as the seed of a fresh `MersenneTwister`; an RNG object is
  used as is.
"""
mutable struct CV <: ResamplingStrategy
    nfolds::Int
    parallel::Bool
    shuffle::Bool
    rng::Union{Int,Random.AbstractRNG}
end
# Keyword constructor with defaults: six folds, `parallel=true`, no shuffling, and
# the global RNG. `rng` may be an integer seed or an RNG object (see the `rng` field).
CV(; nfolds=6, parallel=true, shuffle=false, rng=Random.GLOBAL_RNG) =
    CV(nfolds, parallel, shuffle, rng)
# Extend MLJBase's `show_as_constructed` to return `true` for `CV`.
# NOTE(review): presumably this controls how instances are displayed — confirm
# against the MLJBase definition.
MLJBase.show_as_constructed(::Type{<:CV}) = true


Expand Down Expand Up @@ -67,6 +69,12 @@ function evaluate!(mach::Machine, resampling::Holdout;
measure=nothing, operation=predict,
rows=nothing, force=false,verbosity=1)

# Resolve the strategy's `rng` field: an integer is treated as the seed of a fresh
# `MersenneTwister`; any other value is used directly as the RNG.
rng = resampling.rng isa Integer ?
    MersenneTwister(resampling.rng) :
    resampling.rng

if measure == nothing
_measures = default_measure(mach.model)
if _measures == nothing
Expand All @@ -84,7 +92,7 @@ function evaluate!(mach::Machine, resampling::Holdout;
rows == nothing ? eachindex(y) : rows

# Split the candidate rows into train and test sets, passing the resolved `rng`
# so that any shuffling done by `partition` draws from it.
train, test = partition(all, resampling.fraction_train,
                        shuffle=resampling.shuffle, rng=rng)
if verbosity > 0
all == eachindex(y) ? "Resampling from all rows. " : "Resampling from a subset of all rows. "
which_rows =
Expand Down Expand Up @@ -115,6 +123,12 @@ function evaluate!(mach::Machine, resampling::CV;
measure=nothing, operation=predict,
rows=nothing, force=false, verbosity=1)

# Resolve the strategy's `rng` field: an integer is treated as the seed of a fresh
# `MersenneTwister`; any other value is used directly as the RNG.
rng = resampling.rng isa Integer ?
    MersenneTwister(resampling.rng) :
    resampling.rng

if measure == nothing
_measures = default_measure(mach.model)
if _measures == nothing
Expand Down Expand Up @@ -146,7 +160,7 @@ function evaluate!(mach::Machine, resampling::CV;
nfolds = resampling.nfolds

if resampling.shuffle
    # `Random.shuffle` returns a shuffled copy, so the result must be bound back
    # to `all`; shuffling a temporary `collect(all)` in place would leave the
    # fold order untouched. Passing `rng` makes the permutation reproducible.
    all = Random.shuffle(rng, all)
end

k = floor(Int,n_samples/nfolds)
Expand Down

0 comments on commit e7aeeb4

Please sign in to comment.