Merged
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLJDecisionTreeInterface"
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.2.0"
version = "0.2.1"

[deps]
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
44 changes: 34 additions & 10 deletions src/MLJDecisionTreeInterface.jl
@@ -325,9 +325,20 @@ MMI.metadata_model(

# # DOCUMENT STRINGS

const DOC_CART = "[CART algorithm](https://en.wikipedia.org/wiki/Decision_tree_learning)"*
", originally published in Breiman, Leo; Friedman, J. H.; Olshen, R. A.; "*
"Stone, C. J. (1984): \"Classification and regression trees\". *Monterey, "*
"CA: Wadsworth & Brooks/Cole Advanced Books & Software.*"

const DOC_RANDOM_FOREST = "[Random Forest algorithm]"*
"(https://en.wikipedia.org/wiki/Random_forest), originally published in "*
"Breiman, L. (2001): \"Random Forests.\", *Machine Learning*, vol. 45, pp. 5–32"

"""
$(MMI.doc_header(DecisionTreeClassifier))

`DecisionTreeClassifier` implements the $DOC_CART.
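
The `$DOC_CART` splice above relies on ordinary Julia string interpolation inside a docstring literal. A minimal sketch of the same pattern follows; the names `DOC_EXAMPLE` and `ExampleModel`, the citation text, and the URL are all hypothetical and are not part of this package.

```julia
# All names here (DOC_EXAMPLE, ExampleModel) are hypothetical and only
# illustrate the pattern; they are not part of MLJDecisionTreeInterface.

# Build a reusable citation fragment by concatenating string literals with `*`.
const DOC_EXAMPLE = "[Example algorithm](https://example.org/algorithm)" *
    ", originally published in Doe, J. (2000): \"An Example\""

# `$DOC_EXAMPLE` is spliced in when the docstring literal is evaluated,
# so the stored documentation contains the full citation text.
"""
    ExampleModel

`ExampleModel` implements the $DOC_EXAMPLE.
"""
struct ExampleModel end

# The same interpolation in an ordinary string, for inspection.
rendered = "`ExampleModel` implements the $DOC_EXAMPLE."
println(rendered)
```

Keeping the citation in one constant means several docstrings can share it, which is how `DOC_CART` and `DOC_RANDOM_FOREST` are reused in this change.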

# Training data

In MLJ or MLJBase, bind an instance `model` to data with
@@ -338,10 +349,11 @@ where

- `X`: any table of input features (eg, a `DataFrame`) whose columns
each have one of the following element scitypes: `Continuous`,
`Count`, or `<:OrderedFactor`.
`Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`

- `y`: the target, which can be any `AbstractVector` whose element
scitype is `<:OrderedFactor` or `<:Multiclass`.
scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
with `scitype(y)`

Train the machine using `fit!(mach, rows=...)`.
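
The scitype checks recommended above can be tried on a small table before binding it to a machine. This is a sketch only, assuming `MLJ` is installed; the column names and values are made up for illustration.

```julia
using MLJ

# Hypothetical toy data: a named tuple of vectors is a valid table.
X = (height   = [1.85, 1.67, 1.52, 1.73],  # Float64 -> Continuous
     siblings = [1, 0, 3, 2])              # Int     -> Count

# Raw strings are `Textual`, so coerce them to a categorical target.
y = coerce(["tall", "short", "short", "tall"], Multiclass)

schema(X)    # lists each column's name, scitype and machine type
scitype(y)   # AbstractVector{Multiclass{2}}

# The element scitype is what the docstring's requirement refers to.
@assert elscitype(y) <: Multiclass
```

If a column shows an unsupported scitype in `schema(X)`, `coerce` can usually fix it before training.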

@@ -460,6 +472,9 @@ DecisionTreeClassifier
"""
$(MMI.doc_header(RandomForestClassifier))

`RandomForestClassifier` implements the standard $DOC_RANDOM_FOREST.


# Training data

In MLJ or MLJBase, bind an instance `model` to data with
@@ -470,10 +485,11 @@ where

- `X`: any table of input features (eg, a `DataFrame`) whose columns
each have one of the following element scitypes: `Continuous`,
`Count`, or `<:OrderedFactor`.
`Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`

- `y`: the target, which can be any `AbstractVector` whose element
scitype is `<:OrderedFactor` or `<:Multiclass`.
scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
with `scitype(y)`

Train the machine with `fit!(mach, rows=...)`.
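
For orientation, the bind-and-train steps described here might look as follows. This is a sketch under assumptions: `MLJ` and `MLJDecisionTreeInterface` are installed, the built-in iris data stands in for real data, and the odd/even row split is arbitrary.

```julia
using MLJ

# Load the model type; this requires MLJDecisionTreeInterface in the environment.
Forest = @load RandomForestClassifier pkg=DecisionTree verbosity=0
model = Forest()

# Built-in demo data: X is a table of Continuous columns, y is Multiclass.
X, y = @load_iris

# Bind the model to data, then train on a subset of rows.
mach = machine(model, X, y)
fit!(mach, rows=1:2:150, verbosity=0)

# Predict the most probable class on the held-out rows.
yhat = predict_mode(mach, rows=2:2:150)
```

The `rows` keyword lets the same machine be trained and evaluated on different subsets without copying the data.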

@@ -546,6 +562,7 @@ RandomForestClassifier
"""
$(MMI.doc_header(AdaBoostStumpClassifier))


# Training data

In MLJ or MLJBase, bind an instance `model` to data with
@@ -556,10 +573,11 @@ where:

- `X`: any table of input features (eg, a `DataFrame`) whose columns
each have one of the following element scitypes: `Continuous`,
`Count`, or `<:OrderedFactor`.
`Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`

- `y`: the target, which can be any `AbstractVector` whose element
scitype is `<:OrderedFactor` or `<:Multiclass`.
scitype is `<:OrderedFactor` or `<:Multiclass`; check the scitype
with `scitype(y)`

Train the machine with `fit!(mach, rows=...)`.

@@ -619,6 +637,9 @@ AdaBoostStumpClassifier
"""
$(MMI.doc_header(DecisionTreeRegressor))

`DecisionTreeRegressor` implements the $DOC_CART.


# Training data

In MLJ or MLJBase, bind an instance `model` to data with
@@ -629,10 +650,10 @@ where

- `X`: any table of input features (eg, a `DataFrame`) whose columns
each have one of the following element scitypes: `Continuous`,
`Count`, or `<:OrderedFactor`.
`Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`

- `y`: the target, which can be any `AbstractVector` whose element
scitype is `Continuous`.
scitype is `Continuous`; check the scitype with `scitype(y)`

Train the machine with `fit!(mach, rows=...)`.
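
For the regressors, a common stumbling block is an integer-valued target, which is interpreted as `Count` rather than `Continuous`. A small sketch of the check and the fix, assuming `MLJ` is installed and using made-up values:

```julia
using MLJ

# Hypothetical integer-valued target.
y_raw = [3, 5, 8, 13]
scitype(y_raw)   # AbstractVector{Count}, not accepted by the regressors

# Coerce to Continuous, which converts the elements to floating point.
y = coerce(y_raw, Continuous)

@assert elscitype(y) <: Continuous
```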

@@ -699,6 +720,9 @@ DecisionTreeRegressor
"""
$(MMI.doc_header(RandomForestRegressor))

`RandomForestRegressor` implements the standard $DOC_RANDOM_FOREST.


# Training data

In MLJ or MLJBase, bind an instance `model` to data with
@@ -709,10 +733,10 @@ where

- `X`: any table of input features (eg, a `DataFrame`) whose columns
each have one of the following element scitypes: `Continuous`,
`Count`, or `<:OrderedFactor`.
`Count`, or `<:OrderedFactor`; check column scitypes with `schema(X)`

- `y`: the target, which can be any `AbstractVector` whose element
scitype is `Continuous`.
scitype is `Continuous`; check the scitype with `scitype(y)`

Train the machine with `fit!(mach, rows=...)`.
