update machine(...) doc-string
ablaom committed Dec 21, 2021
1 parent 6b0f2cf commit 0ec5472
Showing 1 changed file with 72 additions and 12 deletions.
84 changes: 72 additions & 12 deletions src/machines.jl
@@ -202,26 +202,53 @@ end
Construct a `Machine` object binding a `model`, storing
hyper-parameters of some machine learning algorithm, to some data,
`args`. Calling `fit!` on a `Machine` object stores in the machine
object the outcomes of applying the algorithm. This in turn enables
generalization to new data using operations such as `predict` or
`transform`:
```julia
using MLJModels
X, y = make_regression()

PCA = @load PCA pkg=MultivariateStats
model = PCA()
mach = machine(model, X)
fit!(mach, rows=1:50)
transform(mach, selectrows(X, 51:100)) # or transform(mach, rows=51:100)

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
model = DecisionTreeRegressor()
mach = machine(model, X, y)
fit!(mach, rows=1:50)
predict(mach, selectrows(X, 51:100)) # or predict(mach, rows=51:100)
```
Specify `cache=false` to prioritize memory management over speed, and
to guarantee data anonymity when serializing composite models.
When building a learning network, `Node` objects can be substituted
for the concrete data.
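For instance, a minimal sketch (reusing `X`, `y` and the `DecisionTreeRegressor` loaded in the example above, and assuming the standard MLJ namespace) of binding a model to source nodes instead of the raw data, with caching disabled:
```julia
# wrap the concrete data in source nodes (the building blocks of a learning network):
Xs = source(X)
ys = source(y)

# bind the model to the nodes; `cache=false` trades training speed for lower
# memory use and for data anonymity when a composite model is serialized:
mach = machine(DecisionTreeRegressor(), Xs, ys; cache=false)
fit!(mach)
```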
### Learning network machines
machine(Xs; oper1=node1, oper2=node2, ...)
machine(Xs, ys; oper1=node1, oper2=node2, ...)
machine(Xs, ys, extras...; oper1=node1, oper2=node2, ...)
Construct a special machine called a *learning network machine*, that
"wraps" a learning network, usually in preparation to export the
network as a stand-alone composite model type. The keyword arguments
declare what nodes are called when operations, such as `predict` and
`transform`, are called on the machine.
wraps a learning network, usually in preparation to export the network
as a stand-alone composite model type. The keyword arguments declare
what nodes are called when operations, such as `predict` and
`transform`, are called on the machine. An advanced option allows one
to additionally pass the output of any node to the machine's report;
see below.
In addition to the operations named in the constructor, the methods
`fit!`, `report`, and `fitted_params` can be applied as usual to the
machine constructed.
machine(Probabilistic(), args...; kwargs...)
machine(Deterministic(), args...; kwargs...)
machine(Unsupervised(), args...; kwargs...)
machine(Static(), args...; kwargs...)
@@ -234,7 +261,7 @@ machine that happens to be bound to a stand-alone composite model
(i.e., an *exported* learning network).
### Examples of learning network machines
Supposing a supervised learning network's final predictions are
obtained by calling a node `yhat`, then the code
@@ -274,7 +301,40 @@ fit!(yhat)
transformed = Xout(Xnew)
predictions = yhat(Xnew)
```
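For a self-contained illustration, here is a hedged sketch of a learning network machine for a hypothetical `Standardizer` + `DecisionTreeRegressor` network (assuming `using MLJ` provides `source`, `Standardizer`, `Deterministic` and the usual operations), also showing `fit!`, `report` and `fitted_params` applied to the wrapped network:
```julia
using MLJ

X, y = make_regression()
Xs = source(X)
ys = source(y)

# component machines of the network:
stand = machine(Standardizer(), Xs)
W = transform(stand, Xs)
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
rgs = machine(DecisionTreeRegressor(), W, ys)
yhat = predict(rgs, W)

# wrap the network; `predict=yhat` declares the node that `predict` calls:
network_mach = machine(Deterministic(), Xs, ys; predict=yhat)
fit!(network_mach)

predict(network_mach, X)    # equivalent to calling yhat(X)
report(network_mach)        # exposes reports of the component machines
fitted_params(network_mach) # exposes fitted parameters of the component machines
```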
### Including a node's output in the report
The return value of a node called with no arguments can be included in
a learning network machine's report, and so in the report of any
composite model type constructed by exporting a learning network. This
is useful for exposing byproducts of network training that are not
readily deduced from the `report`s and `fitted_params` of the
component machines (which are automatically exposed).
The following example shows how to expose `err1()` and `err2()`, where
`err1` and `err2` are nodes in the network delivering training errors.
```julia
X, y = make_moons()
Xs = source(X)
ys = source(y)

model = ConstantClassifier()
mach = machine(model, Xs, ys)
yhat = predict(mach, Xs)

err1 = @node auc(yhat, ys)
err2 = @node accuracy(yhat, ys)

network_mach = machine(Probabilistic(),
                       Xs,
                       ys,
                       predict=yhat,
                       report=(auc=err1, accuracy=err2))

fit!(network_mach)
r = report(network_mach)
@assert r.auc == auc(yhat(), ys())
@assert r.accuracy == accuracy(yhat(), ys())
```
"""
function machine end

