diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cac96ab..621c572 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,9 +17,8 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - '1.10' - '1' - - 'nightly' os: - ubuntu-latest arch: @@ -30,7 +29,7 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 + - uses: julia-actions/cache@v2 env: cache-name: cache-artifacts with: diff --git a/Project.toml b/Project.toml index 9dbfe21..554a425 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJDecisionTreeInterface" uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" authors = ["Anthony D. Blaom "] -version = "0.4.2" +version = "0.4.3" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -11,11 +11,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] -CategoricalArrays = "0.10" +CategoricalArrays = "1" DecisionTree = "0.12" MLJModelInterface = "1.5" Tables = "1.6" -julia = "1.6" +julia = "1.10" [extras] MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index cd5357b..d0e409e 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -23,10 +23,6 @@ end Base.show(stream::IO, c::TreePrinter) = print(stream, "TreePrinter object (call with display depth)") -function classes(y) - p = CategoricalArrays.pool(y) - [p[i] for i in 1:length(p)] -end # # DECISION TREE CLASSIFIER @@ -79,7 +75,7 @@ function MMI.fit( end # returns a dictionary of categorical elements keyed on ref integer: -get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen)) +get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in levels(classes_seen)) # given such a dictionary, return printable class labels, ordered by corresponding ref # integer: @@ -459,7 +455,7 @@ _columnnames(X, ::Val{false}) = Tables.columnnames(first(Tables.rows(X))) # for fit: MMI.reformat(::Classifier, X, y) = - (Tables.matrix(X), MMI.int(y), _columnnames(X), classes(y)) + (Tables.matrix(X), MMI.int(y), _columnnames(X), levels(y)) MMI.reformat(::Regressor, X, y) = (Tables.matrix(X), float(y), _columnnames(X)) MMI.selectrows(::TreeModel, I, Xmatrix, y, meta...) = diff --git a/test/runtests.jl b/test/runtests.jl index 7252294..9179dcd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test import CategoricalArrays import CategoricalArrays.categorical +import CategoricalArrays.levels using MLJBase using StableRNGs using Random @@ -48,7 +49,7 @@ stable_rng() = StableRNGs.StableRNG(123) Xraw, yraw = @load_iris X = Tables.matrix(Xraw); y = int(yraw); -_classes = MLJDecisionTreeInterface.classes(yraw) +_classes = levels(yraw) features = MLJDecisionTreeInterface._columnnames(Xraw) baretree = DecisionTreeClassifier(rng=stable_rng()) @@ -74,7 +75,7 @@ yhat = MLJBase.predict(baretree, fitresult, X); # check preservation of levels: yyhat = predict_mode(baretree, fitresult, X[1:3, :]) -@test MLJBase.classes(yyhat[1]) == MLJBase.classes(yraw) +@test MLJBase.levels(yyhat[1]) == MLJBase.levels(yraw) # check report and fitresult fields: @test Set([:classes_seen, :print_tree, :features]) == Set(keys(report))