In [1]:
using LinearAlgebra

using DifferentialEquations
using Flux
using DiffEqFlux
using Plots

┌ Info: Precompiling Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1186
┌ Info: Precompiling DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0]
└ @ Base loading.jl:1186
┌ Info: Recompiling stale cache file /home/gridsan/aramadhan/.julia/compiled/v1.1/DifferentialEquations/UQdwS.ji for DifferentialEquations [0c46a032-eb83-5123-abaf-570d42b7fbaa]
└ @ Base loading.jl:1184
┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1186


In [2]:
const N = 16
const L = 1
const Δx = L / N
const κ = 1

 d = -2 * ones(N)
sd = ones(N-1)
A = Array(Tridiagonal(sd, d, sd))
A[1, N] = 1
A[N, 1] = 1
A_diffusion = (κ/Δx^2) .* A

function diffusion(∂u∂t, u, p, t)
    ∂u∂t .= A_diffusion * u
    return 
end

diffusion (generic function with 1 method)

In [3]:
x = range(-L/2, L/2, length=N)
u₀ = @. exp(-100*x^2)
tspan = (0.0, 0.1)
params = [Δx, κ]

2-element Array{Float64,1}:
 0.0625
 1.0   

In [5]:
datasize = 30
t = range(tspan[1], tspan[2], length=datasize)

0.0:0.0034482758620689655:0.1

In [7]:
prob = ODEProblem(diffusion, u₀, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

16×30 Array{Float64,2}:
 1.38879e-11  0.000160344  0.00215047  …  0.158223  0.159223  0.160099
 6.99705e-9   0.00087666   0.00636828     0.159464  0.160304  0.161033
 1.44928e-6   0.00462848   0.0199546      0.161636  0.16221   0.16272 
 0.00012341   0.020396     0.0546109      0.164632  0.16482   0.164975
 0.00432024   0.0726686    0.126828       0.167703  0.167515  0.16736 
 0.0621765    0.201762     0.245097    …  0.170699  0.170125  0.169615
 0.367879     0.417262     0.386382       0.172871  0.172031  0.171302
 0.894839     0.611587     0.48795        0.174112  0.173112  0.172236
 0.894839     0.611587     0.48795        0.174112  0.173112  0.172236
 0.367879     0.417262     0.386382       0.172871  0.172031  0.171302
 0.0621765    0.201762     0.245097    …  0.170699  0.170125  0.169615
 0.00432024   0.0726686    0.126828       0.167703  0.167515  0.16736 
 0.00012341   0.020396     0.0546109      0.164632  0.16482   0.164975
 1.44928e-6   0.00462848   0.0199546      0.161636  0

In [8]:
dudt = Chain(Dense(N, 100, tanh),
             Dense(100, N))

Chain(Dense(16, 100, tanh), Dense(100, 16))

In [9]:
ps = Flux.params(dudt)
n_ode = x -> neural_ode(dudt, x, tspan, Tsit5(), saveat=t, reltol=1e-7, abstol=1e-9)

#3 (generic function with 1 method)

In [10]:
pred = n_ode(u₀)

function predict_n_ode()
  n_ode(u₀)
end

predict_n_ode (generic function with 1 method)

In [11]:
loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())

loss_n_ode (generic function with 1 method)

In [12]:
data = Iterators.repeated((), 1000)
opt = ADAM(0.1)

ADAM(0.1, (0.9, 0.999), IdDict{Any,Any}())

In [22]:
cb = function ()  # callback function to observe training
  loss = loss_n_ode()
  display(loss)

  # plot current prediction against data
  # cur_pred = Flux.data(predict_n_ode())

  loss < 0.1 && Flux.stop()
end

cb()

0.9878774926495675 (tracked)

false

In [23]:
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)

1.1244774236232902 (tracked)

1.1655380554515398 (tracked)

0.9650170370991754 (tracked)

0.7130912424715691 (tracked)

0.6420090692789944 (tracked)

0.7295189703805592 (tracked)

0.7604272454051172 (tracked)

0.6750777033668696 (tracked)

0.5794394241904655 (tracked)

0.5298438031519348 (tracked)

0.5073183770679288 (tracked)

0.5025122862980889 (tracked)

0.5047831254375561 (tracked)

0.4806483629051853 (tracked)

0.4370372951704856 (tracked)

0.4133559428075808 (tracked)

