# Learning diffusion with a neural partial differential equation

In [2]:
using LinearAlgebra

using DifferentialEquations
using Flux
using DiffEqFlux
using Plots

using CuArrays

We want to solve the heat (or diffusion) equation

$$\frac{\partial u}{\partial t} = \kappa \frac{\partial^2 u}{\partial x^2}, \quad x \in \left[-\frac{1}{2}, \frac{1}{2}\right], \quad x\left(-\frac{1}{2}\right) = x\left(\frac{1}{2}\right)$$

Let's go with $N = 16$ grid points and $\kappa = 1$.

In [3]:
const N = 16
const L = 1
const Δx = L / N
const κ = 1

1

Discretizing the spatial derivative with a second-order centered finite-difference

$$\frac{\partial^2 u}{\partial x^2} \approx \frac{u_{i-1} - 2u_i + u_{i+1}}{\Delta x^2}$$

In [6]:
 d = -2 * ones(N)
sd = ones(N-1)
A = Array(Tridiagonal(sd, d, sd))
A[1, N] = 1
A[N, 1] = 1
A_diffusion = (κ/Δx^2) .* A |> gpu

function diffusion(∂u∂t, u, p, t)
    ∂u∂t .= A_diffusion * u
    return 
end

diffusion (generic function with 1 method)

In [16]:
x = range(-L/2, L/2, length=N)
u₀ = (@. exp(-100*x^2)) |> gpu
tspan = (0.0, 0.1)

(0.0, 0.1)

In [17]:
datasize = 30
t = range(tspan[1], tspan[2], length=datasize)

0.0:0.0034482758620689655:0.1

In [43]:
prob = ODEProblem(diffusion, u₀, tspan)
ode_data = solve(prob, Tsit5(), saveat=t) |> gpu

16×30 CuArray{Float32,2}:
 1.38879e-11  0.000160344  0.00215047  …  0.158278  0.159276  0.160143
 6.99705e-9   0.00087666   0.00636828     0.159394  0.16024   0.160985
 1.44928e-6   0.00462848   0.0199546      0.161718  0.162283  0.162772
 0.00012341   0.0203959    0.0546109      0.164542  0.164741  0.16492 
 0.00432024   0.0726686    0.126828       0.167793  0.167594  0.167415
 0.0621765    0.201762     0.245097    …  0.170616  0.170051  0.169562
 0.367879     0.417262     0.386382       0.172943  0.172096  0.171351
 0.894839     0.611587     0.48795        0.174057  0.173059  0.172192
 0.894839     0.611587     0.48795        0.174151  0.173152  0.172274
 0.367879     0.417262     0.386382       0.172848  0.172003  0.171269
 0.0621765    0.201762     0.245097    …  0.17071   0.170144  0.169645
 0.00432024   0.0726686    0.126828       0.167699  0.167501  0.167333
 0.00012341   0.0203959    0.0546109      0.164635  0.164834  0.165002
 1.44928e-6   0.00462848   0.0199546      0.161626 

In [28]:
dudt = Chain(Dense(N, 100, tanh),
             Dense(100, N)) |> gpu

Chain(Dense(16, 100, tanh), Dense(100, 16))

In [29]:
ps = Flux.params(dudt)
n_ode = x -> neural_ode(dudt, x, tspan, Tsit5(), saveat=t, reltol=1e-7, abstol=1e-9)

#7 (generic function with 1 method)

In [30]:
pred = n_ode(u₀)

function predict_n_ode()
  n_ode(u₀)
end

predict_n_ode (generic function with 1 method)

In [31]:
loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())

loss_n_ode (generic function with 1 method)

In [32]:
data = Iterators.repeated((), 1000)
opt = ADAM(0.1)

ADAM(0.1, (0.9, 0.999), IdDict{Any,Any}())

In [46]:
cb = function ()  # callback function to observe training
  loss = loss_n_ode()
  println("loss = $loss")
  loss < 0.3 && Flux.stop()
end

#15 (generic function with 1 method)

In [47]:
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)

loss = 0.46809602f0 (tracked)
loss = 0.47622025f0 (tracked)
loss = 0.47136202f0 (tracked)
loss = 0.48475534f0 (tracked)
loss = 0.5023926f0 (tracked)
loss = 0.47430715f0 (tracked)
loss = 0.4032059f0 (tracked)
loss = 0.3536617f0 (tracked)
loss = 0.36590967f0 (tracked)
loss = 0.39769518f0 (tracked)
loss = 0.3960112f0 (tracked)
loss = 0.3728591f0 (tracked)
loss = 0.35423586f0 (tracked)
loss = 0.33518186f0 (tracked)
loss = 0.31479177f0 (tracked)
loss = 0.3118969f0 (tracked)
loss = 0.33061108f0 (tracked)
loss = 0.34539747f0 (tracked)
loss = 0.3359766f0 (tracked)
loss = 0.31420934f0 (tracked)
loss = 0.302181f0 (tracked)
loss = 0.3034338f0 (tracked)
loss = 0.3077032f0 (tracked)
loss = 0.3095644f0 (tracked)
loss = 0.30998874f0 (tracked)
loss = 0.3069221f0 (tracked)
loss = 0.29778913f0 (tracked)


In [48]:
nn_pred = Flux.data(n_ode(u₀))

@gif for n=1:datasize
    plot(x, ode_data[:, n], linewidth=2, ylim=(0, 1), label="data", show=false)
    plot!(x, nn_pred[:, n], linewidth=2, ylim=(0, 1), label="Neural ODE", show=false)
end

┌ Info: Saved animation to 
│   fn = /home/gridsan/aramadhan/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/gridsan/aramadhan/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [49]:
u₀_cos = (@. 1 + cos(2π * x)) |> gpu

prob = ODEProblem(diffusion, u₀_cos, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

nn_pred = Flux.data(n_ode(u₀_cos))
@gif for n=1:datasize
    plot(x, ode_data[:, n], linewidth=2, ylim=(0, 2), label="data", show=false)
    plot!(x, nn_pred[:, n], linewidth=2, ylim=(0, 2), label="Neural ODE", show=false)
end

┌ Info: Saved animation to 
│   fn = /home/gridsan/aramadhan/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/gridsan/aramadhan/.julia/packages/Plots/Iuc9S/src/animation.jl:95
