# On parameterized scientific computing

Scientific computing often involves searching for an optimal solution among the feasible ones.
One way to do this is to parameterize the scientific system of interest and optimize its parameters.
For example, the governing equations can serve as parameterized constraints in physics-informed deep learning.
A suitable ML model is then found by optimizing the trainable parameters.

## Optimization of an ODE system

Let's take the Lotka-Volterra equations, i.e. the predator–prey model, as an example.
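
For reference, writing $x$ for the prey (🐰) and $y$ for the predator (🦊) population (symbols chosen here for illustration only), the system has the same form as the code in this section:

$$
\frac{dx}{dt} = x(\alpha - \beta y), \qquad \frac{dy}{dt} = -y(\gamma - \delta x)
$$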

In [39]:
function lotka_volterra(u, p, t)
    α, β, δ, γ = p
    🐰, 🦊 = u
    return [🐰 * (α - β * 🦊), -🦊 * (γ - δ * 🐰)]
end

lotka_volterra (generic function with 1 method)

In [40]:
using OrdinaryDiffEq, Plots, Flux, Solaris, Optim, Test

We define the initial condition, the model parameters and the time span.

In [41]:
u0 = [1.0, 1.0] # initial number of 🐰 & 🦊
p = [1.5, 1.0, 1.0, 3.0] # α, β, δ, γ
tspan = (0.0, 10.0) # time span

(0.0, 10.0)

The ODE problem is then formulated and solved.

In [42]:
prob = ODEProblem(lotka_volterra, u0, tspan, p)
sol = solve(prob, RK4())

retcode: Success
Interpolation: 3rd order Hermite
t: 92-element Vector{Float64}:
  0.0
  0.04096249791634249
  0.09492579003001751
  0.16461733554330665
  0.24797720114853214
  0.34492625714965414
  0.45379050296581175
  0.5737061160620319
  0.7031638316079847
  0.8402918691464465
  0.9817450277305837
  1.1227117080154019
  1.2581933635073077
  ⋮
  8.95997926998975
  9.034370841332136
  9.113913433668808
  9.196089730632035
  9.29363305492151
  9.384149125956666
  9.479480710618827
  9.576558391631197
  9.680435595783445
  9.79189744517297
  9.9146263422521
 10.0
u: 92-element Vector{Vector{Float64}}:
 [1.0, 1.0]
 [1.0223548367370965, 0.9217512759415004]
 [1.0574508739228143, 0.8292101153731414]
 [1.1121764508851177, 0.7255574058887769]
 [1.1916495092752522, 0.6218989991888514]
 [1.3039098186186444, 0.524647626081038]
 [1.4569993259821232, 0.43972490960557004]
 [1.6618533803393245, 0.3698163393710881]
 [1.9307872963654964, 0.31627683543398516]
 [2.277283700181586, 0.27949699071153483]


In [None]:
plot(sol)

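Now suppose that we would like to stabilize this ecosystem, i.e. keep both populations close to 1 over the whole time span. The loss sums the squared deviations of the solution from 1 at times saved every 0.2.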

In [43]:
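function loss(p)
    # re-solve the ODE with the candidate parameters p, saving every 0.2 time units
    sol = solve(prob, Midpoint(), p = p, saveat = tspan[1]:0.2:tspan[2])
    # squared deviation of both populations from 1
    return sum(abs2, sol .- 1)
end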

loss (generic function with 1 method)
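
The simplest way for this loss to vanish is for $(1, 1)$ to be an equilibrium of the model, so that both populations never leave their initial values. From the right-hand side, this happens when $\alpha = \beta$ and $\gamma = \delta$. The sketch below checks this with a standalone copy of the right-hand side; `rhs` and `p_balanced` are illustrative names and values, not part of the notebook.

```julia
# Standalone copy of the predator-prey right-hand side, with p = [α, β, δ, γ]
rhs(u, p) = [u[1] * (p[1] - p[2] * u[2]), -u[2] * (p[4] - p[3] * u[1])]

u_star = [1.0, 1.0]                 # target state: both populations equal to 1
p_init = [1.5, 1.0, 1.0, 3.0]       # initial parameters used in this notebook
p_balanced = [1.5, 1.5, 3.0, 3.0]   # hypothetical values with α = β and δ = γ

drift_init = rhs(u_star, p_init)          # nonzero: the state moves away from (1, 1)
drift_balanced = rhs(u_star, p_balanced)  # zero: (1, 1) is an equilibrium
```

Under this reading, the optimizer is expected to push $\alpha$ towards $\beta$ and $\delta$ towards $\gamma$, although other parameter combinations may also reduce the loss.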

In [44]:
cb = function (p, l)
    println("loss: $l")
    return false
end

#17 (generic function with 1 method)

In [50]:
res = sci_train(loss, p, ADAM(), callback = cb, maxiters = 1000)

loss: 508.97740745853497
loss: 506.6485126820388
loss: 504.3310082722748
loss: 502.0252858982016
loss: 499.7312171996283
loss: 497.44920241034436
loss: 495.1791823027128
loss: 492.92131931984164
loss: 490.67559688200436
loss: 488.442190992473
loss: 486.2213880795642
loss: 484.0131374919249
loss: 481.8173524640663
loss: 479.634210783654
loss: 477.4637510206442
loss: 475.30623032399114
loss: 473.1614188737079
loss: 471.02955889113247
loss: 468.91055186347796
loss: 466.8044786584641
loss: 464.7112508092704
loss: 462.6310512011346
loss: 460.5637786921046
loss: 458.50951593699034
loss: 456.4681310461464
loss: 454.43973866182796
loss: 452.42414240159275
loss: 450.4215199677923
loss: 448.431641380285
loss: 446.45475375589035
loss: 444.4905485594725
loss: 442.53890822245404
loss: 440.60004201711376
loss: 438.673822683938
loss: 436.76017721927707
loss: 434.8589598144539
loss: 432.97020534559175
loss: 431.0937048765529
loss: 429.229514907397
loss: 427.377496402895
loss: 425.5377824228707
⋮

loss: 126.82431859416356
loss: 126.50010243388148
loss: 126.17699055551438
loss: 125.85491735663688
loss: 125.53395049353838
loss: 125.21407785119706
loss: 124.89519873703179
loss: 124.57735718203584
loss: 124.26061909996172
loss: 123.94495011571176
loss: 123.63025723550933
loss: 123.31662488784264
loss: 123.00403660768232
loss: 122.69242897453447
loss: 122.3818658623644
loss: 122.07235399414016
loss: 121.76381754130831
loss: 121.45629089049916
loss: 121.14979877658558
loss: 120.84429258652888
loss: 120.53975386543469
loss: 120.23625251762842
loss: 119.93374431346366
loss: 119.63215151318494
loss: 119.33155178617956
loss: 119.03197423821872
loss: 118.73333181299073
loss: 118.43562273860027
loss: 118.13894543149088
loss: 117.84320843291594
loss: 117.54837614053807
loss: 117.25451173384802
loss: 116.96162672744225
loss: 116.66965862913518
loss: 116.3786411245215
loss: 116.08857142235041
loss: 115.79938813407846
loss: 115.5111096340866
loss: 115.22376933277789
loss: 114.93737030016615
⋮

loss: 48.28668433950802
loss: 48.185012980461806
loss: 48.08356722032307
loss: 47.98233249946896
loss: 47.88132508565759
loss: 47.78056367122797
loss: 47.680049480915415
loss: 47.579758252926446
loss: 47.47968692514495
loss: 47.37984720276338
loss: 47.28026196024507
loss: 47.18091610619048
loss: 47.08178779355199
loss: 46.98286235447837
loss: 46.88415133028119
loss: 46.78567994643012
loss: 46.68745453046738
loss: 46.589460850409424
loss: 46.49168743893322
loss: 46.39412422157565
loss: 46.29679247491984
loss: 46.19969278688555
loss: 46.10281945650573
loss: 46.00616290437619
loss: 45.90971344562526
loss: 45.813481407503666
loss: 45.7174713419553
loss: 45.62169441958004
loss: 45.52615385818998
loss: 45.430827815131
loss: 45.33570616774367
loss: 45.24079470516073
loss: 45.146103216286086
loss: 45.05164188381661
loss: 44.957389469484994
loss: 44.86334532642762
loss: 44.769514626043566
loss: 44.675899219907244
loss: 44.582507768120394
loss: 44.48932144316477
loss: 44.39633708441996
⋮

u: 4-element Vector{Float64}:
 1.4810722240498688
 1.4266348663473205
 1.5124241030755163
 2.364444861694887

Let's compare the evolution under the optimized parameters (solid lines) with the original solution (dashed lines).

In [20]:
prob = ODEProblem(lotka_volterra, u0, tspan, res.u)
sol1 = solve(prob, Midpoint())
plot(sol1)
plot!(sol, line=:dash)

LoadError: UndefVarError: res not defined

The same strategy can be applied to other systems.
A neural network (NN) is a typical parameterized architecture.
If an NN is built to approximate the numerical solution of a physical system directly, training its parameters brings us closer to the accurate solution.
Similarly, if an NN is used as a surrogate for a numerical scheme, effective training should lead to a good numerical scheme.

## Optimization of a numerical scheme

Let's consider a specific task: the integrator for ODEs.

Runge-Kutta methods target the family of initial value problems

$y'=f(t,y), \quad y(t_0)=y_0$

Runge-Kutta methods make no assumption about the details of the right-hand side $f$, so it can be as generic as possible, e.g. scalar or vector, time-homogeneous or time-inhomogeneous.

In our implementation, we hope to keep this feature: any function can be passed to the neural integrator as its argument.

Let's consider the scalar equation first.

- simple one: $f = -e^{-t}$, with solution $y=e^{-t}+C$ (testing)
- complex one: $f = -ye^{-t}$ (training)

In [21]:
f(y, t) = -exp(-t) * y
f(y, p, t) = -exp(-t) * y # parameterized equation for DifferentialEquations.jl

f (generic function with 2 methods)

#### Data generation

Then we prepare the dataset.
It can be produced from the analytical solution (if it exists) or by a numerical solver.
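
As a hedged sketch of the analytical route for the training equation $y' = -ye^{-t}$: separating variables gives $y(t) = C\exp(e^{-t})$, so one step of size $h$ from $(t_0, y_0)$ maps to $y_0 \exp(e^{-(t_0+h)} - e^{-t_0})$. The names `g`, `exact_step`, `rk4_step` and the sample values below are illustrative, and the hand-written classical RK4 step serves only as an independent check.

```julia
# Right-hand side of the training equation, y' = -y * exp(-t)
g(y, t) = -exp(-t) * y

# Closed-form one-step map: y(t) = C * exp(exp(-t)), with C fixed by y(t0) = y0
exact_step(y0, t0, h) = y0 * exp(exp(-(t0 + h)) - exp(-t0))

# Classical fourth-order Runge-Kutta step, used here only as an independent check
function rk4_step(f, y, t, h)
    k1 = f(y, t)
    k2 = f(y + h / 2 * k1, t + h / 2)
    k3 = f(y + h / 2 * k2, t + h / 2)
    k4 = f(y + h * k3, t + h)
    return y + h / 6 * (k1 + 2k2 + 2k3 + k4)
end

y0, t0, h_demo = 0.7, 0.3, 0.2      # arbitrary sample point and step size
y_exact = exact_step(y0, t0, h_demo)
y_rk4 = rk4_step(g, y0, t0, h_demo)
```

The two values should agree to within the truncation error of a single RK4 step.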

In [22]:
h = 0.2 # time step
X = randn(Float32, 1, 1000)
#T = collect(1:49/1000:50)[1:end-1] |> permutedims
T = rand(Float32, 1, 1000) * 100

Y = zeros(Float32, 1, 1000)
for i in axes(X, 2)
    u0 = X[1, i]
    t0 = T[1, i]
    C = u0 - exp(-t0)
    Y[1, i] = exp(-(t0+h)) + C
end

Y1 = zeros(Float32, 1, 1000)
for i in axes(X, 2)
    u0 = X[1, i]
    tspan = (T[1, i], T[1, i] + h)
    prob = ODEProblem(f, u0, tspan)
    sol = solve(prob, Tsit5(), dt = h, adaptive=false)
    Y1[1, i] = sol.u[end]
end

@show Y1 ≈ Y # should agree up to numerical error; the closed form above solves y' = -exp(-t) rather than f, so the two differ here
Y .= Y1

Y1 ≈ Y = false


1×1000 Matrix{Float32}:
 -0.505689  -0.94243  0.449392  -0.384492  …  -0.21778  -0.0831608  -0.70952

#### Model architecture

The key is to parameterize the numerical scheme.

We first follow the solution algorithm of a three-stage Runge-Kutta scheme and define its coefficients as learnable parameters.

![](nn.png)
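
For reference, an explicit three-stage Runge-Kutta step with step size $h$ can be written as follows, using the coefficient names $a_{ij}$, $b_i$, $c_i$ that the parameter vector in this section is laid out around:

$$
\begin{aligned}
k_1 &= f(t_n + c_1 h,\; y_n) \\
k_2 &= f(t_n + c_2 h,\; y_n + h\, a_{21} k_1) \\
k_3 &= f(t_n + c_3 h,\; y_n + h\,(a_{31} k_1 + a_{32} k_2)) \\
y_{n+1} &= y_n + h\,(b_1 k_1 + b_2 k_2 + b_3 k_3)
\end{aligned}
$$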

In [23]:
# p holds the 9 learnable coefficients of the scheme
# p[1:3]: a21, a31, a32
# p[4:6]: b1, b2, b3
# p[7:9]: c1, c2, c3

function init_params(p, init=:allzero)
    eval(init)(p)
end

# near-zero initialization: small random values in [0, 1/length(p))
function allzero(p)
    p .= 0. .+ rand(Float32, length(p)) / length(p)
end

# Kutta's third-order coefficients plus a small random perturbation
function kutta3(p)
    a21 = 1/2
    a31 = -1
    a32 = 2
    b1 = 1/6
    b2 = 2/3
    b3 = 1/6
    c1 = 0
    c2 = 1/2
    c3 = 1
    
    p .= Float32[a21, a31, a32, b1, b2, b3, c1, c2, c3] .+ rand(Float32, length(p)) / length(p)
end

kutta3 (generic function with 1 method)
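
As a sanity check on the starting point, the unperturbed coefficients in `kutta3` can be verified against the standard order conditions for a third-order explicit scheme. The sketch below uses exact rationals; the variable names are illustrative and the random perturbation added in `kutta3` is left out.

```julia
# Kutta's third-order coefficients as exact rationals
a21, a31, a32 = 1//2, -1//1, 2//1
b = [1//6, 2//3, 1//6]
c = [0//1, 1//2, 1//1]

# Row-sum (consistency) conditions: c_i equals the sum of a_ij over j
row_sums_ok = (c[2] == a21) && (c[3] == a31 + a32)

# Order conditions up to third order for an explicit three-stage scheme
order1 = sum(b) == 1
order2 = sum(b .* c) == 1//2
order3a = sum(b .* c .^ 2) == 1//3
order3b = b[3] * a32 * c[2] == 1//6   # only surviving term of sum(b_i * a_ij * c_j), since c1 = 0

all_ok = row_sums_ok && order1 && order2 && order3a && order3b
```

The first order condition and the two row-sum conditions are the same constraints that appear as penalty terms in the training loss of this section.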

In [24]:
p = Array{Float32}(undef, 9)
#init_params(p, :allzero)
init_params(p, :kutta3)

9-element Vector{Float32}:
  0.5546772
 -0.9887149
  2.0918388
  0.2264811
  0.6862181
  0.23345572
  0.029876795
  0.6026921
  1.0462852

In [25]:
# our neural RK model
# it is nothing but a flow of the parameters through the three RK stages
function model(p)
    k1 = @. f(X, T)
    k2 = @. f(X + h * p[1] * k1, T + p[8] * h)
    k3 = @. f(X + h * (p[2] * k1 + p[3] * k2), T + p[9] * h)
    
    return @. X + h * (p[4] * k1 + p[5] * k2 + p[6] * k3)
end

function loss(p)
    Y1 = model(p)
    return sum((Y .- Y1).^2) + (sum(p[4:6]) - 1.0)^2 + (p[1] - p[8])^2 + (p[2] + p[3] - p[9])^2 # the parameters should satisfy the constraints
end

loss (generic function with 1 method)
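
Written out, with $\hat{Y}_i(p)$ denoting the model prediction for the $i$-th sample (notation introduced here for illustration), the loss combines the data misfit with three soft constraints on the coefficients:

$$
L(p) = \sum_{i}\left(Y_i - \hat{Y}_i(p)\right)^2 + (b_1 + b_2 + b_3 - 1)^2 + (a_{21} - c_2)^2 + (a_{31} + a_{32} - c_3)^2
$$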

In [26]:
cb = function (p, l)
    display(l)
    return false
end

#9 (generic function with 1 method)

In [27]:
res = sci_train(loss, p, ADAM(), cb = cb, maxiters = 200)

0.028000312675804263

0.026561607043064685

0.025169405377358775

0.023824442547736177

0.022527345449783825

0.021278656458672073

0.02007879254671734

0.018927977724118156

0.017826317878567838

0.01677374308990837

0.01576997987796914

0.01481459622260558

0.013907007062003443

0.0130463643681713

0.012231688088414568

0.011461779016115486

0.010735241874515883

0.010050549525677181

0.009405997131563164

0.00879975679199235

0.008229890332646681

0.007694420496195464

0.007191268821520516

0.00671839662554862

0.006273779861983978

0.005855495270595733

0.005461643397281982

0.005090460170379034

0.004740322179873332

0.004409792752511232

0.004097510217458154

0.0038023281604181794

0.0035232124105178736

0.003259303204614023

0.0030098279162718584

0.0027741672145784337

0.0025517622961692475

0.0023421630255402053

0.0021449334107038103

0.001959736107735281

0.0017862140291320834

0.001624046421187872

0.001472890021824791

0.0013324151393957545

0.0012022624301380649

0.0010820636252902966

0.0009714218742279036

0.0008699120496118725

0.0007770918191580097

0.0006925046108745482

0.0006156765321601065

0.0005461151618660056

0.0004833349203848979

0.0004268527873320234

0.0003761902826233569

0.0003308741826978304

0.0002904550068052687

0.00025450171562434176

0.0002226084739883098

0.000194395510393885

0.00016950470392089042

0.0001476071364685182

0.0001284043926750769

0.00011161906627635705

9.699730322754124e-5

8.432173919416869e-5

7.338051362427604e-5

6.398716364940922e-5

5.597268503707367e-5

4.9182732973731126e-5

4.347816515292854e-5

3.872993416985556e-5

3.4822882665572366e-5

3.16486153572275e-5

2.9110263488609008e-5

2.7118912225226767e-5

2.5592989762912174e-5

2.4457846729886448e-5

2.3647107647642206e-5

2.3101878725454e-5

2.2768598281442294e-5

2.2601285686111746e-5

2.25593611312718e-5

2.2609030114893753e-5

2.272171065977005e-5

2.287361519898712e-5

2.3046115106068936e-5

2.322459142018397e-5

2.339799725664494e-5

2.355805712143155e-5

2.369979067698003e-5

2.3820007998292633e-5

2.3917235099358975e-5

2.3991039896422396e-5

2.4041843374531408e-5

2.407092754432679e-5

2.407992247398189e-5

2.4070334607312525e-5

2.404421068884219e-5

2.4003164939528975e-5

2.3948601043965678e-5

2.3882338823130213e-5

2.380576019738385e-5

2.372033936798549e-5

2.3627258395063493e-5

2.3528019947118104e-5

2.3423796772781263e-5

2.3316028751365927e-5

2.3205681132862222e-5

2.309440308105001e-5

2.298313751398286e-5

2.2873153726376548e-5

2.2765502663575236e-5

2.2661404156906752e-5

2.25616175478728e-5

2.2466841450427996e-5

2.2377695514145672e-5

2.2294646650025638e-5

2.2218085781303742e-5

2.214794421740928e-5

2.2084415098921853e-5

2.2027143391158164e-5

2.1976080598232862e-5

2.193068237658124e-5

2.189047377848975e-5

2.185509185738593e-5

2.1824183326857994e-5

2.1796844420364052e-5

2.1773198114261626e-5

2.1752221033525173e-5

2.1733877623627847e-5

2.1717780717202457e-5

2.1703222737680044e-5

2.169018011196205e-5

2.1678691603019474e-5

2.166815790449657e-5

2.1658345490440017e-5

2.164941555951791e-5

2.164108826327038e-5

2.1633266360019904e-5

2.1626036256039782e-5

2.1618849950176276e-5

2.1612105533071708e-5

2.160562456381839e-5

2.159909549975563e-5

2.159280852681239e-5

2.1586328861887585e-5

2.1579895010618767e-5

2.15735764564081e-5

2.1566904492369342e-5

2.156024378595663e-5

2.1553444991260642e-5

2.1546381008187098e-5

2.1539192782255595e-5

2.153216076153316e-5

2.1524646869216327e-5

2.151705372660108e-5

2.1509390082551615e-5

2.150164490601558e-5

2.1494018552963498e-5

2.148593753045712e-5

2.14783018166464e-5

2.1470177797817995e-5

2.1462187410225558e-5

2.145422084410817e-5

2.1446115709632056e-5

2.1438260620872675e-5

2.1430432677616232e-5

2.1422382399158744e-5

2.1414616743567863e-5

2.1406494149490652e-5

2.139856043068044e-5

2.1390718230468227e-5

2.1382758629609273e-5

2.1374920168913944e-5

2.136677475575369e-5

2.1359004578781432e-5

2.135089991167507e-5

2.134330732622801e-5

2.1335228553888243e-5

2.1327326951322286e-5

2.1319366785093688e-5

2.131131286645383e-5

2.1303269639696492e-5

2.1295477032492858e-5

2.1287565823420674e-5

2.1279441783159233e-5

2.127148718801696e-5

2.126346084903588e-5

2.125534899325564e-5

2.124732088624901e-5

2.1239220999806194e-5

2.1231224185652417e-5

2.122289376154952e-5

2.1214927471724994e-5

2.120666379092683e-5

2.1198540528754205e-5

2.1190366831040935e-5

2.118210433869897e-5

2.11739064519779e-5

2.11739064519779e-5

u: 9-element Vector{Float32}:
  0.57793486
 -1.0077721
  2.0728042
  0.17865033
  0.6377906
  0.18471488
  0.029876795
  0.5779062
  1.0650327

In [None]:
res = sci_train(loss, res.u, LBFGS(), cb = cb, maxiters = 1000)

#### Test the trained model

In [None]:
function predict(t, x, h, f, p)
    k1 = f(x, t)
    k2 = f(x + h * p[1] * k1, t + p[8] * h)
    k3 = f(x + h * (p[2] * k1 + p[3] * k2), t + p[9] * h)
    
    return @. x + h * (p[4] * k1 + p[5] * k2 + p[6] * k3)
end

In [None]:
f1(y, t) = -2 * exp(-2 * t)

In [None]:
u = zeros(100); u[1] = 1.0
t = zeros(100); t[1] = 0.0
for i = 2:length(u)
    t[i] = t[i-1] + h
    u[i] = predict(t[i-1], u[i-1], h, f1, res.u)
end

In [None]:
plot(t, u, label="nn", xlabel="t")
scatter!(t[1:4:end], exp.(-2t)[1:4:end], label="exact")
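
For comparison, it may help to know how well the unperturbed Kutta coefficients do on the same test problem. The sketch below is a standalone baseline under that assumption: `rk3_step` mirrors the stage structure of `predict`, while `p_kutta`, `f_test` and `integrate` are illustrative names that do not appear in the notebook.

```julia
# Same stage structure as `predict`; p = [a21, a31, a32, b1, b2, b3, c1, c2, c3]
function rk3_step(t, x, h, f, p)
    k1 = f(x, t)
    k2 = f(x + h * p[1] * k1, t + p[8] * h)
    k3 = f(x + h * (p[2] * k1 + p[3] * k2), t + p[9] * h)
    return x + h * (p[4] * k1 + p[5] * k2 + p[6] * k3)
end

# Unperturbed Kutta third-order coefficients (no training involved)
p_kutta = [1/2, -1, 2, 1/6, 2/3, 1/6, 0, 1/2, 1]

# Test equation y' = -2 exp(-2t); with y(0) = 1 the exact solution is exp(-2t)
f_test(y, t) = -2 * exp(-2 * t)

# March n - 1 steps of size h from (0, u0) and return the time grid and solution
function integrate(f, p, u0, h, n)
    t = zeros(n)
    u = zeros(n)
    u[1] = u0
    for i in 2:n
        t[i] = t[i-1] + h
        u[i] = rk3_step(t[i-1], u[i-1], h, f, p)
    end
    return t, u
end

t_ref, u_ref = integrate(f_test, p_kutta, 1.0, 0.2, 100)
max_err = maximum(abs.(u_ref .- exp.(-2 .* t_ref)))   # worst-case error against the exact solution
```

Comparing the trained parameters against `max_err` gives a rough idea of whether training improved on, or merely recovered, the classical scheme.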