In [None]:
# Number of threads available to this Julia session.
num_threads = Threads.nthreads()

In [None]:
using Distributions
using GeometricIntegrators
using Optim
using Random
using Flux
using Enzyme
using Zygote
using Distances
using Symbolics
using Plots
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

gr()

# Broadcasted product of one or more factors, folded from the right:
# `_prod(a, b, c)` evaluates `a .* (b .* c)`.
_prod(only) = only
_prod(lhs, rhs) = lhs .* rhs
function _prod(head, second, third, rest...)
    return head .* _prod(second, third, rest...)
end

# Build the symbolic state vector: two symbolic arrays `q` and `p`, each
# indexed by `1:dims`, concatenated as (q, p).
function get_z_vector(dims)
    @variables q[1:dims]
    @variables p[1:dims]
    return vcat(q, p)
end

# Enumerate the products of entries of `z` made of exactly `order` factors,
# e.g. `order = 2` yields only the terms whose powers sum to 2.
# `inds` holds the indices already chosen while recursing; callers outside
# this function leave it empty.
function poly_combos(z, order, inds...)
    order == 0 && return Num[1]

    if order == length(inds)
        factors = [z[i] for i in inds]
        return [_prod(factors...)]
    end

    # continue from the last chosen index so the index sequence never decreases
    first_ind = isempty(inds) ? 1 : inds[end]
    branches = [poly_combos(z, order, inds..., j) for j in first_ind:length(z)]
    return vcat(branches...)
end

# Collect the monomials of every degree from 1 up to `order` into one vector.
function primal_monomial_basis(z, order::Int)
    per_degree = [poly_combos(z, degree) for degree in 1:order]
    return Vector{Symbolics.Num}(vcat(per_degree...))
end

# Stack the integer multiples `k .* z` for k = 1, ..., max_coeff.
# Mostly for use as arguments of trigonometric functions, e.g. sin(k*z),
# where k is the coefficient.
function primal_coeff_basis(z, max_coeff::Int)
    multiples = [k .* z for k in 1:max_coeff]
    return Vector{Symbolics.Num}(vcat(multiples...))
end

# Apply a binary `operator` (e.g. +, -, *, /) to every pair of distinct states,
# in both argument orders, as a new basis.
# The two lists are merged with `union`, which drops duplicates; the result is
# returned as a vector.
function primal_operator_basis(z, operator)
    n = length(z)
    index_pairs = [(i, j) for i in 1:n-1 for j in i+1:n]
    forward = [operator(z[i], z[j]) for (i, j) in index_pairs]
    swapped = [operator(z[j], z[i]) for (i, j) in index_pairs]
    return Vector{Symbolics.Num}(union(forward, swapped))
end

# Stack element-wise powers of the states `z`.
#   max_power > 0  -> z.^1, z.^2, ..., z.^max_power
#   max_power < 0  -> z.^-1, z.^-2, ..., z.^max_power
#   max_power == 0 -> empty basis
# Always returns a Vector{Symbolics.Num}, so the result can be passed straight
# to `vcat` / `get_basis_set` (previously the zero case returned `nothing`).
function primal_power_basis(z, max_power::Int)
    if max_power > 0
        return Vector{Symbolics.Num}(vcat([z.^i for i in 1:max_power]...))
    elseif max_power < 0
        return Vector{Symbolics.Num}(vcat([z.^-i for i in 1:abs(max_power)]...))
    end
    # no powers requested
    return Symbolics.Num[]
end

# Assemble the argument terms used by the basis builders: monomials up to
# `polyorder`, integer multiples of the states up to `max_coeff`, and, when an
# `operator` is supplied, the pairwise operator terms.
function polynomial_basis(z::Vector{Symbolics.Num} = get_z_vector(2); polyorder::Int = 0, operator=nothing, max_coeff::Int = 0)
    monomials = primal_monomial_basis(z, polyorder)
    multiples = primal_coeff_basis(z, max_coeff)
    terms = vcat(monomials, multiples)
    isnothing(operator) && return terms
    return vcat(terms, primal_operator_basis(z, operator))
end

