First, check the number of threads in use.

In [None]:
# Number of threads available to this Julia session (informational only in the visible code).
num_threads = Threads.nthreads()

Code from Sparsification module

In [None]:
using Distributions
using GeometricIntegrators
using Optim
using Random
using Flux
using Distances
using Symbolics
using ForwardDiff
using Plots
using LinearAlgebra
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

gr()

# Element-wise product of one or more broadcast-compatible operands.
# The product is folded from the right: a .* (b .* (c .* ...)).
_prod(single) = single
_prod(head, tail...) = head .* _prod(tail...)

# Build the symbolic state vector z = (q, p), where q and p are symbolic
# arrays with `dims` entries each; the q entries come first.
function get_z_vector(dims)
    @variables q[1:dims] p[1:dims]
    return vcat(q, p)
end

# Enumerate the monomials in the entries of `z` whose total degree is exactly
# `order` (e.g. order = 2 gives only the products whose powers sum to 2).
# `inds` holds the factor indices chosen so far by the recursion.
function poly_combos(z, order, inds...)
    # degree zero: the only monomial is the constant 1
    order == 0 && return Num[1]

    # enough factors have been chosen: multiply them together
    if length(inds) == order
        factors = [z[i] for i in inds]
        return [_prod(factors...)]
    end

    # extend the index list without going below the last chosen index, so the
    # indices of every generated monomial are non-decreasing
    first_ind = isempty(inds) ? 1 : inds[end]
    branches = [poly_combos(z, order, inds..., j) for j in first_ind:length(z)]
    return vcat(branches...)
end

# Collect every monomial of total degree 1 through `order`
# (the constant term is not included).
function primal_monomial_basis(z, order::Int)
    per_degree = [poly_combos(z, degree) for degree in 1:order]
    return Vector{Symbolics.Num}(vcat(per_degree...))
end

# Integer multiples k*z for k = 1, ..., max_coeff, concatenated into one vector.
# Mostly meant as arguments for trigonometric terms such as sin(k*z),
# where k is the coefficient.
function primal_coeff_basis(z, max_coeff::Int)
    scaled = [k .* z for k in 1:max_coeff]
    return Vector{Symbolics.Num}(vcat(scaled...))
end

# Apply a binary `operator` (e.g. +, -, *, /) to every pair of distinct states,
# in both argument orders, to form new basis terms.
# The result is a vector with duplicates removed (forward-order terms first).
function primal_operator_basis(z, operator)
    n = length(z)
    index_pairs = [(i, j) for i in 1:n-1 for j in i+1:n]
    forward = [operator(z[i], z[j]) for (i, j) in index_pairs]
    backward = [operator(z[j], z[i]) for (i, j) in index_pairs]
    return Vector{Symbolics.Num}(union(forward, backward))
end

# Element-wise powers of the states.
#   max_power > 0: z.^1, z.^2, ..., z.^max_power
#   max_power < 0: z.^-1, z.^-2, ..., z.^max_power
#   max_power == 0: an empty basis (previously this case fell through and
#   returned `nothing`, which breaks any caller that concatenates the result)
function primal_power_basis(z, max_power::Int)
    max_power == 0 && return Symbolics.Num[]
    exponents = max_power > 0 ? (1:max_power) : (-1:-1:max_power)
    return Vector{Symbolics.Num}(vcat([z .^ e for e in exponents]...))
end

# Assemble the "primal" basis terms for the state vector `z`:
#   - all monomials of degree 1..polyorder,
#   - the integer multiples k*z for k = 1..max_coeff,
#   - and, when `operator` is given, pairwise operator terms between states.
function polynomial_basis(z::Vector{Symbolics.Num} = get_z_vector(2); polyorder::Int = 0, operator=nothing, max_coeff::Int = 0)
    monomials = primal_monomial_basis(z, polyorder)
    multiples = primal_coeff_basis(z, max_coeff)
    primes = vcat(monomials, multiples)
    isnothing(operator) && return primes
    return vcat(primes, primal_operator_basis(z, operator))
end

# sin and cos of every primal term produced by `polynomial_basis`
# (all sines first, then all cosines).
function trigonometric_basis(z::Vector{Symbolics.Num} = get_z_vector(2); polyorder::Int = 0, operator=nothing, max_coeff::Int = 0)
    args = polynomial_basis(z; polyorder = polyorder, operator = operator, max_coeff = max_coeff)
    sines = sin.(args)
    cosines = cos.(args)
    return vcat(sines, cosines)
end

# exp of every primal term produced by `polynomial_basis`.
function exponential_basis(z::Vector{Symbolics.Num} = get_z_vector(2); polyorder::Int = 0, operator=nothing, max_coeff::Int = 0)
    args = polynomial_basis(z; polyorder = polyorder, operator = operator, max_coeff = max_coeff)
    return map(exp, args)
end

# log|term| for every primal term produced by `polynomial_basis`;
# the absolute value is taken before the logarithm.
function logarithmic_basis(z::Vector{Symbolics.Num} = get_z_vector(2); polyorder::Int = 0, operator=nothing, max_coeff::Int = 0)
    args = polynomial_basis(z; polyorder = polyorder, operator = operator, max_coeff = max_coeff)
    magnitudes = abs.(args)
    return log.(magnitudes)
