In [84]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux

using JLD2
using Plots

In [176]:
# Open the simulation output read-only. The handle is deliberately left open:
# later cells keep reading grid and timeseries data from `file`.
file = jldopen("/home/gridsan/aramadhan/ocean_convection_profiles.jld2", "r");

In [177]:
# One key per saved snapshot under "timeseries/t"; the same keys index "timeseries/T".
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

# t[n]: time of snapshot n (seconds; later cells divide by 86400 to get days).
# T[n, :]: temperature profile of snapshot n over the Nz vertical levels.
t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

for (n, key) in enumerate(Is)
    t[n] = file["timeseries/t/$key"]
    snapshot = file["timeseries/T/$key"]
    # Column (1, 1), vertical indices 2:Nz+1 — presumably skipping one halo cell
    # at each end of the stored field; TODO confirm against the output writer.
    T[n, :] = snapshot[1, 1, 2:Nz+1]
end

In [179]:
# Vertical cell-centre coordinates, used as the y-axis of every profile plot.
z = file["grid/zC"]

# One frame per snapshot: the full-resolution temperature profile vs depth.
anim = @gif for n in 1:Nt
    elapsed_days = t[n] / 86400
    t_str = @sprintf("%.2f", elapsed_days)
    plot(T[n, :], z;
         linewidth=2, xlim=(19, 20), ylim=(-100, 0), label="",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days", show=false)
end

display(anim)

┌ Info: Saved animation to 
│   fn = /home/gridsan/aramadhan/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/gridsan/aramadhan/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [180]:
"""
    coarse_grain(data, resolution)

Average `data` down to `resolution` points by splitting it into `resolution`
contiguous, equal-length chunks and taking the mean of each chunk.

# Arguments
- `data`: vector of values; `length(data)` must be a multiple of `resolution`.
- `resolution`: number of output points (must be positive).

# Returns
A `Vector{Float64}` of length `resolution`.

# Throws
- `ArgumentError` if `resolution` is not positive or does not divide `length(data)`.
"""
function coarse_grain(data, resolution)
    N = length(data)
    # Validate with an exception rather than `@assert`, which may be compiled out.
    resolution > 0 && N % resolution == 0 ||
        throw(ArgumentError("length(data) = $N is not divisible by resolution = $resolution"))

    # Integer chunk size, so the slice bounds need no float arithmetic or `Int(...)` casts.
    chunk = N ÷ resolution

    data_cs = zeros(resolution)
    for i in 1:resolution
        # A view avoids copying each chunk just to average it.
        data_cs[i] = mean(@view data[(i - 1) * chunk + 1:i * chunk])
    end

    return data_cs
end

coarse_grain (generic function with 1 method)

In [210]:
# Average each snapshot down to `coarse_resolution` points. The result is stored
# as a (coarse_resolution × Nt) matrix so that column n is the profile at t[n].
coarse_resolution = cr = 32
T_cs = zeros(coarse_resolution, Nt)
for n in 1:Nt
    T_cs[:, n] = coarse_grain(T[n, :], coarse_resolution)
end

In [211]:
# Neural network standing in for dT/dt of the coarse (cr-point) temperature profile.
dTdt_NN = Chain(Dense(cr, 2cr, tanh),
                Dense(2cr, cr))

ps = Flux.params(dTdt_NN)

# Train on the first half of the record. Times are converted from seconds to days.
n_subset = round(Int, Nt/2)
t_subset = t[1:n_subset] ./ 86400
# Both endpoints must be in days: `t[1]` alone is in seconds and would be
# inconsistent with `t_subset[end]` whenever the record does not start at 0.
tspan = (t_subset[1], t_subset[end])

# Initial condition: the first coarse profile. The solution is saved at
# tspan[1] == t_subset[1], and the loss compares that state to T_cs[:, 1].
T₀ = T_cs[:, 1]

# Capture `tspan`/`t_subset` by value. Otherwise the closure would read the
# global bindings at call time, and a later cell that reassigns `tspan`
# (e.g. to the full record) would silently change what the loss integrates.
neural_pde_prediction = let tspan = tspan, saveat = t_subset
    u₀ -> neural_ode(dTdt_NN, u₀, tspan, Tsit5(), saveat=saveat, reltol=1e-7, abstol=1e-9)
end

#67 (generic function with 1 method)

In [212]:
# 1000 empty tuples: one per training step. Each step calls loss_function()
# with no arguments.
data = Iterators.repeated((), 1000)
opt = ADAM(0.1)

# Sum of squared differences between the coarse data and the neural-ODE
# solution over the training window (the first n_subset snapshots).
function loss_function()
    target = T_cs[:, 1:n_subset]
    prediction = neural_pde_prediction(T₀)
    return sum(abs2, target .- prediction)
end

loss_function (generic function with 1 method)

In [213]:
# Training callback: recompute and print the loss, and ask Flux to stop once
# it falls below 1.
cb = () -> begin
    loss = loss_function()
    println("loss = $loss")
    converged = loss < 1
    return converged && Flux.stop()
end

#69 (generic function with 1 method)

In [214]:
# Optimise the network parameters `ps` with `opt`, one step per element of `data`,
# invoking `cb` as the training callback.
Flux.train!(loss_function, ps, data, opt; cb=cb)

loss = 2706.062446357436 (tracked)
loss = 1863.1633467202562 (tracked)
loss = 1874.8052247614066 (tracked)
loss = 1202.1088080301847 (tracked)
loss = 1031.3071229922414 (tracked)
loss = 899.2494160246133 (tracked)
loss = 883.4858394121286 (tracked)
loss = 818.7412086991643 (tracked)
loss = 615.3740432987674 (tracked)
loss = 488.32069426488954 (tracked)
loss = 510.75277191941296 (tracked)
loss = 546.4132893686234 (tracked)
loss = 440.36415616240976 (tracked)
loss = 319.13154066725747 (tracked)
loss = 299.75532367359057 (tracked)
loss = 334.4120427365194 (tracked)
loss = 300.9550825976251 (tracked)
loss = 242.0129499731759 (tracked)
loss = 212.84985526778138 (tracked)
loss = 198.35093825480388 (tracked)
loss = 177.7582885116287 (tracked)
loss = 172.46137484756642 (tracked)
loss = 172.57270369640378 (tracked)
loss = 134.6569214197189 (tracked)
loss = 96.90622106546068 (tracked)
loss = 109.56138860822719 (tracked)
loss = 131.28849376146394 (tracked)
loss = 100.74868996784267 (tracked)
loss

InterruptException: InterruptException:

In [215]:
# Integrate the trained network over the whole record (in days), saving at every
# snapshot time, so it can be compared against the data frame by frame.
tspan = (t[1], t[end]) ./ 86400
nn_pred = Flux.data(neural_ode(dTdt_NN, T₀, tspan, Tsit5();
                               saveat=t ./86400, reltol=1e-7, abstol=1e-9))

# Depths of the coarse cells, averaged the same way as the temperature profiles.
z_cs = coarse_grain(z, cr)

anim = @gif for n in 1:Nt
    day_label = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[:, n], z_cs;
         linewidth=2, xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $day_label days",
         legend=:bottomright, show=false)
    # Solid inside the training window, dashed once the model is extrapolating.
    pred_linestyle = n <= n_subset ? :solid : :dash
    plot!(nn_pred[:, n], z_cs;
          linewidth=2, linestyle=pred_linestyle, label="Neural ODE", show=false)
end

┌ Info: Saved animation to 
│   fn = /home/gridsan/aramadhan/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/gridsan/aramadhan/.julia/packages/Plots/Iuc9S/src/animation.jl:95
