Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ExpandingMan committed Nov 10, 2022
1 parent 448c803 commit f59e148
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Expand Up @@ -12,7 +12,7 @@ XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
[compat]
MLJModelInterface = "0.3.5, 0.4, 1"
Tables = "1.0.5"
XGBoost = "2"
XGBoost = "2.0.1"
julia = "1.6"

[extras]
Expand Down
21 changes: 8 additions & 13 deletions src/MLJXGBoostInterface.jl
Expand Up @@ -100,22 +100,17 @@ end

eval(modelexpr(:XGBoostRegressor, :XGBoostAbstractRegressor, "reg:squarederror", :validate_reg_objective))

function kwargs(model, verbosity::Integer, obj)
function kwargs(model, verbosity, obj)
excluded = [:importance_type]
fn = filter((excluded), fieldnames(typeof(model)))
o = NamedTuple(n=>getfield(model, n) for n fn if !isnothing(getfield(model, n)))
o = merge(o, (silent=(verbosity 0),))
merge(o, (objective=_fix_objective(obj),))
end

function importances(X, r)
fs = schema(X).names
[named_importance(fi, fs) for fi XGB.importance(r)]
end

function MMI.feature_importances(model::XGTypes, (booster, _), (features,))
dict = XGB.importance(booster, model.importance_type)
if length(last(first(importance_dict))) > 1
if length(last(first(dict))) > 1
[features[k] => zero(first(v)) for (k, v) in dict]
else
[features[k] => first(v) for (k, v) in dict]
Expand Down Expand Up @@ -146,10 +141,10 @@ eval(modelexpr(:XGBoostCount, :XGBoostAbstractRegressor, "count:poisson", :valid

eval(modelexpr(:XGBoostClassifier, :XGBoostAbstractClassifier, "automatic", :validate_class_objective))

function MMI.fit(model::XGBoostClassifier
, verbosity::Int #> must be here even if unsupported in pkg
, X
, y)
function MMI.fit(model::XGBoostClassifier,
verbosity, # must be here even if unsupported in pkg
X, y,
)
a_target_element = y[1] # a CategoricalValue or CategoricalString
nclass = length(MMI.classes(a_target_element))

Expand Down Expand Up @@ -204,8 +199,8 @@ MMI.save(::XGBoostClassifier, fr; kw...) = (_save(fr[1]; kw...), fr[2])

MMI.restore(::XGBoostClassifier, fr) = (_restore(fr[1]), fr[2])

MLJModelInterface.reports_feature_importances(::Type{XGBoostAbstractRegressor}) = true
MLJModelInterface.reports_feature_importances(::Type{XGBoostAbstractClassifier}) = true
MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractRegressor}) = true
MLJModelInterface.reports_feature_importances(::Type{<:XGBoostAbstractClassifier}) = true


MMI.package_name(::Type{<:XGTypes}) = "XGBoost"
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Expand Up @@ -176,6 +176,9 @@ end
mach = machine(plain_classifier, X, yclass)
fit!(mach, verbosity=0)
yhat = predict_mode(mach, X);

imps = feature_importances(mach)
@test Set(string.([imp[1] for imp imps])) == Set(["x1", "x2", "x3"])

# serialize:
io = IOBuffer()
Expand Down

0 comments on commit f59e148

Please sign in to comment.