How can I solve complex-valued ordinary differential equations (ODEs) using neural networks, given limitations with complex data types in libraries like Lux? #818

Closed
RomanSahakyan03 opened this issue Feb 22, 2024 · 7 comments

@RomanSahakyan03

Below is the code I use to train the neural network, and where I run into a problem.

Here are the libraries I use.

using NeuralPDE
using DifferentialEquations
using Plots
using Lux, Random

The function system_of_de! defines a system of four coupled complex-valued differential equations, known as the Bloch equations. They describe the dynamics of a physical system with four variables u[1], u[2], u[3], and u[4].

The parameters of the system are Ω, Δ, and Γ, which are defined outside the function and passed in through p.

function system_of_de!(du, u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2

    du[1] =  im * Ω * (u[3] - u[4]) + Γ * u[2]
    du[2] = -im * Ω * (u[3] - u[4]) - Γ * u[2]
    du[3] = -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1])
    du[4] = conj(du[3])

    return nothing
end
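
For reference, here is the system read directly off the function above, written as equations. The notation (dots for time derivatives, an overline for the complex conjugate) is my own choice and is not taken from the original post.

```latex
\begin{aligned}
\dot{u}_1 &= i\,\Omega\,(u_3 - u_4) + \Gamma\,u_2,\\
\dot{u}_2 &= -i\,\Omega\,(u_3 - u_4) - \Gamma\,u_2,\\
\dot{u}_3 &= -(\gamma + i\,\Delta)\,u_3 - i\,\Omega\,(u_2 - u_1), \qquad \gamma = \Gamma/2,\\
\dot{u}_4 &= \overline{\dot{u}_3}.
\end{aligned}
```

Adding the first two lines gives \dot{u}_1 + \dot{u}_2 = 0, so u_1 + u_2 stays constant in time, with the constant fixed by the initial condition. That gives a quick sanity check on any numerical or neural-network solution of this system.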

Initial Conditions, Time Span, and Parameters:

u0 = zeros(ComplexF64, 4)
u0[1] = 1

time_span = (0.0, 7.0)

#             Ω,     Δ,   Γ         
parameters = [100.0, 0.0, 1.0]

Defining the ODE Problem

problem = ODEProblem(system_of_de!, u0, time_span, parameters)
ODEProblem with uType Vector{ComplexF64} and tType Float64. In-place: true
timespan: (0.0, 7.0)
u0: 4-element Vector{ComplexF64}:

This part involves setting up and using the Neural Network Ordinary Differential Equation (NNODE) solver.

  • rng is a random number generator used for initialization.
  • chain defines the architecture of the neural network used by NNODE.
  • ps and st are the initial parameters and states of the neural network.
  • opt is an optimization algorithm used to train the neural network.
  • alg is an NNODE solver object that combines the neural network and optimizer.
rng = Random.default_rng()
Random.seed!(rng, 0)
chain = Chain(Dense(1, 5, σ), Dense(5, 1))
ps, st = Lux.setup(rng, chain) |> Lux.f64

Solving the ODE:

  • sol is the solution obtained by solving the problem using the NNODE solver.
  • maxiters specifies the maximum number of iterations the solver allows.
  • saveat defines the time interval at which the solution is saved.
using OptimizationOptimisers

opt = Adam(0.1)
alg = NNODE(chain, opt, init_params = ps)

sol = solve(problem, alg, verbose = true, maxiters = 2000, saveat = 0.01)

Checking the Result:

  • ground_truth is the solution obtained using a traditional numerical solver (Tsit5()).
  • The code plots both solutions for the first two variables (u[1] and u[2]) to compare them visually.
ground_truth  = solve(problem, Tsit5(), saveat = 0.01)    

plot(ground_truth.t, real.(ground_truth[1, :]), linecolor=:blue, legend = false)
plot!(ground_truth.t, real.(ground_truth[2, :]), linecolor=:blue)

plot!(sol.t, real.(sol[1, :]), linecolor=:red)
plot!(sol.t, real.(sol[2, :]), linecolor=:red)
@ChrisRackauckas
Member

@avik-pal do you know the current state of complex numbers with Lux?

@avik-pal
Member

avik-pal commented Feb 23, 2024

It should just work, but sigmoid will not work. It gives a very nice message as well (it is only printed if depwarn is on, my bad).

image

Use something like tanh

image
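
As a small illustration of why tanh is a reasonable replacement: Julia's built-in tanh accepts complex arguments. The sketch below uses only Base Julia (no Lux) and made-up sample values. It checks tanh against its exponential definition and applies it elementwise to a complex vector, which is how an activation acts on a layer's output.

```julia
# Base.tanh is defined for complex numbers.
z = 0.3 + 0.4im
y = tanh(z)

# Cross-check against the exponential definition of tanh.
y_ref = (exp(z) - exp(-z)) / (exp(z) + exp(-z))
@assert y isa ComplexF64
@assert isapprox(y, y_ref; atol = 1e-12)

# Elementwise application to a complex vector (illustrative values only).
v = ComplexF64[0.1 + 0.2im, -0.5im, 1.0]
activated = tanh.(v)
println(activated)
```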

@sathvikbhagavan
Member

A few things to make it work:

  1. Function should be out of place for NNODE
  2. Chain should have 4 outputs instead of 1 (as there are 4 states)
  3. Have to define - SciMLBase.allowscomplex(::NNODE) = true to allow complex types in the solve. We should maybe have it in the package itself.
  4. For QuadratureTraining, pass it explicitly as the default defines reltol based on the element type of u0 which cannot be complex.

This snippet of code below works.

using NeuralPDE
using DifferentialEquations
using Plots
using Lux, Random
using OptimizationOptimisers

function system_of_de(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    du =  [im * Ω * (u[3] - u[4]) + Γ * u[2],
          -im * Ω * (u[3] - u[4]) - Γ * u[2],
          -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1]),
          conj(-(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1]))]
    return du
end

u0 = zeros(ComplexF64, 4)
u0[1] = 1
time_span = (0.0, 7.0)   
parameters = [100.0, 0.0, 1.0]

problem = ODEProblem(system_of_de, u0, time_span, parameters)

rng = Random.default_rng()
Random.seed!(rng, 0)
chain = Chain(Dense(1, 5, tanh), Dense(5, 4))
ps, st = Lux.setup(rng, chain) |> Lux.f64

opt = Adam(0.1)
alg = NNODE(chain, opt, init_params = ps, strategy = StochasticTraining(2))
SciMLBase.allowscomplex(::NNODE) = true
sol = solve(problem, alg, verbose = true, maxiters = 1, saveat = 0.01)

ground_truth  = solve(problem, Tsit5(), saveat = 0.01)
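
Point 1 in the list above (in-place versus out-of-place) can be checked without any solver packages. The sketch below uses only Base Julia. The names bloch! and bloch and the test state u_test are hypothetical, chosen for illustration. It confirms that the two forms of the right-hand side return the same derivative and that the first two components of the derivative cancel.

```julia
# In-place form: writes the derivative into the preallocated buffer `du`.
function bloch!(du, u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    du[1] =  im * Ω * (u[3] - u[4]) + Γ * u[2]
    du[2] = -im * Ω * (u[3] - u[4]) - Γ * u[2]
    du[3] = -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1])
    du[4] = conj(du[3])
    return nothing
end

# Out-of-place form: builds and returns a new vector, with no mutation.
function bloch(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    du3 = -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1])
    return [im * Ω * (u[3] - u[4]) + Γ * u[2],
            -im * Ω * (u[3] - u[4]) - Γ * u[2],
            du3,
            conj(du3)]
end

p_test = [100.0, 0.0, 1.0]                                 # Ω, Δ, Γ as in the thread
u_test = ComplexF64[0.7, 0.3, 0.1 + 0.2im, 0.1 - 0.2im]    # arbitrary test state

du_inplace = similar(u_test)
bloch!(du_inplace, u_test, p_test, 0.0)
du_oop = bloch(u_test, p_test, 0.0)

# Both forms agree, and du[1] + du[2] vanishes.
@assert all(isapprox.(du_oop, du_inplace; atol = 1e-12))
@assert abs(du_oop[1] + du_oop[2]) < 1e-12
```

The out-of-place version computes the third component once and reuses it for the conjugate, which avoids repeating the expression as the script above does.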

@sathvikbhagavan
Member

@RomanSahakyan03, did the above script solve your problem?

@RomanSahakyan03
Author

RomanSahakyan03 commented Mar 3, 2024

@sathvikbhagavan Yes, it helped a lot. I'm sorry that I didn't answer sooner, but with this method everything worked well. Thank you very much!

@ChrisRackauckas
Member

Have to define - SciMLBase.allowscomplex(::NNODE) = true to allow complex types in the solve. We should maybe have it in the package itself.

We definitely should.

Can we make this a doc example?

@RomanSahakyan03
Author

RomanSahakyan03 commented Mar 3, 2024

@ChrisRackauckas I think you should do this, because the docs have no information about SciMLBase.allowscomplex(::NNODE). If it is not too difficult, please also add an example of using a PINN to solve a system of differential equations with complex values.
