
Add support for equations with complex numbers #821

Closed · sathvikbhagavan opened this issue Mar 3, 2024 · 20 comments

@sathvikbhagavan (Member)

#818 describes a problem with complex numbers. We can add proper support for them in NeuralPDE and also add an example explaining it.

@RomanSahakyan03

@sathvikbhagavan Is there anything I can do to help?

@IromainI commented Mar 9, 2024

@sathvikbhagavan I am working on the same task, and I have code that I consider to be incorrect. You recently helped me with outputting the result.

using NeuralPDE, Lux, CUDA, Random, ComponentArrays
using Optimization
using OptimizationOptimisers
using OptimizationOptimJL
using LineSearches # provides BackTracking() used below
using Plots
using DifferentialEquations
import ModelingToolkit: Interval

@parameters t
@variables ρ11(..), ρ22(..), ρ12(..), ρ21(..)
D = Differential(t)

Ω = 10
Δ = 0
Γ = 1
γ = 1/2

# Note: `~` between complex symbolic expressions yields a pair of equations
# (real and imaginary parts), hence the splat `...` to flatten them into the list.
eqs = [
    (D(ρ11(t)) ~ im * Ω * (ρ12(t) - ρ21(t)) + Γ * ρ22(t))...,
    (D(ρ22(t)) ~ -im * Ω * (ρ12(t) - ρ21(t)) - Γ * ρ22(t))...,
    (D(ρ12(t)) ~ -(γ + im * Δ) * ρ12(t) - im * Ω * (ρ22(t) - ρ11(t)))...,
    D(ρ21(t)) ~ conj(D(ρ12(t)))
]

bcs = [ρ11(0) ~ 1 + 0im, ρ22(0) ~ 0 + 0im, ρ12(0) ~ 0 + 0im, ρ21(0) ~ 0 + 0im]
domains = [t ∈ Interval(0.0, 45.0)]
dt = 0.01

input_ = length(domains)
n = 40
chain = [Lux.Chain(Lux.Dense(input_, n, Lux.σ),
                   Lux.Dense(n, n, Lux.σ), 
                   Lux.Dense(n, n, Lux.σ), 
                   Lux.Dense(n, n, Lux.σ), 
                   Lux.Dense(n, n, Lux.σ), 
                   Lux.Dense(n, 1)) for _ in 1:4]

@named pdesystem = PDESystem(eqs, bcs, domains, [t], [ρ11(t), ρ22(t), ρ12(t), ρ21(t)])

strategy = NeuralPDE.GridTraining(dt)
discretization = PhysicsInformedNN(chain, strategy)
sym_prob = NeuralPDE.symbolic_discretize(pdesystem, discretization)

pde_loss_functions = sym_prob.loss_functions.pde_loss_functions
bc_loss_functions = sym_prob.loss_functions.bc_loss_functions

callback = function (p, l)
    println("loss: ", l)
    # println("pde_losses: ", map(l_ -> l_(p), pde_loss_functions))
    # println("bcs_losses: ", map(l_ -> l_(p), bc_loss_functions))
    return false
end

loss_functions = [pde_loss_functions; bc_loss_functions]

function loss_function(θ, p)
    sum(map(l -> l(θ), loss_functions))
end

f_ = OptimizationFunction(loss_function, Optimization.AutoZygote())
prob = Optimization.OptimizationProblem(f_, sym_prob.flat_init_params)

opt1 = OptimizationOptimisers.Adam(0.001)
opt2 = OptimizationOptimJL.BFGS()
opt3 = LBFGS(linesearch = BackTracking())

res = Optimization.solve(prob, opt3; callback = callback, maxiters = 150)

phi = discretization.phi
ts = 0.0:0.01:45.0 |> collect

minimizers_ = [res.u.depvar[sym_prob.depvars[i]] for i in 1:4]

u_predict = [
    vec(phi[1](ts', res.u.depvar.ρ11)),
    vec(phi[2](ts', res.u.depvar.ρ22)),
    vec(phi[3](ts', res.u.depvar.ρ12)),
    vec(phi[4](ts', res.u.depvar.ρ21))
]

plot(ts, u_predict[1], linecolor=:red, label = "ml ρ11")
plot!(ts, u_predict[2], linecolor=:blue, label = "ml ρ22")
plot!(ts, u_predict[3], linecolor=:green, label = "ml ρ12")
plot!(ts, u_predict[4], linecolor=:yellow, label = "ml ρ21")

@sathvikbhagavan (Member, Author)

Hey @RomanSahakyan03, apologies for the late reply. Adding support for complex numbers is, I think, trivial; I was working on #815 first to get the doc builds passing. I will add complex number support after that.

Hey @IromainI, I am not sure I understand your question 😅. Is there an error?

@RomanSahakyan03 commented Mar 10, 2024

@sathvikbhagavan yep, no problem. Thanks for it. I hope you'll add it for both the NNODE() and PhysicsInformedNN() functions.

@RomanSahakyan03

@sathvikbhagavan I saw that you have closed your issues. Can we work on complex numbers? I have a lot of questions about this.

@sathvikbhagavan (Member, Author)

Yes, I will start working on it.

@RomanSahakyan03

@sathvikbhagavan I can provide some problems I've encountered, if you want.

@sathvikbhagavan (Member, Author)

> I can provide some problems I've encountered, if you want.

Yes, that would be great!

@RomanSahakyan03

OK @sathvikbhagavan. Let's start with the system of differential equations. Here it is:

dρ₁₁ = im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂
dρ₂₂ = -im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂
dρ₁₂ = -(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁)
dρ₂₁ = conj(dρ₁₂)

Here is the function from which we'll create an ODEProblem:

function bloch_equations(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2

    ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁ = u

    dρ = [im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂;
          -im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂;
          -(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁);
          conj(-(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁))]
    return dρ
end

From the analytic solution I definitely know that the solutions for ρ₁₁ and ρ₂₂ are real, unlike those for ρ₁₂ and ρ₂₁, whose solutions are complex. We can confirm this by solving with Tsit5(), so I did.
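For context, the test snippet below assumes a setup roughly like the following; the parameter values, time span, and network here are illustrative guesses based on the constants used earlier in this thread:

using NeuralPDE, Lux, Random, OrdinaryDiffEq
using OptimizationOptimisers: Adam

u0 = ComplexF64[1, 0, 0, 0]      # all population starts in ρ₁₁
tspan = (0.0, 2.0)               # assumed; a short span keeps training tractable
p = [10.0, 0.0, 1.0]             # (Ω, Δ, Γ), matching the constants above
problem = ODEProblem(bloch_equations, u0, tspan, p)

rng = Random.default_rng()
chain = Chain(Dense(1, 16, tanh), Dense(16, 4))  # real-valued parameters, as in the test below
ps, st = Lux.setup(rng, chain)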

Here is some code for testing NNODE:

opt = Adam(0.01)
alg = NNODE(chain, opt, init_params = ps, strategy = StochasticTraining(2))
# Type piracy to force the solve to accept complex initial conditions:
SciMLBase.allowscomplex(::NNODE) = true
sol = solve(problem, alg, verbose = true, maxiters = 3000, saveat = 0.01)
#--------------------------------------


# Checking part

println("This is the maximum value to the imaginary part of NNODE solution ρ₁₁: $(maximum(imag(sol[1, :])))")
println("This is the maximum value to the imaginary part of NNODE solution ρ₂₂: $(maximum(imag(sol[2, :])))")
println("This is the maximum value to the imaginary part of NNODE solution ρ₁₂: $(maximum(imag(sol[3, :])))")
println("This is the maximum value to the imaginary part of NNODE solution ρ₂₁: $(maximum(imag(sol[4, :])))")


ground_truth  = solve(problem, Tsit5(), saveat = 0.01)  

println("This is the maximum value to the imaginary part of ground truth solution ρ₁₁: $(maximum(imag(ground_truth[1, :])))")
println("This is the maximum value to the imaginary part of ground truth solution ρ₂₂: $(maximum(imag(ground_truth[2, :])))")
println("This is the maximum value to the imaginary part of ground truth solution ρ₁₂: $(maximum(imag(ground_truth[3, :])))")
println("This is the maximum value to the imaginary part of ground truth solution ρ₂₁: $(maximum(imag(ground_truth[4, :])))")

And here is the output:

Maximum imaginary part of the NNODE solution ρ₁₁: 0.0
Maximum imaginary part of the NNODE solution ρ₂₂: 0.0
Maximum imaginary part of the NNODE solution ρ₁₂: 0.0
Maximum imaginary part of the NNODE solution ρ₂₁: 0.0
Maximum imaginary part of the ground truth solution ρ₁₁: 0.0
Maximum imaginary part of the ground truth solution ρ₂₂: 0.0
Maximum imaginary part of the ground truth solution ρ₁₂: 0.48891955620393385
Maximum imaginary part of the ground truth solution ρ₂₁: 0.2558438236688021

But what is interesting is that NNODE does somehow approximate ρ₁₁ and ρ₂₂. Here is the plot:

[plot: NNODE approximation of ρ₁₁ and ρ₂₂]

In short:

  • The problem is that SciMLBase.allowscomplex(::NNODE) = true allows the solve to run with the ComplexF64 data type, but NNODE does not approximate the imaginary part, which causes the problem.
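One workaround, sketched here on our own initiative rather than taken from NNODE itself, is to rewrite the complex system as a purely real one by splitting each state into real and imaginary parts, so the solver only ever handles real numbers:

# Recast the 4 complex states as 8 real states (Re/Im interleaved).
# The variable names and layout are illustrative.
function bloch_equations_real(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    ρ₁₁ = complex(u[1], u[2]); ρ₂₂ = complex(u[3], u[4])
    ρ₁₂ = complex(u[5], u[6]); ρ₂₁ = complex(u[7], u[8])
    d₁₁ = im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂
    d₂₂ = -im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂
    d₁₂ = -(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁)
    d₂₁ = conj(d₁₂)  # ρ₂₁ = conj(ρ₁₂) for a Hermitian density matrix
    return [real(d₁₁), imag(d₁₁), real(d₂₂), imag(d₂₂),
            real(d₁₂), imag(d₁₂), real(d₂₁), imag(d₂₁)]
end

# Usage: u0_real = [1.0, 0, 0, 0, 0, 0, 0, 0], then solve as an ordinary real ODE.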

@sathvikbhagavan (Member, Author) commented Mar 18, 2024

So, there are two problems here:

  1. The reason the neural network doesn't output imaginary values is that its input is time, which is real, and its parameters are also real. To fix this, the parameters have to be initialized with complex values (a minimal sketch of this is shown just below the list).
  2. PINNs generally don't learn high-frequency behaviour easily. For this reason, I have reduced the value of the parameter in the differential equation and shortened the time span.
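To isolate the first point: initializing the weights as ComplexF64 is what lets a real time input produce complex outputs. A minimal sketch, illustrative only since the full script below does the same thing inline:

using Lux, Random

rng = Random.default_rng()
# kaiming_normal accepts an element type, so weights can be drawn as ComplexF64.
layer = Dense(1, 4; init_weight = (rng, dims...) -> kaiming_normal(rng, ComplexF64, dims...))
ps, st = Lux.setup(rng, layer)
eltype(ps.weight)  # ComplexF64: a real input t now yields complex outputs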

Here is my script:

using NeuralPDE
using OrdinaryDiffEq
using Plots
using Lux, Random
using OptimizationOptimisers

function bloch_equations(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2

    ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁ = u

    dρ = [im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂;
          -im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂;
          -(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁);
          conj(-(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁))]
    return dρ
end

u0 = zeros(ComplexF64, 4)
u0[1] = 1
time_span = (0.0, 2.0)   
parameters = [2.0, 0.0, 1.0]

problem = ODEProblem(bloch_equations, u0, time_span, parameters)

rng = Random.default_rng()
Random.seed!(rng, 0)

chain = Chain(Dense(1, 16, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)),
              Dense(16, 4; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)))
ps, st = Lux.setup(rng, chain)

opt = Adam(0.01)
alg = NNODE(chain, opt, ps; strategy = GridTraining(0.01))
sol = solve(problem, alg, verbose = true, maxiters = 5000, saveat = 0.01)
ground_truth  = solve(problem, Tsit5(), saveat = 0.01)

plot(sol.t, real.(reduce(hcat, sol.u)[1, :]));
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[1, :]))

plot(sol.t, imag.(reduce(hcat, sol.u)[1, :]));
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[1, :]))

plot(sol.t, real.(reduce(hcat, sol.u)[2, :]));
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[2, :]))

plot(sol.t, imag.(reduce(hcat, sol.u)[2, :]));
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[2, :]))

plot(sol.t, real.(reduce(hcat, sol.u)[3, :]));
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[3, :]))

plot(sol.t, imag.(reduce(hcat, sol.u)[3, :]));
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[3, :]))

plot(sol.t, real.(reduce(hcat, sol.u)[4, :]));
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[4, :]))

plot(sol.t, imag.(reduce(hcat, sol.u)[4, :]));
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[4, :]))

[plots for u1 through u4: real and imaginary parts of the NNODE solution compared with the Tsit5 ground truth]

You can see it learns the real parts of u1 and u2 and the imaginary parts of u3 and u4 well.

This is just a demonstration that training with complex valued functions does work.

@sathvikbhagavan (Member, Author)

@RomanSahakyan03, can I use this example in the documentation?

@RomanSahakyan03

Hi @sathvikbhagavan, my sincere apologies for the late response.

We appreciate your consideration of our equation for inclusion in the library's documentation. For us students, me and my colleague, the opportunity to see our work in your library would be a great honor. This step would not only be significant for us but would also reflect our contribution to the scientific community. As long as the example credits me and my colleague, you can certainly use it in the documentation.

@RomanSahakyan03 commented Mar 23, 2024

@sathvikbhagavan if it's not difficult, can you explain why NNODE is bad at finding constant values? Where can I find information about how NNODE works and what kind of loss function it has?

And about the documentation: is there anything I can help you with?

@RomanSahakyan03

@sathvikbhagavan I came across a problem in another equation: when the curve of a function changes very slowly (or does not change at all, i.e. it is a constant value), NNODE has trouble approximating the answer. This was already visible in the case we considered earlier, where the imaginary part was zero, and NNODE behaved strangely on it even then.

@sathvikbhagavan (Member, Author)

@RomanSahakyan03, I would say it depends on how the training is done and what loss function is used to train PINNs. They are always approximations. For NNODE, we use an L2 loss - https://github.com/SciML/NeuralPDE.jl/blob/master/src/ode_solve.jl#L189
You can always tinker with the source code to try out new things.
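Schematically, that loss is the mean squared ODE residual at a set of training points. A self-contained toy sketch of the idea (an illustration, not the actual NeuralPDE source):

using Lux, Random, ForwardDiff

f(u, p, t) = -u                          # toy ODE: du/dt = -u with u(0) = 1
chain = Chain(Dense(1, 8, tanh), Dense(8, 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, chain)

# Trial solution with the initial condition built in (Lagaris-style ansatz).
φ(t, θ) = 1.0 + t * first(first(chain([t;;], θ, st)))
dφdt(t, θ) = ForwardDiff.derivative(τ -> φ(τ, θ), t)

ts = 0.0:0.1:1.0
# L2 loss: mean squared residual of the ODE over the collocation points.
loss(θ) = sum(abs2, dφdt(t, θ) - f(φ(t, θ), nothing, t) for t in ts) / length(ts)
loss(ps)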

For the documentation, I will work on it this week and add the example. Can you give a reference to the paper where it is described?

@RomanSahakyan03

@sathvikbhagavan Yep. Here is the reference.

@article{steck2023,
  title  = {alkali data},
  author = {Steck, Daniel A},
  url    = {https://steck.us/alkalidata/},
  year   = {2023}
}

@ChrisRackauckas (Member)

Merged #839

@RomanSahakyan03

I'm very glad that we managed to close this issue. I'm very grateful to @sathvikbhagavan and @ChrisRackauckas. Our team is very interested in using this method for calculations. If you're interested, we could organize a 30-40 minute seminar with the aim of future collaboration.

@ChrisRackauckas (Member)

I'd be happy to talk.

@RomanSahakyan03

> I'd be happy to talk.

Hi @ChrisRackauckas. Thanks for waiting. I will send you an invitation to your email.