# Sine of every polynomial-basis term followed by the cosine of every term.
function trigonometric_basis(z::Vector{Symbolics.Num} = get_z_vector(2); polyorder::Int = 0, operator=nothing, max_coeff::Int = 0)
    args = polynomial_basis(z; polyorder = polyorder, operator = operator, max_coeff = max_coeff)
    sines = sin.(args)
    cosines = cos.(args)
    return vcat(sines, cosines)
end

# Exponential of every polynomial-basis term.
function exponential_basis(z::Vector{Symbolics.Num} = get_z_vector(2); polyorder::Int = 0, operator=nothing, max_coeff::Int = 0)
    args = polynomial_basis(z; polyorder = polyorder, operator = operator, max_coeff = max_coeff)
    return map(exp, args)
end

# Logarithm of the absolute value of every polynomial-basis term.
function logarithmic_basis(z::Vector{Symbolics.Num} = get_z_vector(2); polyorder::Int = 0, operator=nothing, max_coeff::Int = 0)
    args = polynomial_basis(z; polyorder = polyorder, operator = operator, max_coeff = max_coeff)
    magnitudes = abs.(args)
    return log.(magnitudes)
end

# Cross terms between different bases: for every pair of distinct input bases
# (first before second), the product of each entry of the first with each
# entry of the second. Entries of the same basis are never multiplied together.
function mixed_states_basis(basis::Vector{Symbolics.Num}...)
    products = Symbolics.Num[]
    nsets = length(basis)

    for i in 1:nsets, j in (i + 1):nsets
        # entries of basis[i] vary slowest, entries of basis[j] fastest
        for left in basis[i], right in basis[j]
            push!(products, left * right)
        end
    end

    return products
end

# Number of coefficients the basis needs: one per basis term.
get_numCoeffs(basis::Vector{Symbolics.Num}) = length(basis)


# Merge any number of bases into a single basis vector, keeping only the first
# occurrence of each repeated term.
function get_basis_set(basis::Vector{Symbolics.Num}...)
    stacked = vcat(basis...)
    distinct_terms = collect(unique(stacked))
    return Vector{Symbolics.Num}(distinct_terms)
end

#TODO: maybe basis shouldn't be variable number of arguments
# Returns a function that evaluates the Hamiltonian vector field of the
# parametrised Hamiltonian  H(z) = Σᵢ a[i] * basis[i].
#
# Arguments
#   d      half the state dimension, i.e. z is expected to hold 2d entries (q then p)
#   z      symbolic state vector
#   basis  one or more symbolic bases; they are merged and de-duplicated
#
# The returned function is the in-place variant produced by `build_function`;
# it is called as `f(out, z_values, a_values)` and writes the result into `out`.
function ΔH_func_builder(d::Int, z::Vector{Symbolics.Num} = get_z_vector(d), basis::Vector{Symbolics.Num}...) 
    # one differential operator per state variable
    Dz = Differential.(z)
    
    # merge all given bases into one vector without duplicates
    basis = get_basis_set(basis...)
   
    # one symbolic coefficient per basis term
    @variables a[1:get_numCoeffs(basis)]
    
    # collect and sum combinations of basis and coefficients
    ham = sum(collect(a .* basis))
    
    # gives derivative of the hamiltonian, but not the skew-symmetric true one
    f = [expand_derivatives(dz(ham)) for dz in Dz]

    #simplify the expression potentially to make it faster
    f = simplify(f)
    
    # line below makes the vector into a hamiltonian vector field by multiplying with the skew-symmetric matrix
    ∇H = vcat(f[d+1:2d], -f[1:d])
    
    # builds a function that calculates Hamiltonian gradient and converts the function to a native Julia function
    # ([2] selects the in-place version returned by build_function)
    ∇H_eval = @RuntimeGeneratedFunction(Symbolics.inject_registered_module_functions(build_function(∇H, z, a)[2]))
    
    return ∇H_eval
end

