In [1]:
using Printf
using Statistics
using LinearAlgebra

using Flux
using DifferentialEquations
using DiffEqFlux

using JLD2
using Plots

using Flux: @epochs

In [2]:
file = jldopen("../data/ocean_convection_profiles.jld2");

In [57]:
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Lz = file["grid/Lz"]
Nt = length(Is)

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)
wT = zeros(Nt, Nz)

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
    wT[i, :] = file["timeseries/wT/$I"][1, 1, 2:Nz+1]
end

In [4]:
z = file["grid/zC"]

anim = @animate for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T[n, :], z, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Deepening mixed layer: $t_str days", show=false)
end

gif(anim, "deepening_mixed_layer.gif", fps=15)

┌ Info: Saved animation to 
│   fn = /home/alir/6S898-climate-parameterization/notebooks/deepening_mixed_layer.gif
└ @ Plots /home/alir/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [58]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [60]:
coarse_resolution = cr = 32

T_cs = zeros(coarse_resolution+2, Nt)
wT_cs = zeros(coarse_resolution+2, Nt)

for n=1:Nt
    T_cs[2:end-1, n] .= coarse_grain(T[n, :], coarse_resolution)
    wT_cs[2:end-1, n] .= coarse_grain(wT[n, :], coarse_resolution)
end


T_cs[1,   :] .= T_cs[2,     :]
T_cs[end, :] .= T_cs[end-1, :]

wT_cs[1,   :] .= wT_cs[2,     :]
wT_cs[end, :] .= wT_cs[end-1, :]

34×1153 Array{Float64,2}:
 0.0  -2.30125e-26  8.71449e-26   9.19177e-26  …   6.32709e-8   3.07046e-8
 0.0  -2.30125e-26  8.71449e-26   9.19177e-26      6.32709e-8   3.07046e-8
 0.0  -6.18334e-26  1.07636e-25   1.55799e-25      9.17763e-7   1.3074e-7 
 0.0  -7.8825e-26   1.35601e-25   2.0034e-25       1.83954e-6  -1.31288e-6
 0.0  -9.93663e-26  1.74263e-25   2.55798e-25      2.16761e-6   7.44881e-7
 0.0  -1.31238e-25  2.30056e-25   3.23143e-25  …   4.87671e-7   1.16325e-6
 0.0  -1.69133e-25  2.94401e-25   4.12384e-25      1.82237e-6  -4.41882e-7
 0.0  -2.16432e-25  3.72444e-25   5.35574e-25      7.72243e-7  -2.58858e-6
 0.0  -2.8043e-25   4.82045e-25   6.81883e-25     -4.281e-7    -1.68256e-6
 0.0  -3.5289e-25   6.09157e-25   8.65815e-25     -6.44661e-7  -1.21975e-7
 0.0  -4.49962e-25  7.89255e-25   1.08192e-24  …   1.70495e-6  -2.21358e-6
 0.0  -5.88042e-25  1.0169e-24    1.38808e-24     -1.10878e-6  -6.40545e-7
 0.0  -7.46524e-25  1.29864e-24   1.73614e-24     -1.78399e-6  -1.38452e-6

In [61]:
Tₙ   = zeros(coarse_resolution+2, Nt-1)
wTₙ  = zeros(coarse_resolution+2, Nt-1)
Tₙ₊₁ = zeros(coarse_resolution+2, Nt-1)

for i in 1:Nt-1
      Tₙ[:, i] .= T_cs[:,   i]
     wTₙ[:, i] .= wT_cs[:,   i]
    Tₙ₊₁[:, i] .= T_cs[:, i+1]
end

In [62]:
# Generate differentiation matrices
cr_Δz = Lz / cr  # Coarse resolution Δz

# Dzᶠ computes the derivative from cell center to cell (F)aces
Dzᶠ = 1/cr_Δz * Tridiagonal(-ones(cr+1), ones(cr+2), zeros(cr+1))

# Dzᶜ computes the derivative from cell faces to cell (C)enters
Dzᶜ = 1/cr_Δz * Tridiagonal(zeros(cr+1), -ones(cr+2), ones(cr+1))

# Impose boundary condition that derivative goes to zero at top and bottom.
Dzᶠ[1, 1] = 0
Dzᶜ[cr, cr] = 0;

In [63]:
dTdt_NN = Chain(Dense(cr+2,  2cr, tanh),
                Dense(2cr,  cr+2))
# dTdt_NN = Chain(T -> Dzᶠ*T,
#                 Dense(cr+2,  2cr, tanh),
#                 Dense(2cr,  cr+2),
#                 NNDzT -> Dzᶜ * NNDzT)