0.40707257993453755 (tracked)

0.39173566535592597 (tracked)

0.3728646151262793 (tracked)

0.3684286901675059 (tracked)

0.3684485897947804 (tracked)

0.35697040106378874 (tracked)

0.34214990683270013 (tracked)

0.3357683564133558 (tracked)

0.3346289223606603 (tracked)

0.334009709829519 (tracked)

0.3348207953386077 (tracked)

0.3335565984231016 (tracked)

0.324662731996302 (tracked)

0.31409899123359986 (tracked)

0.3122212701334841 (tracked)

0.3166421583084412 (tracked)

0.31757736247962276 (tracked)

0.31285822144577136 (tracked)

0.3068091724016528 (tracked)

0.30212624170271396 (tracked)

0.3002082028157587 (tracked)

0.3019558239122372 (tracked)

0.30380494913645834 (tracked)

0.30064899733668593 (tracked)

0.29414740030576875 (tracked)

0.2905698363920781 (tracked)

0.29126452547858644 (tracked)

0.2922119006463454 (tracked)

0.2907575042464443 (tracked)

0.2874110492535603 (tracked)

0.28385011175712316 (tracked)

0.28203768531810763 (tracked)

0.28242750343370676 (tracked)

0.28256633498903067 (tracked)

0.2803688543624795 (tracked)

0.2772639202464709 (tracked)

0.27546056616576303 (tracked)

0.2748272385270413 (tracked)

0.2741592498408033 (tracked)

0.272891132275898 (tracked)

0.270983663026747 (tracked)

0.2689718637150607 (tracked)

0.26772073612804814 (tracked)

0.2671198519183927 (tracked)

0.26616015122245135 (tracked)

0.2645683978597102 (tracked)

0.2629794362866221 (tracked)

0.2617361585783582 (tracked)

0.2606828320066894 (tracked)

0.2596353512700322 (tracked)

0.2584093726699992 (tracked)

0.2569674766220358 (tracked)

0.25560213197176074 (tracked)

0.2544881648055259 (tracked)

0.2534029894252765 (tracked)

0.25217057133882576 (tracked)

0.25086810724066366 (tracked)

0.2495640227063859 (tracked)

0.24829626291330853 (tracked)

0.24711889244892982 (tracked)

0.24594322800734805 (tracked)

0.24464010326544924 (tracked)

0.24328965901781577 (tracked)

0.24203288726342478 (tracked)

0.24084065475389338 (tracked)

0.23961403440054452 (tracked)

0.23832329854673737 (tracked)

0.2370041443947506 (tracked)

0.23572606520765435 (tracked)

0.2345025189549599 (tracked)

0.23324969793274847 (tracked)

0.23192881492628345 (tracked)

0.23061116871181866 (tracked)

0.22933919389750088 (tracked)

0.22807176317298394 (tracked)

0.22677520881073812 (tracked)

0.22545561112004353 (tracked)

0.22413476633787624 (tracked)

0.22283097816830588 (tracked)

0.22152922850734422 (tracked)

0.2202004821207841 (tracked)

0.2188548428759701 (tracked)

0.21751803619734547 (tracked)

0.21618484829935652 (tracked)

0.21483964766047656 (tracked)

0.2134809761370031 (tracked)

0.2121133737575157 (tracked)

0.2107436453839158 (tracked)

0.20937223033464597 (tracked)

0.20798947629609535 (tracked)

0.20659325044696691 (tracked)

0.20519146384448153 (tracked)

0.20378562626247979 (tracked)

0.20237227610913008 (tracked)

0.20094899180404066 (tracked)

0.19951445113112806 (tracked)

0.1980722614348433 (tracked)

0.1966255332003557 (tracked)

0.19516926737458556 (tracked)

0.19370035324079996 (tracked)

0.19222296745597944 (tracked)

0.19073908376895088 (tracked)

0.18924644119248016 (tracked)

0.1877431893765062 (tracked)

0.18622940576362995 (tracked)

0.1847073227297304 (tracked)

0.1831775471147904 (tracked)

0.1816375488355526 (tracked)

0.1800867558658716 (tracked)

0.1785275312608745 (tracked)

0.17696011499674172 (tracked)

0.17538334128968064 (tracked)

0.17379677614329614 (tracked)

0.17220118847791768 (tracked)

0.1705976286435575 (tracked)

