How can I solve complex-valued ordinary differential equations (ODEs) using neural networks, given limitations with complex data types in libraries like Lux? #818
@avik-pal do you know the current state of complex numbers with Lux?
A few things to make it work:
This snippet of code below works:

```julia
using NeuralPDE
using DifferentialEquations
using Plots
using Lux, Random
using OptimizationOptimisers

# Out-of-place right-hand side of the four coupled complex-valued Bloch equations
function system_of_de(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    du = [im * Ω * (u[3] - u[4]) + Γ * u[2],
        -im * Ω * (u[3] - u[4]) - Γ * u[2],
        -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1]),
        conj(-(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1]))]
    return du
end

# Complex initial state: all population starts in u[1]
u0 = zeros(ComplexF64, 4)
u0[1] = 1
time_span = (0.0, 7.0)
parameters = [100.0, 0.0, 1.0]  # Ω, Δ, Γ
problem = ODEProblem(system_of_de, u0, time_span, parameters)

rng = Random.default_rng()
Random.seed!(rng, 0)
chain = Chain(Dense(1, 5, tanh), Dense(5, 4))
# Convert the initial parameters and state to Float64
ps, st = Lux.setup(rng, chain) |> Lux.f64
opt = Adam(0.1)
alg = NNODE(chain, opt, init_params = ps, strategy = StochasticTraining(2))

# Declare that NNODE accepts complex-valued problems (must be defined before `solve`)
SciMLBase.allowscomplex(::NNODE) = true

# maxiters = 1 is only a smoke test; actual training needs many more iterations
sol = solve(problem, alg, verbose = true, maxiters = 1, saveat = 0.01)

# Reference solution from a classical solver, for comparison
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
```
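Because `system_of_de` is plain Julia, its complex-valued behaviour can be sanity-checked without NeuralPDE or DifferentialEquations. The following is a minimal sketch, assuming only Julia's standard library; `rk4_solve` is a hypothetical fixed-step RK4 helper, not part of any SciML package. It checks two properties that follow directly from the equations: `u[1] + u[2]` stays equal to 1 (the first two derivatives cancel), and `u[4]` stays the complex conjugate of `u[3]` (its derivative is defined as the conjugate of `du[3]`).

```julia
# Same out-of-place right-hand side as `system_of_de` in the snippet above
function system_of_de(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    du3 = -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1])
    return [im * Ω * (u[3] - u[4]) + Γ * u[2],
        -im * Ω * (u[3] - u[4]) - Γ * u[2],
        du3,
        conj(du3)]
end

# Hypothetical helper: classic fixed-step RK4; works unchanged for complex states
function rk4_solve(f, u0, tspan, p; dt = 1e-3)
    t0, t1 = tspan
    n = round(Int, (t1 - t0) / dt)
    u = copy(u0)
    t = t0
    for _ in 1:n
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
    end
    return u
end

u0 = zeros(ComplexF64, 4)
u0[1] = 1
# Small step because Ω = 100 makes the populations oscillate quickly
uT = rk4_solve(system_of_de, u0, (0.0, 7.0), [100.0, 0.0, 1.0]; dt = 1e-4)

println("u1 + u2       = ", uT[1] + uT[2])         # total population, should stay 1
println("u4 - conj(u3) = ", uT[4] - conj(uT[3]))   # should stay 0
```

The same two checks can be applied to `sol.u` from the NNODE run; a network that has not been trained long enough will typically violate them noticeably.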
@RomanSahakyan03, did the above script solve your problem?
@sathvikbhagavan Yes, it helped a lot. Sorry I didn't reply sooner, but with this method everything worked well. Thank you very much!
We definitely should. Can we make this a doc example?
@ChrisRackauckas I think you should, because the docs have no information about the `SciMLBase.allowscomplex(::NNODE)` override. If it is not difficult, please also add an example of using a PINN to solve a system of differential equations with complex values.
Below I describe the code with which I am trying to train the neural network, and where I run into the problem.
Here are the libraries I use.
The function `system_of_de!` defines a system of four coupled complex-valued differential equations. They describe the dynamics of a physical system with four variables `u[1]`, `u[2]`, `u[3]`, and `u[4]`; these equations are known as the Bloch equations. The parameters of the system are `Ω`, `Δ`, and `Γ`, which are defined outside the function.

Initial Conditions, Time Span, and Parameters:
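For reference, the equations implemented by the function, together with the initial condition, time span, and parameter values used in the working snippet from the comments, form the following initial-value problem (with $\gamma = \Gamma/2$). Adding the first two equations shows that the total population $u_1 + u_2$ is conserved, and the definition of $\dot u_4$ keeps $u_4$ equal to the complex conjugate of $u_3$:

```math
\begin{aligned}
\dot u_1 &= i\Omega\,(u_3 - u_4) + \Gamma u_2, \\
\dot u_2 &= -i\Omega\,(u_3 - u_4) - \Gamma u_2, \\
\dot u_3 &= -(\gamma + i\Delta)\,u_3 - i\Omega\,(u_2 - u_1), \\
\dot u_4 &= \overline{-(\gamma + i\Delta)\,u_3 - i\Omega\,(u_2 - u_1)} = \overline{\dot u_3}, \\
u(0) &= (1, 0, 0, 0), \qquad t \in [0, 7], \qquad (\Omega, \Delta, \Gamma) = (100, 0, 1), \\
\frac{d}{dt}\,(u_1 + u_2) &= 0 \;\Rightarrow\; u_1(t) + u_2(t) = 1, \\
\frac{d}{dt}\,\bigl(u_4 - \overline{u_3}\bigr) &= 0 \;\Rightarrow\; u_4(t) = \overline{u_3(t)}.
\end{aligned}
```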
Defining the ODE Problem
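When some component of the pipeline only handles real numbers (the limitation mentioned in the title of this issue), a common workaround is to define an equivalent real-valued ODE on the stacked real and imaginary parts, `x = [real.(u); imag.(u)]`. A minimal sketch, assuming only Julia's standard library; `realify` is a hypothetical helper name, not a NeuralPDE or SciML function:

```julia
# Hypothetical helper: wrap a complex out-of-place RHS f(u, p, t) into a real RHS
# acting on x = [real.(u); imag.(u)] (length 2n for n complex unknowns)
function realify(f)
    return function (x, p, t)
        n = length(x) ÷ 2
        u = complex.(x[1:n], x[n+1:end])   # rebuild the complex state
        du = f(u, p, t)
        return [real.(du); imag.(du)]      # split the complex derivative again
    end
end

# Same Bloch right-hand side as in the issue
function system_of_de(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    return [im * Ω * (u[3] - u[4]) + Γ * u[2],
        -im * Ω * (u[3] - u[4]) - Γ * u[2],
        -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1]),
        conj(-(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1]))]
end

u0 = zeros(ComplexF64, 4)
u0[1] = 1
x0 = [real.(u0); imag.(u0)]   # 8 real initial values
g = realify(system_of_de)
# At u0 only Im(u[3]) and Im(u[4]) change, at rates +Ω and -Ω
dx0 = g(x0, [100.0, 0.0, 1.0], 0.0)
```

The wrapped `g` has the usual `(u, p, t)` signature, so it could be passed to `ODEProblem(g, x0, time_span, parameters)`; a network would then only need 8 real outputs, and the complex solution is recovered afterwards with `complex.(x[1:4], x[5:8])`.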
This part sets up and uses `NNODE`, the neural-network ODE solver from NeuralPDE.
- `rng` is a random number generator used for initialization.
- `chain` defines the architecture of the neural network used by `NNODE`.
- `ps` and `st` are the initial parameters and state of the neural network.
- `opt` is the optimization algorithm used to train the neural network.
- `alg` is an `NNODE` solver object that combines the neural network and the optimizer.

Solving the ODE:
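Conceptually, a neural-network ODE solver does not time-step the equations; it trains the network so that a trial solution satisfies the ODE at sampled time points. The sketch below imitates such a loss with only Julia's standard library. It assumes a trial solution of the form `u0 + (t - t0) * NN(t)`, which matches the initial condition by construction, and approximates `du/dt` with central finite differences; `TinyNet`, `trial`, and `residual_loss` are hypothetical stand-ins, not NeuralPDE internals. Since a network with real weights and a real input produces real outputs, the sketch lets the network emit 8 real numbers and reads them as the real and imaginary parts of the 4 complex unknowns.

```julia
using Random

# Hypothetical one-hidden-layer network with real weights: t -> 8 real outputs
struct TinyNet
    W1::Matrix{Float64}
    b1::Vector{Float64}
    W2::Matrix{Float64}
    b2::Vector{Float64}
end
TinyNet(rng, hidden) = TinyNet(randn(rng, hidden, 1), zeros(hidden),
                               0.1 .* randn(rng, 8, hidden), zeros(8))
(net::TinyNet)(t) = net.W2 * tanh.(net.W1 * [t] .+ net.b1) .+ net.b2

# Hypothetical trial solution: equals u0 at t0; 8 real outputs -> 4 complex values
function trial(net, u0, t0, t)
    y = net(t)
    return u0 .+ (t - t0) .* complex.(y[1:4], y[5:8])
end

# Mean squared ODE residual |du/dt - f(u, p, t)|^2 over the sampled time points
function residual_loss(net, f, u0, p, t0, ts; h = 1e-5)
    total = 0.0
    for t in ts
        dudt = (trial(net, u0, t0, t + h) .- trial(net, u0, t0, t - h)) ./ (2h)
        total += sum(abs2, dudt .- f(trial(net, u0, t0, t), p, t))
    end
    return total / length(ts)
end

# Same Bloch right-hand side as in the issue
function system_of_de(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    return [im * Ω * (u[3] - u[4]) + Γ * u[2],
        -im * Ω * (u[3] - u[4]) - Γ * u[2],
        -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1]),
        conj(-(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1]))]
end

rng = MersenneTwister(0)
net = TinyNet(rng, 5)
u0 = zeros(ComplexF64, 4)
u0[1] = 1
p = [100.0, 0.0, 1.0]
ts = 7.0 .* rand(rng, 2)   # 2 random points per evaluation, in the spirit of StochasticTraining(2)
loss = residual_loss(net, system_of_de, u0, p, 0.0, ts)
```

Training would then minimize this loss over the network weights with an optimizer such as Adam; in the actual code, NeuralPDE performs that step inside `solve`.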
- `sol` is the solution obtained by solving `problem` with the `NNODE` solver.
- `maxiters` specifies the maximum number of training iterations.
- `saveat` defines the time interval at which the solution is saved.

Checking the Result:
`ground_truth` is the solution obtained with a traditional numerical solver (`Tsit5()`). The two solutions are then plotted for the first two components (`u[1]` and `u[2]`) to compare them visually.
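Besides the visual comparison, a single number summarizing the discrepancy can make it easier to judge whether more training is needed. A minimal sketch, assuming both solutions are saved on the same `saveat = 0.01` grid so that `sol.u` and `ground_truth.u` are equally long vectors of complex state vectors; `max_component_error` is a hypothetical helper, and the two-point input below is made-up illustration data:

```julia
# Hypothetical helper: largest absolute deviation of component k over all saved times
max_component_error(us, vs, k) = maximum(abs(u[k] - v[k]) for (u, v) in zip(us, vs))

# With the solutions from the snippet in the comments one would call, e.g.:
# max_component_error(sol.u, ground_truth.u, 1)

# Self-contained usage with illustrative (made-up) data
us = [ComplexF64[1, 0, 0, 0], ComplexF64[0.9, 0.1, 0.2im, -0.2im]]
vs = [ComplexF64[1, 0, 0, 0], ComplexF64[0.8, 0.2, 0.1im, -0.1im]]
e1 = max_component_error(us, vs, 1)
e2 = max_component_error(us, vs, 2)
```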