Merged
32 changes: 30 additions & 2 deletions src/MLJDecisionTreeInterface.jl
@@ -456,7 +456,8 @@ Train the machine using `fit!(mach, rows=...)`.

- `display_depth=5`: max depth to show when displaying the tree

-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
+  :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -591,7 +592,8 @@ Train the machine with `fit!(mach, rows=...)`.

- `sampling_fraction=0.7` fraction of samples to train each tree on

-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
+  :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -613,6 +615,11 @@ The fields of `fitted_params(mach)` are:
- `forest`: the `Ensemble` object returned by the core DecisionTree.jl algorithm


# Report

- `features`: the names of the features encountered in training


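The `features` entry can be read off the machine's report. A minimal sketch, assuming the iris setup used in the examples below (not part of the docstring's own example):

```
using MLJ
Forest = @load RandomForestClassifier pkg=DecisionTree
forest = Forest()
X, y = @load_iris
mach = machine(forest, X, y) |> fit!
report(mach).features  # names of the features encountered in training
```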
# Examples

```
@@ -632,6 +639,11 @@ predict_mode(mach, Xnew) # point predictions
pdf.(yhat, "virginica") # probabilities for the "virginica" class

fitted_params(mach).forest # raw `Ensemble` object from DecisionTree.jl

feature_importances(mach) # `:impurity` feature importances
forest.feature_importance = :split
fit!(mach)
feature_importances(mach) # `:split` feature importances

```
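The two `feature_importance` methods can rank features differently. A toy sketch of the distinction, using hypothetical splits rather than DecisionTree.jl's internals: `:split` counts how often a feature is used to split, while `:impurity` weights each feature by the impurity decrease its splits achieve.

```
# Hypothetical record of the splits in a fitted tree: the feature each split
# uses and the impurity decrease it achieves. (Illustration only.)
struct Split
    feature::Int
    impurity_decrease::Float64
end

splits = [Split(1, 0.30), Split(1, 0.05), Split(2, 0.40)]
nfeatures = 2

# `:split`-style importance: normalized count of splits per feature
split_imp = zeros(nfeatures)
for s in splits
    split_imp[s.feature] += 1
end
split_imp ./= length(splits)

# `:impurity`-style importance: normalized total impurity decrease per feature
impurity_imp = zeros(nfeatures)
for s in splits
    impurity_imp[s.feature] += s.impurity_decrease
end
impurity_imp ./= sum(impurity_imp)
```

Here feature 1 splits more often, but feature 2 achieves the larger total impurity decrease, so the two methods disagree on the top feature.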
See also
[DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and
@@ -692,6 +704,12 @@ The fields of `fitted_params(mach)` are:

- `coefficients`: the stump coefficients (one per stump)

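For intuition on the stump `coefficients`: in classic discrete AdaBoost (an assumed formula for illustration; DecisionTree.jl's exact weighting may differ), each stump's coefficient is derived from its weighted training error, rewarding stumps that beat chance.

```
# Assumed classic discrete AdaBoost stump weight, for intuition only:
# low weighted error => large coefficient; error 0.5 (chance) => coefficient 0.
stump_coefficient(err) = 0.5 * log((1 - err) / err)

stump_coefficient(0.1)  # a confident stump gets a large positive weight
stump_coefficient(0.5)  # a chance-level stump gets zero weight
```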

# Report

- `features`: the names of the features encountered in training

# Examples

```
using MLJ
Booster = @load AdaBoostStumpClassifier pkg=DecisionTree
@@ -781,6 +799,11 @@ The fields of `fitted_params(mach)` are:
DecisionTree.jl algorithm


# Report

- `features`: the names of the features encountered in training


# Examples

```
@@ -864,6 +887,11 @@ The fields of `fitted_params(mach)` are:
- `forest`: the `Ensemble` object returned by the core DecisionTree.jl algorithm


# Report

- `features`: the names of the features encountered in training


# Examples

```