In [25]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots

using Flux: @epochs

In [2]:
file = jldopen("../data/ocean_convection_profiles.jld2");

In [3]:
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
end

In [4]:
z = file["grid/zC"]

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T[n, :], z, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days", show=false)
end

display(anim)

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Documents\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95


In [5]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [6]:
coarse_resolution = cr = 32
T_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    T_cs[n, :] = coarse_grain(T[n, :], coarse_resolution)
end

T_cs = transpose(T_cs) |> Array;

In [7]:
T₀ = T_cs[:, 1]
n_train = round(Int, Nt/2)
t_train = t[1:n_train] ./ 86400
tspan_train = (t_train[1], t_train[end])

(0.0, 3.993078849612772)

In [80]:
#dTdt_NN = Chain(Dense(cr, 2cr, tanh),
#                Dense(2cr, cr))

ann = Chain(Dense(cr, 2cr, tanh),
                Dense(2cr, cr))

n_ann = 2*cr^2 + 2cr + 2*cr^2+cr

#function dTdt_NN(du,u,p,t)
#    du = DiffEqFlux.restructure(ann,p[1:n_ann])(u)
#end

4192

In [81]:
p1 = Flux.data(DiffEqFlux.destructure(ann))
ps = Flux.params(ann)

Params([Float32[-0.2003274 0.08194095 … -0.1919033 -0.053825617; 0.24139124 -0.23989993 … -0.18962204 -0.17265826; … ; 0.16427964 0.19293076 … -0.0038294792 -0.11314386; -0.1969539 0.09005076 … -0.1270439 0.04919845] (tracked), Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked), Float32[0.22591776 -0.08994216 … -0.16212755 -0.08436048; 0.22652137 -0.051735282 … -0.22390294 -0.024571836; … ; 0.24483758 0.24031246 … -0.17377567 -0.14136171; -0.07514995 -0.08042383 … 0.113624394 0.20905578] (tracked), Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked)])

In [82]:
cr_z = Nt/cr
Dzᶠ = 1/cr_z * Tridiagonal(-ones(cr-1), ones(cr), zeros(cr-1))
Dzᶜ = 1/cr_z * Tridiagonal(zeros(cr-1), -ones(cr), ones(cr-1))
Dzᶠ[1, 1] = 0
Dzᶜ[cr, cr] = 0
Dzᶜ

32×32 Tridiagonal{Float64,Array{Float64,1}}:
 -0.0277537   0.0277537    ⋅         …    ⋅           ⋅          ⋅       
  0.0        -0.0277537   0.0277537       ⋅           ⋅          ⋅       
   ⋅          0.0        -0.0277537       ⋅           ⋅          ⋅       
   ⋅           ⋅          0.0             ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅         …    ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅         …    ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅ 

In [83]:
#neural_pde_prediction(T₀) = neural_ode(dudt_, T₀, tspan_train, Tsit5(), saveat=t_train, reltol=1e-7, abstol=1e-9)
#function dudt_(du,u,p,t)
#    du = Flux.data(DiffEqFlux.restructure(ann,p[1:n_ann])(u))
#    println("du",du)
#end

In [84]:
foretold(u,p,t) = Dzᶜ*((DiffEqFlux.restructure(ann,p[1:n_ann])(u)).*(Dzᶠ*u))

foretold (generic function with 1 method)

In [85]:
prob = ODEProblem(foretold,T₀,tspan_train,p1)
Flux.Tracker.collect(diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6))

Tracked 32×5 Array{Float64,2}:
 19.0156  19.0156  19.0156  19.0155  19.0155
 19.0469  19.0469  19.0469  19.047   19.047 
 19.0781  19.0781  19.078   19.0775  19.0772
 19.1094  19.1094  19.1095  19.1099  19.1102
 19.1406  19.1406  19.1406  19.1407  19.1407
 19.1719  19.1719  19.172   19.1722  19.1724
 19.2031  19.2031  19.203   19.2027  19.2026
 19.2344  19.2344  19.2344  19.2343  19.2343
 19.2656  19.2656  19.2657  19.2658  19.2659
 19.2969  19.2969  19.2969  19.2969  19.2969
 19.3281  19.3281  19.3282  19.3285  19.3287
 19.3594  19.3593  19.3592  19.3587  19.3584
 19.3906  19.3906  19.3907  19.391   19.3912
  ⋮                                         
 19.6406  19.6406  19.6407  19.641   19.6412
 19.6719  19.6719  19.6718  19.6717  19.6716
 19.7031  19.7031  19.7031  19.703   19.703 
 19.7344  19.7344  19.7344  19.7343  19.7342
 19.7656  19.7656  19.7657  19.7658  19.7659
 19.7969  19.7969  19.7969  19.7971  19.7972
 19.8281  19.8281  19.828   19.8277  19.8275
 19.8594  19.8594  19.85

In [86]:
function predict_adjoint()
  Flux.Tracker.collect(diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,saveat=t_train,reltol=1e-6, abstol=1e-8))
end

predict_adjoint (generic function with 1 method)

In [87]:
opt = ADAM(1)

data = [(T₀, T_cs[:, 1:n_train])]

loss_function(T₀, T_data) = sum(abs2, T_data .- predict_adjoint())

loss_function (generic function with 1 method)

In [90]:
# Callback function to observe training.
cb = function ()
    #println(DiffEqFlux.destructure(ann)[1:10])
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
end

#21 (generic function with 1 method)

In [91]:
for _ in 1:100
    Flux.train!(loss_function, ps, data, opt, cb = cb)
end

loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (tracked)
loss = 180.21029942863254 (t

In [25]:
tspan = (t[1], t[end]) ./ 86400
nn_pred = neural_ode(dudt_, T₀, tspan, Tsit5(), saveat=t ./86400, reltol=1e-7, abstol=1e-9) |> Flux.data

z_cs = coarse_grain(z, cr)

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[:, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    if n <= n_train
        plot!(nn_pred[:, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[:, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Documents\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95