# Settings container for the Hamiltonian SINDy method.
# T is the common type of the three floating-point settings (λ, noise_level,
# noiseGen_timeStep); GHT is the type of the analytical vector field, which is
# either a callable or `missing`.
struct HamiltonianSINDy{T, GHT}
    basis::Vector{Symbolics.Num} # the augmented basis for sparsification
    # analytical right-hand side; gen_noisy_ref_data calls it as analytical_fθ(dx, x, params, t)
    analytical_fθ::GHT
    # symbolic state vector the basis is expressed in
    z::Vector{Symbolics.Num} 
    λ::T # Sparsification Parameter
    noise_level::T # Noise amplitude added to the data
    noiseGen_timeStep::T # Time step for the integrator to get noisy data 
    nloops::Int # Sparsification Loops
    batch_size::Int # Batch size for training
    basis_coeff::Float64 # Coefficient for the coefficients of the basis
    
    # Inner constructor. `basis` is required; `analytical_fθ` and `z` are optional
    # positional arguments. `batch_size` and `basis_coeff` are keyword arguments
    # without defaults, so they must always be supplied by the caller.
    # λ, noise_level and noiseGen_timeStep share the type parameter T.
    function HamiltonianSINDy(basis::Vector{Symbolics.Num},
        analytical_fθ::GHT = missing,
        z::Vector{Symbolics.Num} = get_z_vector(2);
        λ::T = 0.05,
        noise_level::T = 0.00,
        noiseGen_timeStep::T = 0.05,
        nloops::Int = 10,
        batch_size::Int,
        basis_coeff::Float64) where {T, GHT <: Union{Base.Callable,Missing}}

        new{T, GHT}(basis, analytical_fθ, z, λ, noise_level, noiseGen_timeStep, nloops, batch_size::Int, basis_coeff::Float64)
    end
end

# Advance every sample in `x` by one step of size `method.noiseGen_timeStep`
# using the analytical vector field, then add Gaussian noise scaled by
# `method.noise_level`. Returns one noisy state per input sample.
function gen_noisy_ref_data(method::HamiltonianSINDy, x)
    dt = method.noiseGen_timeStep
    interval = (zero(dt), dt)

    # integrate a single sample over one step and keep the last state
    function advance(x)
        rhs = (dx, t, x, params) -> method.analytical_fθ(dx, x, params, t)
        problem = ODEProblem(rhs, interval, dt, x)
        solution = integrate(problem, Gauss(2))
        return solution.q[end]
    end

    propagated = [advance(sample) for sample in x]

    # add noise
    return [state .+ method.noise_level .* randn(size(state)) for state in propagated]
end

# Bundle of training arrays that all share one array type AT.
# The two-argument constructor leaves `y` unassigned.
struct TrainingData{AT<:AbstractArray}
    x::AT # sampled states (initial conditions)
    ẋ::AT # derivative data paired with x
    y::AT # noisy data at next time step

    function TrainingData(x::AT, ẋ::AT, y::AT) where {AT}
        return new{AT}(x, ẋ, y)
    end

    function TrainingData(x::AT, ẋ::AT) where {AT}
        return new{AT}(x, ẋ)
    end
end



In [None]:
# --------------------
# Setup
# --------------------

println("Setting up...")

# 2D system with 4 variables [q₁, q₂, p₁, p₂]
const nd = 4

# Integer division keeps the dimension an Int; `nd/2` would hand the Float64
# value 2.0 to get_z_vector, which uses it as the end of the range `1:dims`.
z = get_z_vector(nd ÷ 2)
polynomial = polynomial_basis(z, polyorder=3)
trigonometric  = trigonometric_basis(z, max_coeff=1)
# prime_diff = primal_operator_basis(z, -)
# basis = get_basis_set(polynomial, trigonometric, prime_diff)
basis = get_basis_set(polynomial, trigonometric)
# initialize analytical function, keep λ smaller than ϵ so system is identifiable
ϵ = 0.5
m = 1

# two-dim simple harmonic oscillator (not used anywhere only in case some testing needed)
# H_ana(x, p, t) = ϵ * x[1]^2 + ϵ * x[2]^2 + 1/(2*m) * x[3]^2 + 1/(2*m) * x[4]^2
# H_ana(x, p, t) = cos(x[1]) + cos(x[2]) + 1/(2*m) * x[3]^2 + 1/(2*m) * x[4]^2

# Analytical right-hand side of the 2D system for x = [q₁, q₂, p₁, p₂]:
# the first two entries are the momenta, the last two are sin(q₁) and sin(q₂).
# grad_H_ana(x) = [x[3]; x[4]; -2ϵ * x[1]; -2ϵ * x[2]]
function grad_H_ana(x)
    return [x[3], x[4], sin(x[1]), sin(x[2])]
end

