diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl
index 281b58d..86c5626 100644
--- a/src/MLJDecisionTreeInterface.jl
+++ b/src/MLJDecisionTreeInterface.jl
@@ -79,7 +79,7 @@ end
 
 function get_encoding(classes_seen)
     a_cat_element = classes_seen[1]
-    return Dict(c => MMI.int(c) for c in MMI.classes(a_cat_element))
+    return Dict(MMI.int(c) => c for c in MMI.classes(a_cat_element))
 end
 
 MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
@@ -538,9 +538,9 @@ To interpret the internal class labelling:
 ```
 julia> fitted_params(mach).encoding
-Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
-  "virginica" => 0x00000003
-  "setosa" => 0x00000001
-  "versicolor" => 0x00000002
+Dict{UInt32, CategoricalArrays.CategoricalValue{String, UInt32}} with 3 entries:
+  0x00000003 => "virginica"
+  0x00000001 => "setosa"
+  0x00000002 => "versicolor"
 ```
 
 See also
diff --git a/test/runtests.jl b/test/runtests.jl
index 14553f2..ad97520 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -49,6 +49,9 @@ yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
 fp = fitted_params(baretree, fitresult)
 @test Set([:tree, :encoding, :features]) == Set(keys(fp))
 @test fp.features == report.features
+enc = fp.encoding
+@test Set(values(enc)) == Set(["virginica", "setosa", "versicolor"])
+@test enc[MLJBase.int(y[end])] == "virginica"
 
 using Random: seed!
 seed!(0)