Skip to content

Commit

Permalink
put mutations in clean! method
Browse files Browse the repository at this point in the history
  • Loading branch information
OkonSamuel committed Jun 11, 2020
1 parent accd57f commit 0f1a480
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/tuned_models.jl
Expand Up @@ -159,26 +159,19 @@ function TunedModel(; model=nothing,
model == nothing && error("You need to specify model=... .\n"*
"If `tuning=Explicit()`, any model in the "*
"range will do. ")

if (acceleration isa CPUThreads &&
acceleration_resampling isa CPUProcesses)
acceleration = CPUProcesses()
acceleration_resampling = CPUThreads()
end
_acceleration = _process_accel_settings(acceleration)

if model isa Deterministic
tuned_model = DeterministicTunedModel(model, tuning, resampling,
measure, weights, operation, range,
train_best, repeats, n,
_acceleration,
acceleration,
acceleration_resampling,
check_measure)
elseif model isa Probabilistic
tuned_model = ProbabilisticTunedModel(model, tuning, resampling,
measure, weights, operation, range,
train_best, repeats, n,
_acceleration,
acceleration,
acceleration_resampling,
check_measure)
else
Expand All @@ -205,6 +198,7 @@ function MLJBase.clean!(tuned_model::EitherTunedModel)
"Setting measure=$(tuned_model.measure). "
end
end

if (tuned_model.acceleration isa CPUProcesses &&
tuned_model.acceleration_resampling isa CPUProcesses)
message *=
Expand All @@ -213,7 +207,8 @@ function MLJBase.clean!(tuned_model::EitherTunedModel)
" not generally optimal. You may want to consider setting"*
" `acceleration = CPUProcesses()` and"*
" `acceleration_resampling = CPUThreads()`."
end
end

if (tuned_model.acceleration isa CPUThreads &&
tuned_model.acceleration_resampling isa CPUProcesses)
message *=
Expand All @@ -222,7 +217,13 @@ function MLJBase.clean!(tuned_model::EitherTunedModel)
" supported. \n Resetting to"*
" `acceleration = CPUProcesses()` and"*
" `acceleration_resampling = CPUThreads()`."
end

tuned_model.acceleration = CPUProcesses()
tuned_model.acceleration_resampling = CPUThreads()
end

tuned_model.acceleration =
_process_accel_settings(tuned_model.acceleration)

return message
end
Expand Down

0 comments on commit 0f1a480

Please sign in to comment.