# Learning diffusion with a neural partial differential equation

In [None]:
using LinearAlgebra

using DifferentialEquations
using Flux
using DiffEqFlux
using Plots

We want to solve the heat (or diffusion) equation on a periodic domain:

$$\frac{\partial u}{\partial t} = \kappa \frac{\partial^2 u}{\partial x^2}, \quad x \in \left[-\frac{1}{2}, \frac{1}{2}\right], \quad u\left(-\frac{1}{2}, t\right) = u\left(\frac{1}{2}, t\right)$$

We use $N = 16$ grid points on a domain of length $L = 1$ (so the grid spacing is $\Delta x = L/N$), and a diffusivity of $\kappa = 1$.

In [None]:
const N = 16        # number of grid points
const L = 1         # length of the (periodic) domain
const Δx = L / N    # uniform grid spacing
const κ = 1         # diffusivity

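Discretizing the second spatial derivative with a second-order centered finite difference gives the approximation below. Because the domain is periodic, the indices wrap around, so $u_0 \equiv u_N$ and $u_{N+1} \equiv u_1$: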

$$\frac{\partial^2 u}{\partial x^2} \approx \frac{u_{i-1} - 2u_i + u_{i+1}}{\Delta x^2}$$

In [None]:
 d = -2 * ones(N)   # main diagonal of the second-difference stencil
sd = ones(N-1)      # sub- and super-diagonals
A = Array(Tridiagonal(sd, d, sd))
# Periodic boundary conditions: the stencil wraps around, coupling the first and
# last grid points. (Assumes N ≥ 3; for smaller N these corners would already lie
# on a diagonal.)
A[1, N] = 1
A[N, 1] = 1

# `const` so that `diffusion` below, which reads this global, is type-stable
# (consistent with the other `const` globals N, L, Δx, κ).
const A_diffusion = (κ/Δx^2) .* A

"""
    diffusion(∂u∂t, u, p, t)

In-place right-hand side of the semi-discretized heat equation,
`∂u∂t = A_diffusion * u`, in the `f(du, u, p, t)` form passed to `ODEProblem`.
`p` and `t` are unused.
"""
function diffusion(∂u∂t, u, p, t)
    mul!(∂u∂t, A_diffusion, u)  # write directly into ∂u∂t; no temporary allocated
    return nothing
end

In [None]:
# Periodic grid: N points spaced Δx apart, starting at -L/2. The right endpoint L/2
# is omitted because it is the same point as -L/2 on a periodic domain; this also
# makes the spacing match the Δx used in the finite-difference matrix.
x = range(-L/2, step=Δx, length=N)
u₀ = @. exp(-100*x^2)   # narrow Gaussian bump centred at x = 0
tspan = (0.0, 0.1)

In [None]:
datasize = 30   # number of snapshots saved along the trajectory
t = range(first(tspan), last(tspan); length=datasize)   # evenly spaced save times

In [None]:
prob = ODEProblem(diffusion, u₀, tspan)
# Reference solution sampled at the save times `t`, one column per snapshot.
ode_data = solve(prob, Tsit5(); saveat=t) |> Array

In [None]:
# Network for the right-hand side: a length-N state goes through a 100-unit tanh
# hidden layer and back to length N.
dudt = Chain(Dense(N, 100, tanh), Dense(100, N))

In [None]:
ps = Flux.params(dudt)   # trainable parameters of the network
# Neural ODE built from `dudt`, solved from initial state `u` over `tspan` and saved
# at the same times `t` as the reference data.
n_ode = u -> neural_ode(dudt, u, tspan, Tsit5(); saveat=t, reltol=1e-7, abstol=1e-9)

In [None]:
pred = n_ode(u₀)   # prediction before any training

"Run the neural ODE from the training initial condition `u₀`."
predict_n_ode() = n_ode(u₀)

In [None]:
"Sum of squared differences between the reference data and the neural ODE prediction."
function loss_n_ode()
    residual = ode_data .- predict_n_ode()
    return sum(abs2, residual)
end

In [None]:
# One element per optimisation step; each is an empty tuple because the loss
# function takes no arguments.
data = Iterators.repeated((), 1_000)
opt = ADAM(0.1)   # Adam optimiser, learning rate 0.1

In [None]:
# Training callback: print the current loss, and call `Flux.stop()` once it
# falls below 0.1.
cb = function ()
    current = loss_n_ode()
    println("loss = $current")
    return current < 0.1 && Flux.stop()
end

cb()   # report the loss before training starts

In [None]:
# Optimise the parameters `ps` with `opt`, invoking `cb` during training.
Flux.train!(loss_n_ode, ps, data, opt; cb=cb)

In [None]:
# Plain-array trajectory of the trained neural ODE from the training initial condition.
nn_pred = n_ode(u₀) |> Flux.data

# Animate reference solution vs. neural ODE prediction, one frame per snapshot.
@gif for n in 1:datasize
    frame_opts = (linewidth=2, ylim=(0, 1), show=false)
    plot(x, ode_data[:, n]; label="data", frame_opts...)
    plot!(x, nn_pred[:, n]; label="Neural ODE", frame_opts...)
end

In [None]:
# Check generalization to an initial condition the network was not trained on.
# Distinct names keep the Gaussian training problem and data (`prob`, `ode_data`)
# and `nn_pred` intact: `loss_n_ode` reads the global `ode_data`, so overwriting it
# would silently change the training target if the training cells were re-run.
u₀_cos = @. 1 + cos(2π * x)

prob_cos = ODEProblem(diffusion, u₀_cos, tspan)
ode_data_cos = Array(solve(prob_cos, Tsit5(), saveat=t))

nn_pred_cos = Flux.data(n_ode(u₀_cos))
@gif for n = 1:datasize
    plot(x, ode_data_cos[:, n], linewidth=2, ylim=(0, 2), label="data", show=false)
    plot!(x, nn_pred_cos[:, n], linewidth=2, ylim=(0, 2), label="Neural ODE", show=false)
end

In [None]:
# (Intentionally empty: a leftover placeholder `plot(1:10)` that plotted unrelated
#  data, and was not part of the diffusion experiment, was removed from this cell.)