# Using IterationControl to train a tree-booster on the iris data set

In this demonstration we show how to use the controls in
[IterationControl.jl](https://github.com/ablaom/IterationControl.jl)
with an iterative
[MLJ](https://github.com/alan-turing-institute/MLJ.jl) model, using
our bare hands. (MLJ will ultimately provide its own canned
`IteratedModel` wrapper to make this more convenient and
compositional.)

In [1]:
using Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()

import MLJ
using IterationControl

using Statistics
using Random
Random.seed!(123)

using Plots
pyplot(size=(600, 300*(sqrt(5)-1)));

MLJ.color_off()

 Activating environment at `~/Dropbox/Julia7/MLJ/IterationControl/examples/iris/Project.toml`


false

Loading some data and splitting observation indices into train/test:

In [2]:
X, y = MLJ.@load_iris;
train, test = MLJ.partition(eachindex(y), 0.7, shuffle=true)

([125, 100, 130, 9, 70, 148, 39, 64, 6, 107  …  134, 114, 52, 74, 44, 61, 83, 18, 122, 26], [97, 78, 30, 108, 101, 24, 85, 91, 135, 96  …  112, 144, 140, 72, 109, 41, 106, 147, 47, 5])

Import a model type:

In [3]:
Booster = MLJ.@load EvoTreeClassifier verbosity=0

EvoTrees.EvoTreeClassifier

Note that in MLJ a "model" is just a container for
hyper-parameters. The objects we will iterate here are MLJ
[*machines*](https://alan-turing-institute.github.io/MLJ.jl/dev/machines/);
these bind the model to train/test data and, in the case of
iterative models, can be trained using a warm-restart.
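
To make the warm-restart idea concrete, here is a minimal pure-Julia sketch. It does not use MLJ: `ToyModel`, `ToyMachine` and `toy_fit!` are hypothetical stand-ins invented for illustration, and the bookkeeping is only a cartoon of what a real iterative model might do when its iteration count is increased between fits.

```julia
# Hypothetical stand-ins for illustration; these are not MLJ types.
mutable struct ToyModel
    nrounds::Int              # hyper-parameter: total rounds requested
end

mutable struct ToyMachine
    model::ToyModel
    rounds_done::Int          # rounds already computed and stored
    work_log::Vector{Int}     # rounds computed by each call to `toy_fit!`
end
ToyMachine(model::ToyModel) = ToyMachine(model, 0, Int[])

function toy_fit!(mach::ToyMachine)
    extra = mach.model.nrounds - mach.rounds_done
    if extra >= 0
        # warm restart: only the additional rounds are computed
        push!(mach.work_log, extra)
    else
        # fewer rounds requested than stored: start again from scratch
        push!(mach.work_log, mach.model.nrounds)
    end
    mach.rounds_done = mach.model.nrounds
    return mach
end

toy_mach = ToyMachine(ToyModel(1))
toy_fit!(toy_mach)               # first fit computes 1 round
toy_mach.model.nrounds += 5      # mutate the hyper-parameter in place
toy_fit!(toy_mach)               # refit computes only the 5 new rounds
println(toy_mach.work_log)
```

In this toy the second fit does five rounds of work rather than six, which is the saving a warm restart is meant to provide.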

Creating a machine:

In [4]:
mach = MLJ.machine(Booster(nrounds=1), X, y);

Lifting MLJ's `fit!(::Machine)` method to `IterationControl.train!`:

In [5]:
function IterationControl.train!(mach::MLJ.Machine{<:Booster}, n::Int)
    mlj_model = mach.model
    # request `n` additional boosting rounds ...
    mlj_model.nrounds = mlj_model.nrounds + n
    # ... and refit on the training rows only
    MLJ.fit!(mach, rows=train, verbosity=0)
end

Lifting the out-of-sample loss:

In [6]:
function IterationControl.loss(mach::MLJ.Machine{<:Booster})
    # probabilistic predictions on the held-out rows
    yhat = MLJ.predict(mach, rows=test)
    # mean of the per-observation log losses
    return MLJ.log_loss(yhat, y[test]) |> mean
end

Iterating with controls:

In [7]:
logging(mach) = "loss: $(IterationControl.loss(mach))"

IterationControl.train!(mach,
                        Train(5),
                        GL(),
                        Info(logging))

┌ Info: loss: 0.40747008
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.2794046
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.2300303
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.21370259
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.21425402
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.22378877
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: Early stop triggered by GL(2.0) stopping criterion. 
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:50


((Train(5), NamedTuple()), (GL(2.0), (done = true, log = "Early stop triggered by GL(2.0) stopping criterion. ")), (Info{typeof(Main.##260.logging)}(Main.##260.logging), NamedTuple()))
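
The log above shows `GL(2.0)` stopping training at the sixth reported loss. As a rough check, the logged values can be replayed by hand, assuming `GL` follows Prechelt's generalization-loss criterion: stop once `100 * (current loss / lowest loss so far - 1)` exceeds the threshold. That formula is an assumption about the control's definition, and the helper names below are hypothetical, not part of IterationControl.jl.

```julia
# Hypothetical helpers for illustration; not part of IterationControl.jl.

# Generalization loss (in percent) of the latest loss relative to the best so far.
generalization_loss(losses) = 100 * (last(losses) / minimum(losses) - 1)

# Index of the first loss at which the criterion exceeds `alpha`, or `nothing`.
function first_gl_stop(losses; alpha=2.0)
    for t in eachindex(losses)
        generalization_loss(view(losses, 1:t)) > alpha && return t
    end
    return nothing
end

# Out-of-sample losses copied from the log above.
logged = [0.40747008, 0.2794046, 0.2300303, 0.21370259, 0.21425402, 0.22378877]

stop_at = first_gl_stop(logged)
println("criterion first exceeded at logged loss number ", stop_at)
```

Under that assumption the fifth loss gives a generalization loss of about 0.26 and the sixth about 4.7, which is consistent with the stop reported in the log.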

Continuing iteration with a different stopping criterion:

In [8]:
IterationControl.train!(mach,
                        Train(5),
                        NumberLimit(10),
                        Info(logging))

┌ Info: loss: 0.23854089
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.25629267
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.27593076
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.29661915
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.31790704
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.3308279
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.34393334
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.3614778
└ @ IterationControl /Users/anthony/Dropbox/Julia7/MLJ/IterationControl/src/controls.jl:104
┌ Info: loss: 0.375098
└ @ IterationControl /Users/anthony/Dropbox

((Train(5), NamedTuple()), (NumberLimit(10), (done = true, log = "Early stop triggered by NumberLimit(10) stopping criterion. ")), (Info{typeof(Main.##260.logging)}(Main.##260.logging), NamedTuple()))
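
Both calls above repeat the same cycle: train a fixed number of rounds, report the loss, then check a stopping criterion. The following pure-Julia sketch mimics that cycle on a toy problem. It is a simplified picture under stated assumptions, not the package's implementation: `ToyLearner`, `toy_train!`, `toy_loss` and `run_cycles!` are hypothetical names, and only a cycle-count limit is modelled.

```julia
# Hypothetical toy model: each round moves the estimate halfway to the target.
mutable struct ToyLearner
    target::Float64
    estimate::Float64
    rounds::Int
end
ToyLearner(target) = ToyLearner(target, 0.0, 0)

# Add `n` rounds to whatever training has already been done.
function toy_train!(m::ToyLearner, n::Int)
    for _ in 1:n
        m.estimate += (m.target - m.estimate) / 2
    end
    m.rounds += n
    return m
end

toy_loss(m::ToyLearner) = abs(m.target - m.estimate)

# One cycle = train `step` rounds, record the loss, check the cycle limit.
function run_cycles!(m::ToyLearner; step=5, number_limit=10)
    history = Float64[]
    while true
        toy_train!(m, step)                        # plays the role of Train(5)
        push!(history, toy_loss(m))                # recorded instead of logged
        length(history) >= number_limit && break   # plays the role of NumberLimit(10)
    end
    return history
end

learner = ToyLearner(1.0)
history = run_cycles!(learner)                     # 10 cycles of 5 rounds
more = run_cycles!(learner; number_limit=3)        # continue from the current state
println((learner.rounds, length(history), length(more)))
```

Because `toy_train!` adds rounds to the existing state, a second call to `run_cycles!` continues where the first left off, just as the second `IterationControl.train!` call above continues training the same machine.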

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*