end

# Cross terms between different bases: for every pair of distinct bases
# (i < j), the product of each term of basis i with each term of basis j.
# Terms within the same basis are not multiplied with each other.
function mixed_states_basis(basis::Vector{Symbolics.Num}...)
    products = Vector{Symbolics.Num}()
    nbases = length(basis)
    for i in 1:nbases, j in i+1:nbases
        left, right = basis[i], basis[j]
        # terms of `left` vary slowest, terms of `right` fastest
        for a in left, b in right
            push!(products, a * b)
        end
    end
    return products
end

# One coefficient is needed per basis term, so this is simply the basis length.
get_numCoeffs(basis::Vector{Symbolics.Num}) = length(basis)


# Merge any number of bases into a single basis vector, dropping duplicate
# terms while keeping the first occurrence of each.
function get_basis_set(basis::Vector{Symbolics.Num}...)
    combined = vcat(basis...)
    return Vector{Symbolics.Num}(unique(combined))
end

# Builds and returns a native Julia function that evaluates the Hamiltonian
# vector field of H(z) = Σᵢ aᵢ * basisᵢ(z), i.e. (∂H/∂p, -∂H/∂q).
#   d     : number of position dimensions, so z is expected to have 2d entries
#   z     : symbolic state vector (q first, then p, as built by get_z_vector)
#   basis : one or more bases; they are merged and de-duplicated via get_basis_set
# The returned function takes the state and the coefficient vector `a`, which
# has one entry per term of the merged basis.
function ΔH_func_builder(d::Int, z::Vector{Symbolics.Num} = get_z_vector(d), basis::Vector{Symbolics.Num}...) 
    # nd is the total number of dimensions of all the states, e.g. if q,p each of 3 dims, that is 6 dims in total
    # NOTE(review): nd is not used anywhere else in this function
    nd = 2d
    # one differential operator per state variable
    Dz = Differential.(z)
    
    # merges the given bases into a single basis without duplicates
    basis = get_basis_set(basis...)
   
    # one symbolic coefficient per basis term
    @variables a[1:get_numCoeffs(basis)]
    
    # symbolic Hamiltonian: sum of coefficient * basis term
    ham = sum(collect(a .* basis))
    
    # gives derivative of the hamiltonian, but not the skew-symmetric true one
    f = [expand_derivatives(dz(ham)) for dz in Dz]

    #simplify the expression potentially to make it faster
    f = simplify(f)
    
    # line below makes the vector into a hamiltonian vector field by multiplying with the skew-symmetric matrix
    # (first d entries: ∂H/∂p, last d entries: -∂H/∂q)
    ∇H = vcat(f[d+1:2d], -f[1:d])
    
    # builds a function that calculates Hamiltonian gradient and converts the function to a native Julia function
    # NOTE(review): element [2] of build_function's result is presumably the in-place variant,
    # since callers in this file invoke the result as fθ(out, z, a) — confirm against Symbolics docs
    ∇H_eval = @RuntimeGeneratedFunction(Symbolics.inject_registered_module_functions(build_function(∇H, z, a)[2]))
    
    return ∇H_eval
end

# Settings for the Hamiltonian SINDy procedure.
# T is the common numeric type of λ, noise_level and noiseGen_timeStep;
# GHT is the type of the analytical vector field (a callable, or `missing`).
struct HamiltonianSINDy{T, GHT}
    basis::Vector{Symbolics.Num} # augmented basis used for sparsification
    analytical_fθ::GHT           # analytical vector field, called as (dx, x, params, t)
    z::Vector{Symbolics.Num}     # symbolic state vector
    λ::T                         # sparsification threshold
    noise_level::T               # amplitude of the noise added to the data
    noiseGen_timeStep::T         # integrator time step used to generate the noisy data
    nloops::Int                  # number of sparsification loops

    # All three numeric keywords must share the same type T.
    function HamiltonianSINDy(
        basis::Vector{Symbolics.Num},
        analytical_fθ::GHT = missing,
        z::Vector{Symbolics.Num} = get_z_vector(2);
        λ::T = 0.05,
        noise_level::T = 0.00,
        noiseGen_timeStep::T = 0.05,
        nloops = 10,
    ) where {T, GHT <: Union{Base.Callable,Missing}}
        return new{T, GHT}(
            basis, analytical_fθ, z, λ, noise_level, noiseGen_timeStep, nloops
        )
    end
end

# For every state in `x`, integrate the analytical vector field over a single
# step of length `method.noiseGen_timeStep` and return the resulting states
# with Gaussian noise of amplitude `method.noise_level` added.
function gen_noisy_t₂_data(method::HamiltonianSINDy, x)
    # time span covering exactly one integrator step
    tstep = method.noiseGen_timeStep
    tspan = (zero(tstep), tstep)

    # adapt the (dx, x, params, t) signature of the analytical vector field to
    # the (dx, t, x, params) argument order used for the ODEProblem
    rhs = (dx, t, state, params) -> method.analytical_fθ(dx, state, params, t)

    # state at the end of the step, starting from x₀
    function propagate(x₀)
        solution = integrate(ODEProblem(rhs, tspan, tstep, x₀), Gauss(2))
        return solution.q[end]
    end

    clean = [propagate(x₀) for x₀ in x]

    # add noise
    return [state .+ method.noise_level .* randn(size(state)) for state in clean]
