In [1]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux

using JLD2
using Plots

using Flux: @epochs
using CuArrays
using CUDAdrv; CUDAdrv.name(CuDevice(0))

"GeForce GTX 960"

In [2]:
# Open the convection-profile output file. The handle is deliberately left open:
# later cells read "timeseries/..." and "grid/..." entries from it.
# NOTE(review): nothing in the visible notebook calls `close(file)`; close it once all reads are done.
file = jldopen("../data/ocean_convection_profiles.jld2");

In [3]:
# Snapshot keys stored under "timeseries/t".
# NOTE(review): assumes JLD2 returns them in chronological (write) order; t is not re-sorted below.
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

# t[n] is the time of snapshot n and row n of T is its vertical temperature profile.
# T_data is an alias for the very same matrix.
t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

# Take index 1 in the first two dimensions and z-indices 2:Nz+1.
# Presumably this drops a bottom halo cell; TODO confirm against the writer.
for (n, key) in enumerate(Is)
    t[n] = file["timeseries/t/$key"]
    T[n, :] = file["timeseries/T/$key"][1, 1, 2:Nz+1]
end

In [4]:
# Vertical coordinate read from "grid/zC" (presumably cell centers). The displayed
# output shows 256 values from about -99.8 up to -0.2.
z = file["grid/zC"]

256-element Array{Float64,1}:
 -99.8046875
 -99.4140625
 -99.0234375
 -98.6328125
 -98.2421875
 -97.8515625
 -97.4609375
 -97.0703125
 -96.6796875
 -96.2890625
 -95.8984375
 -95.5078125
 -95.1171875
   ⋮        
  -4.4921875
  -4.1015625
  -3.7109375
  -3.3203125
  -2.9296875
  -2.5390625
  -2.1484375
  -1.7578125
  -1.3671875
  -0.9765625
  -0.5859375
  -0.1953125

In [5]:
"""
    coarse_grain(data, resolution)

Average `data` onto a coarser grid of `resolution` cells.

`data` is split into `resolution` contiguous, equally sized chunks of
`length(data) ÷ resolution` consecutive elements. Each coarse value is the mean of
its chunk. The result is always a `Vector{Float64}` of length `resolution`.

# Throws
- `ArgumentError` if `resolution` is not positive or `data` is empty.
- `DimensionMismatch` if `length(data)` is not divisible by `resolution`.
"""
function coarse_grain(data, resolution)
    resolution > 0 ||
        throw(ArgumentError("resolution must be positive, got $resolution"))
    isempty(data) && throw(ArgumentError("cannot coarse-grain empty data"))
    N = length(data)
    N % resolution == 0 || throw(DimensionMismatch(
        "length(data) = $N is not divisible by resolution = $resolution"))

    # Fine-grid points per coarse cell. Integer division keeps this exact and avoids
    # a Float64 -> Int round trip.
    s = N ÷ resolution

    data_cs = zeros(resolution)
    # partition walks `data` from its first index, so non-1-based arrays also work.
    for (i, chunk) in enumerate(Iterators.partition(data, s))
        data_cs[i] = mean(chunk)
    end

    return data_cs
end

coarse_grain (generic function with 1 method)

In [9]:
coarse_resolution = cr = 32

# Coarse-grained profiles, one column per snapshot: a (cr, Nt) matrix.
# Filling it column-wise gives the layout the model expects directly. This avoids
# handing `gpu` a lazy `transpose` wrapper, whose slices (T_cs[:, n]) could fall
# back to element-wise indexing on the device.
T_cs = zeros(coarse_resolution, Nt)
for n in 1:Nt
    T_cs[:, n] = coarse_grain(T[n, :], coarse_resolution)
end

T_cs = T_cs |> gpu;

In [10]:
# Network modelling the right-hand side dT/dt of the coarse (length-cr) profile.
dTdt_NN = gpu(Chain(Dense(cr, 2cr, tanh), Dense(2cr, cr)))

ps = Flux.params(dTdt_NN)

# Initial condition plus the first half of the record, with times converted from
# seconds to days, used as the training window.
T₀ = T_cs[:, 1]
n_train = round(Int, Nt / 2)
t_train = t[1:n_train] ./ 86400
tspan_train = (first(t_train), last(t_train))

# Integrate the neural ODE from T₀ across the training window, saving at every training time.
neural_pde_prediction(T₀) = neural_ode(dTdt_NN, T₀, tspan_train, Tsit5();
                                       saveat=t_train, reltol=1e-7, abstol=1e-9)

neural_pde_prediction (generic function with 1 method)

In [11]:
opt = ADAM(0.1)

# The single (initial condition, training trajectory) pair repeated 100 times, so
# one call to Flux.train! takes 100 optimizer steps on the whole training window.
data = Iterators.repeated((T₀, T_cs[:, 1:n_train]), 100)

# Sum of squared errors between a target trajectory and the neural ODE solution from T₀.
# The prediction always spans the training window, so the target needs n_train columns.
loss_function(T₀, T_target) = sum(abs2, T_target .- neural_pde_prediction(T₀))

loss_function (generic function with 1 method)

In [12]:
# Callback function to observe training.
# Callback to observe training. Each call does a full neural-ODE solve over the
# training window, which costs as much as a training step. Throttle it so it runs at
# most once every 10 seconds (plus immediately on the first call), rather than after
# every optimizer update.
cb = Flux.throttle(10) do
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
end

#3 (generic function with 1 method)

In [None]:
# One pass over `data` (100 repeats of the training pair); `cb` is invoked after each step.
# To repeat the pass N times, wrap the call as `@epochs N Flux.train!(...)`.
Flux.train!(loss_function, ps, data, opt; cb=cb)

loss = 1.03841244f6 (tracked)


InterruptException: InterruptException:




In [None]:
# Integrate the trained network over the full record (training window plus the held-out rest).
# `Flux.data` drops autodiff tracking, and `cpu` brings the result (and the data
# below) back to host memory if they live on the GPU. Plots then works on plain
# Arrays instead of indexing device arrays element by element.
tspan = (t[1], t[end]) ./ 86400
nn_pred = neural_ode(dTdt_NN, T₀, tspan, Tsit5(), saveat=t ./ 86400,
                     reltol=1e-7, abstol=1e-9) |> Flux.data |> cpu
T_cs_host = cpu(T_cs)  # host copy of the coarse-grained data for plotting

z_cs = coarse_grain(z, cr)

# Frames with n <= n_train lie in the training window (solid line). Later frames
# are extrapolation (dashed line).
anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs_host[:, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    if n <= n_train
        plot!(nn_pred[:, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[:, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end