In this notebook, we introduce an alternative to JAX and Flax: Enzyme and Reactant, used together with Lux as the ML library. To this end, we reimplement the good old MLP on MNIST once more.
Note that JAX is responsible not only for AD but also for compiling code down to XLA (`@jax.jit`). Here, the compilation is handled by Reactant, while the AD itself is performed by Enzyme.

Note: Due to some CUDA.jl issues, we need at least Julia version 1.11.2. I recommend 1.11.3; this notebook has not been tested with other versions.

In [1]:
ENV["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false"
Base.active_project()

"/home/jonas/Documents/project/jax_intro/lux_intro/lux_intro_env/Project.toml"

In [2]:
using Lux, LuxCUDA, Random, Optimisers, Plots, Reactant, Enzyme, OneHotArrays, Zygote
Reactant.set_default_backend("gpu")

I0000 00:00:1739379197.046268   64843 service.cc:152] XLA service 0x398c32a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1739379197.046296   64843 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 4090, Compute Capability 8.9
I0000 00:00:1739379197.046299   64843 service.cc:160]   StreamExecutor device (1): NVIDIA GeForce RTX 4090, Compute Capability 8.9
I0000 00:00:1739379197.046301   64843 service.cc:160]   StreamExecutor device (2): NVIDIA GeForce RTX 4090, Compute Capability 8.9
I0000 00:00:1739379197.046302   64843 service.cc:160]   StreamExecutor device (3): NVIDIA GeForce RTX 4090, Compute Capability 8.9
I0000 00:00:1739379197.047911   64843 se_gpu_pjrt_client.cc:987] Using BFC allocator.
I0000 00:00:1739379197.047941   64843 gpu_helpers.cc:136] XLA backend allocating 18946572288 bytes on device 0 for BFCAllocator.
I0000 00:00:1739379197.047968   64843 gpu_helpers.cc:136] XLA backend allocating 18946572288 byt

Reactant.XLA.Client(Ptr{Nothing} @0x0000000039fce750, Int32[0, 1, 2, 3])

In [3]:
rng  = Random.Xoshiro(123)

Xoshiro(0xfefa8d41b8f5dca5, 0xf80cc98e147960c1, 0x20e2ccc17662fc1d, 0xea7a7dcb2e787c01, 0xf4e85a418b9c4f80)

In [None]:
dev     = reactant_device()
gpu_dev = gpu_device(2)
cpu_dev = cpu_device()

(::CPUDevice) (generic function with 1 method)

Lux models are decoupled from their parameters and states. Hence, we have to create them explicitly using the appropriate initializers. For this, Lux provides `Lux.setup`, which traverses the model and returns the parameters and states as `NamedTuple`s with one entry per layer.

In [5]:
model = Chain(Dense(28*28, 128, relu), Dense(128, 10), softmax)
ps, st = Lux.setup(rng, model)
ps_dev = ps |> dev
st_dev = st |> dev

(layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple())
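
To make the decoupling described above concrete, here is a minimal sketch in plain Julia. It is not how Lux is implemented; `TinyDense`, `tiny_setup`, and the initialization scale are hypothetical names and choices used only for illustration. The point is that the layer object stores just its shape, while parameters and state are created separately and passed in on every call.

```julia
using Random

# Hypothetical stand-in for a dense layer: it stores only its shape, never its weights.
struct TinyDense
    in_dims::Int
    out_dims::Int
end

# Rough analogue of `Lux.setup`: build parameters and (empty) state from an RNG.
function tiny_setup(rng::AbstractRNG, layer::TinyDense)
    tiny_ps = (weight = 0.01f0 .* randn(rng, Float32, layer.out_dims, layer.in_dims),
               bias = zeros(Float32, layer.out_dims))
    tiny_st = NamedTuple()  # this layer carries no state
    return tiny_ps, tiny_st
end

# Calling the layer is a pure function of input, parameters, and state.
# It returns the output together with the (possibly updated) state.
(layer::TinyDense)(x, ps, st) = (ps.weight * x .+ ps.bias, st)

tiny_layer = TinyDense(4, 2)
tiny_ps, tiny_st = tiny_setup(Xoshiro(0), tiny_layer)
tiny_y, tiny_st_new = tiny_layer(ones(Float32, 4, 3), tiny_ps, tiny_st)
```

Because nothing is hidden inside the layer, the same parameters can be moved to a device, differentiated, or swapped out without touching the model definition.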

In [6]:
x_dev = Random.randn(Float32, 28*28, 60) |> dev;
y_dev = Random.rand(0:9, 60) |> (x -> onehotbatch(x, 0:9)) .|> Float32 |> dev

10×60 CuArray{Float32, 2, CUDA.DeviceMemory}:
 1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  1.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  1.0  0.0  0.0  1.0  0.0  0.0
 0.0  0.0  0.0  0.0  1.0  0.0  1.0  1.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  1.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  1.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  1.0  0.0  0.0  1.0

In [7]:
compiled_model = @compile model(x_dev, ps_dev, st_dev)

ErrorException: cannot copy Ptr{CUDA.CUctx_st} @0x0000000036717300 of type Ptr{CUDA.CUctx_st}

# Loading the MNIST data
MNIST is a common benchmark dataset and is therefore included in MLDatasets.jl, from which it can be downloaded easily.
(You probably have to do that in the REPL, since you are asked to confirm the download.)

In [8]:
using MLDatasets, MLUtils, OneHotArrays, Plots, Images, Zygote

In [9]:
MNIST_train  = MLDatasets.MNIST(:train)
MNIST_test   = MLDatasets.MNIST(:test)

dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :test
  features  =>    28×28×10000 Array{Float32, 3}
  targets   =>    10000-element Vector{Int64}

In [10]:
train_data = MNIST_train.features |> x-> reshape(x, (28*28, size(x,3)))
test_data = MNIST_test.features   |> x-> reshape(x, (28*28, size(x,3)))
train_labels = MNIST_train.targets |> x-> onehotbatch(x,0:9) .|> Float32
test_labels = MNIST_test.targets |> x-> onehotbatch(x,0:9) .|> Float32;
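
The preprocessing above flattens each 28×28 image into a column and turns each digit label into a one-hot column. The following sketch shows the same two steps on toy data with plain arrays. `encode_onehot` and `decode_onehot` are hypothetical helpers that mirror what `onehotbatch` and `onecold` compute; they are not the OneHotArrays implementation, which uses a more compact array type.

```julia
# Flatten a stack of images (height, width, samples) into (pixels, samples).
demo_images = rand(Float32, 28, 28, 3)
demo_flat = reshape(demo_images, 28 * 28, :)

# One column per sample, with a 1 in the row of the matching class.
encode_onehot(targets, classes) = Float32.(collect(classes) .== permutedims(targets))

# Inverse direction: pick the class whose row holds the largest entry.
decode_onehot(onehot, classes) = [classes[argmax(col)] for col in eachcol(onehot)]

demo_targets = [3, 0, 9]
demo_onehot = encode_onehot(demo_targets, 0:9)   # 10×3 matrix
demo_decoded = decode_onehot(demo_onehot, 0:9)   # recovers the original labels
```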

In [11]:
train_dataloader = DataLoader((data = train_data, label = train_labels), batchsize=60, shuffle=true) |> dev

DeviceIterator{CUDADevice{CuDevice}, DataLoader{BatchView{@NamedTuple{data::Matrix{Float32}, label::Matrix{Float32}}, ObsView{@NamedTuple{data::Matrix{Float32}, label::Matrix{Float32}}, Vector{Int64}}, Val{nothing}}, Bool, :serial, Val{nothing}, @NamedTuple{data::Matrix{Float32}, label::Matrix{Float32}}, TaskLocalRNG}}(CUDADevice{CuDevice}(CuDevice(0)), DataLoader(::@NamedTuple{data::Matrix{Float32}, label::Matrix{Float32}}, shuffle=true, batchsize=60))

# The Loss Function

In [12]:
function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return CrossEntropyLoss()(pred, y)
end

loss_function (generic function with 1 method)
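
For reference, the quantity behind the loss can be written out by hand. This is a sketch under assumptions: it takes predicted probabilities (the model ends in `softmax`), sums over the class dimension, and averages over the batch. The epsilon value and the helper name `manual_cross_entropy` are illustrative; the exact defaults of Lux's `CrossEntropyLoss` should be checked in its documentation.

```julia
using Statistics

# Cross-entropy between predicted probabilities `p` and one-hot targets `y`,
# both of size (classes, batch). `ϵ` guards against log(0); its value is an assumption.
function manual_cross_entropy(p, y; ϵ = eps(eltype(p)))
    per_sample = -sum(y .* log.(p .+ ϵ); dims = 1)  # one loss value per column
    return mean(per_sample)                          # average over the batch
end

# A uniform prediction over 10 classes gives a loss of about log(10),
# regardless of which class is the true one.
p_uniform = fill(0.1f0, 10, 4)
y_demo = Float32.(collect(0:9) .== permutedims([1, 7, 7, 2]))
uniform_loss = manual_cross_entropy(p_uniform, y_demo)
```

A loss close to `log(10)` at the start of training is therefore a useful sanity check: it means the untrained network is roughly guessing uniformly.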

In [13]:
function loss_gradient(model, ps, st, x, y)
    return Enzyme.gradient(Enzyme.Reverse, Const(loss_function),
                           Const(model), ps, Const(st), Const(x), Const(y))[2]
end

loss_gradient (generic function with 1 method)

In [None]:
compiled_loss_grad = @compile loss_gradient(model, ps_dev, st_dev, x_dev, y_dev)

# The Training Loop
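Next, we define the training loop. It wraps the model, parameters, states, and optimizer in a `Training.TrainState` and updates that state with `Training.single_train_step!` for every batch. Note that this loop uses `AutoZygote()` as its AD backend, and that `iterations` counts full passes over the dataloader.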

In [15]:
function train_model(model, ps, st, dataloader, optimizer, iterations::Integer)
    train_state = Training.TrainState(model, ps, st, optimizer)
    for iter in 1:iterations
        for (x,y) in dataloader
            _, loss, _, train_state = Training.single_train_step!(AutoZygote(), CrossEntropyLoss(), (x,y), train_state)
        end
    end
    return train_state
end

train_model (generic function with 1 method)
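
The optimizer used below, `Optimisers.Descent`, performs plain gradient descent: each parameter `p` is replaced by `p - η * g`, where `g` is its gradient and `η` the learning rate. The sketch below spells out that update on a toy problem. `sgd_step` and `run_descent` are hypothetical helpers, not part of Optimisers.jl, and for brevity they handle a flat `NamedTuple` rather than the nested one Lux produces.

```julia
# One gradient-descent update, applied entry by entry to a flat NamedTuple.
sgd_step(ps::NamedTuple, grads::NamedTuple, η) = map((p, g) -> p .- η .* g, ps, grads)

# Toy loss: 0.5 * ‖w‖² + 0.5 * ‖b‖², whose gradient equals the parameters themselves.
function run_descent(ps, η, steps)
    for _ in 1:steps
        grads = ps                   # analytic gradient of the toy loss
        ps = sgd_step(ps, grads, η)
    end
    return ps
end

toy_ps = (w = [1.0, -2.0], b = [0.5])
toy_trained = run_descent(toy_ps, 0.01, 1000)   # every entry shrinks by 0.99 per step
```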

In [16]:
end_state = train_model(model, ps_dev, st_dev, train_dataloader, Optimisers.Descent(0.01), 1000)

TrainState
    model: Chain{@NamedTuple{layer_1::Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::WrappedFunction{typeof(softmax)}}, Nothing}((layer_1 = Dense(784 => 128, relu), layer_2 = Dense(128 => 10), layer_3 = WrappedFunction(softmax)), nothing)
    # of parameters: 101770
    # of states: 0
    optimizer: Descent(0.01)
    step: 1000000

In [17]:
ps_trained = end_state.parameters
st_trained = end_state.states

(layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple())

In [18]:
function accuracy(model, ps, st, x, y)
    pred, _ = model(x,ps,st)
    return mean(onecold(pred |> cpu_dev) .== onecold(y |> cpu_dev))
end

accuracy (generic function with 1 method)

In [19]:
using Statistics
test_dataloader  = DataLoader((data = test_data, label = test_labels), batchsize=60) |> dev
map(test_dataloader) do (x,y)
    # Accuracy call disabled, so every batch yields `missing` and `mean` gets an empty collection:
    # size(x) == (784, 60) ? accuracy(compiled_model, ps_trained, st_trained, x, y) : missing
    missing
end |> skipmissing |> mean

ArgumentError: ArgumentError: reducing over an empty collection is not allowed; consider supplying `init` to the reducer
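
Since the 10000 test samples do not divide evenly into batches of 60, the last batch is shorter than the others. One way to handle this without skipping batches is to count correct predictions and samples separately and divide once at the end. The sketch below does that with plain arrays. `batched_accuracy` is a hypothetical helper, `predict` stands for any function mapping inputs to class scores, and the labels are 1-based class indices rather than one-hot columns.

```julia
# Accuracy over batches of unequal size: count hits and samples,
# rather than averaging per-batch means.
function batched_accuracy(predict, x, labels; batchsize = 60)
    correct = 0
    total = 0
    for idx in Iterators.partition(axes(x, 2), batchsize)
        scores = predict(x[:, idx])                          # (classes, batch) scores
        predicted = [argmax(col) for col in eachcol(scores)]
        correct += count(predicted .== labels[idx])
        total += length(idx)
    end
    return correct / total
end

# Toy check: the "model" is the identity, so the scores equal the inputs.
demo_x = Float32[1 0 0 1 0; 0 1 0 0 0; 0 0 1 0 1]
demo_labels = [1, 2, 3, 1, 2]   # the last label is deliberately wrong
demo_acc = batched_accuracy(identity, demo_x, demo_labels; batchsize = 2)
```

With a batch size of 2 the five samples split into batches of 2, 2, and 1, and the short last batch is counted with the correct weight.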

In [20]:
function show_image(flat_image)
    # Input: a flattened image with 28*28 entries; reshape it and transpose for display.
    flat_image = reshape(flat_image, (28, 28))
    return Gray.(flat_image')
end

show_image (generic function with 1 method)