end

# Container for the training set; all three arrays share the same type AT.
struct TrainingData{AT<:AbstractArray}
    x::AT # sampled states (initial conditions)
    ẋ::AT # time derivatives of the states
    y::AT # noisy data at the next time step

    function TrainingData(x::AT, ẋ::AT, y::AT) where {AT}
        return new{AT}(x, ẋ, y)
    end

    # Without `y` the field is left uninitialized.
    function TrainingData(x::AT, ẋ::AT) where {AT}
        return new{AT}(x, ẋ)
    end
end



In [None]:
# --------------------
# Setup
# --------------------

println("Setting up...")

# 2D system with 4 variables [q₁, q₂, p₁, p₂]
const nd = 4

# symbolic state vector with nd÷2 positions followed by nd÷2 momenta
z = get_z_vector(nd÷2)
# monomials of the states up to total degree 3
polynomial = polynomial_basis(z, polyorder=3)
# sin and cos of each state (max_coeff=1 gives the arguments 1*z only)
trigonometric  = trigonometric_basis(z, max_coeff=1)
# prime_diff = primal_operator_basis(z, -)
# basis = get_basis_set(polynomial, trigonometric, prime_diff)
# merged, de-duplicated basis used by the rest of the notebook
basis = get_basis_set(polynomial, trigonometric)
# initialize analytical function, keep λ smaller than ϵ so system is identifiable
# NOTE(review): ϵ and m only appear in commented-out code in the visible part of the file
ϵ = 0.5
m = 1

# two-dim simple harmonic oscillator (not used anywhere only in case some testing needed)
# H_ana(x, p, t) = ϵ * x[1]^2 + ϵ * x[2]^2 + 1/(2*m) * x[3]^2 + 1/(2*m) * x[4]^2
# H_ana(x, p, t) = cos(x[1]) + cos(x[2]) + 1/(2*m) * x[3]^2 + 1/(2*m) * x[4]^2

# Analytical vector field for the state x = [q₁, q₂, p₁, p₂]:
# q̇ = p and ṗ = sin(q).
# grad_H_ana(x) = [x[3]; x[4]; -2ϵ * x[1]; -2ϵ * x[2]]
function grad_H_ana(x)
    q₁, q₂, p₁, p₂ = x[1], x[2], x[3], x[4]
    return [p₁, p₂, sin(q₁), sin(q₂)]
end

# In-place variant: writes the vector field at `x` into `dx` and returns `dx`.
# `p` and `t` are accepted for the ODE-style signature but are not used.
function grad_H_ana!(dx, x, p, t)
    dx .= grad_H_ana(x)
    return dx
end

# ------------------------------------------------------------
# Training Data
# ------------------------------------------------------------

println("Generate Training Data...")

# number of sample points along each state coordinate
num_samp = 10

# sample values used for every coordinate of the (q, p) space
samp_range = LinRange(-10, 10, num_samp)

# initialize vector of matrices to store ODE solve output

# s is the full grid over all nd coordinates (num_samp^nd points),
# i.e. samp_range x samp_range x samp_range x samp_range for nd = 4
s = collect(Iterators.product(fill(samp_range, nd)...))


# turn every grid point (a tuple) into a state vector
x = [collect(s[i]) for i in eachindex(s)]

# normal_x = Flux.normalize(x)


# evaluate the analytical vector field at every state;
# dx is only a template buffer that is copied per sample, p and t are dummies
dx = zeros(nd)
p = 0
t = 0
ẋ = [grad_H_ana!(copy(dx), _x, p, t) for _x in x]


# ----------------------------------------
# Compute Sparse Regression
# ----------------------------------------

# choose SINDy method
# (λ parameter must be close to noise value so that only coeffs with value around the noise are sparsified away)
# noiseGen_timeStep chosen randomly for now
method = HamiltonianSINDy(basis, grad_H_ana!, z, λ = 0.05, noise_level = 0.00, noiseGen_timeStep = 0.3)

# generate noisy reference data at the next time step
y = gen_noisy_t₂_data(method, x)

# collect training data: one column per sample.
# reduce(hcat, v) gives the same matrix as hcat(v...) without splatting
# thousands of vectors into a single varargs call.
x = reduce(hcat, x)
ẋ = reduce(hcat, ẋ)
y = reduce(hcat, y)
tdata = TrainingData(x, ẋ, y)

In [None]:
# dimension of system: half the number of state rows (positions only)
d = size(tdata.x, 1) ÷ 2

# returns function that builds hamiltonian gradient through symbolics;
# fθ is read as a global by loss_kernel and vectorfield below
fθ = ΔH_func_builder(d, method.z, method.basis)

In [None]:
# Affine encoder: applies the first layer of the global `model`.
# NOTE(review): `model` is read at call time and is not modified by the training
# loops in this file (they update `flattened_model`) — confirm this is intended.
function encoder(x)
    layer = model[1]
    return layer.W * x .+ layer.b
end

