Skip to content

Commit

Permalink
Merge pull request #558 from JuliaAI/constructor
Browse files Browse the repository at this point in the history
Base the Model Registry on constructors, rather than model types; and ancillary changes
  • Loading branch information
ablaom committed Jun 5, 2024
2 parents 6517209 + ae426cb commit 0dcf82c
Show file tree
Hide file tree
Showing 14 changed files with 1,288 additions and 856 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModels"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.16.17"
version = "0.17.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -37,7 +37,7 @@ Distributions = "0.25"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "<0.0.1, 1"
Markdown = "<0.0.1, 1"
MLJModelInterface = "1.4"
MLJModelInterface = "1.10"
OrderedCollections = "1.1"
Parameters = "0.12"
Pkg = "<0.0.1, 1"
Expand Down
10 changes: 5 additions & 5 deletions src/MLJModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ nonmissing = nonmissingtype

include("utilities.jl")

# load built-in models:
include("builtins/Constant.jl")
include("builtins/Transformers.jl")
include("builtins/ThresholdPredictors.jl")

Handle = NamedTuple{(:name, :pkg), Tuple{String,String}}
(::Type{Handle})(name,string) = NamedTuple{(:name, :pkg)}((name, string))

Expand All @@ -79,11 +84,6 @@ include("loading.jl")
include("registry/src/Registry.jl")
using .Registry

# load built-in models:
include("builtins/Constant.jl")
include("builtins/Transformers.jl")
include("builtins/ThresholdPredictors.jl")

# finalize:
include("init.jl")

Expand Down
10 changes: 2 additions & 8 deletions src/builtins/ThresholdPredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,8 @@ MMI.package_uuid(::Type{<:ThresholdUnion}) = ""
MMI.is_wrapper(::Type{<:ThresholdUnion}) = true
MMI.package_url(::Type{<:ThresholdUnion}) =
"https://github.com/JuliaAI/MLJModels.jl"

for New in THRESHOLD_TYPE_EXS
New_str = string(New)
quote
MMI.load_path(::Type{<:$New{M}}) where M = "MLJModels."*$New_str
end |> eval
end

MMI.load_path(::Type{<:ThresholdUnion}) = "MLJModels.BinaryThresholdPredictor"
MMI.constructor(::Type{<:ThresholdUnion}) = BinaryThresholdPredictor

for trait in [:supports_weights,
:supports_class_weights,
Expand Down
198 changes: 21 additions & 177 deletions src/builtins/Transformers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,90 +183,7 @@ function MMI.fitted_params(::FillImputer, fr)
filler_given_feature=filler_given_feature)
end


# # FOR FEATURE (COLUMN) SELECTION

mutable struct FeatureSelector <: Unsupervised
# features to be selected; empty means all
features::Union{Vector{Symbol}, Function}
ignore::Bool # features to be ignored
end

# keyword constructor
function FeatureSelector(
;
features::Union{AbstractVector{Symbol}, Function}=Symbol[],
ignore::Bool=false
)
transformer = FeatureSelector(features, ignore)
message = MMI.clean!(transformer)
isempty(message) || throw(ArgumentError(message))
return transformer
end

function MMI.clean!(transformer::FeatureSelector)
err = ""
if (
typeof(transformer.features) <: AbstractVector{Symbol} &&
isempty(transformer.features) &&
transformer.ignore
)
err *= "Features to be ignored must be specified in features field."
end
return err
end

function MMI.fit(transformer::FeatureSelector, verbosity::Int, X)
all_features = Tables.schema(X).names

if transformer.features isa AbstractVector{Symbol}
if isempty(transformer.features)
features = collect(all_features)
else
features = if transformer.ignore
!issubset(transformer.features, all_features) && verbosity > -1 &&
@warn("Excluding non-existent feature(s).")
filter!(all_features |> collect) do ftr
!(ftr in transformer.features)
end
else
issubset(transformer.features, all_features) ||
throw(ArgumentError("Attempting to select non-existent feature(s)."))
transformer.features |> collect
end
end
else
features = if transformer.ignore
filter!(all_features |> collect) do ftr
!(transformer.features(ftr))
end
else
filter!(all_features |> collect) do ftr
transformer.features(ftr)
end
end
isempty(features) && throw(
ArgumentError("No feature(s) selected.\n The specified Bool-valued"*
" callable with the `ignore` option set to `$(transformer.ignore)` "*
"resulted in an empty feature set for selection")
)
end

fitresult = features
report = NamedTuple()
return fitresult, nothing, report
end

MMI.fitted_params(::FeatureSelector, fitresult) = (features_to_keep=fitresult,)

function MMI.transform(::FeatureSelector, features, X)
all(e -> e in Tables.schema(X).names, features) ||
throw(ArgumentError("Supplied frame does not admit previously selected features."))
return MMI.selectcols(X, features)
end


# # UNIVARIATE DISCRETIZER
## UNIVARIATE DISCRETIZER

# helper function:
reftype(::CategoricalArray{<:Any,<:Any,R}) where R = R
Expand Down Expand Up @@ -1027,9 +944,14 @@ function MMI.transform(transformer::ContinuousEncoder, fitresult, X)
features_to_keep, hot_encoder, hot_fitresult = values(fitresult)