# In-place variant with the (dx, x, p, t) argument order; `p` and `t` are not
# used. Overwrites `dx` and returns it.
function grad_H_ana!(dx, x, p, t)
    dx .= grad_H_ana(x)
    return dx
end

# ------------------------------------------------------------
# Training Data
# ------------------------------------------------------------

println("Generate Training Data...")

# number of samples
num_samp = 8

# samples in p and q space
samp_range = LinRange(-10, 10, num_samp)

# initialize vector of matrices to store ODE solve output

# s depend on size of nd (total dims), 4 in the case here so we use samp_range x samp_range x samp_range x samp_range
s = collect(Iterators.product(fill(samp_range, nd)...))


# compute vector field from x state values
x = [collect(s[i]) for i in eachindex(s)]

# scratch output vector plus placeholder parameter/time arguments for grad_H_ana!
dx = zeros(nd)
p = 0
t = 0
# every sample gets its own copy of dx so the results do not alias each other
ẋ = [grad_H_ana!(copy(dx), _x, p, t) for _x in x]


# ----------------------------------------
# Compute Sparse Regression
# ----------------------------------------

# choose SINDy method
# (λ parameter must be close to noise value so that only coeffs with value around the noise are sparsified away)
# noiseGen_timeStep chosen randomly for now
method = HamiltonianSINDy(basis, grad_H_ana!, z, λ = 0.05, noise_level = 0.05, noiseGen_timeStep = 0.0, batch_size = 500, basis_coeff = 0.52)

# generate noisy references data at next time step
# y = gen_noisy_ref_data(method, x)

# Change to matrices for faster computations
# (reduce(hcat, ...) avoids splatting thousands of vectors into a single call)
x = reduce(hcat, x)
ẋ = reduce(hcat, ẋ)
# y = hcat(y...)

# collect training data
tdata = TrainingData(x, ẋ)

In [144]:
# half the number of rows of the data matrix (rows hold q and p stacked)
d = size(tdata.x, 1) ÷ 2

# in-place function evaluating the symbolic Hamiltonian vector field for the chosen basis
fθ = ΔH_func_builder(d, method.z, method.basis)