# Affine decoder: applies the second layer of the global `model`.
function decoder(x)
    layer = model[2]
    return layer.W * x .+ layer.b
end

In [None]:
# Jacobian of `model_layer` at every column of `x_batch`.
# Returns an array of size (output_dim, batch_size, input_dim) whose slice
# [:, i, :] is the Jacobian at sample i.
# The `model` argument is accepted but not used.
function batched_jacobian(model_layer, x_batch, model)
    # the output dimension is probed by evaluating the layer on the first sample
    n_out = size(model_layer(x_batch[:, 1]), 1)
    n_in, n_samples = size(x_batch, 1), size(x_batch, 2)

    jacobians = zeros(n_out, n_samples, n_in)
    for k in 1:n_samples
        sample = x_batch[:, k]
        jacobians[:, k, :] = Flux.jacobian(model_layer, sample)[1]
    end
    return jacobians
end

# Latent time derivative ż = (dz/dx) * ẋ, computed per sample.
# `enc_jac_batch[:, i, :]` is the encoder Jacobian of sample i; the result has
# the same size as `ẋ_batch`.
function enc_ż(enc_jac_batch, ẋ_batch)
    encoded = zero(ẋ_batch)
    for sample in axes(enc_jac_batch, 2)
        jac = enc_jac_batch[:, sample, :]
        encoded[:, sample] = jac * ẋ_batch[:, sample]
    end
    return encoded
end

# Evaluate the in-place vector field `fθ(out, z, coeffs)` at every column of
# `enc_x_batch` and return the results as a matrix of the same size.
function evaluate_fθ(fθ, enc_x_batch, coeffs)
    scratch = zero(enc_x_batch[:, 1])   # buffer that fθ overwrites on each call
    result = zero(enc_x_batch)
    for col in axes(enc_x_batch, 2)
        fθ(scratch, enc_x_batch[:, col], coeffs)
        result[:, col] = scratch
    end
    return result
end

# Decoded time derivative ẋ = (dx/dz) * ż, computed per sample.
# `dec_jac_batch[:, i, :]` is the decoder Jacobian of sample i; the result has
# the same size as `ż`.
function dec_ẋ(dec_jac_batch, ż)
    decoded = zero(ż)
    for sample in axes(dec_jac_batch, 2)
        jac = dec_jac_batch[:, sample, :]
        decoded[:, sample] = jac * ż[:, sample]
    end
    return decoded
end

# Sum of squared differences between the SINDy latent derivative and its reference.
function Diff_ż(grad_fθ, ż_ref)
    residual = grad_fθ - ż_ref
    return sum(abs2, residual)
end

# Sum of squared differences between the decoded SINDy derivative
# (dx/dz) * fθ(z) and the reference derivative `ẋ_ref`.
#   dec_jac_batch : decoder Jacobians, slice [:, i, :] belongs to sample i
#   grad_fθ       : SINDy vector field in latent coordinates, one column per sample
#   ẋ_ref         : reference derivatives, one column per sample
#
# The buffer takes the promoted element type of its inputs and the values are
# used as they are. The previous version stored `ForwardDiff.value.(...)` in a
# Float64 buffer, which discarded the dual parts, so this loss term contributed
# nothing to the gradient computed by ForwardDiff.
function Diff_ẋ(dec_jac_batch, grad_fθ, ẋ_ref)
    T = promote_type(eltype(dec_jac_batch), eltype(grad_fθ))
    ẋ_SINDy = zeros(T, size(ẋ_ref))
    for i in axes(dec_jac_batch, 2)
        ẋ_SINDy[:, i] = dec_jac_batch[:, i, :] * grad_fθ[:, i]
    end
    return sum(abs2, ẋ_SINDy - ẋ_ref)
end

# Return a zeroed copy of `model_Coeffs` in which the entries selected by
# `biginds` are set to the values of Ξ; `model_Coeffs` itself is not modified.
function set_coeffs(model_Coeffs, Ξ, biginds)
    return setindex!(zero(model_Coeffs), Ξ, biginds)
end

# Total loss for one batch, given a model in layer form:
#   recon_model[1].W : encoder weights, recon_model[2].W : decoder weights,
#   recon_model[3].W : SINDy coefficients (the biases are not used here).
# The loss is (L_r + L_ż + L_ẋ) / batch_size + basis_coeff * L_c, where
#   L_r : reconstruction error of decode(encode(x)),
#   L_ż : alphas/100 * error of the SINDy latent derivative against ż_ref_batch,
#   L_ẋ : alphas * error of the decoded SINDy derivative against ẋ_batch,
#   L_c : mean absolute value of the coefficients.
# Reads the global `fθ`.
function loss_kernel(recon_model, x_batch, ẋ_batch, ż_ref_batch, dec_jac_batch, alphas, basis_coeff)
    encoder_W = recon_model[1].W
    decoder_W = recon_model[2].W
    coeffs = recon_model[3].W

    # reconstruction loss over the whole batch
    latent = encoder_W * x_batch
    L_r = sum(abs2, decoder_W * latent - x_batch)

    # The SINDy terms are computed inside the loss so that they depend on the
    # model parameters being differentiated.
    grad_fθ = evaluate_fθ(fθ, latent, coeffs)

    # latent-space derivative mismatch
    #TODO: could also try alphas/10 instead of alphas/100
    L_ż = alphas / 100 * Diff_ż(grad_fθ, ż_ref_batch)

    # data-space derivative mismatch
    L_ẋ = alphas * Diff_ẋ(dec_jac_batch, grad_fθ, ẋ_batch)

    # mean absolute coefficient value (sparsity regularizer)
    L_c = sum(abs, coeffs) / length(coeffs)

    n_samples = size(x_batch, 2)
    return (L_r + L_ż + L_ẋ) / n_samples + basis_coeff * L_c