NN_params = Flux.params(dTdt_NN)

tspan = (0.0, 600.0)  # 10 minutes
neural_pde_prediction(T₀) = neural_ode(dTdt_NN, T₀, tspan, Tsit5(), reltol=1e-4, save_start=false, saveat=tspan[2])

neural_pde_prediction (generic function with 1 method)

In [101]:
N = 100
pre_training_data = [(Tₙ[:, i], wTₙ[:, i]) for i in 5:N+5]
pre_loss_function(Tₙ, wTₙ) = sum(abs2, dTdt_NN(Tₙ) .- wTₙ)
popt = ADAM(0.01)

ADAM(0.01, (0.9, 0.999), IdDict{Any,Any}())

In [103]:
function precb()
    loss = sum(abs2, [pre_loss_function(pre_training_data[i]...) for i in 1:N-1])
    println("loss = $loss")
end

precb()

loss = 0.001050726170451445 (tracked)


In [104]:
Flux.train!(pre_loss_function, NN_params, pre_training_data, popt, cb = precb)

loss = 19441.526925437374 (tracked)
loss = 82.43273897664903 (tracked)
loss = 890.7500012397583 (tracked)
loss = 6427.454689818436 (tracked)
loss = 3643.4587697000975 (tracked)
loss = 291.6611854109149 (tracked)
loss = 0.015156221095973825 (tracked)
loss = 298.20051651815743 (tracked)
loss = 1449.8994643600556 (tracked)
loss = 1356.6842617207976 (tracked)
loss = 338.19445523903846 (tracked)
loss = 5.8585415123809845 (tracked)
loss = 3.3051187331028045 (tracked)
loss = 163.2637168144713 (tracked)
loss = 447.0136345896831 (tracked)
loss = 335.57739758404034 (tracked)
loss = 67.83323198001965 (tracked)
loss = 0.539355046660905 (tracked)
loss = 2.3871967429099303 (tracked)
loss = 59.48885849337132 (tracked)
loss = 136.8708385903774 (tracked)
loss = 89.88745217276015 (tracked)
loss = 14.242353850556823 (tracked)
loss = 0.030499012937749793 (tracked)
loss = 1.855291727856702 (tracked)
loss = 24.432446195459665 (tracked)
loss = 43.15165203510688 (tracked)
loss = 21.335359358937215 (tracked)
l

In [105]:
opt = ADAM(0.1)

N = 10
training_data = [(Tₙ[:, i], Tₙ₊₁[:, i]) for i in 1:N]

loss_function(Tₙ, Tₙ₊₁) = sum(abs2, Tₙ₊₁ .- neural_pde_prediction(Tₙ))

loss_function (generic function with 1 method)

In [106]:
# Callback function to observe training.
cb = function ()
    nn_pred = neural_ode(dTdt_NN, Tₙ[:, 1], (t[1], t[N]), Tsit5(), saveat=t[1:N], reltol=1e-4) |> Flux.data
    loss = sum(abs2, T_cs[:, 1:N] .- nn_pred)
    println("total loss = $loss")
end

cb()

total loss = 6826.981446915973


In [108]:
@epochs 1 Flux.train!(loss_function, NN_params, training_data, opt, cb=cb) # cb=Flux.throttle(cb, 20))

┌ Info: Epoch 1
└ @ Main /home/alir/.julia/packages/Flux/dkJUV/src/optimise/train.jl:105


total loss = 6826.981446915973
total loss = 6826.981446915973
total loss = 6826.981446915973
total loss = 6826.981446915973
total loss = 6826.981446915973
total loss = 6826.981446915973
total loss = 6826.981446915973
total loss = 6826.981446915973
total loss = 6826.981446915973
total loss = 6826.981446915973


In [13]:
tspan2 = (t[1], t[end])
nn_pred = neural_ode(dTdt_NN, Tₙ[:, 1], tspan2, Tsit5(), saveat=t, reltol=1e-4) |> Flux.data

z_cs = coarse_grain(z, cr)

anim = @animate for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[2:end-1, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Deepening mixed layer: $t_str days",
         legend=:bottomright, show=false)
    plot!(nn_pred[2:end-1, n], z_cs, linewidth=2, label="Neural PDE", show=false)
end

gif(anim, "deepening_mixed_layer_neural_PDE.gif", fps=15)

┌ Info: Saved animation to 
│   fn = /home/alir/6S898-climate-parameterization/notebooks/deepening_mixed_layer_neural_PDE.gif
└ @ Plots /home/alir/.julia/packages/Plots/Iuc9S/src/animation.jl:95
