final minor fixes before PR
tlienart committed Jul 16, 2019
1 parent c432752 commit fb97aed
Showing 1 changed file with 43 additions and 38 deletions.
81 changes: 43 additions & 38 deletions src/tuning.jl
@@ -1,29 +1,42 @@
abstract type TuningStrategy <: MLJ.MLJType end

"""
$TYPEDEF
Grid object for a grid search tuning strategy.
"""
@with_kw mutable struct Grid <: TuningStrategy
resolution::Int = 10
parallel::Bool = true
end

MLJBase.show_as_constructed(::Type{<:Grid}) = true
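# Note (not part of the diff): a minimal usage sketch. `@with_kw` from
# Parameters.jl generates a keyword constructor with the declared defaults:
#
#     g  = Grid()               # Grid(resolution = 10, parallel = true)
#     g2 = Grid(resolution=20)  # override a single field by keyword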

"""
$TYPEDEF
Container for a deterministic tuning strategy.
"""
mutable struct DeterministicTunedModel{T,M<:Deterministic} <: MLJ.Deterministic
model::M
tuning::T # tuning strategy
resampling # resampling strategy
measure
operation
ranges::Union{Vector,ParamRange}
minimize::Bool
full_report::Bool
train_best::Bool
end

"""
$TYPEDEF
Container for a probabilistic tuning strategy.
"""
mutable struct ProbabilisticTunedModel{T,M<:Probabilistic} <: MLJ.Probabilistic
model::M
tuning::T # tuning strategy
resampling # resampling strategy
measure
operation
@@ -76,13 +89,13 @@ function TunedModel(;model=nothing,
minimize=true,
full_report=true,
train_best=true)

!isempty(ranges) || error("You need to specify ranges=... ")
model !== nothing || error("You need to specify model=... ")

message = clean!(model)
isempty(message) || @info message

if model isa Deterministic
return DeterministicTunedModel(model, tuning, resampling,
measure, operation, ranges, minimize, full_report, train_best)
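# Sketch (not part of the diff): typical construction via the keyword
# constructor above, assuming a hypothetical deterministic model `ridge`
# with a hyperparameter `lambda`:
#
#     r = range(ridge, :lambda, lower=0.01, upper=10.0, scale=:log10)
#     tuned = TunedModel(model=ridge, tuning=Grid(resolution=10),
#                        resampling=CV(), measure=rms, ranges=r)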
@@ -111,13 +124,13 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
end

ranges isa Vector{<:ParamRange} ||
error("ranges must be a ParamRange object or a vector of "*
error("ranges must be a ParamRange object or a vector of " *
"ParamRange objects. ")


parameter_names = [string(r.field) for r in ranges]
scales = [scale(r) for r in ranges]

# the mutating model:
clone = deepcopy(tuned_model.model)

@@ -137,7 +150,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
elseif range isa MLJ.NumericRange
MLJ.iterator(range, tuned_model.tuning.resolution)
else
throw(TypeError(:iterator, "", MLJ.ParamRange, range))
end
end
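# Sketch of the assumed iterator behavior (not part of the diff): a
# NominalRange iterates over its values; a NumericRange is discretized into
# `resolution` points on the range's scale, e.g.
#
#     r = range(clone, :lambda, lower=1.0, upper=100.0, scale=:log10)
#     MLJ.iterator(r, 3)   # ≈ [1.0, 10.0, 100.0]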

@@ -154,24 +167,22 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y

# initialize search for best model:
best_model = deepcopy(tuned_model.model)
best_measurement = ifelse(tuned_model.minimize, Inf, -Inf)
s = ifelse(tuned_model.minimize, 1, -1)
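# With s = +1 (minimizing), the test `s*(best_measurement - e) > 0` below
# accepts a new measurement only when e < best_measurement; with s = -1 the
# inequality flips, so one comparison covers both directions.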

# evaluate all the models using specified resampling:
# TODO: parallelize!

meter = Progress(N+1, dt=0, desc="Iterating over a $N-point grid: ",
barglyphs=BarGlyphs("[=> ]"), barlen=25, color=:yellow)
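# `Progress`, `BarGlyphs` and `next!` come from ProgressMeter.jl; `dt=0`
# requests a redraw on every `next!` call instead of rate-limiting updates.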
for i in 1:N
verbosity != 1 || next!(meter)

A_row = Tuple(A[i,:])

# new_params = copy(nested_iterators, A_row)

# mutate `clone` (the model to which `resampler` points):
for k in 1:n_iterators
@@ -181,19 +192,17 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y

if verbosity == 2
fit!(resampling_machine, verbosity=0)

else
fit!(resampling_machine, verbosity=verbosity-1)
end
e = mean(evaluate(resampling_machine))

if verbosity > 1
text = prod("$(parameter_names[j])=$(A_row[j]) \t" for j in 1:length(A_row))
text *= "measurement=$e"
println(text)
end

if s*(best_measurement - e) > 0
best_model = deepcopy(clone)
best_measurement = e
@@ -203,7 +212,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
# models[i] = deepcopy(clone)
measurements[i] = e
end

end

if tuned_model.train_best
@@ -234,11 +243,11 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
measurements=[best_measurement, ][1:0], # empty vector
best_measurement=best_measurement)
end

cache = nothing

return fitresult, cache, report

end

MLJBase.fitted_params(::EitherTunedModel, fitresult) = (best_model=fitresult.model,)
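# Usage sketch (not part of the diff): once a machine wrapping a TunedModel
# has been fit, the winning model is recoverable as
#
#     fitted_params(mach).best_model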
@@ -268,7 +277,7 @@ MLJBase.target_scitype_union(::Type{<:ProbabilisticTunedModel{T,M}}) where {T,M}
MLJBase.input_is_multivariate(::Type{<:ProbabilisticTunedModel{T,M}}) where {T,M} = MLJBase.input_is_multivariate(M)


## LEARNING CURVES

"""
curve = learning_curve!(mach; resolution=30, resampling=Holdout(), measure=rms, operation=predict, range=nothing, n=1)
@@ -310,9 +319,9 @@ plot(curves.parameter_values, curves.measurements, xlab=curves.parameter_name)
"""
function learning_curve!(mach::Machine{<:Supervised};
resolution=30, resampling=Holdout(),
measure=rms, operation=predict,
range=nothing, verbosity=1, n=1)

range !== nothing || error("No param range specified. Use range=... ")

@@ -328,14 +337,10 @@ function learning_curve!(mach::Machine{<:Supervised};
parameter_name=report.parameter_names[1]
parameter_scale=report.parameter_scales[1]
parameter_values=[report.parameter_values[:, 1]...]
measurements_ = (n == 1) ? [measurements...] : measurements

return (parameter_name=parameter_name,
parameter_scale=parameter_scale,
parameter_values=parameter_values,
measurements = measurements_)
end