end

In [None]:
# Loss as a function of the flattened parameter vector.
# The parameters are unpacked into layer form using the global `ld` and `ndim`
# and handed to `loss_kernel`.
function loss(flattened_model::AbstractVector, x_batch, ẋ_batch, ż_ref_batch, dec_jac_batch, alphas, basis_coeff)
    layers = reconstruct_model(flattened_model, ld, ndim)
    return loss_kernel(layers, x_batch, ẋ_batch, ż_ref_batch, dec_jac_batch, alphas, basis_coeff)
end

In [None]:
# Pack the model parameters into a single vector, in the order
# encoder W, encoder b, decoder W, decoder b, SINDy coefficients.
# Matrices are flattened column by column.
function flatten_model(model)
    enc, dec, sindy = model[1], model[2], model[3]
    return vcat(vec(enc.W), vec(enc.b), vec(dec.W), vec(dec.b), vec(sindy.W))
end


In [None]:
# Unpack a flattened parameter vector into layer form. Layout of the vector:
#   encoder W (ld × ndim) | encoder b (ld) | decoder W (ndim × ld) |
#   decoder b (ld) | SINDy coefficients (everything that remains)
# NOTE(review): the decoder bias is read with length `ld`; for ld != ndim a
# decoder bias would need `ndim` entries — confirm before using ld != ndim.
function reconstruct_model(flattened_model, ld, ndim)
    n_weights = ld * ndim

    # end position of each segment inside the flat vector
    enc_W_end = n_weights
    enc_b_end = enc_W_end + ld
    dec_W_end = enc_b_end + n_weights
    dec_b_end = dec_W_end + ld

    enc_layer = (
        W = reshape(flattened_model[1:enc_W_end], ld, ndim),
        b = reshape(flattened_model[enc_W_end+1:enc_b_end], ld),
    )
    dec_layer = (
        W = reshape(flattened_model[enc_b_end+1:dec_W_end], ndim, ld),
        b = reshape(flattened_model[dec_W_end+1:dec_b_end], ld),
    )
    sindy_layer = (W = flattened_model[dec_b_end+1:end], )

    return (enc_layer, dec_layer, sindy_layer)
end

In [None]:
# latent dimension: ld, data dimension: ndim (currently identical)
ld = size(tdata.x, 1)
ndim = size(tdata.x, 1)

# Initial model: identity encoder/decoder weights, zero biases and all SINDy
# coefficients set to one.
# The encoder maps ndim -> ld, so its weight matrix is ld × ndim and its bias
# has ld entries; this matches the layout assumed by reconstruct_model and by
# encoder(x) = W * x .+ b. (The previous definition had ndim and ld swapped,
# which only worked because ld == ndim.)
# NOTE(review): the decoder bias is kept at length ld because flatten_model /
# reconstruct_model slice it with that length; dimensionally it should be ndim.
model = (
    (W = Matrix{Float64}(LinearAlgebra.I, ld, ndim), b = zeros(ld)),
    (W = Matrix{Float64}(LinearAlgebra.I, ndim, ld), b = zeros(ld)),
    (W = ones(get_numCoeffs(method.basis)), )
)

In [None]:
# loss history of the first training phase (initial loss + one entry per epoch)
Initial_loss_array = Vector{Float64}()

# initial guess
println("Initial Guess...")

# the same optimizer object is reused by all training phases below
optimizer = Adam()

batch_size = 500  # Set your desired batch size
# number of mini-batches per epoch; the last batch may be smaller
num_batches = ceil(Int, size(tdata.x, 2) / batch_size)

# Derivatives of the encoder and decoder, evaluated on the full data set
enc_jac = batched_jacobian(encoder, tdata.x, model)
dec_jac = batched_jacobian(decoder, model[1].W * tdata.x .+ model[1].b, model)

# encoded gradient ż = dz/dx*ẋ
ż_ref = enc_ż(enc_jac, tdata.ẋ)

# weight of the derivative loss terms: ratio of the squared magnitudes of the
# states and of their derivatives, rounded to three significant digits
alphas = round(sum(abs2, tdata.x) / sum(abs2, tdata.ẋ), sigdigits = 3)

# weight of the coefficient (sparsity) term in the loss
basis_coeff = 0.62

# flat parameter vector that the optimizer updates in place
flattened_model = flatten_model(model)
println(model[3].W)
# loss of the initial guess on the full data set
push!(Initial_loss_array, loss(flattened_model, tdata.x, tdata.ẋ, ż_ref, dec_jac, alphas, basis_coeff))

In [None]:
using Random