RuntimeGeneratedFunction(#=in Main=#, #=using Main=#, :((ˍ₋out, ˍ₋arg1, a)->begin
          #= C:\Users\nigel\.julia\packages\SymbolicUtils\Oyu8Z\src\code.jl:373 =#
          #= C:\Users\nigel\.julia\packages\SymbolicUtils\Oyu8Z\src\code.jl:374 =#
          #= C:\Users\nigel\.julia\packages\SymbolicUtils\Oyu8Z\src\code.jl:375 =#
          begin
              begin
                  #= C:\Users\nigel\.julia\packages\Symbolics\BQlmn\src\build_function.jl:519 =#
                  #= C:\Users\nigel\.julia\packages\SymbolicUtils\Oyu8Z\src\code.jl:422 =# @inbounds begin
                          #= C:\Users\nigel\.julia\packages\SymbolicUtils\Oyu8Z\src\code.jl:418 =#
                          ˍ₋out[1] = (+)((+)((+)((+)((+)((+)((+)((+)((+)((+)((+)((+)((+)((+)((+)((+)((*)((^)(ˍ₋arg1[4], 2), (getindex)(a, 33)), (*)((^)(ˍ₋arg1[1], 2), (getindex)(a, 17))), (*)((^)(ˍ₋arg1[2], 2), (getindex)(a, 26))), (*)((cos)(ˍ₋arg1[3]), (getindex)(a, 37))), (*)((getindex)(a, 7), ˍ₋arg1[1])), (*)((getindex)(a, 13

In [145]:
# Create the trainable model as a 3-tuple of NamedTuples, each with a field `W`:
#   1. encoder network, 2. decoder network, 3. SINDy coefficient vector (zeros).
# Both networks are Dense(ndim => ld, sigmoid) followed by Dense(ld => ndim),
# where ndim is the number of rows of `data.x`.
function set_model(data, method)
    #TODO: updated ld to be a parameter
    ndim = size(data.x, 1)
    ld = size(data.x, 1)

    # same layout for both networks, but each call creates fresh parameters
    build_net() = Chain(Dense(ndim => ld, sigmoid), Dense(ld => ndim))

    encoder = build_net()
    decoder = build_net()
    coeffs = zeros(Float32, get_numCoeffs(method.basis))

    return ((W = encoder,), (W = decoder,), (W = coeffs,))
end

set_model (generic function with 1 method)

# ------Enzyme Code------

In [None]:
# Multiply each sample's Jacobian with that sample's derivative column.
# `enc_jac_batch` is indexed as [output, sample, input]; the result has the
# same shape and element type as `ẋ_batch`.
function enc_ż(enc_jac_batch, ẋ_batch)
    ż = zero(ẋ_batch)
    nsamples = size(enc_jac_batch, 2)
    for sample in 1:nsamples
        jac = enc_jac_batch[:, sample, :]
        ż[:, sample] = jac * ẋ_batch[:, sample]
    end
    return ż
end


# Evaluate the in-place function `fθ(out, state, coeffs)` on every column of
# `enc_x_batch` and gather the outputs column by column into a new matrix.
function evaluate_fθ(fθ, enc_x_batch, coeffs)
    # scratch vector that fθ overwrites on every call
    scratch = zero(enc_x_batch[:, 1])
    result = zero(enc_x_batch)
    ncols = size(enc_x_batch, 2)
    for col in 1:ncols
        fθ(scratch, enc_x_batch[:, col], coeffs)
        result[:, col] = scratch
    end
    return result
end

# Multiply each sample's Jacobian with that sample's column of `ż`.
# `dec_jac_batch` is indexed as [output, sample, input]; the result has the
# same shape and element type as `ż`.
function dec_ẋ(dec_jac_batch, ż)
    decoded = zero(ż)
    nsamples = size(dec_jac_batch, 2)
    for sample in 1:nsamples
        jac = dec_jac_batch[:, sample, :]
        decoded[:, sample] = jac * ż[:, sample]
    end
    return decoded
end

# Sum of squared entries of the difference `grad_fθ - enc_mult_ż`.
function Diff_ż(grad_fθ, enc_mult_ż)
    residual = grad_fθ - enc_mult_ż
    return sum(abs2, residual)
end

# Apply each sample's Jacobian (indexed [output, sample, input]) to the
# matching column of `grad_fθ`, then return the sum of squared differences
# between the result and `ẋ_batch`.
function Diff_ẋ(dec_jac_batch, grad_fθ, ẋ_batch)
    decoded = zero(ẋ_batch)
    nsamples = size(dec_jac_batch, 2)
    for sample in 1:nsamples
        jac = dec_jac_batch[:, sample, :]
        decoded[:, sample] = jac * grad_fθ[:, sample]
    end
    residual = decoded - ẋ_batch
    return sum(abs2, residual)
end

# Batch loss for the Enzyme training path.
#
# Terms (the first three are summed and divided by the number of samples):
#   L_r  reconstruction error of decoder(encoder(x))
#   L_ż  alphas / 10 times the squared mismatch between the SINDy output and enc_mult_ż
#   L_ẋ  alphas times the squared mismatch between the decoded SINDy output and ẋ_batch
#   L_c  mean absolute coefficient, weighted by method.basis_coeff
#
# NOTE: `fθ` is not an argument of this method; it is looked up in the
# enclosing (global) scope when the loss is evaluated.
function loss(model, x_batch, ẋ_batch, enc_mult_ż, dec_jac_batch, alphas, method)
    encoder, decoder = model[1].W, model[2].W
    coeffs = model[3].W

    enc_x_batch = encoder(x_batch)

    # reconstruction loss for the entire batch
    L_r = sum(abs2, decoder(enc_x_batch) - x_batch)

    # The SINDy evaluation stays inside the loss so the gradient sees the
    # dependence on both the encoder and the coefficients.
    grad_fθ = evaluate_fθ(fθ, enc_x_batch, coeffs)

    # mismatch in the encoded space
    L_ż = alphas / 10 * Diff_ż(grad_fθ, enc_mult_ż)

    # mismatch after mapping the SINDy output back through the decoder Jacobian
    L_ẋ = alphas * Diff_ẋ(dec_jac_batch, grad_fθ, ẋ_batch)

    # mean absolute value of the coefficients
    L_c = sum(abs, coeffs) / length(coeffs)

    return (L_r + L_ż + L_ẋ) / size(x_batch, 2) + method.basis_coeff * L_c
end

In [None]:

# Fresh model: (encoder, decoder, SINDy coefficients)
model = set_model(tdata, method)

# dmodel = Flux.fmap(model) do x
#     x isa Array ? zero(x) : x
# end

# container with the same structure as the model, later handed to Enzyme
# as the gradient (shadow) buffer
model_gradients = deepcopy(model)

# Flux gradient has problem working with the structure data directly
x = Float32.(tdata.x)
ẋ = Float32.(tdata.ẋ)

# one column per sample
total_samples = size(x, 2)
num_batches = cld(total_samples, method.batch_size)

# Coefficients for the loss_kernel terms
alphas = round(sum(abs2, x) / sum(abs2, ẋ), sigdigits = 3)

# Derivatives of the encoder and decoder
enc_jac_batch = batched_jacobian(model[1].W, x)
dec_jac_batch = batched_jacobian(model[2].W, model[1].W(x))

# encoded gradient ż = ẋ*dz/dx
enc_mult_ż = enc_ż(enc_jac_batch, ẋ)

initial_loss_array = Float32[]

# initial guess
println("Initial Guess...")

# Set up the optimizer's state
opt_state = Flux.setup(Adam(), model)

println("Coeffs: $(model[3].W)")
println()
println("initial total loss:", loss(model, x, ẋ, enc_mult_ż, dec_jac_batch, alphas, method))

In [None]:

# Array to store the losses
epoch_loss_array = Vector{Float32}()

# Mini-batch training with Enzyme reverse-mode gradients.
for epoch in 1:500
    epoch_loss = 0.0
    # Shuffle the data indices for each epoch
    shuffled_indices = shuffle(1:total_samples)
    for batch in 1:num_batches
        # Get the indices for the current batch (the last batch may be shorter)
        batch_start = (batch - 1) * method.batch_size + 1
        batch_end = min(batch * method.batch_size, total_samples)
        batch_indices = shuffled_indices[batch_start:batch_end]

        # Extract the data for the current batch
        x_batch = x[:, batch_indices]
        ẋ_batch = ẋ[:, batch_indices]
        # Derivatives of the encoder and decoder
        # (computed outside the differentiated call and passed in as constants)
        enc_jac_batch = batched_jacobian(model[1].W, x_batch)
        dec_jac_batch = batched_jacobian(model[2].W, model[1].W(x_batch))
        # encoded gradient ż = ẋ*dz/dx
        enc_mult_ż = enc_ż(enc_jac_batch, ẋ_batch)

        # Enzyme ADDS the gradient into the shadow of a Duplicated argument, so
        # the shadow has to start from zeros for every batch. A deepcopy of the
        # model (non-zero weights, never reset) would pollute the gradient with
        # the weights and with all earlier batches.
        model_gradients = Flux.fmap(p -> p isa AbstractArray ? zero(p) : p, model)

        # Compute gradients using Enzyme
        Enzyme.autodiff(
            Reverse,
            (model, x_batch, ẋ_batch, enc_mult_ż, dec_jac_batch, alphas, method) -> loss(model, x_batch, ẋ_batch, enc_mult_ż, dec_jac_batch, alphas, method),
            Active,
            Duplicated(model, model_gradients),
            Const(x_batch),
            Const(ẋ_batch),
            Const(enc_mult_ż),
            Const(dec_jac_batch),
            Const(alphas),
            Const(method),
        )

        # Update the parameters
        Flux.Optimise.update!(opt_state, model, model_gradients)

        # Accumulate the loss for the current batch
        epoch_loss += loss(model, x_batch, ẋ_batch, enc_mult_ż, dec_jac_batch, alphas, method)
    end
    # Compute the average loss for the epoch
    epoch_loss /= num_batches

    # Store the epoch loss
    push!(epoch_loss_array, epoch_loss)

    # Print loss after some iterations
    if epoch % 100 == 0
        println("Epoch $epoch: Average Loss: $epoch_loss")
        println("Epoch $epoch: Coefficents: $(model[3].W)")
        println()
    end
end

# ------Zygote Code------

In [146]:
# Evaluate the SINDy function on the encoded batch, one column at a time,
# using Zygote.Buffer so the in-place writes of `fθ` can be differentiated.
#
#   x_batch  input samples, one per column
#   model    tuple whose first entry holds the encoder (`.W`) and whose third
#            entry holds the coefficient vector (`.W`)
#   fθ       in-place function called as fθ(out, state, coeffs)
#
# The buffers are shaped like `x_batch`, i.e. the encoder output is taken to
# have as many rows as its input.
function evaluate_fθ(x_batch, model, fθ)
    # Encode the whole batch once. Calling the encoder inside the loop
    # re-encoded every sample on every iteration just to read one column.
    enc_x_batch = model[1].W(x_batch)
    coeffs = model[3].W

    out = Zygote.Buffer(x_batch)
    f = Zygote.Buffer(x_batch[:,1])
    for i in 1:size(x_batch, 2)
        fθ(f, enc_x_batch[:,i], coeffs)
        out[:,i] = f[:]
    end
    return copy(out) 
end

evaluate_fθ (generic function with 1 method)

In [147]:
# Needed because Flux.gradient can't handle Flux.jacobian
# Jacobian of `model_layer` at every column of `x_batch`, stored in a Float32
# array indexed as [output, sample, input].
function batched_jacobian(model_layer, x_batch)
    input_dim = size(x_batch, 1)
    batch_size = size(x_batch, 2)
    # probe the first sample to learn the output dimension
    output_dim = size(model_layer(x_batch[:, 1]), 1)

    batch_jac = zeros(Float32, output_dim, batch_size, input_dim)

    for i in 1:batch_size
        sample = x_batch[:, i]
        batch_jac[:, i, :] = first(Flux.jacobian(model_layer, sample))
    end
    return batch_jac
end

batched_jacobian (generic function with 1 method)

In [148]:
# Get ż from encoder derivative and ẋ: per sample, multiply the Jacobian slice
# [output, sample, input] with the matching column of ẋ_batch. Written through
# a Zygote.Buffer and returned as a regular array.
function enc_ż(enc_jac_batch, ẋ_batch)
    buffer = Zygote.Buffer(ẋ_batch)
    nsamples = size(enc_jac_batch, 2)
    for i in 1:nsamples
        jac = enc_jac_batch[:, i, :]
        buffer[:, i] = jac * ẋ_batch[:, i]
    end
    return copy(buffer)
end

# Get ẋ from decoder derivative and ż: per sample, multiply the Jacobian slice
# [output, sample, input] with the matching column of ż. Written through a
# Zygote.Buffer and returned as a regular array.
function dec_ẋ(dec_jac_batch, ż)
    buffer = Zygote.Buffer(ż)
    nsamples = size(dec_jac_batch, 2)
    for i in 1:nsamples
        jac = dec_jac_batch[:, i, :]
        buffer[:, i] = jac * ż[:, i]
    end
    return copy(buffer)
end

dec_ẋ (generic function with 1 method)

In [150]:
# Batch loss for the Zygote training path.
#
# Terms (the first three are summed and divided by the number of samples):
#   L_r  reconstruction error of decoder(encoder(x))
#   L_ż  alphas / 10 times the squared mismatch between enc_mult_ż and the SINDy output
#   L_ẋ  alphas times the squared mismatch between the decoded SINDy output and ẋ_batch
#   L_c  mean absolute coefficient, weighted by method.basis_coeff
#
# `enc_jac_batch` is accepted but not read inside this function.
function loss(model, x_batch, ẋ_batch, enc_jac_batch, dec_jac_batch, enc_mult_ż, method, fθ, alphas)
    encoder, decoder = model[1].W, model[2].W

    # reconstruction loss for the entire batch
    L_r = sum(abs2, decoder(encoder(x_batch)) - x_batch)

    # The SINDy evaluation stays inside the loss so the gradient sees the
    # dependence on the model parameters.
    grad_fθ = evaluate_fθ(x_batch, model, fθ)

    L_ż = alphas / 10 * sum(abs2, enc_mult_ż - grad_fθ)

    # decoded SINDy gradient ẋ = dx/dz*grad_fθ
    L_ẋ = alphas * sum(abs2, dec_ẋ(dec_jac_batch, grad_fθ) - ẋ_batch)

    # mean absolute value of the coefficients
    coeffs = model[3].W
    L_c = sum(abs, coeffs) / length(coeffs)

    return (L_r + L_ż + L_ẋ) / size(x_batch, 2) + method.basis_coeff * L_c
end

loss (generic function with 1 method)

In [151]:
# Create a fresh model: (encoder, decoder, SINDy coefficients initialised to zero)
model = set_model(tdata, method)

((W = Chain(Dense(4 => 4, σ), Dense(4 => 4)),), (W = Chain(Dense(4 => 4, σ), Dense(4 => 4)),), (W = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],))

In [161]:
# Flux gradient has problem working with the structure data directly
x = Float32.(tdata.x)
ẋ = Float32.(tdata.ẋ)

# one column per sample
total_samples = size(x, 2)
num_batches = cld(total_samples, method.batch_size)

# Coefficients for the loss_kernel terms
alphas = round(sum(abs2, x) / sum(abs2, ẋ), sigdigits = 3)

# Derivatives of the encoder and decoder
enc_jac_batch = batched_jacobian(model[1].W, x)
dec_jac_batch = batched_jacobian(model[2].W, model[1].W(x))

# encoded gradient ż = ẋ*dz/dx
enc_mult_ż = enc_ż(enc_jac_batch, ẋ)

4×4096 Matrix{Float32}:
 -0.104432   -0.154544  -0.290448  …  0.289954  0.154615  0.104467
 -0.269342   -0.306433  -1.42981      1.43657   0.307815  0.269566
 -0.0365111  -0.122082  -0.616071     0.617489  0.122315  0.0365417
 -0.309275   -0.371181  -1.12722      1.13055   0.372116  0.309462

In [171]:
initial_loss_array = Vector{Float32}()

# initial guess
println("Initial Guess...")

# Set up the optimizer's state
opt_state = Flux.setup(Adam(), model)

# println prints the messages as written; @show on a string literal would
# echo the quoted source expression in front of the value.
println("Coeffs: $(model[3].W)")
println("initial total loss:", loss(model, x, ẋ, enc_jac_batch, dec_jac_batch, enc_mult_ż, method, fθ, alphas))

342.00853027578444

In [None]:
# Array to store the losses
epoch_loss_array = Vector{Float32}()

# Mini-batch training with gradients from Flux.gradient.
# NOTE(review): with `1:2` epochs the `epoch % 20 == 0` report below never fires.
for epoch in 1:2
    epoch_loss = 0.0
    # Shuffle the data indices for each epoch
    shuffled_indices = shuffle(1:total_samples)
    
    for batch in 1:num_batches
        # Get the indices for the current batch
        # (the last batch is clipped to total_samples and may be shorter)
        batch_start = (batch - 1) * method.batch_size + 1
        batch_end = min(batch * method.batch_size, total_samples)
        batch_indices = shuffled_indices[batch_start:batch_end]

        # Extract the data for the current batch
        x_batch = x[:, batch_indices]
        ẋ_batch = ẋ[:, batch_indices]
        # Derivatives of the encoder and decoder
        # (evaluated before the gradient call and passed to the loss as plain arrays)
        enc_jac_batch = batched_jacobian(model[1].W, x_batch)
        dec_jac_batch = batched_jacobian(model[2].W, model[1].W(x_batch))
        # encoded gradient ż = ẋ*dz/dx
        enc_mult_ż = enc_ż(enc_jac_batch, ẋ_batch)

        # Compute gradients using Flux
        # ([1] picks the gradient with respect to the single differentiated argument, the model)
        gradients = Flux.gradient(model -> loss(model, x_batch, ẋ_batch, enc_jac_batch, dec_jac_batch, enc_mult_ż, method, fθ, alphas), model)[1]

        # Update the parameters
        Flux.Optimise.update!(opt_state, model, gradients)

        # Accumulate the loss for the current batch
        # (evaluated after the update, with the Jacobians computed before it)
        epoch_loss += loss(model, x_batch, ẋ_batch, enc_jac_batch, dec_jac_batch, enc_mult_ż, method, fθ, alphas)
    end
    # Compute the average loss for the epoch
    epoch_loss /= num_batches

    # Store the epoch loss
    push!(epoch_loss_array, epoch_loss)

    # Print loss after some iterations
    if epoch % 20 == 0
        println("Epoch $epoch: Average Loss: $epoch_loss")
        println("Epoch $epoch: Coefficents: $(model[3].W)")
        println()
    end
end