# Incremental Training with MLJFlux
In this workflow example we explore how to incrementally train MLJFlux models.

**Julia version** is assumed to be 1.10.*

### Basic Imports

```julia
using MLJ               # MLJ interface; MLJFlux models are loaded via @load
using Flux              # For activation functions and optimisers
import RDatasets        # Dataset source
```

### Loading and Splitting the Data

```julia
iris = RDatasets.dataset("datasets", "iris");
# Target is :Species; all remaining columns become features
y, X = unpack(iris, ==(:Species), colname -> true, rng=123);
X = Float32.(X)      # Match the Float32 type of the network parameters
(X_train, X_test), (y_train, y_test) = partition((X, y), 0.8,
                                                 multi = true,
                                                 shuffle = true,
                                                 rng=42);
```
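With 150 rows in `iris` and a fraction of 0.8, the split should leave 120 rows for training and 30 for testing. The stdlib-only sketch below mimics a seeded, shuffled 80/20 split on row indices. The helper `split_indices` is hypothetical and is not MLJ's `partition`, which may shuffle and round differently.

```julia
using Random

# Hypothetical helper: shuffle row indices with a seeded RNG, then cut them
# into a training chunk (fraction `frac`) and a test chunk (the remainder).
function split_indices(n::Int, frac::Real; rng=MersenneTwister(42))
    idx = shuffle(rng, collect(1:n))
    n_train = round(Int, frac * n)
    return idx[1:n_train], idx[n_train+1:end]
end

train_idx, test_idx = split_indices(150, 0.8)
length(train_idx), length(test_idx)   # (120, 30)
```

Every row lands in exactly one of the two chunks, so no flower is used for both training and evaluation.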

### Instantiating the model
Now let's construct our model. This follows a setup similar to the one in the [Quick Start](../../index.md#quick-start).

```julia
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
clf = NeuralNetworkClassifier(
    builder=MLJFlux.MLP(; hidden=(5,4), σ=Flux.relu),
    optimiser=Flux.Adam(0.01),   # `ADAM` is the deprecated spelling of `Adam`
    batch_size=8,
    epochs=10,
    rng=42,
)
```

```
[ Info: For silent loading, specify `verbosity=0`.
import MLJFlux ✔

NeuralNetworkClassifier(
  builder = MLP(
        hidden = (5, 4),
        σ = NNlib.relu),
  finaliser = NNlib.softmax,
  optimiser = Flux.Optimise.Adam(0.01, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}()),
  loss = Flux.Losses.crossentropy,
  epochs = 10,
  batch_size = 8,
  lambda = 0.0,
  alpha = 0.0,
  rng = 42,
  optimiser_changes_trigger_retraining = false,
  acceleration = ComputationalResources.CPU1{Nothing}(nothing))
```
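As a rough sense of model size, the sketch below counts the trainable parameters of the network that `MLP(hidden=(5, 4))` would build for iris (4 input features, 3 classes). It assumes every layer is a fully connected `Dense` layer with a weight matrix and a bias vector, and that the `softmax` finaliser adds no parameters. `count_params` is a hypothetical helper written only for this illustration.

```julia
# Hypothetical helper: parameters of a stack of dense layers, where a layer
# mapping `n_in` inputs to `n_out` outputs has an n_out×n_in weight matrix
# plus a bias vector of length n_out.
function count_params(layer_sizes::Vector{Int})
    total = 0
    for (n_in, n_out) in zip(layer_sizes[1:end-1], layer_sizes[2:end])
        total += n_out * n_in + n_out
    end
    return total
end

# 4 iris features -> hidden layers of 5 and 4 units -> 3 species
count_params([4, 5, 4, 3])   # 25 + 24 + 15 = 64
```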

### Initial round of training
Now let's train the model. Calling `fit!` trains it for the 10 epochs specified above.

```julia
mach = machine(clf, X_train, y_train)
fit!(mach)
```

```
[ Info: Training machine(NeuralNetworkClassifier(builder = MLP(hidden = (5, 4), …), …), …).

trained Machine; caches model-specific representations of data
  model: NeuralNetworkClassifier(builder = MLP(hidden = (5, 4), …), …)
  args: 
    1:	Source @655 ⏎ ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}
    2:	Source @902 ⏎ AbstractVector{ScientificTypesBase.Multiclass{3}}
```
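It can help to translate `epochs` and `batch_size` into optimiser steps. Assuming each epoch makes one pass over the 120 training rows in mini-batches of `batch_size = 8`, and that a trailing partial batch (if any) still counts as one step, the initial `fit!` performs the number of parameter updates computed below. `n_updates` is a hypothetical helper, not part of MLJFlux.

```julia
# Hypothetical helper: optimiser steps for a given number of epochs,
# counting a trailing partial batch as one extra step.
n_updates(n_rows, batch_size, epochs) = cld(n_rows, batch_size) * epochs

n_updates(120, 8, 10)   # 15 batches per epoch × 10 epochs = 150
```

Since 120 is divisible by 8, there is no partial batch here, and the same arithmetic gives 450 steps for any 30 further epochs.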


Let's evaluate the training loss and validation accuracy:

```julia
training_loss = cross_entropy(predict(mach, X_train), y_train)
```

```
0.5187556517212482
```

```julia
val_acc = accuracy(predict_mode(mach, X_test), y_test)
```

```
0.5333333333333333
```

The performance is poor so far.

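For reference, both measures reduce to simple formulas: cross-entropy is the mean of `-log` of the probability assigned to the true class, and accuracy is the fraction of observations whose most probable class (what `predict_mode` returns) matches the target. The plain-Julia sketch below computes both from a made-up matrix of class probabilities; `my_cross_entropy` and `my_accuracy` are hypothetical names, not MLJ's implementations, which operate on probabilistic predictions and may clamp probabilities to avoid `log(0)`.

```julia
# Made-up predictions: rows are observations, columns are classes;
# each row sums to 1.
probs = [0.7 0.2 0.1;
         0.1 0.6 0.3;
         0.3 0.3 0.4;
         0.5 0.4 0.1]
labels = [1, 2, 3, 2]   # true class index for each row

# Mean negative log-probability assigned to the true class.
my_cross_entropy(P, y) = -sum(log(P[i, y[i]]) for i in eachindex(y)) / length(y)

# Fraction of rows whose most probable class equals the true class.
my_accuracy(P, y) = count(argmax(P[i, :]) == y[i] for i in eachindex(y)) / length(y)

my_accuracy(probs, labels)   # 3 of 4 correct -> 0.75
```

By the same definition, the validation accuracy of 0.5333 above corresponds to 16 of the 30 test flowers being classified correctly.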
### Incremental Training
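Now let's train it for another 30 epochs at half the original learning rate. All we need to do is change these
hyperparameters and call `fit!` again. Because only `epochs` and the optimiser changed (and `optimiser_changes_trigger_retraining = false`), the model parameters are not reset: training resumes from the current weights and runs only the 30 additional epochs.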
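Conceptually, this incremental update is a warm restart: the machine remembers how many epochs it has already run and, when only `epochs` (and here the optimiser) changed, trains just the difference starting from the current weights, whereas `force=true` discards the weights and starts over. The toy below sketches that bookkeeping on a single weight minimised by gradient descent. `ToyMachine` and `toy_fit!` are hypothetical names; this is not MLJFlux's `update` code, and it assumes, as a simplification, that reducing `epochs` also triggers a fresh start.

```julia
# Hypothetical toy "machine": one weight trained by gradient descent on
# the loss (w - 3)^2, plus a record of how many epochs have been run.
mutable struct ToyMachine
    w::Float64
    epochs_done::Int
end
ToyMachine() = ToyMachine(0.0, 0)

toy_loss(w) = (w - 3.0)^2
toy_grad(w) = 2 * (w - 3.0)

# Train up to `epochs` total epochs with learning rate `eta`.
# Warm restart: only the extra epochs run, starting from the current weight.
# Cold restart (force=true, or fewer epochs than already run): reset first.
function toy_fit!(m::ToyMachine, epochs::Int, eta::Float64; force::Bool=false)
    if force || epochs < m.epochs_done
        m.w, m.epochs_done = 0.0, 0
    end
    n_new = epochs - m.epochs_done
    for _ in 1:n_new
        m.w -= eta * toy_grad(m.w)
    end
    m.epochs_done = epochs
    return n_new   # number of epochs actually run in this call
end

m = ToyMachine()
eta = 0.01
toy_fit!(m, 10, eta)                   # runs 10 epochs from w = 0
loss_after_10 = toy_loss(m.w)
eta, epochs = eta / 2, 10 + 30         # halve the rate, add 30 epochs
toy_fit!(m, epochs, eta)               # runs only the 30 extra epochs
toy_loss(m.w) < loss_after_10          # true: training continued
toy_fit!(m, epochs, eta; force=true)   # resets w, then runs all 40 epochs
```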

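```julia
clf.optimiser.eta = clf.optimiser.eta / 2   # Halve the learning rate: 0.01 -> 0.005
clf.epochs = clf.epochs + 30                # 10 -> 40 epochs in total
fit!(mach, verbosity=2);                    # verbosity=2 logs the loss after each epoch
```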

```
[ Info: Updating machine(NeuralNetworkClassifier(builder = MLP(hidden = (5, 4), …), …), …).
[ Info: Loss is 0.5195
[ Info: Loss is 0.5113
[ Info: Loss is 0.5056
[ Info: Loss is 0.501
[ Info: Loss is 0.497
[ Info: Loss is 0.4944
[ Info: Loss is 0.4909
[ Info: Loss is 0.4881
[ Info: Loss is 0.4855
[ Info: Loss is 0.4833
[ Info: Loss is 0.4813
[ Info: Loss is 0.4794
[ Info: Loss is 0.4777
[ Info: Loss is 0.476
[ Info: Loss is 0.4744
[ Info: Loss is 0.4729
[ Info: Loss is 0.471
[ Info: Loss is 0.4685
[ Info: Loss is 0.4357
[ Info: Loss is 0.3986
[ Info: Loss is 0.354
[ Info: Loss is 0.3212
[ Info: Loss is 0.294
[ Info: Loss is 0.2832
[ Info: Loss is 0.2727
[ Info: Loss is 0.247
[ Info: Loss is 0.2285
[ Info: Loss is 0.2153
[ Info: Loss is 0.2024
[ Info: Loss is 0.1928
```


Let's evaluate the training loss and validation accuracy again:

```julia
training_loss = cross_entropy(predict(mach, X_train), y_train)
```

```
0.18276122841169196
```

```julia
val_acc = accuracy(predict_mode(mach, X_test), y_test)
```

```
0.9333333333333333
```

That's much better. If we instead want to reset the model parameters before fitting, we can call `fit!(mach, force=true)`.

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*