
Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox



FluxML/MLJFlux.jl



An interface to the Flux deep learning models for the MLJ machine learning framework



Code Snippet

using MLJ, MLJFlux, RDatasets, Plots

Grab some data and split into features and target:

iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), rng=123);
X = Float32.(X);      # Convert to Float32 to optimise for GPUs
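The `Float32.(X)` conversion halves the storage per value, and single-precision arithmetic is what GPUs are generally fastest at; Flux also initialises layer weights as `Float32` by default. A minimal sketch in plain Julia (no MLJ required), using a random 150×4 matrix as a stand-in for the iris features, shows the memory difference:

```julia
# Compare the memory footprint of the same-shaped data stored as Float64 and Float32.
X64 = rand(150, 4)        # Float64 by default; same shape as the iris feature table
X32 = Float32.(X64)       # elementwise conversion, as in the snippet above

println(eltype(X64), ": ", sizeof(X64), " bytes")   # Float64: 4800 bytes
println(eltype(X32), ": ", sizeof(X32), " bytes")   # Float32: 2400 bytes
```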

Load model code and instantiate an MLJFlux model:

NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

clf = NeuralNetworkClassifier(
    builder=MLJFlux.MLP(; hidden=(5,4)),
    batch_size=8,
    epochs=50,
    acceleration=CUDALibs()  # for training on a GPU; omit to train on the CPU
)

Wrap the model in "iteration controls":

stop_conditions = [
    Step(1),            # Apply controls every epoch
    NumberLimit(1000),  # Don't train for more than 1000 steps
    Patience(4),        # Stop after 4 consecutive increases in validation loss
    NumberSinceBest(5), # Or if the best loss occurred 5 iterations ago
    TimeLimit(30/60),   # Or if 30 minutes (0.5 hours) have passed
]
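To see how two of the stopping rules above differ, here is a minimal sketch in plain Julia that replays a made-up loss history. The helpers `patience_stop` and `since_best_stop` are hypothetical, written only for illustration: they mimic the documented behaviour of `Patience(n)` (stop after `n` consecutive loss increases) and `NumberSinceBest(n)` (stop once the best loss so far is `n` iterations old), but they are not the IterationControl.jl implementations.

```julia
# Hypothetical helpers (not the IterationControl.jl code) illustrating two stopping rules.

# Return the iteration at which `n` consecutive loss increases have been seen,
# or `nothing` if that never happens.
function patience_stop(losses, n)
    increases = 0
    for i in 2:length(losses)
        # Any non-increase resets the count of consecutive increases.
        increases = losses[i] > losses[i-1] ? increases + 1 : 0
        increases >= n && return i
    end
    return nothing
end

# Return the iteration at which the lowest loss so far occurred `n` iterations ago,
# or `nothing` if that never happens.
function since_best_stop(losses, n)
    best, best_i = Inf, 0
    for (i, loss) in enumerate(losses)
        if loss < best
            best, best_i = loss, i
        end
        i - best_i >= n && return i
    end
    return nothing
end

demo_losses = [0.9, 0.7, 0.5, 0.55, 0.52, 0.58, 0.6, 0.61]   # made-up validation losses
patience_stop(demo_losses, 4)     # nothing
since_best_stop(demo_losses, 5)   # 8 (best loss 0.5 was at iteration 3)
```

Because the small improvement at iteration 5 resets the count of consecutive increases, `Patience(4)` never fires on this history, while `NumberSinceBest(5)` still stops at iteration 8.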

validation_losses = []
train_losses = []
callbacks = [
    WithLossDo(loss->push!(validation_losses, loss)),
    WithTrainingLossesDo(losses->push!(train_losses, losses[end])),
]

iterated_model = IteratedModel(
    model=clf,
    resampling=Holdout(fraction_train=0.5), # loss and stopping are based on out-of-sample data
    measures=log_loss,
    controls=vcat(stop_conditions, callbacks),
);
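The `log_loss` measure scores the probabilistic predictions on the holdout set. Roughly, it is the average negative log of the probability the model assigned to the observed class. The sketch below computes that quantity by hand in plain Julia; `mean_log_loss` is a hypothetical helper, and the clamping tolerance `tol` is an assumption of this sketch rather than MLJ's exact setting.

```julia
# Hypothetical helper: mean log loss (cross-entropy) for probabilistic class predictions.
# Row i of `probs` holds the predicted class probabilities for observation i;
# `y_true[i]` is the index of the observed class.
function mean_log_loss(probs::AbstractMatrix, y_true::AbstractVector{<:Integer}; tol=1e-15)
    total = 0.0
    for i in eachindex(y_true)
        p = clamp(probs[i, y_true[i]], tol, 1 - tol)   # avoid log(0)
        total -= log(p)
    end
    return total / length(y_true)
end

# Three observations whose true classes are 1, 2 and 3 respectively.
probs_demo = [0.8 0.1 0.1
              0.2 0.7 0.1
              0.3 0.3 0.4]
y_true = [1, 2, 3]
mean_log_loss(probs_demo, y_true)   # (-log(0.8) - log(0.7) - log(0.4)) / 3 ≈ 0.499
```

Confident, correct predictions drive the loss towards zero, while assigning low probability to the observed class is penalised heavily.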

Train the wrapped model:

julia> mach = machine(iterated_model, X, y)
julia> fit!(mach)

[ Info: No iteration parameter specified. Using `iteration_parameter=:(epochs)`. 
[ Info: final loss: 0.1284184007796247
[ Info: final training loss: 0.055630706
[ Info: Stop triggered by NumberSinceBest(5) stopping criterion. 
[ Info: Total of 811 iterations. 

Inspect results:

julia> plot(train_losses, label="Training Loss")
julia> plot!(validation_losses, label="Validation Loss", linewidth=2, size=(800,400))