In [5]:
using Pkg
# Install any missing dependencies first, so the status report that follows
# describes the environment the rest of the notebook actually runs against.
pkg"instantiate"
pkg"st"

[36m[1mProject [22m[39mClimateParameterization v0.1.0
[32m[1m    Status[22m[39m `~/6S898-climate-parameterization/Project.toml`
 [90m [aae7a2af][39m[37m DiffEqFlux v0.7.0[39m
 [90m [0c46a032][39m[37m DifferentialEquations v6.8.0[39m
 [90m [587475ba][39m[37m Flux v0.9.0[39m
 [90m [033835bb][39m[37m JLD2 v0.1.3[39m
 [90m [9e8cae18][39m[37m Oceananigans v0.14.1[39m
 [90m [91a5bcdd][39m[37m Plots v0.27.0[39m


In [7]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux

using JLD2
using Plots

In [8]:
# Open the convection output read-only; later cells read the timeseries and grid from `file`.
file = jldopen("../data/ocean_convection_profiles.jld2", "r");

In [9]:
# Keys of the saved snapshots, in the order JLD2 returns them.
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

# Preallocate: snapshot times (seconds) and temperature profiles, one row per
# snapshot. `T` and `T_data` are two names for the same array.
T_data = zeros(Nt, Nz)
T = T_data
t = zeros(Nt)

for (n, key) in enumerate(Is)
    t[n] = file["timeseries/t/$key"]
    # Keep entries 2:Nz+1 of the stored column; presumably index 1 and Nz+2
    # are halo cells — TODO confirm against the output writer.
    T[n, :] = file["timeseries/T/$key"][1, 1, 2:Nz+1]
end

In [18]:
z = file["grid/zC"]

# Animate every 10th temperature profile against depth; time is shown in days.
anim = @gif for n in 1:10:Nt
    day_label = @sprintf("%.2f", t[n] / 86400)
    plot(T[n, :], z;
         linewidth = 2, label = "", show = false,
         xlim = (19, 20), ylim = (-100, 0),
         xlabel = "Temperature (C)", ylabel = "Depth (z)",
         title = "Free convection: $day_label days")
end

display(anim)

┌ Info: Saved animation to 
│   fn = /home/alir/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/alir/.julia/packages/Plots/Iuc9S/src/animation.jl:95


In [19]:
"""
    coarse_grain(data, resolution)

Split `data` into `resolution` contiguous blocks of equal length and return the
mean of each block, in order, as a vector of length `resolution`.

# Throws
- `ArgumentError` if `resolution` is not positive, or if `length(data)` is not an
  exact multiple of `resolution`.
"""
function coarse_grain(data, resolution)
    resolution > 0 ||
        throw(ArgumentError("resolution must be positive, got $resolution"))
    N = length(data)
    N % resolution == 0 ||
        throw(ArgumentError("length(data) = $N is not divisible by resolution = $resolution"))

    # Integer block size: avoids a Float64 division followed by Int() conversions.
    block = N ÷ resolution
    return [mean(@view data[(i-1)*block+1:i*block]) for i in 1:resolution]
end

coarse_grain (generic function with 1 method)

In [20]:
# Number of vertical cells after coarse-graining; `cr` is a short alias.
coarse_resolution = 32
cr = coarse_resolution

# Coarse-grain every snapshot directly into a (cr, Nt) matrix:
# column n holds the coarse profile at time t[n].
T_cs = zeros(coarse_resolution, Nt)
for n in 1:Nt
    T_cs[:, n] = coarse_grain(T[n, :], coarse_resolution)
end

In [21]:
# Network for dT/dt on the coarse grid: cr -> 2cr (tanh) -> cr.
dTdt_NN = Chain(
    Dense(cr, 2cr, tanh),
    Dense(2cr, cr),
)
ps = Flux.params(dTdt_NN)

# Initial condition: the first coarse-grained profile.
T₀ = T_cs[:, 1]

# Train on the first half of the snapshots; times are converted to days.
n_train = round(Int, Nt / 2)
t_train = t[1:n_train] ./ 86400
tspan_train = (first(t_train), last(t_train))

# Solve the neural ODE from u₀ with Tsit5, saving at the training times.
neural_pde_prediction = u₀ -> neural_ode(dTdt_NN, u₀, tspan_train, Tsit5();
                                         saveat = t_train, reltol = 1e-7, abstol = 1e-9)

#7 (generic function with 1 method)

In [22]:
# 1000 empty argument tuples: train! calls loss_function() once per element.
data = Iterators.repeated((), 1000)
opt = ADAM(0.1)

# Sum of squared differences between the coarse training snapshots and the
# neural ODE trajectory started from T₀.
function loss_function()
    T_train = T_cs[:, 1:n_train]
    return sum(abs2, T_train .- neural_pde_prediction(T₀))
end

loss_function (generic function with 1 method)

In [35]:
# Callback function to observe training.
# Training callback: print the current loss and stop training once it drops below 5.
cb = function ()
    current_loss = loss_function()
    println("loss = $current_loss")
    reached_target = current_loss < 5
    return reached_target && Flux.stop()
end

#17 (generic function with 1 method)

In [36]:
# Up to 1000 ADAM steps (one per element of `data`); `cb` may stop training early.
Flux.train!(loss_function, ps, data, opt; cb = cb)

loss = 10.466750298014475 (tracked)
loss = 9.460028443511291 (tracked)
loss = 9.101091559783775 (tracked)
loss = 8.786542533428557 (tracked)
loss = 8.467523758659178 (tracked)
loss = 8.42651734262251 (tracked)
loss = 7.691266135421882 (tracked)
loss = 7.708739061527556 (tracked)
loss = 7.781001504802827 (tracked)
loss = 7.026281444136295 (tracked)
loss = 7.1946710118797865 (tracked)
loss = 7.111724173159896 (tracked)
loss = 6.708637722931412 (tracked)
loss = 6.764039451760459 (tracked)
loss = 6.5863768793894675 (tracked)
loss = 6.525112712848079 (tracked)
loss = 6.392718291377726 (tracked)
loss = 6.271094376333861 (tracked)
loss = 6.318695909266896 (tracked)
loss = 6.149306210717498 (tracked)
loss = 6.065484472347909 (tracked)
loss = 6.113724328763357 (tracked)
loss = 6.0125371248927575 (tracked)
loss = 5.888194072623142 (tracked)
loss = 5.994350040221017 (tracked)
loss = 5.8591100833629195 (tracked)
loss = 5.822981129311834 (tracked)
loss = 5.850880308049506 (tracked)
loss = 5.7785849

InterruptException: InterruptException:

In [37]:
# Integrate the trained network over the whole record (train + test window), in days.
tspan = (t[1], t[end]) ./ 86400
nn_pred = Flux.data(neural_ode(dTdt_NN, T₀, tspan, Tsit5();
                               saveat = t ./ 86400, reltol = 1e-7, abstol = 1e-9))

# Depths `z` averaged onto the same coarse grid as T_cs.
z_cs = coarse_grain(z, cr)

anim = @gif for n in 1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[:, n], z_cs;
         linewidth = 2, label = "Data", legend = :bottomright, show = false,
         xlim = (19, 20), ylim = (-100, 0),
         xlabel = "Temperature (C)", ylabel = "Depth (z)",
         title = "Free convection: $t_str days")
    # Default line style inside the training window, dashed beyond it.
    nn_style = if n <= n_train
        (label = "Neural ODE (train)",)
    else
        (label = "Neural ODE (test)", linestyle = :dash)
    end
    plot!(nn_pred[:, n], z_cs; linewidth = 2, show = false, nn_style...)
end

┌ Info: Saved animation to 
│   fn = /home/alir/6S898-climate-parameterization/notebooks/tmp.gif
└ @ Plots /home/alir/.julia/packages/Plots/Iuc9S/src/animation.jl:95