# dump unseen or untransformable features:
selector = FeatureSelector(features=features_to_keep)
selector_fitresult, _, _ = MMI.fit(selector, 0, X)
X0 = transform(selector, selector_fitresult, X)
if !issubset(features_to_keep, MMI.schema(X).names)
throw(
ArgumentError(
"Supplied frame does not admit previously selected features."
)
)
end
X0 = MMI.selectcols(X, features_to_keep)

# one-hot encode:
X1 = transform(hot_encoder, hot_fitresult, X0)
Expand Down Expand Up @@ -1080,11 +1002,18 @@ end
# # METADATA FOR ALL BUILT-IN TRANSFORMERS

metadata_pkg.(
(FeatureSelector, UnivariateStandardizer,
UnivariateDiscretizer, Standardizer,
UnivariateBoxCoxTransformer, UnivariateFillImputer,
OneHotEncoder, FillImputer, ContinuousEncoder,
UnivariateTimeTypeToContinuous, InteractionTransformer),
(
UnivariateStandardizer,
UnivariateDiscretizer,
Standardizer,
UnivariateBoxCoxTransformer,
UnivariateFillImputer,
OneHotEncoder,
FillImputer,
ContinuousEncoder,
UnivariateTimeTypeToContinuous,
InteractionTransformer
),
package_name = "MLJModels",
package_uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7",
package_url = "https://github.com/JuliaAI/MLJModels.jl",
Expand All @@ -1106,11 +1035,6 @@ metadata_model(FillImputer,
output_scitype = Table,
load_path = "MLJModels.FillImputer")

metadata_model(FeatureSelector,
input_scitype = Table,
output_scitype = Table,
load_path = "MLJModels.FeatureSelector")

metadata_model(UnivariateDiscretizer,
input_scitype = AbstractVector{<:Continuous},
output_scitype = AbstractVector{<:OrderedFactor},
Expand Down Expand Up @@ -1371,86 +1295,6 @@ See also [`UnivariateFillImputer`](@ref).
"""
FillImputer

"""
$(MLJModelInterface.doc_header(FeatureSelector))
Use this model to select features (columns) of a table, usually as
part of a model `Pipeline`.
# Training data
In MLJ or MLJBase, bind an instance `model` to data with
mach = machine(model, X)
where
- `X`: any table of input features, where "table" is in the sense of Tables.jl
Train the machine using `fit!(mach, rows=...)`.
# Hyper-parameters
- `features`: one of the following, with the behavior indicated:
- `[]` (empty, the default): filter out all features (columns) which
were not encountered in training
- non-empty vector of feature names (symbols): keep only the
specified features (`ignore=false`) or keep only unspecified
features (`ignore=true`)
- function or other callable: keep a feature if the callable returns
`true` on its name. For example, specifying
`FeatureSelector(features = name -> name in [:x1, :x3], ignore =
true)` has the same effect as `FeatureSelector(features = [:x1,
:x3], ignore = true)`, namely to select all features, with the
exception of `:x1` and `:x3`.
- `ignore`: whether to ignore or keep specified `features`, as
explained above
# Operations
- `transform(mach, Xnew)`: select features from the table `Xnew` as
specified by the model, taking features seen during training into
account, if relevant
# Fitted parameters
The fields of `fitted_params(mach)` are:
- `features_to_keep`: the features that will be selected
# Example
```
using MLJ
X = (ordinal1 = [1, 2, 3],
ordinal2 = coerce(["x", "y", "x"], OrderedFactor),
ordinal3 = [10.0, 20.0, 30.0],
ordinal4 = [-20.0, -30.0, -40.0],
nominal = coerce(["Your father", "he", "is"], Multiclass));
selector = FeatureSelector(features=[:ordinal3, ], ignore=true);
julia> transform(fit!(machine(selector, X)), X)
(ordinal1 = [1, 2, 3],
ordinal2 = CategoricalValue{Symbol,UInt32}["x", "y", "x"],
ordinal4 = [-20.0, -30.0, -40.0],
nominal = CategoricalValue{String,UInt32}["Your father", "he", "is"],)
```
"""
FeatureSelector


"""
$(MLJModelInterface.doc_header(Standardizer))
Expand Down
10 changes: 8 additions & 2 deletions src/metadata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ function decode_dic(s::String)
if s[1] == ':'
return Symbol(s[2:end])
elseif s[1] == '`' && s[2] != '`' # to exclude strings starting with ```
return eval(Meta.parse(s[2:end-1]))
ex = Meta.parse(s[2:end-1])
# We need a `try` here because `constructor` trait generally returns a
# function not in the namespace, because pkg defining it has not been loaded.
return try
eval(ex)
catch
ex
end
else
return s
end
Expand Down Expand Up @@ -150,4 +157,3 @@ function model_traits_in_registry(info_given_handle)
first_entry = info_given_handle[Handle("ConstantRegressor")]
return keys(first_entry) |> collect
end

Loading

0 comments on commit 0dcf82c

Please sign in to comment.