Skip to content

Commit 7ae5821

Browse files
committed
annotate type for old_model field of Machine type
oops
1 parent f01a03c commit 7ae5821

File tree

4 files changed

+27
-18
lines changed

4 files changed

+27
-18
lines changed

src/machines.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ caches_data_by_default(m) = caches_data_by_default(typeof(m))
4747
caches_data_by_default(::Type) = true
4848
caches_data_by_default(::Type{<:Symbol}) = false
4949

50-
mutable struct Machine{M,C} <: MLJType
50+
mutable struct Machine{M,OM,C} <: MLJType
5151

5252
model::M
53-
old_model # for remembering the model used in last call to `fit!`
53+
old_model::OM # for remembering the model used in last call to `fit!`
5454
fitresult
5555
cache
5656

@@ -77,8 +77,11 @@ mutable struct Machine{M,C} <: MLJType
7777
function Machine(
7878
model::M, args::AbstractNode...;
7979
cache=caches_data_by_default(model),
80-
) where M
81-
mach = new{M,cache}(model)
80+
) where M
81+
# In the case of symbolic model, machine cannot know the type of model to be fit
82+
# at time of construction:
83+
OM = M == Symbol ? Any : M
84+
mach = new{M,OM,cache}(model)
8285
mach.frozen = false
8386
mach.state = 0
8487
mach.args = args
@@ -115,7 +118,7 @@ any upstream dependencies in a learning network):
115118
replace(mach, :args => (), :data => (), :data_resampled_data => (), :cache => nothing)
116119
117120
"""
118-
function Base.replace(mach::Machine{<:Any,C}, field_value_pairs::Pair...) where C
121+
function Base.replace(mach::Machine{<:Any,<:Any,C}, field_value_pairs::Pair...) where C
119122
# determined new `model` and `args` and build replacement dictionary:
120123
newfield_given_old = Dict(field_value_pairs) # to be extended
121124
fields_to_be_replaced = keys(newfield_given_old)
@@ -436,8 +439,8 @@ machines(::Source) = Machine[]
436439

437440
## DISPLAY
438441

439-
_cache_status(::Machine{<:Any,true}) = "caches model-specific representations of data"
440-
_cache_status(::Machine{<:Any,false}) = "does not cache data"
442+
_cache_status(::Machine{<:Any,<:Any,true}) = "caches model-specific representations of data"
443+
_cache_status(::Machine{<:Any,<:Any,false}) = "does not cache data"
441444

442445
function Base.show(io::IO, mach::Machine)
443446
model = mach.model
@@ -502,8 +505,8 @@ end
502505
# for getting model specific representation of the row-restricted
503506
# training data from a machine, according to the value of the machine
504507
# type parameter `C` (`true` or `false`):
505-
_resampled_data(mach::Machine{<:Any,true}, model, rows) = mach.resampled_data
506-
function _resampled_data(mach::Machine{<:Any,false}, model, rows)
508+
_resampled_data(mach::Machine{<:Any,<:Any,true}, model, rows) = mach.resampled_data
509+
function _resampled_data(mach::Machine{<:Any,<:Any,false}, model, rows)
507510
raw_args = map(N -> N(), mach.args)
508511
data = MMI.reformat(model, raw_args...)
509512
return selectrows(model, rows, data...)
@@ -518,6 +521,10 @@ err_no_real_model(mach) = ErrorException(
518521
"""
519522
)
520523

524+
err_missing_model(model) = ErrorException(
525+
"Specified `composite` model does not have `:$(model)` as a field."
526+
)
527+
521528
"""
522529
last_model(mach::Machine)
523530
@@ -605,7 +612,7 @@ more on these lower-level training methods.
605612
606613
"""
607614
function fit_only!(
608-
mach::Machine{<:Any,cache_data};
615+
mach::Machine{<:Any,<:Any,cache_data};
609616
rows=nothing,
610617
verbosity=1,
611618
force=false,
@@ -628,7 +635,8 @@ function fit_only!(
628635
# `getproperty(composite, mach.model)`:
629636
model = if mach.model isa Symbol
630637
isnothing(composite) && throw(err_no_real_model(mach))
631-
mach.model in propertynames(composite)
638+
mach.model in propertynames(composite) ||
639+
throw(err_missing_model(model))
632640
getproperty(composite, mach.model)
633641
else
634642
mach.model
@@ -967,7 +975,7 @@ A machine returned by `serializable` is characterized by the property
967975
See also [`restore!`](@ref), [`MLJBase.save`](@ref).
968976
969977
"""
970-
function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C
978+
function serializable(mach::Machine{<:Any,<:Any,C}, model=mach.model; verbosity=1) where C
971979

972980
isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
973981
mach.state == -1 && return mach

src/operations.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ for operation in OPERATIONS
7474
operation == :inverse_transform && continue
7575

7676
ex = quote
77-
function $(operation)(mach::Machine{<:Model,false}; rows=:)
77+
function $(operation)(mach::Machine{<:Model,<:Any,false}; rows=:)
7878
# catch deserialized machine with no data:
7979
isempty(mach.args) && throw(err_serialized($operation))
8080
return ($operation)(mach, mach.args[1](rows=rows))
8181
end
82-
function $(operation)(mach::Machine{<:Model,true}; rows=:)
82+
function $(operation)(mach::Machine{<:Model,<:Any,true}; rows=:)
8383
# catch deserialized machine with no data:
8484
isempty(mach.args) && throw(err_serialized($operation))
8585
model = last_model(mach)
@@ -92,8 +92,10 @@ for operation in OPERATIONS
9292
end
9393

9494
# special case of Static models (no training arguments):
95-
$operation(mach::Machine{<:Static,true}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED)
96-
$operation(mach::Machine{<:Static,false}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED)
95+
$operation(mach::Machine{<:Static,<:Any,true}; rows=:) =
96+
throw(ERR_ROWS_NOT_ALLOWED)
97+
$operation(mach::Machine{<:Static,<:Any,false}; rows=:) =
98+
throw(ERR_ROWS_NOT_ALLOWED)
9799
end
98100
eval(ex)
99101

src/resampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1106,7 +1106,7 @@ end
11061106
@static if VERSION >= v"1.3.0-DEV.573"
11071107

11081108
# determines if an instantiated machine caches data:
1109-
_caches_data(::Machine{M, C}) where {M, C} = C
1109+
_caches_data(::Machine{<:Any,<:Any,C}) where C = C
11101110

11111111
function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity)
11121112

test/machines.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@ end
272272
X = ones(2, 3)
273273

274274
mach = @test_logs machine(Scale(2))
275-
@test mach isa Machine{Scale, false}
276275
transform(mach, X) # triggers training of `mach`, ie is mutating
277276
@test report(mach) in [nothing, NamedTuple()]
278277
@test isnothing(fitted_params(mach))

0 commit comments

Comments
 (0)