# First training phase: mini-batch optimization of all parameters (encoder,
# decoder and every SINDy coefficient). `flattened_model` is updated in place
# and the average batch loss of each epoch is appended to Initial_loss_array.
for epoch in 1:1000
    # Shuffle the sample (column) indices so the batches differ between epochs
    shuffled_indices = shuffle(1:size(tdata.x, 2))

    # Training phase
    train_batch_loss = 0.0
    for batch in 1:num_batches
        # Get the indices for the current batch (the last batch may be shorter)
        batch_indices = shuffled_indices[(batch-1)*batch_size+1:min(batch*batch_size, size(tdata.x, 2))]

        # Extract the states and their time derivatives for the current batch.
        # The derivatives must come from tdata.ẋ; tdata.y holds the states at
        # the next time step and was previously used here by mistake.
        x_batch = tdata.x[:, batch_indices]
        ẋ_batch = tdata.ẋ[:, batch_indices]

        # Derivatives of the encoder and decoder; the decoder Jacobian is
        # evaluated at the encoded states, as in the initial-guess cell
        enc_jac_batch = batched_jacobian(encoder, x_batch, model)
        dec_jac_batch = batched_jacobian(decoder, model[1].W * x_batch .+ model[1].b, model)

        # encoded gradient ż = dz/dx*ẋ
        ż_ref_batch = enc_ż(enc_jac_batch, ẋ_batch)

        # Compute gradients using ForwardDiff.jl
        gradients = ForwardDiff.gradient(params -> loss(params, x_batch, ẋ_batch, ż_ref_batch, dec_jac_batch, alphas, basis_coeff), flattened_model)

        # Update the parameters in place using the optimizer
        Flux.Optimise.update!(optimizer, flattened_model, gradients)

        # loss after the update, accumulated for the epoch average
        train_batch_loss += loss(flattened_model, x_batch, ẋ_batch, ż_ref_batch, dec_jac_batch, alphas, basis_coeff)
    end
    train_batch_loss /= num_batches

    push!(Initial_loss_array, train_batch_loss)
    if epoch % 100 == 0
        println("Epoch $epoch: Average Train Loss: $train_batch_loss")
    end
end

In [None]:
# Natural log of the loss history of the first training phase
# (first entry: initial guess, then one entry per epoch).
plot(log.(Initial_loss_array), label = "Initial Batch Loss", ylabel="Log Loss", xlabel="Iterations")

In [None]:
# Unpack the trained flat parameter vector into layer form.
reconstructed_model = reconstruct_model(flattened_model, ld, ndim)
# `coeffs` refers to the same array as reconstructed_model[3].W (no copy is made),
# so the in-place edits to `coeffs` below also change that field.
coeffs = reconstructed_model[3].W
println(reconstructed_model)

In [None]:
# Loss used while regressing the dynamics onto the remaining (non-sparsified) terms.
# Here `flattened_model` ends with only the coefficients selected by `biginds`;
# they are scattered into a full-length coefficient vector (zeros elsewhere)
# before the loss is evaluated. `coeffs` is only used for its axes.
function sparseloss(flattened_model::AbstractVector, x_batch, ẋ_batch, coeffs, ż_ref_batch, dec_jac_batch, alphas, basis_coeff, biginds, ld, ndim)
    # encoder/decoder layers plus the active coefficients stored at the tail
    layers = reconstruct_model(flattened_model, ld, ndim)

    # full-length coefficient vector with the element type of the parameters,
    # so dual numbers from ForwardDiff pass through
    full_coeffs = zeros(eltype(flattened_model), axes(coeffs))
    full_coeffs[biginds] .= layers[3].W

    sparse_model = (layers[1], layers[2], (W = full_coeffs, ))
    return loss_kernel(sparse_model, x_batch, ẋ_batch, ż_ref_batch, dec_jac_batch, alphas, basis_coeff)
end

In [None]:
# loss history of the sparsification phase (one entry per epoch)
sparse_loss_array = Vector{Float64}()

