Skip to content

Commit

Permalink
Merge pull request #178 from rikhuijzer/rh/instantiate-models
Browse files Browse the repository at this point in the history
Clarify error for uninstantiated model further
  • Loading branch information
ablaom committed Jun 8, 2022
2 parents 0f50eca + 69f0fae commit 9436c17
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.7.1"
version = "0.7.2"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand Down
11 changes: 9 additions & 2 deletions src/tuned_models.jl
Expand Up @@ -19,6 +19,10 @@ const ERR_BOTH_DISALLOWED = ArgumentError(
"You cannot specify both `model` and `models`. ")
const ERR_MODEL_TYPE = ArgumentError(
"Only `Deterministic` and `Probabilistic` model types supported.")
const ERR_UNINSTANTIATED_MODEL = AssertionError(
"Type encountered where model instance expected. (Tuning evaluates "*
"models by mutating clones of the provided instance, as specified "*
"by `range`.) ")
const INFO_MODEL_IGNORED =
"`model` being ignored. Using `model=first(range)`. "
const ERR_TOO_MANY_ARGUMENTS =
Expand Down Expand Up @@ -260,11 +264,11 @@ function TunedModel(args...; model=nothing,

# user can specify model as argument instead of kwarg:
length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS)
if length(args) === 1
if length(args) == 1
arg = first(args)
model === nothing ||
@warn warn_double_spec(arg, model)
model =arg
model = arg
end

# either `models` is specified and `tuning` is set to `Explicit`,
Expand Down Expand Up @@ -309,7 +313,10 @@ function TunedModel(args...; model=nothing,
else
throw(ERR_MODEL_TYPE)
end
elseif model isa Type
throw(ERR_UNINSTANTIATED_MODEL)
else
# Model is probably an instantiated model.
M = typeof(model)
end

Expand Down
6 changes: 1 addition & 5 deletions test/runtests.jl
Expand Up @@ -9,11 +9,7 @@ using StableRNGs
# Display Number of processes and if necessary number
# of Threads
@info "nworkers: $(nworkers())"
@static if VERSION >= v"1.3.0-DEV.573"
@info "nthreads: $(Threads.nthreads())"
else
@info "Running julia $(VERSION). Multithreading tests excluded. "
end
@info "nthreads: $(Threads.nthreads())"

include("test_utilities.jl")

Expand Down
6 changes: 4 additions & 2 deletions test/tuned_models.jl
Expand Up @@ -17,10 +17,10 @@ N = 30
x1 = rand(N);
x2 = rand(N);
x3 = rand(N);
X = (x1=x1, x2=x2, x3=x3);
X = (; x1, x2, x3);
y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.4*rand(N);

m(K) = KNNRegressor(K=K)
m(K) = KNNRegressor(; K)
r = [m(K) for K in 13:-1:2]

# TODO: replace the above with the line below and post an issue on
Expand Down Expand Up @@ -66,6 +66,8 @@ r = [m(K) for K in 13:-1:2]
tm = @test_logs TunedModel(model=first(r), range=r, measure=rms)
@test tm.tuning isa RandomSearch
@test input_scitype(tm) == Table(Continuous)

@test_throws MLJTuning.ERR_UNINSTANTIATED_MODEL TunedModel(; model=KNNRegressor, range=r)
end

results = [(evaluate(model, X, y,
Expand Down

0 comments on commit 9436c17

Please sign in to comment.