0.16898586587692566 (tracked)

0.16736499157878915 (tracked)

0.16573586898296622 (tracked)

0.16409948026635268 (tracked)

0.16245569735102441 (tracked)

0.16080444029672988 (tracked)

0.15914609074173802 (tracked)

0.15748148045473326 (tracked)

0.15581107939772967 (tracked)

0.15413489084635856 (tracked)

0.1524533799171507 (tracked)

0.150767526154225 (tracked)

0.14907772152964666 (tracked)

0.14738432245454428 (tracked)

0.14568795903767218 (tracked)

0.14398943680223658 (tracked)

0.14228936981899612 (tracked)

0.14058827728844359 (tracked)

0.1388869695622943 (tracked)

0.13718621453283136 (tracked)

0.13548681593889145 (tracked)

0.1337893781309692 (tracked)

0.1320947438773598 (tracked)

0.1304037027284312 (tracked)

0.12871711765342259 (tracked)

0.12703564365153736 (tracked)

0.1253600554545829 (tracked)

0.1236913171312844 (tracked)

0.12203009509288999 (tracked)

0.12037730978363363 (tracked)

0.11873368343145259 (tracked)

0.11709988708402831 (tracked)

0.11547678770175053 (tracked)

0.11386511281789055 (tracked)

0.11226560774328236 (tracked)

0.1106790021657475 (tracked)

0.10910593986708718 (tracked)

0.107547154608971 (tracked)

0.10600311402173368 (tracked)

0.10447453408705486 (tracked)

0.10296188216401464 (tracked)

0.10146584664629645 (tracked)

0.09998677409210671 (tracked)

In [24]:
nn_pred = Flux.data(predict_n_ode())
@gif for n=1:datasize
    plot(x, ode_data[:, n], ylim=(0, 1), label="data", show=false)
    plot!(x, nn_pred[:, n], ylim=(0, 1), label="Neural ODE", show=false)
end

┌ Info: Saved animation to 
│   fn = /home/gridsan/aramadhan/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/gridsan/aramadhan/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [25]:
u1 = @. exp(-10*abs(x))

16-element Array{Float64,1}:
 0.006737946999085467
 0.013123728736940956
 0.025561533206507402
 0.049787068367863944
 0.09697196786440505 
 0.18887560283756188 
 0.36787944117144233 
 0.7165313105737893  
 0.7165313105737893  
 0.36787944117144233 
 0.18887560283756188 
 0.09697196786440505 
 0.049787068367863944
 0.025561533206507402
 0.013123728736940956
 0.006737946999085467

In [37]:
u₀_cos = @. 1 + cos(2π * x)

prob = ODEProblem(diffusion, u₀_cos, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

nn_pred = Flux.data(n_ode(u₀_cos))
@gif for n=1:datasize
    plot(x, ode_data[:, n], ylim=(0, 2), label="data", show=false)
    plot!(x, nn_pred[:, n], ylim=(0, 2), label="Neural ODE", show=false)
end

┌ Info: Saved animation to 
│   fn = /home/gridsan/aramadhan/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/gridsan/aramadhan/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [52]:
dudt.layers[1].W

Tracked 100×16 Array{Float32,2}:
  0.219476     0.387633   -0.115517    …   0.633287    0.142618    0.479156 
  0.187322     0.346181    0.0562944       0.779724    0.104201    0.612242 
  0.0395474    0.29002     0.146384        0.51009     0.287297    0.307355 
 -0.499803    -0.312583   -0.212347       -0.284019   -0.245401   -0.225501 
  0.117012     0.386874   -0.161308        0.343742    0.251343    0.787597 
 -0.460909    -0.576501   -0.523187    …  -0.22026    -0.582976    0.0193454
 -0.00468038  -0.0315812   0.207367        0.0825084   0.326486   -0.174053 
  0.386303     0.437262    0.613244        0.385165    0.47834     0.102952 
 -0.382644    -0.0915628  -0.37098        -0.232857   -0.384615    0.0556619
 -0.181194     0.167625   -0.0316671       0.117747   -0.162832   -0.164185 
 -0.178336    -0.208751   -0.337404    …  -0.20487    -0.117129   -0.233638 
 -0.544685    -0.104979   -0.00568667     -0.537699   -0.291795   -0.741161 
 -0.106091     0.0180631   0.384981       -