# Sequential thresholding: repeatedly zero the coefficients below λ and
# re-train the remaining parameters, for at most method.nloops rounds.
for n in 1:method.nloops
    # find coefficients below λ threshold
    smallinds = abs.(coeffs) .< method.λ
    biginds = .~smallinds

    # check if there are any small coefficients != 0 left
    all(coeffs[smallinds] .== 0) && break

    println("Iteration #$n...")

    # set all small coefficients to zero
    coeffs[smallinds] .= 0

    # coeffs[biginds] is a copy, so the optimized values have to be written
    # back into coeffs explicitly after every update (see below)
    θ = [reconstructed_model[1].W, reconstructed_model[1].b, reconstructed_model[2].W, reconstructed_model[2].b, coeffs[biginds]]
    # Flatten the model into a single vector
    flattened_model = vcat([vec(θ[i]) for i in 1:length(θ)]...)

    for epoch in 1:500
        epoch_loss = 0.0
        # Shuffle the indices of the data
        shuffled_indices = shuffle(1:size(tdata.x, 2))
        for batch in 1:num_batches
            # Get the indices for the current batch
            batch_indices = shuffled_indices[(batch-1)*batch_size+1:min(batch*batch_size, size(tdata.x, 2))]

            # Extract the states and their time derivatives for the current batch.
            # The derivatives must come from tdata.ẋ; tdata.y holds the states at
            # the next time step and was previously used here by mistake.
            x_batch = tdata.x[:, batch_indices]
            ẋ_batch = tdata.ẋ[:, batch_indices]

            # Derivatives of the encoder and decoder; the decoder Jacobian is
            # evaluated at the encoded states
            enc_jac_batch = batched_jacobian(encoder, x_batch, model)
            dec_jac_batch = batched_jacobian(decoder, model[1].W * x_batch .+ model[1].b, model)

            # encoded gradient ż = dz/dx*ẋ
            ż_ref_batch = enc_ż(enc_jac_batch, ẋ_batch)

            # Compute gradients using ForwardDiff.jl
            gradients = ForwardDiff.gradient(params -> sparseloss(params, x_batch, ẋ_batch, coeffs, ż_ref_batch, dec_jac_batch, alphas, basis_coeff, biginds, ld, ndim), flattened_model)

            # Update the parameters using the optimizer
            Flux.Optimise.update!(optimizer, flattened_model, gradients)

            # write the updated active coefficients back into the full vector
            coeffs[biginds] = flattened_model[(2*ld*ndim)+(2*ld)+1:end]

            epoch_loss += sparseloss(flattened_model, x_batch, ẋ_batch, coeffs, ż_ref_batch, dec_jac_batch, alphas, basis_coeff, biginds, ld, ndim)
        end

        avg_epoch_loss = epoch_loss / num_batches  # Calculate average loss for the epoch
        push!(sparse_loss_array, avg_epoch_loss)
        if epoch % 50 == 0
            println("Epoch $epoch: Average Loss: $avg_epoch_loss")
        end
    end

    coeffs[biginds] = flattened_model[(2*ld*ndim)+(2*ld)+1:end]

    # Rebuild the layer form. The coefficient layer stores the FULL coefficient
    # vector (zeros included): fθ expects one coefficient per basis term, and
    # the tail of flattened_model only holds the surviving ones.
    param_matrix_size = ld*ndim
    reconstructed_model = (
        (W = reshape(flattened_model[1:param_matrix_size], ld, ndim), b = reshape(flattened_model[param_matrix_size+1:param_matrix_size+ld], ld)),
        (W = reshape(flattened_model[param_matrix_size+ld+1:2*param_matrix_size+ld], ndim, ld), b = reshape(flattened_model[2*param_matrix_size+ld+1:2*param_matrix_size+2*ld], ld)),
        (W = copy(coeffs), )
    )

    println(coeffs)
end

In [None]:
# Natural log of the per-epoch average loss recorded during the sparsification phase.
plot(log.(sparse_loss_array), label = "Sparse Batch Loss", ylabel="Log Loss", xlabel="Iterations")

Continue optimizing until the desired convergence is reached, i.e. until the loss no longer changes appreciably.

In [None]:
# loss history of the final refinement phase (one entry per epoch)
final_loss_array = Vector{Float64}()

# find coefficients below λ threshold
smallinds = abs.(coeffs) .< method.λ
biginds = .~smallinds

# set all small coefficients to zero
coeffs[smallinds] .= 0

# coeffs[biginds] is a copy, so the optimized values have to be written back
# into coeffs explicitly after every update (see below)
θ = [reconstructed_model[1].W, reconstructed_model[1].b, reconstructed_model[2].W, reconstructed_model[2].b, coeffs[biginds]]
# Flatten the model into a single vector
flattened_model = vcat([vec(θ[i]) for i in 1:length(θ)]...)

# (A leftover `randperm(length(tdata.x))` was removed here: it permuted the
# number of matrix elements rather than samples, and its result was
# overwritten before being used.)

for epoch in 1:500
    epoch_loss = 0.0
    # Shuffle the indices of the data
    shuffled_indices = shuffle(1:size(tdata.x, 2))
    for batch in 1:num_batches
        # Get the indices for the current batch
        batch_indices = shuffled_indices[(batch-1)*batch_size+1:min(batch*batch_size, size(tdata.x, 2))]

        # Extract the states and their time derivatives for the current batch.
        # The derivatives must come from tdata.ẋ; tdata.y holds the states at
        # the next time step and was previously used here by mistake.
        x_batch = tdata.x[:, batch_indices]
        ẋ_batch = tdata.ẋ[:, batch_indices]

        # Derivatives of the encoder and decoder; the decoder Jacobian is
        # evaluated at the encoded states
        enc_jac_batch = batched_jacobian(encoder, x_batch, model)
        dec_jac_batch = batched_jacobian(decoder, model[1].W * x_batch .+ model[1].b, model)

        # encoded gradient ż = dz/dx*ẋ
        ż_ref_batch = enc_ż(enc_jac_batch, ẋ_batch)

        # Compute gradients using ForwardDiff.jl
        gradients = ForwardDiff.gradient(params -> sparseloss(params, x_batch, ẋ_batch, coeffs, ż_ref_batch, dec_jac_batch, alphas, basis_coeff, biginds, ld, ndim), flattened_model)

        # Update the parameters using the optimizer
        Flux.Optimise.update!(optimizer, flattened_model, gradients)

        # write the updated active coefficients back into the full vector
        coeffs[biginds] = flattened_model[(2*ld*ndim)+(2*ld)+1:end]

        epoch_loss += sparseloss(flattened_model, x_batch, ẋ_batch, coeffs, ż_ref_batch, dec_jac_batch, alphas, basis_coeff, biginds, ld, ndim)
    end

    avg_epoch_loss = epoch_loss / num_batches
    push!(final_loss_array, avg_epoch_loss)
    if epoch % 50 == 0
        println("Epoch $epoch: Average Loss: $avg_epoch_loss")
    end
