use of Parameters.jl and more cosmetic fixes
tlienart committed Jul 16, 2019
1 parent 58d1b72 commit c432752
Showing 3 changed files with 75 additions and 47 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -15,6 +15,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4 changes: 4 additions & 0 deletions src/MLJ.jl
@@ -64,7 +64,11 @@ import Distributions
import StatsBase
using ProgressMeter
import Tables
import Random

# convenience packages
using DocStringExtensions: SIGNATURES, TYPEDEF
using Parameters

# to be extended:
import Base.==
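The `using Parameters` import above is what enables the `@with_kw` struct definitions in `src/resampling.jl` below. As a quick illustrative sketch (a hypothetical struct, not MLJ code), `@with_kw` generates a keyword constructor honoring per-field defaults:

using Parameters

# @with_kw adds a keyword constructor that honors the
# per-field defaults declared inline in the struct.
@with_kw mutable struct Options
    nfolds::Int = 6
    shuffle::Bool = false
end

Options()            # defaults: nfolds == 6, shuffle == false
Options(nfolds=3)    # override: nfolds == 3, shuffle == false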
117 changes: 70 additions & 47 deletions src/resampling.jl
@@ -1,37 +1,56 @@
## RESAMPLING STRATEGIES

abstract type ResamplingStrategy <: MLJType end
import Random

# resampling strategies are `==` if they have the same type and their
# field values are `==`:
function ==(s1::S, s2::S) where S<:ResamplingStrategy
ret = true
for fld in fieldnames(S)
ret = ret && getfield(s1, fld) == getfield(s2, fld)
end
return ret
function ==(s1::S, s2::S) where S <: ResamplingStrategy
return all(getfield(s1, fld) == getfield(s2, fld) for fld in fieldnames(S))
end
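For example (a sketch, using the `Holdout` defaults defined below), two strategies of the same type whose field values are `==` now compare equal:

Holdout(fraction_train=0.8) == Holdout(fraction_train=0.8)   # true
Holdout(fraction_train=0.8) == Holdout(fraction_train=0.7)   # false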

mutable struct Holdout <: ResamplingStrategy
fraction_train::Float64
shuffle::Bool
rng::Union{Int,AbstractRNG}
"""
$TYPEDEF
Single train-test split, with a (randomly selected) portion of the
data used for training and the rest for testing.
* `fraction_train`: a number strictly between 0 and 1 indicating the
  proportion of the samples to use for training
* `shuffle`: a boolean indicating whether to select the training samples
  at random
* `rng`: a random number generator to use
"""
@with_kw mutable struct Holdout <: ResamplingStrategy
fraction_train::Float64 = 0.7
shuffle::Bool = false
rng::Union{Int,AbstractRNG} = Random.GLOBAL_RNG

function Holdout(fraction_train, shuffle, rng)
0 < fraction_train && fraction_train < 1 ||
error("fraction_train must be between 0 and 1.")
0 < fraction_train < 1 || error("`fraction_train` must be between 0 and 1.")
return new(fraction_train, shuffle, rng)
end
end
Holdout(; fraction_train=0.7, shuffle=false, rng=Random.GLOBAL_RNG) = Holdout(fraction_train, shuffle, rng)

show_as_constructed(::Type{<:Holdout}) = true
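With `@with_kw` in place, the hand-written keyword constructor deleted above is redundant; a sketch of the resulting usage (defaults as declared in the struct):

h = Holdout()                                  # fraction_train=0.7, shuffle=false
h = Holdout(fraction_train=0.8, shuffle=true)
h.fraction_train                               # 0.8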

mutable struct CV <: ResamplingStrategy
nfolds::Int
parallel::Bool
shuffle::Bool ## TODO: add seed/rng
rng::Union{Int,AbstractRNG}
"""
$TYPEDEF
Cross-validation resampling where the data is (randomly) partitioned into `nfolds` folds
and the model is evaluated `nfolds` times, each time using one fold for testing and the
remaining folds for training; the test performances are then averaged.
For instance, if `nfolds=3`, the data is partitioned into three folds A, B and C, and the
model is trained three times: on A and B with testing on C, then on A and C with testing
on B, and finally on B and C with testing on A. The test performances are then averaged
over the three cases.
"""
@with_kw mutable struct CV <: ResamplingStrategy
nfolds::Int = 6
parallel::Bool = true
shuffle::Bool = false ## TODO: add seed/rng
rng::Union{Int,AbstractRNG} = Random.GLOBAL_RNG
end
CV(; nfolds=6, parallel=true, shuffle=false, rng=Random.GLOBAL_RNG) = CV(nfolds, parallel, shuffle, rng)

MLJBase.show_as_constructed(::Type{<:CV}) = true
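A sketch matching the docstring's three-fold example (keyword constructor courtesy of `@with_kw`):

cv = CV(nfolds=3, shuffle=true)    # three folds A, B, C, as described above
cv == CV(nfolds=3, shuffle=true)   # true, via the `==` defined earlier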


@@ -88,21 +107,20 @@ function evaluate!(mach::Machine, resampling::Holdout;
y = mach.args[2]
length(mach.args) == 2 || error("Multivariate targets not yet supported.")

all =
rows === nothing ? eachindex(y) : rows
unspecified_rows = (rows === nothing)
all = unspecified_rows ? eachindex(y) : rows

train, test = partition(all, resampling.fraction_train,
shuffle=resampling.shuffle, rng=rng)
if verbosity > 0
which_rows =
all == eachindex(y) ? "Resampling from all rows. " : "Resampling from a subset of all rows. "
@info "Evaluating using a holdout set. \n"*
"fraction_train=$(resampling.fraction_train) \n"*
"shuffle=$(resampling.shuffle) \n"*
"measure=$_measures \n"*
"operation=$operation \n"*
"$which_rows"
which_rows = ifelse(unspecified_rows, "Resampling from all rows. ",
"Resampling from a subset of all rows. ")
@info "Evaluating using a holdout set. \n" *
"fraction_train=$(resampling.fraction_train) \n" *
"shuffle=$(resampling.shuffle) \n" *
"measure=$_measures \n" *
"operation=$operation \n" *
"$which_rows"
end

fit!(mach, rows=train, verbosity=verbosity-1, force=force)
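The split itself is delegated to `MLJBase.partition`, called above with the strategy's `shuffle` and `rng`; a sketch of that call on toy indices (return values per MLJBase's behavior):

train, test = partition(1:10, 0.7)                       # first 7 indices, last 3
train, test = partition(1:10, 0.7, shuffle=true, rng=42) # a shuffled 7/3 split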
@@ -142,18 +160,18 @@ function evaluate!(mach::Machine, resampling::CV;
y = mach.args[2]
length(mach.args) == 2 || error("Multivariate targets not yet supported.")

all =
rows === nothing ? eachindex(y) : rows
unspecified_rows = (rows === nothing)
all = unspecified_rows ? eachindex(y) : rows

if verbosity > 0
which_rows =
all == eachindex(y) ? "Resampling from all rows. " : "Resampling from a subset of all rows. "
@info "Evaluating using cross-validation. \n"*
"nfolds=$(resampling.nfolds). \n"*
"shuffle=$(resampling.shuffle) \n"*
"measure=$_measures \n"*
"operation=$operation \n"*
"$which_rows"
which_rows = ifelse(unspecified_rows, "Resampling from all rows. ",
"Resampling from a subset of all rows. ")
@info "Evaluating using cross-validation. \n" *
"nfolds=$(resampling.nfolds). \n" *
"shuffle=$(resampling.shuffle) \n" *
"measure=$_measures \n" *
"operation=$operation \n" *
"$which_rows"
end

n_samples = length(all)
@@ -163,7 +181,8 @@ function evaluate!(mach::Machine, resampling::CV;
shuffle!(rng, collect(all))
end

k = floor(Int,n_samples/nfolds)
# number of samples per fold
k = floor(Int, n_samples/nfolds)

# function to return the measures for the fold `all[f:s]`:
function get_measure(f, s)
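The fold boundary vectors `firsts` and `seconds` used below are defined in lines elided from this diff; a hypothetical sketch consistent with `k` above (assuming any remainder rows are dropped):

nfolds, n_samples = 3, 10
k = floor(Int, n_samples/nfolds)            # 3 samples per fold
firsts  = [(n - 1)*k + 1 for n in 1:nfolds] # [1, 4, 7]  fold start indices
seconds = [n*k for n in 1:nfolds]           # [3, 6, 9]  fold end indices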
@@ -184,8 +203,8 @@ function evaluate!(mach::Machine, resampling::CV;
if resampling.parallel && nworkers() > 1
## TODO: progress meter for distributed case
if verbosity > 0
@info "Distributing cross-validation computation "*
"among $(nworkers()) workers."
@info "Distributing cross-validation computation " *
"among $(nworkers()) workers."
end
measure_values = @distributed vcat for n in 1:nfolds
[get_measure(firsts[n], seconds[n])]
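The `@distributed vcat` reduction above concatenates the one-element vector produced for each fold. A minimal self-contained sketch of the same pattern:

using Distributed

squares = @distributed vcat for n in 1:4
    [n^2]    # each iteration contributes a one-element vector
end
# squares == [1, 4, 9, 16]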
@@ -216,8 +235,11 @@ end

## RESAMPLER - A MODEL WRAPPER WITH `evaluate` OPERATION

# this is needed for the `TunedModel` `fit` defined in tuning.jl
"""
$TYPEDEF
Resampler structure for the `TunedModel` `fit` defined in `tuning.jl`.
"""
mutable struct Resampler{S,M<:Supervised} <: Supervised
model::M
resampling::S # resampling strategy
@@ -230,8 +252,9 @@ MLJBase.is_wrapper(::Type{<:Resampler}) = true


Resampler(; model=ConstantRegressor(), resampling=Holdout(),
measure=nothing, operation=predict) =
Resampler(model, resampling, measure, operation)
measure=nothing, operation=predict) =
Resampler(model, resampling, measure, operation)


function MLJBase.fit(resampler::Resampler, verbosity::Int, X, y)

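Finally, a sketch of constructing the wrapper through its keyword constructor (defaults as defined above; `rms` stands in for whatever measure is available):

resampler = Resampler(model=ConstantRegressor(),
                      resampling=CV(nfolds=3),
                      measure=rms)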
