!!! warning

    Models implementing the MLJ model interface according to the instructions
    given here should import MLJModelInterface version 0.3 or higher. This is
    enforced with a statement such as `MLJModelInterface = "^0.3"` under
    `[compat]` in the Project.toml file of the package containing the
    implementation.
This guide outlines the specification of the MLJ model interface and provides detailed guidelines for implementing the interface for models intended for general use. See also the more condensed Quick-Start Guide to Adding Models.
For sample implementations, see MLJModels/src.
The machine learning tools provided by MLJ can be applied to the models in any package that imports the package MLJModelInterface and implements the API defined there, as outlined below. For a quick-and-dirty implementation of user-defined models see Simple User Defined Models. To make new models available to all MLJ users, see Where to place code implementing new models.
MLJModelInterface is a very light-weight interface allowing you to define your interface, but it does not provide the functionality required to use or test your interface; this requires MLJBase. So, while you only need to add `MLJModelInterface` to your project's `[deps]`, for testing purposes you need to add `MLJBase` to your project's `[extras]` and `[targets]`. In testing, simply use `MLJBase` in place of `MLJModelInterface`.
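To make the `[deps]`/`[extras]`/`[targets]` arrangement concrete, here is a sketch of the relevant Project.toml entries for a wrapping package (verify the UUIDs against the General registry before copying; the version bound shown matches the warning above):

```toml
[deps]
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"

[compat]
MLJModelInterface = "^0.3"

[extras]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["MLJBase", "Test"]
```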
It is assumed the reader has read Getting Started. To implement the API described here, some familiarity with the following packages is also helpful:
- MLJScientificTypes.jl (for specifying model requirements of data)
- Distributions.jl (for probabilistic predictions)
- CategoricalArrays.jl (essential if you are implementing a model handling data of `Multiclass` or `OrderedFactor` scitype; familiarity with `CategoricalPool` objects required)
- Tables.jl (if your algorithm needs input data in a novel format)
In MLJ, the basic interface exposed to the user, built atop the model interface described here, is the machine interface. After a first reading of this document, the reader may wish to refer to MLJ Internals for context.
A model is an object storing hyperparameters associated with some machine learning algorithm. In MLJ, hyperparameters include configuration parameters, like the number of threads, and special instructions, such as "compute feature rankings", which may or may not affect the final learning outcome. However, the logging level (`verbosity` below) is excluded.
The name of the Julia type associated with a model indicates the associated algorithm (e.g., `DecisionTreeClassifier`). The outcome of training a learning algorithm is called a *fitresult*. For ordinary multivariate regression, for example, this would be the coefficients and intercept. For a general supervised model, it is the (generally minimal) information needed to make new predictions.
The ultimate supertype of all models is `MLJModelInterface.Model`, which has two abstract subtypes:

```julia
abstract type Supervised <: Model end
abstract type Unsupervised <: Model end
```
`Supervised` models are further divided according to whether they are able to furnish probabilistic predictions of the target (which they will then do by default) or directly predict "point" estimates, for each new input pattern:

```julia
abstract type Probabilistic <: Supervised end
abstract type Deterministic <: Supervised end
```
Further division of model types is realized through Trait declarations.
Associated with every concrete subtype of `Model` there must be a `fit` method, which implements the associated algorithm to produce the fitresult. Additionally, every `Supervised` model has a `predict` method, while `Unsupervised` models must have a `transform` method. More generally, methods such as these, that are dispatched on a model instance and a fitresult (plus other data), are called *operations*. `Probabilistic` supervised models optionally implement a `predict_mode` operation (in the case of classifiers) or `predict_mean` and/or `predict_median` operations (in the case of regressors), although MLJModelInterface also provides fallbacks that will suffice in most cases. `Unsupervised` models may implement an `inverse_transform` operation.
Here is an example of a concrete supervised model type declaration:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

mutable struct RidgeRegressor <: MMI.Deterministic
    lambda::Float64
end
```
Models (which are mutable) should not be given internal constructors. It is recommended that they be given an external lazy keyword constructor of the same name. This constructor defines default values for every field, and optionally corrects invalid field values by calling a `clean!` method (whose fallback returns an empty message string):
```julia
function MMI.clean!(model::RidgeRegressor)
    warning = ""
    if model.lambda < 0
        warning *= "Need lambda ≥ 0. Resetting lambda=0. "
        model.lambda = 0
    end
    return warning
end

# keyword constructor
function RidgeRegressor(; lambda=0.0)
    model = RidgeRegressor(lambda)
    message = MMI.clean!(model)
    isempty(message) || @warn message
    return model
end
```
An alternative to declaring the model struct, `clean!` method and keyword constructor is to use the `@mlj_model` macro, as in the following example:

```julia
@mlj_model mutable struct YourModel <: MMI.Deterministic
    a::Float64 = 0.5::(_ > 0)
    b::String = "svd"::(_ in ("svd","qr"))
end
```
This declaration specifies:

- A keyword constructor (here `YourModel(; a=..., b=...)`),
- Default values for the hyperparameters,
- Constraints on the hyperparameters, where `_` refers to a value passed.

For example, `a::Float64 = 0.5::(_ > 0)` indicates that the field `a` is a `Float64`, takes `0.5` as its default value, and expects its value to be positive.
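For instance, with the declaration above, the generated keyword constructor behaves as in this usage sketch:

```julia
model = YourModel()               # defaults: a = 0.5, b = "svd"
model = YourModel(a=1.0, b="qr")  # keyword overrides
```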
You cannot use the `@mlj_model` macro if your model struct has type parameters.
The compulsory and optional methods to be implemented for each concrete type `SomeSupervisedModel <: MMI.Supervised` are summarized below. An `=` indicates the return value for a fallback version of the method.
Compulsory:

```julia
MMI.fit(model::SomeSupervisedModel, verbosity::Integer, X, y) -> fitresult, cache, report
MMI.predict(model::SomeSupervisedModel, fitresult, Xnew) -> yhat
```

Optional, to check and correct invalid hyperparameter values:

```julia
MMI.clean!(model::SomeSupervisedModel) = ""
```

Optional, to return user-friendly form of fitted parameters:

```julia
MMI.fitted_params(model::SomeSupervisedModel, fitresult) = fitresult
```

Optional, to avoid redundant calculations when re-fitting machines associated with a model:

```julia
MMI.update(model::SomeSupervisedModel, verbosity, old_fitresult, old_cache, X, y) =
    MMI.fit(model, verbosity, X, y)
```

Optional, to specify default hyperparameter ranges (for use in tuning):

```julia
MMI.hyperparameter_ranges(T::Type) = Tuple(fill(nothing, length(fieldnames(T))))
```

Optional, if `SomeSupervisedModel <: Probabilistic`:

```julia
MMI.predict_mode(model::SomeSupervisedModel, fitresult, Xnew) =
    mode.(predict(model, fitresult, Xnew))
MMI.predict_mean(model::SomeSupervisedModel, fitresult, Xnew) =
    mean.(predict(model, fitresult, Xnew))
MMI.predict_median(model::SomeSupervisedModel, fitresult, Xnew) =
    median.(predict(model, fitresult, Xnew))
```

Required, if the model is to be registered (findable by general users):

```julia
MMI.load_path(::Type{<:SomeSupervisedModel})    = ""
MMI.package_name(::Type{<:SomeSupervisedModel}) = "Unknown"
MMI.package_uuid(::Type{<:SomeSupervisedModel}) = "Unknown"
MMI.input_scitype(::Type{<:SomeSupervisedModel}) = Unknown
```

Strongly recommended, to constrain the form of target data passed to fit:

```julia
MMI.target_scitype(::Type{<:SomeSupervisedModel}) = Unknown
```

Optional but recommended:

```julia
MMI.package_url(::Type{<:SomeSupervisedModel})     = "unknown"
MMI.is_pure_julia(::Type{<:SomeSupervisedModel})   = false
MMI.package_license(::Type{<:SomeSupervisedModel}) = "unknown"
```

If `SomeSupervisedModel` supports sample weights, then instead of the `fit` above, one implements

```julia
MMI.fit(model::SomeSupervisedModel, verbosity::Integer, X, y, w=nothing) -> fitresult, cache, report
```

and, if appropriate,

```julia
MMI.update(model::SomeSupervisedModel, verbosity, old_fitresult, old_cache, X, y, w=nothing) =
    MMI.fit(model, verbosity, X, y, w)
```

Additionally, if `SomeSupervisedModel` supports sample weights, one must declare

```julia
MMI.supports_weights(model::Type{<:SomeSupervisedModel}) = true
```
The model implementer does not have absolute control over the types of data `X`, `y` and `Xnew` appearing in the `fit` and `predict` methods they must implement. Rather, they can specify the scientific type of this data by making appropriate declarations of the traits `input_scitype` and `target_scitype` discussed later under Trait declarations.
**Important Note.** Unless it genuinely makes little sense to do so, the MLJ recommendation is to specify a `Table` scientific type for `X` (and hence `Xnew`) and an `AbstractVector` scientific type (e.g., `AbstractVector{Continuous}`) for targets `y`. Algorithms requiring matrix input can coerce their inputs appropriately; see below.
If the core algorithm being wrapped requires data in a different or more specific form, then `fit` will need to coerce the table into the form desired (and the same coercions applied to `X` will have to be repeated for `Xnew` in `predict`). To assist with common cases, MLJ provides the convenience method `MMI.matrix`. `MMI.matrix(Xtable)` has type `Matrix{T}`, where `T` is the tightest common type of elements of `Xtable`, and `Xtable` is any table.
Other auxiliary methods provided by MLJModelInterface for handling tabular data are: `selectrows`, `selectcols`, `select` and `schema` (for extracting the size, names and eltypes of a table's columns). See Convenience methods below for details.
It is to be understood that the columns of the table `X` correspond to features and the rows to observations. So, for example, the `predict` method for a linear regression model might look like `predict(model, w, Xnew) = MMI.matrix(Xnew)*w`, where `w` is the vector of learned coefficients.
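Putting these pieces together, here is a minimal sketch of `fit` and `predict` for the `RidgeRegressor` declared earlier, coercing the input table with `MMI.matrix`. This is illustrative only: it uses a closed-form solution with no intercept, and a real implementation would populate `report` with anything of interest:

```julia
using LinearAlgebra  # provides I, the identity operator

function MMI.fit(model::RidgeRegressor, verbosity::Int, X, y)
    Xmatrix = MMI.matrix(X)  # coerce the input table to a Matrix
    # closed-form ridge solution (intercept omitted for brevity):
    fitresult = (Xmatrix' * Xmatrix + model.lambda * I) \ (Xmatrix' * y)
    cache = nothing
    report = NamedTuple()
    return fitresult, cache, report
end

MMI.predict(model::RidgeRegressor, fitresult, Xnew) = MMI.matrix(Xnew) * fitresult
```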
A compulsory `fit` method returns three objects:

```julia
MMI.fit(model::SomeSupervisedModel, verbosity::Int, X, y) -> fitresult, cache, report
```

*Note.* The `Int` typing of `verbosity` cannot be omitted.
- `fitresult` is the fitresult in the sense above (which becomes an argument for `predict` discussed below).

- `report` is a (possibly empty) `NamedTuple`, for example, `report=(deviance=..., dof_residual=..., stderror=..., vcov=...)`. Any training-related statistics, such as internal estimates of the generalization error, and feature rankings, should be returned in the `report` tuple. How, or if, these are generated should be controlled by hyperparameters (the fields of `model`). Fitted parameters, such as the coefficients of a linear model, do not go in the report as they will be extractable from `fitresult` (and accessible to MLJ through the `fitted_params` method described below).

- The value of `cache` can be `nothing`, unless one is also defining an `update` method (see below). The Julia type of `cache` is not presently restricted.
It is not necessary for `fit` to provide type or dimension checks on `X` or `y` or to call `clean!` on the model; MLJ will carry out such checks.
The method `fit` should never alter hyperparameter values, the sole exception being fields of type `<:AbstractRNG`. If the package is able to suggest better hyperparameters, as a byproduct of training, return these in the report field.
The `verbosity` level (0 for silent) is for passing to the learning algorithm itself. A `fit` method wrapping such an algorithm should generally avoid doing any of its own logging.
**Sample weight support.** If `supports_weights(::Type{<:SomeSupervisedModel})` has been declared `true`, then one instead implements the following variation on the above `fit`:

```julia
MMI.fit(model::SomeSupervisedModel, verbosity::Int, X, y, w=nothing) -> fitresult, cache, report
```
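For instance, a weighted `fit` might dispatch to a core algorithm as in this sketch, where `SomePackage.fit` and its `weights` keyword are hypothetical placeholders:

```julia
function MMI.fit(model::SomeSupervisedModel, verbosity::Int, X, y, w=nothing)
    Xmatrix = MMI.matrix(X)
    # fall back to the unweighted algorithm if no weights are supplied:
    core_fitresult = w === nothing ?
        SomePackage.fit(Xmatrix, y) :
        SomePackage.fit(Xmatrix, y; weights=w)
    return core_fitresult, nothing, NamedTuple()
end
```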
A `fitted_params` method may be optionally overloaded. Its purpose is to provide MLJ access to a user-friendly representation of the learned parameters of the model (as opposed to the hyperparameters). They must be extractable from `fitresult`.

```julia
MMI.fitted_params(model::SomeSupervisedModel, fitresult) -> friendly_fitresult::NamedTuple
```

For a linear model, for example, one might declare something like `friendly_fitresult=(coefs=[...], bias=...)`.

The fallback is to return `(fitresult=fitresult,)`.
A compulsory `predict` method has the form

```julia
MMI.predict(model::SomeSupervisedModel, fitresult, Xnew) -> yhat
```

Here `Xnew` will have the same form as the `X` passed to `fit`.
In the case of `Deterministic` models, `yhat` should have the same scitype as the `y` passed to `fit` (see above). Any `CategoricalValue` or `CategoricalString` elements of `yhat` must have a pool `==` to that of the target `y` presented in training, even if not all levels appear in the training data or prediction itself. For example, in the case of a univariate target, such as `scitype(y) <: AbstractVector{Multiclass{3}}`, one requires `MLJ.classes(yhat[i]) == MLJ.classes(y[j])` for all admissible `i` and `j`. (The method `classes` is described under Convenience methods below.)
Unfortunately, code not written with the preservation of categorical levels in mind poses special problems. To help with this, MLJModelInterface provides three utility methods: `int` (for converting a `CategoricalValue` or `CategoricalString` into an integer, the ordering of these integers being consistent with that of the pool), `decoder` (for constructing a callable object that decodes the integers back into `CategoricalValue`/`CategoricalString` objects), and `classes`, for extracting all the `CategoricalValue` or `CategoricalString` objects sharing the pool of a particular value. Refer to Convenience methods below for important details.
Note that a decoder created during `fit` may need to be bundled with `fitresult` to make it available to `predict` during re-encoding. So, for example, if the core algorithm being wrapped by `fit` expects a nominal target `yint` of type `Vector{<:Integer}`, then a `fit` method may look something like this:
```julia
function MMI.fit(model::SomeSupervisedModel, verbosity, X, y)
    yint = MMI.int(y)
    a_target_element = y[1]                # a CategoricalValue/String
    decode = MMI.decoder(a_target_element) # can be called on integers
    core_fitresult = SomePackage.fit(X, yint, verbosity=verbosity)
    fitresult = (decode, core_fitresult)
    cache = nothing
    report = nothing
    return fitresult, cache, report
end
```
while a corresponding deterministic `predict` operation might look like this:

```julia
function MMI.predict(model::SomeSupervisedModel, fitresult, Xnew)
    decode, core_fitresult = fitresult
    yhat = SomePackage.predict(core_fitresult, Xnew)
    return decode.(yhat)  # or decode(yhat) also works
end
```
For a concrete example, refer to the code for `SVMClassifier`.

Of course, if you are coding a learning algorithm from scratch, rather than wrapping an existing one, these extra measures may be unnecessary.
In the case of `Probabilistic` models with univariate targets, `yhat` must be an `AbstractVector` whose elements are distributions (one distribution per row of `Xnew`).

Presently, a distribution is any object `d` for which `MMI.isdistribution(::d) = true`, which is the case for objects of type `Distributions.Sampleable`.
Use the distribution `MMI.UnivariateFinite` for `Probabilistic` models predicting a target with `Finite` scitype (classifiers). In this case the eltype of the training target `y` will be a `CategoricalValue`.

For efficiency, one should not construct `UnivariateFinite` instances one at a time. Rather, once a probability vector or matrix is known, construct an instance of `UnivariateFiniteVector <: AbstractArray{<:UnivariateFinite,1}` to return. Both `UnivariateFinite` and `UnivariateFiniteVector` objects are constructed using the single `UnivariateFinite` function.
For example, suppose the target `y` arrives as a subsample of some `ybig` and is missing some classes:

```julia
ybig = categorical([:a, :b, :a, :a, :b, :a, :rare, :a, :b])
y = ybig[1:6]
```

Your fit method has bundled the first element of `y` with the `fitresult` to make it available to `predict` for purposes of tracking the complete pool of classes. Let's call this `an_element = y[1]`. Then, supposing the corresponding probabilities of the observed classes `[:a, :b]` are in an `n x 2` matrix `probs` (where `n` is the number of rows of `Xnew`), then you return

```julia
yhat = UnivariateFinite([:a, :b], probs, pool=an_element)
```
This object automatically assigns zero-probability to the unseen class `:rare` (i.e., `pdf.(yhat, :rare)` works and returns a zero vector). If you would like to assign `:rare` non-zero probabilities, simply add it to the first vector (the support) and supply a larger `probs` matrix.
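For instance, continuing the example above with a hypothetical 3-row `probs` matrix, and assuming `pdf` and `mode` are in scope (e.g., via MLJBase), one can query the returned vector of distributions like this:

```julia
probs = [0.8 0.2; 0.1 0.9; 0.5 0.5]  # hypothetical probabilities for [:a, :b]
yhat = UnivariateFinite([:a, :b], probs, pool=an_element)

pdf.(yhat, :a)     # [0.8, 0.1, 0.5]
pdf.(yhat, :rare)  # [0.0, 0.0, 0.0] - the unseen class gets zero probability
mode.(yhat)        # CategoricalValues sharing the pool of y
```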
If instead of raw labels `[:a, :b]` you have the corresponding `CategoricalElement`s (from, e.g., `filter(cv->cv in unique(y), classes(y))`), then you can use these instead and drop the `pool` specifier.
In a binary classification problem, it suffices to specify a single vector of probabilities, provided you specify `augment=true`, as in the following example, and note carefully that these probabilities are associated with the *last* (second) class you specify in the constructor:

```julia
y = categorical([:TRUE, :FALSE, :FALSE, :TRUE, :TRUE])
an_element = y[1]
probs = rand(10)
yhat = UnivariateFinite([:FALSE, :TRUE], probs, augment=true, pool=an_element)
```

The constructor has a lot of options, including passing a dictionary instead of vectors. See `UnivariateFinite` for details.
See LinearBinaryClassifier for an example of a Probabilistic classifier implementation.
**Important note on binary classifiers.** There is no "Binary" scitype distinct from `Multiclass{2}` or `OrderedFactor{2}`; `Binary` is just an alias for `Union{Multiclass{2},OrderedFactor{2}}`. The `target_scitype` of a binary classifier will generally be `AbstractVector{<:Binary}` and, according to the MLJ scitype convention, elements of `y` have type `CategoricalValue`, and not `Bool`. See BinaryClassifier for an example.
Two trait functions allow the implementer to restrict the types of data `X`, `y` and `Xnew` discussed above. The MLJ task interface uses these traits for data type checks but also for model search. If they are omitted (and your model is registered), then a general user may attempt to use your model with inappropriately typed data.
The trait functions `input_scitype` and `target_scitype` take scientific data types as values. We assume here familiarity with MLJScientificTypes.jl (see Getting Started for the basics).
For example, to ensure that the `X` presented to the `DecisionTreeClassifier` `fit` method is a table whose columns all have `Continuous` element type (and hence `AbstractFloat` machine type), one declares

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
```

or, equivalently,

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Continuous)
```
If, instead, columns were allowed to have either: (i) a mixture of `Continuous` and `Missing` values, or (ii) `Count` (i.e., integer) values, then the declaration would be

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Union{Continuous,Missing},Count)
```
Similarly, to ensure the target is an `AbstractVector` whose elements have `Finite` scitype (and hence `CategoricalValue` machine type), we declare

```julia
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:Finite}
```
The above remarks continue to hold unchanged for the case of multivariate targets. For example, if we declare

```julia
target_scitype(SomeSupervisedModel) = Table(Continuous)
```

then this constrains the target to be any table whose columns have `Continuous` element scitype (i.e., `AbstractFloat`), while

```julia
target_scitype(SomeSupervisedModel) = Table(Continuous, Finite{2})
```

restricts to tables with continuous or binary (ordered or unordered) columns.
For predicting variable-length sequences of, say, binary values (`CategoricalValue`s with some common size-two pool), we declare

```julia
target_scitype(SomeSupervisedModel) = AbstractVector{<:NTuple{<:Finite{2}}}
```
The trait functions controlling the form of data are summarized as follows:

| method | return type | declarable return values | fallback value |
|---|---|---|---|
| `input_scitype` | `Type` | some scientific type | `Unknown` |
| `target_scitype` | `Type` | some scientific type | `Unknown` |
Additional trait functions tell MLJ's `@load` macro how to find your model if it is registered, and provide other self-explanatory metadata about the model:

| method | return type | declarable return values | fallback value |
|---|---|---|---|
| `load_path` | `String` | unrestricted | `"unknown"` |
| `package_name` | `String` | unrestricted | `"unknown"` |
| `package_uuid` | `String` | unrestricted | `"unknown"` |
| `package_url` | `String` | unrestricted | `"unknown"` |
| `package_license` | `String` | unrestricted | `"unknown"` |
| `is_pure_julia` | `Bool` | `true` or `false` | `false` |
| `supports_weights` | `Bool` | `true` or `false` | `false` |
**New.** A final trait you can optionally implement is the `hyperparameter_ranges` trait. It declares default `ParamRange` objects for one or more of your model's hyperparameters. This is for use (in the future) by tuning algorithms (e.g., grid generation). It does not represent the full space of allowed values. This information is encoded in your `clean!` method (or `@mlj_model` call).

The value returned by `hyperparameter_ranges` must be a tuple of `ParamRange` objects (query `?range` for details) whose length is the number of hyperparameters (fields of your model). Note that varying a hyperparameter over a specified range should not alter any type parameters in your model struct (this never applies to numeric ranges). If it doesn't make sense to provide a range for a parameter, a `nothing` entry is allowed. The fallback returns a tuple of `nothing`s.
For example, for a three-parameter model of the form

```julia
mutable struct MyModel{D} <: Deterministic
    alpha::Float64
    beta::Int
    distribution::D
end
```

you might declare (order matters):

```julia
MMI.hyperparameter_ranges(::Type{<:MyModel}) =
    (range(Float64, :alpha, lower=0, upper=1, scale=:log),
     range(Int, :beta, lower=1, upper=Inf, origin=100, unit=50, scale=:log),
     nothing)
```
Here is the complete list of trait function declarations for `DecisionTreeClassifier` (source):

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:MMI.Finite}
MMI.load_path(::Type{<:DecisionTreeClassifier}) = "MLJModels.DecisionTree_.DecisionTreeClassifier"
MMI.package_name(::Type{<:DecisionTreeClassifier}) = "DecisionTree"
MMI.package_uuid(::Type{<:DecisionTreeClassifier}) = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
MMI.package_url(::Type{<:DecisionTreeClassifier}) = "https://github.com/bensadeghi/DecisionTree.jl"
MMI.is_pure_julia(::Type{<:DecisionTreeClassifier}) = true
```
Alternatively, these traits can also be declared using the `MMI.metadata_pkg` and `MMI.metadata_model` helper functions, as in

```julia
MMI.metadata_pkg(DecisionTreeClassifier,
    name="DecisionTree",
    uuid="7806a523-6efd-50cb-b5f6-3fa6f1930dbb",
    url="https://github.com/bensadeghi/DecisionTree.jl",
    julia=true)

MMI.metadata_model(DecisionTreeClassifier,
    input=MMI.Table(MMI.Continuous),
    target=AbstractVector{<:MMI.Finite},
    path="MLJModels.DecisionTree_.DecisionTreeClassifier")
```
**Important.** Do not omit the `path` specification. See the docstrings of `MMI.metadata_pkg` and `MMI.metadata_model` for details.
You can test all your declarations of traits by calling `MLJBase.info_dict(SomeModel)`.
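For example (assuming `MLJBase` is available in your test environment):

```julia
using MLJBase
MLJBase.info_dict(DecisionTreeClassifier)  # inspect all declared trait values
```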
An `update` method may be optionally overloaded to enable a call by MLJ to retrain a model (on the same training data) to avoid repeating computations unnecessarily.

```julia
MMI.update(model::SomeSupervisedModel, verbosity, old_fitresult, old_cache, X, y) -> fitresult, cache, report
MMI.update(model::SomeSupervisedModel, verbosity, old_fitresult, old_cache, X, y, w=nothing) -> fitresult, cache, report
```

Here the second variation applies if `SomeSupervisedModel` supports sample weights.
If an MLJ `Machine` is being `fit!` and it is not the first time, then `update` is called instead of `fit`, unless the machine `fit!` has been called with a new `rows` keyword argument. However, MLJModelInterface defines a fallback for `update` which just calls `fit`. For context, see MLJ Internals.
Learning networks wrapped as models constitute one use-case (see Composing Models): one would like each component model to be retrained only when hyperparameter changes "upstream" make this necessary. In this case MLJ provides a fallback (specifically, the fallback is for any subtype of `SupervisedNetwork = Union{DeterministicNetwork,ProbabilisticNetwork}`). A second, more generally relevant use-case is iterative models, where calls to increase the number of iterations should only restart the iterative procedure if other hyperparameters have also changed. (A useful method for inspecting model changes in such cases is `MLJModelInterface.is_same_except`.) For an example, see the MLJ ensemble code.
A third use-case is to avoid repeating time-consuming preprocessing of `X` and `y` required by some models.
In the event that the argument `fitresult` (returned by a preceding call to `fit`) is not sufficient for performing an update, the author can arrange for `fit` to output in its `cache` return value any additional information required (for example, pre-processed versions of `X` and `y`), as this is also passed as an argument to the `update` method.
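As an illustration of the iterative use-case, here is a hypothetical sketch. The model type `SomeIterativeModel`, its `n_iterations` field, and `SomePackage.train!` are placeholders (`MMI.is_same_except` is real), and the corresponding `fit` is assumed to return `cache = (model = deepcopy(model),)`:

```julia
function MMI.update(model::SomeIterativeModel, verbosity,
                    old_fitresult, old_cache, X, y)
    old_model = old_cache.model
    # restart from scratch unless only `n_iterations` has (weakly) increased:
    if !(MMI.is_same_except(model, old_model, :n_iterations) &&
         model.n_iterations >= old_model.n_iterations)
        return MMI.fit(model, verbosity, X, y)
    end
    # otherwise, perform only the extra iterations:
    extra = model.n_iterations - old_model.n_iterations
    fitresult = SomePackage.train!(old_fitresult, X, y, extra)
    cache = (model = deepcopy(model),)
    report = NamedTuple()
    return fitresult, cache, report
end
```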
A supervised model may optionally implement a `transform` method, whose signature is the same as `predict`. In that case, the implementation should define a value for the `output_scitype` trait. A declaration

```julia
output_scitype(::Type{<:SomeSupervisedModel}) = T
```

is an assurance that `scitype(transform(model, fitresult, Xnew)) <: T` always holds, for any `model` of type `SomeSupervisedModel`.
A use-case for a `transform` method for a supervised model is a neural network that learns feature embeddings for categorical input features as part of overall training. Such a model becomes a transformer that other supervised models can use to transform the categorical features (instead of applying the higher-dimensional one-hot encoding representations).
Unsupervised models implement the MLJ model interface in a very similar fashion. The main differences are listed below (a minimal sketch of an unsupervised transformer follows the list):

- The `fit` method has only one training argument `X`, as in `MLJModelInterface.fit(model, verbosity::Int, X)`. However, it has the same return value `(fitresult, cache, report)`. An `update` method (e.g., for iterative models) can be optionally implemented in the same way.

- A `transform` method is compulsory and has the same signature as `predict`, as in `MLJModelInterface.transform(model, fitresult, Xnew)`.

- Instead of defining the `target_scitype` trait, one declares an `output_scitype` trait (see above for the meaning).

- An `inverse_transform` can be optionally implemented. The signature is the same as `transform`, as in `MLJModelInterface.inverse_transform(model, fitresult, Xout)`, which:

    - must make sense for any `Xout` for which `scitype(Xout) <: output_scitype(SomeSupervisedModel)` (see below); and

    - must return an object `Xin` satisfying `scitype(Xin) <: input_scitype(SomeSupervisedModel)`.

- A `predict` method may be optionally implemented, and has the same signature as for supervised models, as in `MLJModelInterface.predict(model, fitresult, Xnew)`. A use-case is clustering algorithms that `predict` labels and `transform` new input features into a space of lower dimension. See Transformers that also predict for an example.
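To make these differences concrete, here is a minimal sketch of a hypothetical unsupervised transformer (`MinMaxScaler`, a name of our choosing) that rescales each column of a table of `Continuous` features to the interval [0, 1]; trait declarations and guards (e.g., against constant columns) are omitted:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

mutable struct MinMaxScaler <: MMI.Unsupervised end

function MMI.fit(model::MinMaxScaler, verbosity::Int, X)
    Xmatrix = MMI.matrix(X)
    # learn the per-column extrema:
    fitresult = (lo = minimum(Xmatrix, dims=1), hi = maximum(Xmatrix, dims=1))
    cache = nothing
    report = NamedTuple()
    return fitresult, cache, report
end

function MMI.transform(model::MinMaxScaler, fitresult, Xnew)
    Xmatrix = MMI.matrix(Xnew)
    scaled = (Xmatrix .- fitresult.lo) ./ (fitresult.hi .- fitresult.lo)
    return MMI.table(scaled)  # wrap the matrix back up as a table
end
```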
!!! warning "Experimental"

    The following API is experimental.
Models that learn a probability distribution, or more generally a "sampler" object, should be regarded as `Supervised` models that fit a distribution to the target `y`, given a void input feature, `X = nothing`. Here is a working implementation of a model to fit any distribution from the Distributions.jl package to some data `y`, illustrating the idea (trait declarations omitted):
```julia
import MLJModelInterface
import Distributions

# Implementation:
mutable struct DistributionFitter{D<:Distributions.Distribution} <:
        MLJModelInterface.Supervised
    distribution::D
end
DistributionFitter(; distribution=Distributions.Normal()) =
    DistributionFitter(distribution)

function MLJModelInterface.fit(model::DistributionFitter{D},
                               verbosity::Int,
                               ::Nothing,
                               y) where D
    fitresult = Distributions.fit(D, y)
    report = (params=Distributions.params(fitresult),)
    cache = nothing
    verbosity > 0 && @info "Fitted a $fitresult"
    return fitresult, cache, report
end

MLJModelInterface.predict(model::DistributionFitter,
                          fitresult,
                          ::Nothing) = fitresult

# Example use (`machine`, `fit!` and `predict` are provided by MLJ/MLJBase):
y = randn(100)
mach = machine(DistributionFitter(), nothing, y) |> fit!

yhat = predict(mach, nothing)
@assert yhat isa Distributions.Normal
```
The convenience methods referred to above are documented in their docstrings:

- `MLJModelInterface.int`
- `MLJModelInterface.classes`
- `MLJModelInterface.decoder`
- `MLJModelInterface.matrix`
- `MLJModelInterface.table`
- `MLJModelInterface.select`
- `MLJModelInterface.selectrows`
- `MLJModelInterface.selectcols`
- `UnivariateFinite`
Note that different packages can implement models having the same name without causing conflicts, although an MLJ user cannot simultaneously load two such models.
There are two options for making a new model implementation available to all MLJ users:

- **Native implementations** (preferred option). The implementation code lives in the same package that contains the learning algorithms implementing the interface. In this case, it is sufficient to open an issue at MLJ requesting the package to be registered with MLJ. Registering a package allows the MLJ user to access its models' metadata and to selectively load them.

- **External implementations** (short-term alternative). The model implementation code is necessarily separate from the package `SomePkg` defining the learning algorithm being wrapped. In this case, the recommended procedure is to include the implementation code at MLJModels/src via a pull-request, and test code at MLJModels/test. Assuming `SomePkg` is the only package imported by the implementation code, one needs to: (i) register `SomePkg` with MLJ as explained above; and (ii) add a corresponding `@require` line in the PR to MLJModels/src/MLJModels.jl to enable lazy-loading of that package by MLJ (following the pattern of existing additions). If other packages must be imported, add them to the MLJModels project file after checking they are not already there. If it is really necessary, packages can also be added to Project.toml for testing purposes.
Additionally, one needs to ensure that the implementation code defines the `package_name` and `load_path` model traits appropriately, so that MLJ's `@load` macro can find the necessary code (see MLJModels/src for examples). The `@load` command can only be tested after registration. If changes are made, lodge a new issue at MLJ requesting your changes to be updated.
The MLJ model registry is located in the MLJModels.jl repository. To add a model, you need to follow these steps:

- Ensure your model conforms to the interface defined above.

- Raise an issue at MLJModels.jl and point out where the MLJ-interface implementation is, e.g., by providing a link to the code.

- An administrator will then review your implementation and work with you to add the model to the registry.