end

coeffs[biginds] = flattened_model[(2*ld*ndim)+(2*ld)+1:end]

# Rebuild the layer form. The coefficient layer stores the FULL coefficient
# vector (zeros included): fθ expects one coefficient per basis term, and the
# tail of flattened_model only holds the surviving ones.
param_matrix_size = ld*ndim
reconstructed_model = (
    (W = reshape(flattened_model[1:param_matrix_size], ld, ndim), b = reshape(flattened_model[param_matrix_size+1:param_matrix_size+ld], ld)),
    (W = reshape(flattened_model[param_matrix_size+ld+1:2*param_matrix_size+ld], ndim, ld), b = reshape(flattened_model[2*param_matrix_size+ld+1:2*param_matrix_size+2*ld], ld)),
    (W = copy(coeffs), )
)

println(coeffs)

In [None]:
# Per-epoch average loss of the final refinement phase (linear scale).
plot((final_loss_array), label = "Final Iteration Loss")

In [None]:
# Identified vector field in latent coordinates: writes fθ(z) into `dz` and
# returns `dz`.
# The full-length global coefficient vector `coeffs` is used because fθ expects
# one coefficient per basis term. After sparsification
# `reconstructed_model[3].W` holds only the surviving coefficients, so passing
# it to fθ would mis-assign coefficients to basis terms.
function vectorfield(dz, z)
    fθ(dz, z, coeffs)
    return dz
end

# ODE-style signature; the parameters `p` and the time `t` are ignored.
vectorfield(dz, z, p, t) = vectorfield(dz, z)

In [None]:
# Decode every row of the SINDy solution into data-space coordinates.
# NOTE(review): `data_sindy` is not defined anywhere before this point in the
# visible file (it is only assigned inside the plotting loop below), so this
# cell presumably relies on leftover notebook state — confirm or remove.
xsol = hcat([decoder(data_sindy.q[i,:]) for i in axes(data_sindy.q, 1)]...)

In [None]:
# ----------------------------------------
# Plot Results
# ----------------------------------------

println("Plotting...")

# integration step and time span used for both trajectories
t_step = 0.01
t_span = (0.0,1.0)

for i in 1:1
    # pick a random training sample as the initial condition
    idx = rand(1:length(s))

    # reference trajectory from the analytical vector field
    prob_reference = ODEProblem((dx, t, x, params) -> grad_H_ana!(dx, x, params, t), t_span, t_step, (x[:, idx]))
    data_reference = integrate(prob_reference, Gauss(1))

    # trajectory of the identified system, integrated in encoded coordinates
    # NOTE(review): encoder/decoder read the global `model`, which the training
    # loops above do not modify — confirm the trained weights are meant to be ignored here
    prob_sindy = ODEProblem((dx, t, x, params) -> vectorfield(dx, x, params, t), t_span, t_step, (encoder(x[:, idx])))
    data_sindy = integrate(prob_sindy, Gauss(1))

    # decode the identified trajectory back to data coordinates (one column per time step)
    xsol = hcat([decoder(data_sindy.q[i,:]) for i in axes(data_sindy.q, 1)]...)
    
    p1 = plot(xlabel = "Time", ylabel = "q₁")
    scatter!(p1, data_reference.t, data_reference.q[:,1], label = "Data q₁")
    scatter!(p1, data_sindy.t, xsol[1,:], markershape=:xcross, label = "Identified q₁")

    p3 = plot(xlabel = "Time", ylabel = "p₁")
    scatter!(p3, data_reference.t, data_reference.q[:,3], label = "Data p₁")
    scatter!(p3, data_sindy.t, xsol[3,:], markershape=:xcross, label = "Identified p₁")

    # NOTE(review): the titles below say "Euler" while both problems are
    # integrated with Gauss(1) — confirm the intended wording
    plot!(size=(1000,1000))
    display(plot(p1, p3, title="Analytical vs Calculated q₁ & p₁ in a 2D system with Euler"))

    p2 = plot(xlabel = "Time", ylabel = "q₂")
    scatter!(p2, data_reference.t, data_reference.q[:,2], label = "Data q₂")
    scatter!(p2, data_sindy.t, xsol[2,:], markershape=:xcross, label = "Identified q₂")

    p4 = plot(xlabel = "Time", ylabel = "p₂")
    scatter!(p4, data_reference.t, data_reference.q[:,4], label = "Data p₂")
    scatter!(p4, data_sindy.t, xsol[4,:], markershape=:xcross, label = "Identified p₂")

    plot!(size=(1000,1000))
    display(plot(p2, p4, title="Analytical vs Calculated q₂ & p₂ in a 2D system with Euler"))

end