Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jun 11, 2020
1 parent 23b46d3 commit 476e00c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ problem). While each MLJ model has a simple default builder, users
will generally need to define their own builders to get good results,
and this will require familiarity with the [Flux
API](https://fluxml.ai/Flux.jl/stable/) for defining a neural network
chain.
chain.

In the future MLJFlux may provided an assortment of more sophisticated
canned builders.
Expand All @@ -37,7 +37,7 @@ standardization of input features.

```julia
using MLJ
import RDatasets
import RDatasets
iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), colname -> true, rng=123);
@load NeuralNetworkClassifier
Expand All @@ -58,7 +58,41 @@ NeuralNetworkClassifier(
optimiser_changes_trigger_retraining = false) @ 160
```

#### Inspecting evolution of out-of-sample performance
#### Incremental training

```julia
import Random.seed!; seed!(123)
mach = machine(clf, X, y)
fit!(mach)

julia> training_loss = cross_entropy(predict(mach, X), y) |> mean
0.89526004f0

# reduce learning rate and add iterations:
clf.optimiser.eta = clf.optimiser.eta / 5
clf.epochs = clf.epochs + 5

julia> fit!(mach, verbosity=2)
[ Info: Updating Machine{NeuralNetworkClassifier{Short,}} @ 142.
[ Info: Loss is 0.9638
[ Info: Loss is 0.9226
[ Info: Loss is 0.9197
[ Info: Loss is 0.9477
[ Info: Loss is 0.9517
Machine{NeuralNetworkClassifier{Short,}} @ 142

julia> training_loss = cross_entropy(predict(mach, X), y) |> mean
0.87368965f0
```
#### Accessing the Flux chain (model)
```julia
julia> fitted_params(mach).chain
Chain(Chain(Dense(4, 3, σ), Flux.Dropout{Float64}(0.5, false), Dense(3, 3)), softmax)
```
#### Evolution of out-of-sample performance
```julia
r = range(clf, :epochs, lower=1, upper=200, scale=:log10)
Expand Down Expand Up @@ -104,7 +138,7 @@ model type | prediction type | `scitype(X) <: _` | `scitype(y) <: _`
> Table 1. Input and output types for MLJFlux models
#### Matrix input
#### Non-tabular input
Any `AbstractMatrix{<:AbstractFloat}` object `Xmat` can be forced to
have scitype `Table(Continuous)` by replacing it with ` X =
Expand Down
Binary file modified learning_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 476e00c

Please sign in to comment.