In [15]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux

using JLD2
using Plots

using Flux: @epochs

In [2]:
# Open the saved convection profiles read-only; later cells index into `file`.
file = jldopen("../data/ocean_convection_profiles.jld2", "r");

In [3]:
# Labels under which each saved snapshot is stored in the file.
# NOTE(review): assumes `keys` returns them in chronological order — confirm.
Is = keys(file["timeseries/t"])

Nt = length(Is)       # number of saved snapshots
Nz = file["grid/Nz"]  # grid size read from the file

# `T` and `T_data` are two names for the same (Nt, Nz) array.
t = zeros(Nt)
T_data = zeros(Nt, Nz)
T = T_data

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    # Keep entries 2:Nz+1 along the third dimension, at index 1 of the first two.
    # NOTE(review): presumably this drops one halo cell on each side — confirm.
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:(Nz + 1)]
end

In [4]:
# Vertical coordinate read from the grid group of the file.
z = file["grid/zC"]

# Animate every 10th saved profile of the full-resolution data.
anim = @gif for n in 1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)  # value reported as days in the title
    plot(T[n, :], z;
         label="", linewidth=2,
         xlabel="Temperature (C)", ylabel="Depth (z)",
         xlim=(19, 20), ylim=(-100, 0),
         title="Free convection: $t_str days", show=false)
end

display(anim)

┌ Info: Saved animation to 
│   fn = /home/alir/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/alir/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [5]:
"""
    coarse_grain(data, resolution)

Average `data` down to `resolution` points.

`data` is split into `resolution` consecutive, equally sized chunks (indexed
`1:length(data)`), and the mean of each chunk is returned.

# Arguments
- `data`: indexable collection of numbers whose length is a multiple of `resolution`.
- `resolution`: number of output points; must be positive.

# Returns
A `Vector{Float64}` of length `resolution` holding the chunk means.

# Throws
- `ArgumentError` if `resolution` is not positive or does not evenly divide
  `length(data)`.
"""
function coarse_grain(data, resolution)
    # Explicit checks instead of `@assert`, which is meant for internal invariants
    # and may be disabled.
    resolution > 0 ||
        throw(ArgumentError("resolution must be positive, got $resolution"))
    n = length(data)
    n % resolution == 0 || throw(ArgumentError(
        "length(data) = $n is not divisible by resolution = $resolution"))

    # Integer chunk width keeps every index computation in exact integer arithmetic.
    width = n ÷ resolution

    data_cs = zeros(resolution)
    for i in 1:resolution
        lo = (i - 1) * width + 1
        hi = i * width
        data_cs[i] = mean(data[lo:hi])
    end

    return data_cs
end

coarse_grain (generic function with 1 method)

In [6]:
coarse_resolution = cr = 32

# Coarse-grained profiles, one snapshot per column: size (coarse_resolution, Nt).
T_cs = zeros(coarse_resolution, Nt)
for n in 1:Nt
    T_cs[:, n] = coarse_grain(T[n, :], coarse_resolution)
end

In [7]:
# Network used as the right-hand side of the ODE: cr inputs -> 2cr hidden (tanh) -> cr outputs.
dTdt_NN = Chain(
    Dense(cr, 2 * cr, tanh),
    Dense(2 * cr, cr),
)

# Trainable parameters handed to the optimiser.
ps = Flux.params(dTdt_NN)

# Initial condition: first coarse-grained snapshot.
T₀ = T_cs[:, 1]

# Train on the first half of the record, with time divided by 86400.
n_train = round(Int, Nt / 2)
t_train = t[1:n_train] ./ 86400
tspan_train = (first(t_train), last(t_train))

# Integrate the network from `T₀` over the training window, saving at `t_train`.
function neural_pde_prediction(T₀)
    return neural_ode(dTdt_NN, T₀, tspan_train, Tsit5();
                      saveat=t_train, reltol=1e-7, abstol=1e-9)
end

neural_pde_prediction (generic function with 1 method)

In [8]:
opt = ADAM(0.1)

# A single training sample: the initial condition and the training-window snapshots.
data = [(T₀, T_cs[:, 1:n_train])]

# Sum of squared differences between the data and the prediction started from `T₀`.
function loss_function(T₀, T_data)
    return sum(abs2, T_data .- neural_pde_prediction(T₀))
end

loss_function (generic function with 1 method)

In [16]:
# Training callback: re-evaluates the loss on the training window and prints it.
# The inputs are hard-coded to the globals `T₀` and `T_cs`, so it is tied to this setup.
cb = () -> begin
    loss = loss_function(T₀, T_cs[:, 1:n_train])
    println("loss = $loss")
end

#9 (generic function with 1 method)

In [18]:
# Ten passes over `data`; `cb` is passed to `train!` to report the loss.
foreach(1:10) do _
    Flux.train!(loss_function, ps, data, opt; cb=cb)
end

loss = 89088.93874210154 (tracked)
loss = 59804.84817204311 (tracked)
loss = 52345.53072604725 (tracked)
loss = 61221.80255229286 (tracked)
loss = 53524.74858639417 (tracked)
loss = 39311.858308891184 (tracked)
loss = 38274.32900512774 (tracked)
loss = 40469.45639064265 (tracked)
loss = 32620.42814464225 (tracked)
loss = 26183.02627073725 (tracked)


In [None]:
# Integrate the network over the whole record (time divided by 86400, as in training)
# and strip the tracking wrapper from the result.
tspan = (t[1] / 86400, t[end] / 86400)
nn_pred = Flux.data(
    neural_ode(dTdt_NN, T₀, tspan, Tsit5();
               saveat=t ./ 86400, reltol=1e-7, abstol=1e-9)
)

# Vertical coordinate averaged onto the coarse grid.
z_cs = coarse_grain(z, cr)

# Compare data and prediction on every 10th snapshot.
anim = @gif for n in 1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[:, n], z_cs;
         label="Data", linewidth=2,
         xlabel="Temperature (C)", ylabel="Depth (z)",
         xlim=(19, 20), ylim=(-100, 0),
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    # Snapshots beyond the training window are drawn dashed and labelled as test.
    if n > n_train
        plot!(nn_pred[:, n], z_cs;
              linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    else
        plot!(nn_pred[:, n], z_cs;
              linewidth=2, label="Neural ODE (train)", show=false)
    end
end