In [1]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux

using JLD2
using Plots

In [2]:
file = jldopen("/home/gridsan/aramadhan/ocean_convection_profiles.jld2");

In [3]:
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
end

In [4]:
z = file["grid/zC"]

anim = @gif for n=1:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T[n, :], z, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days", show=false)
end

display(anim)

┌ Info: Saved animation to 
│   fn = /home/gridsan/aramadhan/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/gridsan/aramadhan/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [5]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [6]:
coarse_resolution = cr = 32
T_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    T_cs[n, :] = coarse_grain(T[n, :], coarse_resolution)
end

T_cs = transpose(T_cs) |> Array;

In [7]:
dTdt_NN = Chain(Dense(cr, 2cr, tanh),
             Dense(2cr, cr))

ps = Flux.params(dTdt_NN)

# n_subset = round(Int, Nt/2)
# t_subset = t[6:n_subset] ./ 86400
# tspan = (t_subset[1], t_subset[end])
# neural_pde_prediction = T₀ -> neural_ode(dTdt_NN, T₀, tspan, Tsit5(), saveat=t_subset, reltol=1e-7, abstol=1e-9)

T₀ = T_cs[:, 1]
n_train = round(Int, Nt/2)
t_train = t[1:n_train] ./ 86400
tspan_train = (t_train[1], t_train[end])
neural_pde_prediction = T₀ -> neural_ode(dTdt_NN, T₀, tspan_train, Tsit5(), saveat=t_train, reltol=1e-7, abstol=1e-9)

#3 (generic function with 1 method)

In [8]:
data = Iterators.repeated((), 1000)
opt = ADAM(0.1)

loss_function() = sum(abs2, T_cs[:, 1:n_train] .- neural_pde_prediction(T₀))

loss_function (generic function with 1 method)

In [9]:
# Callback function to observe training.
cb = function ()
    loss = loss_function()
    println("loss = $loss")
    loss < 500 && Flux.stop()
end

#5 (generic function with 1 method)

In [10]:
Flux.train!(loss_function, ps, data, opt, cb = cb)

loss = 679988.9808831854 (tracked)
loss = 513885.556727988 (tracked)
loss = 205270.06166129094 (tracked)
loss = 195461.9027166017 (tracked)
loss = 210649.71658004256 (tracked)
loss = 110673.94389786883 (tracked)
loss = 93012.42439102547 (tracked)
loss = 99172.39695962367 (tracked)
loss = 95073.65019254299 (tracked)
loss = 75136.38505421471 (tracked)
loss = 55971.448537842145 (tracked)
loss = 57039.38978936615 (tracked)
loss = 64298.248046931796 (tracked)
loss = 51644.37034981651 (tracked)
loss = 33937.94493673736 (tracked)
loss = 37059.70172364954 (tracked)
loss = 43706.57975807752 (tracked)
loss = 33054.74965167249 (tracked)
loss = 22389.15239496138 (tracked)
loss = 26066.405687523544 (tracked)
loss = 28658.32782595946 (tracked)
loss = 21237.677016090012 (tracked)
loss = 15876.013731373678 (tracked)
loss = 17564.955276368873 (tracked)
loss = 18365.835083502276 (tracked)
loss = 14898.858847902591 (tracked)
loss = 11377.12166946548 (tracked)
loss = 10896.525798237057 (tracked)
loss = 11

In [11]:
tspan = (t[1], t[end]) ./ 86400
nn_pred = neural_ode(dTdt_NN, T₀, tspan, Tsit5(), saveat=t ./86400, reltol=1e-7, abstol=1e-9) |> Flux.data

z_cs = coarse_grain(z, cr)

anim = @gif for n=1:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[:, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    if n <= n_subset
        plot!(nn_pred[:, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[:, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end

UndefVarError: UndefVarError: n_subset not defined