From c432752569e6295ff21feacd21faca78bc7c4c90 Mon Sep 17 00:00:00 2001 From: Thibaut Lienart Date: Tue, 16 Jul 2019 18:15:34 +0100 Subject: [PATCH] use of Parameters.jl and more cosmetic fixes --- Project.toml | 1 + src/MLJ.jl | 4 ++ src/resampling.jl | 117 +++++++++++++++++++++++++++------------------- 3 files changed, 75 insertions(+), 47 deletions(-) diff --git a/Project.toml b/Project.toml index 61380f7ab..e76d0f323 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/MLJ.jl b/src/MLJ.jl index 6e97be437..729e0abdf 100644 --- a/src/MLJ.jl +++ b/src/MLJ.jl @@ -64,7 +64,11 @@ import Distributions import StatsBase using ProgressMeter import Tables +import Random + +# convenience packages using DocStringExtensions: SIGNATURES, TYPEDEF +using Parameters # to be extended: import Base.== diff --git a/src/resampling.jl b/src/resampling.jl index 2b9883974..69c9c3f01 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -1,37 +1,56 @@ ## RESAMPLING STRATEGIES abstract type ResamplingStrategy <: MLJType end -import Random + # resampling strategies are `==` if they have the same type and their # field values are `==`: -function ==(s1::S, s2::S) where S<:ResamplingStrategy - ret = true - for fld in fieldnames(S) - ret = ret && getfield(s1, fld) == getfield(s2, fld) - end - return ret +function ==(s1::S, s2::S) where S <: ResamplingStrategy + return all(getfield(s1, fld) == getfield(s2, fld) for fld in fieldnames(S)) end -mutable struct Holdout <: ResamplingStrategy - fraction_train::Float64 - shuffle::Bool - rng::Union{Int,AbstractRNG} +""" +$TYPEDEF + +Single train-test split with a (randomly selected) portion of the +data being selected for training and the rest for testing. + +* `fraction_train` a number between 0 and 1 indicating the proportion +of the samples to use for training +* `shuffle` a boolean indicating whether to select the training samples +at random +* `rng` a random number generator to use +""" +@with_kw mutable struct Holdout <: ResamplingStrategy + fraction_train::Float64 = 0.7 + shuffle::Bool = false + rng::Union{Int,AbstractRNG} = Random.GLOBAL_RNG + function Holdout(fraction_train, shuffle, rng) - 0 < fraction_train && fraction_train < 1 || - error("fraction_train must be between 0 and 1.") + 0 < fraction_train < 1 || error("`fraction_train` must be between 0 and 1.") return new(fraction_train, shuffle, rng) end end -Holdout(; fraction_train=0.7, shuffle=false, rng=Random.GLOBAL_RNG) = Holdout(fraction_train, shuffle, rng) + show_as_constructed(::Type{<:Holdout}) = true -mutable struct CV <: ResamplingStrategy - nfolds::Int - parallel::Bool - shuffle::Bool ## TODO: add seed/rng - rng::Union{Int,AbstractRNG} +""" +$TYPEDEF + +Cross validation resampling where the data is (randomly) partitioned in `nfolds` folds +and the model is evaluated `nfolds` time, each time taking one fold for testing and the +other folds for training and the test performances averaged. +For instance if `nfolds=3` then the data will be partitioned in three folds A, B and C +and the model will be trained three times, first with A and B and tested on C, then on +A, C and tested on B and finally on B, C and tested on A. The test performances are then +averaged over the three cases. +""" +@with_kw mutable struct CV <: ResamplingStrategy + nfolds::Int = 6 + parallel::Bool = true + shuffle::Bool = false ## TODO: add seed/rng + rng::Union{Int,AbstractRNG} = Random.GLOBAL_RNG end -CV(; nfolds=6, parallel=true, shuffle=false, rng=Random.GLOBAL_RNG) = CV(nfolds, parallel, shuffle, rng) + MLJBase.show_as_constructed(::Type{<:CV}) = true @@ -88,21 +107,20 @@ function evaluate!(mach::Machine, resampling::Holdout; y = mach.args[2] length(mach.args) == 2 || error("Multivariate targets not yet supported.") - all = - rows === nothing ? eachindex(y) : rows + unspecified_rows = (rows === nothing) + all = unspecified_rows ? eachindex(y) : rows train, test = partition(all, resampling.fraction_train, shuffle=resampling.shuffle, rng=rng) if verbosity > 0 - all == eachindex(y) ? "Resampling from all rows. " : "Resampling from a subset of all rows. " - which_rows = - all == eachindex(y) ? "Resampling from all rows. " : "Resampling from a subset of all rows. " - @info "Evaluating using a holdout set. \n"* - "fraction_train=$(resampling.fraction_train) \n"* - "shuffle=$(resampling.shuffle) \n"* - "measure=$_measures \n"* - "operation=$operation \n"* - "$which_rows" + which_rows = ifelse(unspecified_rows, "Resampling from all rows. ", + "Resampling from a subset of all rows. ") + @info "Evaluating using a holdout set. \n" * + "fraction_train=$(resampling.fraction_train) \n" * + "shuffle=$(resampling.shuffle) \n" * + "measure=$_measures \n" * + "operation=$operation \n" * + "$which_rows" end fit!(mach, rows=train, verbosity=verbosity-1, force=force) @@ -142,18 +160,18 @@ function evaluate!(mach::Machine, resampling::CV; y = mach.args[2] length(mach.args) == 2 || error("Multivariate targets not yet supported.") - all = - rows === nothing ? eachindex(y) : rows + unspecified_rows = (rows === nothing) + all = unspecified_rows ? eachindex(y) : rows if verbosity > 0 - which_rows = - all == eachindex(y) ? "Resampling from all rows. " : "Resampling from a subset of all rows. " - @info "Evaluating using cross-validation. \n"* - "nfolds=$(resampling.nfolds). \n"* - "shuffle=$(resampling.shuffle) \n"* - "measure=$_measures \n"* - "operation=$operation \n"* - "$which_rows" + which_rows = ifelse(unspecified_rows, "Resampling from all rows. ", + "Resampling from a subset of all rows. ") + @info "Evaluating using cross-validation. \n" * + "nfolds=$(resampling.nfolds). \n" * + "shuffle=$(resampling.shuffle) \n" * + "measure=$_measures \n" * + "operation=$operation \n" * + "$which_rows" end n_samples = length(all) @@ -163,7 +181,8 @@ function evaluate!(mach::Machine, resampling::CV; shuffle!(rng, collect(all)) end - k = floor(Int,n_samples/nfolds) + # number of samples per fold + k = floor(Int, n_samples/nfolds) # function to return the measures for the fold `all[f:s]`: function get_measure(f, s) @@ -184,8 +203,8 @@ function evaluate!(mach::Machine, resampling::CV; if resampling.parallel && nworkers() > 1 ## TODO: progress meter for distributed case if verbosity > 0 - @info "Distributing cross-validation computation "* - "among $(nworkers()) workers." + @info "Distributing cross-validation computation " * + "among $(nworkers()) workers." end measure_values = @distributed vcat for n in 1:nfolds [get_measure(firsts[n], seconds[n])] @@ -216,8 +235,11 @@ end ## RESAMPLER - A MODEL WRAPPER WITH `evaluate` OPERATION -# this is needed for the `TunedModel` `fit` defined in tuning.jl +""" +$TYPEDEF +Resampler structure for the `TunedModel` `fit` defined in `tuning.jl`. +""" mutable struct Resampler{S,M<:Supervised} <: Supervised model::M resampling::S # resampling strategy @@ -230,8 +252,9 @@ MLJBase.is_wrapper(::Type{<:Resampler}) = true Resampler(; model=ConstantRegressor(), resampling=Holdout(), - measure=nothing, operation=predict) = - Resampler(model, resampling, measure, operation) + measure=nothing, operation=predict) = + Resampler(model, resampling, measure, operation) + function MLJBase.fit(resampler::Resampler, verbosity::Int, X, y)