Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.10'
- '1'
- 'nightly'
os:
- ubuntu-latest
arch:
Expand All @@ -30,7 +29,7 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
- uses: julia-actions/cache@v2
env:
cache-name: cache-artifacts
with:
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJDecisionTreeInterface"
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.4.2"
version = "0.4.3"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -11,11 +11,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
CategoricalArrays = "0.10"
CategoricalArrays = "1"
DecisionTree = "0.12"
MLJModelInterface = "1.5"
Tables = "1.6"
julia = "1.6"
julia = "1.10"

[extras]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Expand Down
8 changes: 2 additions & 6 deletions src/MLJDecisionTreeInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ end
Base.show(stream::IO, c::TreePrinter) =
print(stream, "TreePrinter object (call with display depth)")

function classes(y)
p = CategoricalArrays.pool(y)
[p[i] for i in 1:length(p)]
end

# # DECISION TREE CLASSIFIER

Expand Down Expand Up @@ -79,7 +75,7 @@ function MMI.fit(
end

# returns a dictionary of categorical elements keyed on ref integer:
get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))
get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in levels(classes_seen))

# given such a dictionary, return printable class labels, ordered by corresponding ref
# integer:
Expand Down Expand Up @@ -459,7 +455,7 @@ _columnnames(X, ::Val{false}) = Tables.columnnames(first(Tables.rows(X)))

# for fit:
MMI.reformat(::Classifier, X, y) =
(Tables.matrix(X), MMI.int(y), _columnnames(X), classes(y))
(Tables.matrix(X), MMI.int(y), _columnnames(X), levels(y))
MMI.reformat(::Regressor, X, y) =
(Tables.matrix(X), float(y), _columnnames(X))
MMI.selectrows(::TreeModel, I, Xmatrix, y, meta...) =
Expand Down
5 changes: 3 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
import CategoricalArrays
import CategoricalArrays.categorical
import CategoricalArrays.levels
using MLJBase
using StableRNGs
using Random
Expand Down Expand Up @@ -48,7 +49,7 @@ stable_rng() = StableRNGs.StableRNG(123)
Xraw, yraw = @load_iris
X = Tables.matrix(Xraw);
y = int(yraw);
_classes = MLJDecisionTreeInterface.classes(yraw)
_classes = levels(yraw)
features = MLJDecisionTreeInterface._columnnames(Xraw)

baretree = DecisionTreeClassifier(rng=stable_rng())
Expand All @@ -74,7 +75,7 @@ yhat = MLJBase.predict(baretree, fitresult, X);

# check preservation of levels:
yyhat = predict_mode(baretree, fitresult, X[1:3, :])
@test MLJBase.classes(yyhat[1]) == MLJBase.classes(yraw)
@test MLJBase.levels(yyhat[1]) == MLJBase.levels(yraw)

# check report and fitresult fields:
@test Set([:classes_seen, :print_tree, :features]) == Set(keys(report))
Expand Down
Loading