In [None]:
using Lux, Random, Optimisers, Zygote, MLUtils

In [None]:
# Seed the default RNG with 0 so that every later random draw is reproducible.
# `Random.seed!` returns the generator it seeded, so `rng` is bound to that same object.
rng = Random.seed!(Random.default_rng(), 0)

In [None]:
# Inputs 1, 2, …, 10 laid out as a 1×10 matrix: (in_features, batch).
x = reshape(Float32.(1:10), 1, :)
# Targets follow the exact linear rule y = 3x + 1, computed element-wise.
y = @. 3f0 * x + 1f0

In [None]:
# Mini-batch iterator over the (x, y) pair: 2 observations per batch, shuffling enabled.
# NOTE(review): presumably batches are taken along the last dimension and the order is
# reshuffled on every pass over `loader` — confirm against the MLUtils DataLoader docs.
loader = DataLoader((x, y); batchsize=2, shuffle=true)

In [None]:
# A single dense layer mapping 1 input feature to 1 output feature.
# No activation is passed, so the layer's default is used.
model = Dense(1 => 1)

In [None]:
# Initialise the model's parameters and states from `rng`.
# `params` is what the optimiser updates; `states` is passed to (and returned by)
# every forward call of `model`.
params, states = Lux.setup(rng, model)

In [None]:
"""
    loss(params, x, y, states)

Evaluate the global `model` on `x` with `params` and `states`, and return the mean
squared error against `y` together with the states returned by the model.

# Returns
A tuple `(mse, new_states)`, where `mse` is the sum of squared residuals divided by
`length(y)`.
"""
function loss(params, x, y, states)
    prediction, updated_states = model(x, params, states)
    residual = prediction .- y
    mse = sum(residual .^ 2) / length(y)
    return mse, updated_states
end

In [None]:
# grad = gradient(p -> first(loss(p, x, y, states)), params)  # for precompilation only; `loss` returns a tuple, so take `first` to get the scalar

In [None]:
# Adam optimisation rule; 0.01 is its first positional argument (the learning rate in
# Optimisers.jl).
opt = Optimisers.Adam(0.01)
# Optimiser state built for `params`; it is replaced on every `Optimisers.update` call.
opt_state = Optimisers.setup(opt, params)

In [None]:
"""
    training_step(loss, params, states, opt_state, xb, yb)

Perform one optimisation step on the mini-batch `(xb, yb)`.

`loss(params, xb, yb, states)` must return `(scalar_loss, new_states)`.

# Returns
`(batch_loss, params, states, opt_state)` with the updated parameters, the states
returned by `loss`, and the updated optimiser state.
"""
function training_step(loss, params, states, opt_state, xb, yb)
    # `states` is never reassigned in this function, so the closure captures it directly.
    (batch_loss, new_states), back = Zygote.pullback(
        p -> loss(p, xb, yb, states), params
    )
    # Seed the pullback with 1 for the scalar loss and `nothing` for the states output.
    grad = back((one(batch_loss), nothing))[1]
    opt_state, params = Optimisers.update(opt_state, params, grad)
    return batch_loss, params, new_states, opt_state
end

"""
    run_training(loss, params, states, opt_state, loader; nepochs=50)

Train for `nepochs` passes over `loader`, logging the mean batch loss after each epoch.

Keeping the loop inside a function means it does not depend on the REPL/notebook
soft-scope rule to update `params`, `states` and `opt_state`, and it no longer runs on
untyped globals.

# Returns
`(params, states, opt_state)` after the final update.
"""
function run_training(loss, params, states, opt_state, loader; nepochs=50)
    for epoch in 1:nepochs
        epoch_loss = 0f0

        for (xb, yb) in loader
            batch_loss, params, states, opt_state = training_step(
                loss, params, states, opt_state, xb, yb
            )
            epoch_loss += batch_loss
        end

        @info "epoch $epoch, loss = $(epoch_loss / length(loader))"
    end
    return params, states, opt_state
end

nepochs = 50
# Top-level assignment, so the globals are updated in a notebook and in a script alike.
# The trailing `;` suppresses the cell's output in a notebook.
params, states, opt_state = run_training(
    loss, params, states, opt_state, loader; nepochs=nepochs
);

In [None]:
# Pull the learned weight and bias out of the trained parameters and print both.
# The data were generated as y = 3x + 1, so a converged fit gives W ≈ 3 and b ≈ 1.
W, b = params.weight, params.bias

@show W b

In [None]:
# TODO: use ComponentArray(W) for flattening?