@@ -47,10 +47,10 @@ caches_data_by_default(m) = caches_data_by_default(typeof(m))
4747caches_data_by_default (:: Type ) = true
4848caches_data_by_default (:: Type{<:Symbol} ) = false
4949
50- mutable struct Machine{M,C} <: MLJType
50+ mutable struct Machine{M,OM, C} <: MLJType
5151
5252 model:: M
53- old_model # for remembering the model used in last call to `fit!`
53+ old_model:: OM # for remembering the model used in last call to `fit!`
5454 fitresult
5555 cache
5656
@@ -77,8 +77,11 @@ mutable struct Machine{M,C} <: MLJType
7777 function Machine (
7878 model:: M , args:: AbstractNode... ;
7979 cache= caches_data_by_default (model),
80- ) where M
81- mach = new {M,cache} (model)
80+ ) where M
81+ # In the case of symbolic model, machine cannot know the type of model to be fit
82+ # at time of construction:
83+ OM = M == Symbol ? Any : M
84+ mach = new {M,OM,cache} (model)
8285 mach. frozen = false
8386 mach. state = 0
8487 mach. args = args
@@ -115,7 +118,7 @@ any upstream dependencies in a learning network):
115118replace(mach, :args => (), :data => (), :data_resampled_data => (), :cache => nothing)
116119
117120"""
118- function Base. replace (mach:: Machine{<:Any,C} , field_value_pairs:: Pair... ) where C
121+ function Base. replace (mach:: Machine{<:Any,<:Any, C} , field_value_pairs:: Pair... ) where C
119122 # determined new `model` and `args` and build replacement dictionary:
120123 newfield_given_old = Dict (field_value_pairs) # to be extended
121124 fields_to_be_replaced = keys (newfield_given_old)
@@ -436,8 +439,8 @@ machines(::Source) = Machine[]
436439
437440# # DISPLAY
438441
439- _cache_status (:: Machine{<:Any,true} ) = " caches model-specific representations of data"
440- _cache_status (:: Machine{<:Any,false} ) = " does not cache data"
442+ _cache_status (:: Machine{<:Any,<:Any, true} ) = " caches model-specific representations of data"
443+ _cache_status (:: Machine{<:Any,<:Any, false} ) = " does not cache data"
441444
442445function Base. show (io:: IO , mach:: Machine )
443446 model = mach. model
502505# for getting model specific representation of the row-restricted
503506# training data from a machine, according to the value of the machine
504507# type parameter `C` (`true` or `false`):
505- _resampled_data (mach:: Machine{<:Any,true} , model, rows) = mach. resampled_data
506- function _resampled_data (mach:: Machine{<:Any,false} , model, rows)
508+ _resampled_data (mach:: Machine{<:Any,<:Any, true} , model, rows) = mach. resampled_data
509+ function _resampled_data (mach:: Machine{<:Any,<:Any, false} , model, rows)
507510 raw_args = map (N -> N (), mach. args)
508511 data = MMI. reformat (model, raw_args... )
509512 return selectrows (model, rows, data... )
@@ -518,6 +521,10 @@ err_no_real_model(mach) = ErrorException(
518521 """
519522)
520523
524+ err_missing_model (model) = ErrorException (
525+ " Specified `composite` model does not have `:$(model) ` as a field."
526+ )
527+
521528"""
522529 last_model(mach::Machine)
523530
@@ -605,7 +612,7 @@ more on these lower-level training methods.
605612
606613"""
607614function fit_only! (
608- mach:: Machine{<:Any,cache_data} ;
615+ mach:: Machine{<:Any,<:Any, cache_data} ;
609616 rows= nothing ,
610617 verbosity= 1 ,
611618 force= false ,
@@ -628,7 +635,8 @@ function fit_only!(
628635 # `getproperty(composite, mach.model)`:
629636 model = if mach. model isa Symbol
630637 isnothing (composite) && throw (err_no_real_model (mach))
631- mach. model in propertynames (composite)
638+ mach. model in propertynames (composite) ||
639+ throw (err_missing_model (model))
632640 getproperty (composite, mach. model)
633641 else
634642 mach. model
@@ -967,7 +975,7 @@ A machine returned by `serializable` is characterized by the property
967975See also [`restore!`](@ref), [`MLJBase.save`](@ref).
968976
969977"""
970- function serializable (mach:: Machine{<:Any, C} , model= mach. model; verbosity= 1 ) where C
978+ function serializable (mach:: Machine{<:Any,<:Any, C} , model= mach. model; verbosity= 1 ) where C
971979
972980 isdefined (mach, :fitresult ) || throw (ERR_SERIALIZING_UNTRAINED)
973981 mach. state == - 1 && return mach
0 commit comments