Trait declarations

Two trait functions allow the implementer to restrict the types of data X, y and Xnew discussed above. The MLJ task interface uses these traits both for data type checks and for model search. If they are omitted (and your model is registered), then a general user may attempt to use your model with inappropriately typed data.

The trait functions input_scitype and target_scitype take scientific data types as values. We assume here familiarity with ScientificTypes.jl (see Getting Started for the basics).
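
A minimal sketch of such a data type check follows, assuming Julia 1.6 or later with MLJModelInterface.jl installed. The model CountRegressor and the helper data_ok are hypothetical. The scitypes of X and y are written out by hand here; in practice they would be computed from the data with scitype from ScientificTypes.jl, and MLJ's own checks are implemented elsewhere and differ in detail.

```julia
import MLJModelInterface as MMI

# Hypothetical regressor taking integer-valued features, for illustration only.
mutable struct CountRegressor <: MMI.Deterministic end

MMI.input_scitype(::Type{<:CountRegressor}) = MMI.Table(MMI.Count)
MMI.target_scitype(::Type{<:CountRegressor}) = AbstractVector{<:MMI.Continuous}

# Hypothetical helper: do the scitypes SX (of X) and Sy (of y) satisfy both traits?
function data_ok(model, SX, Sy)
    M = typeof(model)
    return SX <: MMI.input_scitype(M) && Sy <: MMI.target_scitype(M)
end

model = CountRegressor()

# Integer (Count) feature columns and a Continuous target: passes.
println(data_ok(model, MMI.Table{AbstractVector{MMI.Count}}, AbstractVector{MMI.Continuous}))     # true
# A categorical (Multiclass) target: fails, since the target must be Continuous.
println(data_ok(model, MMI.Table{AbstractVector{MMI.Count}}, AbstractVector{MMI.Multiclass{2}}))  # false
```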

For example, to ensure that the X presented to the DecisionTreeClassifier fit method is a table whose columns all have Continuous element scitype (and hence AbstractFloat machine type), one declares

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)

or, equivalently, if the names Table and Continuous have been imported,

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Continuous)
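
Because Table(Continuous) is an ordinary Julia type, the subtype operator <: shows which table scitypes this declaration admits. A sketch, assuming MLJModelInterface.jl is installed; the table scitypes are written out by hand rather than computed from data:

```julia
import MLJModelInterface as MMI

allowed_input = MMI.Table(MMI.Continuous)

# Every column has Continuous element scitype: accepted.
all_continuous = MMI.Table{AbstractVector{MMI.Continuous}}
# One Continuous column and one Multiclass column: rejected,
# because every column must have Continuous element scitype.
mixed = MMI.Table{Union{AbstractVector{MMI.Continuous},
                        AbstractVector{MMI.Multiclass{3}}}}

println(all_continuous <: allowed_input)  # true
println(mixed <: allowed_input)           # false
```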

If, instead, columns were allowed to have either: (i) a mixture of Continuous and Missing values, or (ii) Count (i.e., integer) values, then the declaration would be

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Union{Continuous,Missing},Count)
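
A sketch of what this relaxed declaration admits, assuming MLJModelInterface.jl is installed and writing the table scitypes out by hand. Note that a column with no missing values still qualifies, because Continuous <: Union{Continuous,Missing}:

```julia
import MLJModelInterface as MMI

relaxed_input = MMI.Table(Union{MMI.Continuous,Missing}, MMI.Count)

# A Continuous column with missing values plus a Count (integer) column: accepted.
with_missing = MMI.Table{Union{AbstractVector{Union{Missing,MMI.Continuous}},
                               AbstractVector{MMI.Count}}}
# A purely Continuous column (no missing values): also accepted.
no_missing = MMI.Table{AbstractVector{MMI.Continuous}}
# A Multiclass column is covered by neither alternative: rejected.
categorical = MMI.Table{AbstractVector{MMI.Multiclass{2}}}

println(with_missing <: relaxed_input)  # true
println(no_missing <: relaxed_input)    # true
println(categorical <: relaxed_input)   # false
```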

Similarly, to ensure the target is an AbstractVector whose elements have Finite scitype (and hence CategoricalValue machine type), we declare

MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:Finite}
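
Subtype checks with <: show which target scitypes this declaration admits. A sketch, assuming MLJModelInterface.jl is installed. Note that a vector containing missing values is rejected, since Missing is not a Finite scitype:

```julia
import MLJModelInterface as MMI

allowed_target = AbstractVector{<:MMI.Finite}

println(AbstractVector{MMI.Multiclass{3}} <: allowed_target)                 # true: unordered, 3 classes
println(AbstractVector{MMI.OrderedFactor{2}} <: allowed_target)              # true: ordered, 2 classes
println(AbstractVector{MMI.Continuous} <: allowed_target)                    # false
println(AbstractVector{Union{Missing,MMI.Multiclass{3}}} <: allowed_target)  # false
```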

Multivariate targets

The above remarks continue to hold unchanged for the case of multivariate targets. For example, if we declare

MMI.target_scitype(::Type{<:SomeSupervisedModel}) = Table(Continuous)

then this constrains the target to be any table whose columns have Continuous element scitype (i.e., AbstractFloat machine type), while

MMI.target_scitype(::Type{<:SomeSupervisedModel}) = Table(Continuous, Finite{2})

restricts to tables with continuous or binary (ordered or unordered) columns.
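
A sketch of which multivariate target scitypes Table(Continuous, Finite{2}) admits, assuming MLJModelInterface.jl is installed. An OrderedFactor{2} column is binary and so accepted, while a Multiclass{3} column is not:

```julia
import MLJModelInterface as MMI

allowed_multi = MMI.Table(MMI.Continuous, MMI.Finite{2})

# A Continuous column plus an ordered binary column: accepted.
binary_ok = MMI.Table{Union{AbstractVector{MMI.Continuous},
                            AbstractVector{MMI.OrderedFactor{2}}}}
# A Continuous column plus a three-class column: rejected.
three_classes = MMI.Table{Union{AbstractVector{MMI.Continuous},
                                AbstractVector{MMI.Multiclass{3}}}}

println(binary_ok <: allowed_multi)      # true
println(three_classes <: allowed_multi)  # false
```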

For predicting variable-length sequences of, say, binary values (CategoricalValues with some common size-two pool), we declare

MMI.target_scitype(::Type{<:SomeSupervisedModel}) = AbstractVector{<:Tuple{Vararg{Finite{2}}}}

The trait functions controlling the form of data are summarized as follows:

| method | return type | declarable return values | fallback value |
|---|---|---|---|
| input_scitype | Type | some scientific type | Unknown |
| target_scitype | Type | some scientific type | Unknown |
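
To see the fallback at work, here is a sketch with a hypothetical model MyRegressor that declares input_scitype but omits target_scitype, assuming MLJModelInterface.jl is installed:

```julia
import MLJModelInterface as MMI

# Hypothetical model declaring only its input scitype.
mutable struct MyRegressor <: MMI.Deterministic end
MMI.input_scitype(::Type{<:MyRegressor}) = MMI.Table(MMI.Continuous)

# The declared trait returns the declared scitype...
println(MMI.input_scitype(MyRegressor) == MMI.Table(MMI.Continuous))  # true
# ...while the undeclared trait falls back to Unknown.
println(MMI.target_scitype(MyRegressor) == MMI.Unknown)               # true
```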

Additional trait functions tell MLJ's @load macro how to find your model if it is registered, and provide other self-explanatory metadata about the model:

| method | return type | declarable return values | fallback value |
|---|---|---|---|
| load_path | String | unrestricted | "unknown" |
| package_name | String | unrestricted | "unknown" |
| package_uuid | String | unrestricted | "unknown" |
| package_url | String | unrestricted | "unknown" |
| package_license | String | unrestricted | "unknown" |
| is_pure_julia | Bool | true or false | false |
| supports_weights | Bool | true or false | false |
| supports_class_weights | Bool | true or false | false |
| supports_training_losses | Bool | true or false | false |
| reports_feature_importances | Bool | true or false | false |
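
A model declaring none of these traits gets the fallback values listed above. A sketch with a hypothetical, metadata-free model BareModel, assuming MLJModelInterface.jl is installed:

```julia
import MLJModelInterface as MMI

# Hypothetical model with no metadata traits declared.
mutable struct BareModel <: MMI.Deterministic end

println(MMI.load_path(BareModel))         # unknown
println(MMI.package_name(BareModel))      # unknown
println(MMI.is_pure_julia(BareModel))     # false
println(MMI.supports_weights(BareModel))  # false
```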

Here is the complete list of trait function declarations for DecisionTreeClassifier, whose core algorithms are provided by DecisionTree.jl but whose interface lives in MLJDecisionTreeInterface.jl:

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:MMI.Finite}
MMI.load_path(::Type{<:DecisionTreeClassifier}) = "MLJDecisionTreeInterface.DecisionTreeClassifier"
MMI.package_name(::Type{<:DecisionTreeClassifier}) = "DecisionTree"
MMI.package_uuid(::Type{<:DecisionTreeClassifier}) = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
MMI.package_url(::Type{<:DecisionTreeClassifier}) = "https://github.com/bensadeghi/DecisionTree.jl"
MMI.is_pure_julia(::Type{<:DecisionTreeClassifier}) = true
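
Because each declaration is a method on the model type, a trait can be queried on the type; current versions of MLJModelInterface.jl also forward a call on an instance to its type. A sketch with a hypothetical stand-in MyTree and a hypothetical load path, assuming MLJModelInterface.jl is installed (declaring traits for DecisionTreeClassifier itself outside its interface package would be type piracy):

```julia
import MLJModelInterface as MMI

# Hypothetical stand-in for DecisionTreeClassifier.
mutable struct MyTree <: MMI.Probabilistic
    max_depth::Int
end

MMI.load_path(::Type{<:MyTree}) = "MyTreePkg.MyTree"  # hypothetical package and path
MMI.is_pure_julia(::Type{<:MyTree}) = true

tree = MyTree(3)
println(MMI.load_path(MyTree))    # queried on the type: MyTreePkg.MyTree
println(MMI.load_path(tree))      # queried on an instance: MyTreePkg.MyTree
println(MMI.is_pure_julia(tree))  # true
```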

Alternatively, these traits can be declared using the helper functions MMI.metadata_pkg and MMI.metadata_model:

MMI.metadata_pkg(
  DecisionTreeClassifier,
  name="DecisionTree",
  package_uuid="7806a523-6efd-50cb-b5f6-3fa6f1930dbb",
  package_url="https://github.com/bensadeghi/DecisionTree.jl",
  is_pure_julia=true
)

MMI.metadata_model(
  DecisionTreeClassifier,
  input_scitype=MMI.Table(MMI.Continuous),
  target_scitype=AbstractVector{<:MMI.Finite},
  load_path="MLJDecisionTreeInterface.DecisionTreeClassifier"
)
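
A sketch of the helpers applied to a hypothetical model MyForest, assuming MLJModelInterface.jl is installed; the package name, UUID, URL and load path are placeholders. The helpers generate ordinary trait method definitions, equivalent to declaring each trait by hand, and are normally called at the top level of the module that defines the model:

```julia
import MLJModelInterface as MMI

# Hypothetical model; all metadata values below are placeholders.
mutable struct MyForest <: MMI.Probabilistic end

MMI.metadata_pkg(
    MyForest,
    name="MyForestPkg",
    package_uuid="00000000-0000-0000-0000-000000000000",
    package_url="https://example.com/MyForestPkg.jl",
    is_pure_julia=true,
)

MMI.metadata_model(
    MyForest,
    input_scitype=MMI.Table(MMI.Continuous),
    target_scitype=AbstractVector{<:MMI.Finite},
    load_path="MyForestPkg.MyForest",
)

# The helpers have defined ordinary trait methods for MyForest:
println(MMI.package_name(MyForest))   # MyForestPkg
println(MMI.load_path(MyForest))      # MyForestPkg.MyForest
println(MMI.is_pure_julia(MyForest))  # true
```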

Important. Do not omit the load_path specification. If you are unsure what it should be, post an issue at MLJ.

For the full list of keyword arguments these helpers accept, see the docstrings of MMI.metadata_pkg and MMI.metadata_model.