In [1]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots

using Flux: @epochs

In [2]:
file = jldopen("../data/ocean_convection_profiles.jld2");

In [3]:
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
end

In [4]:
z = file["grid/zC"]

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T[n, :], z, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days", show=false)
end

display(anim)

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Documents\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95


In [5]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [6]:
coarse_resolution = cr = 32
T_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    T_cs[n, :] = coarse_grain(T[n, :], coarse_resolution)
end

T_cs = transpose(T_cs) |> Array;

In [7]:
T₀ = T_cs[:, 1]
n_train = round(Int, Nt/2)
t_train = t[1:n_train] ./ 86400
tspan_train = (t_train[1], t_train[end])

(0.0, 3.993078849612772)

In [8]:
#dTdt_NN = Chain(Dense(cr, 2cr, tanh),
#                Dense(2cr, cr))

ann = Chain(Dense(cr, 2cr, tanh),
                Dense(2cr, cr))

n_ann = 2*cr^2 + 2cr + 2*cr^2+cr

#function dTdt_NN(du,u,p,t)
#    du = DiffEqFlux.restructure(ann,p[1:n_ann])(u)
#end

4192

In [60]:
p1 = Flux.data(DiffEqFlux.destructure(ann))
ps = Flux.params(ann)

Params([Float32[0.07075596 -0.11194432 … -0.098255575 0.037179112; -0.017896652 0.050411344 … -0.04322982 0.2449969; … ; -0.24679124 0.20533681 … 0.21790558 0.06270504; -0.083367825 -0.1452614 … 0.09139168 -0.0508523] (tracked), Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked), Float32[-0.052592814 0.050476134 … -0.17914873 0.030118942; 0.09448749 -0.07606655 … 0.21055168 0.24841237; … ; 0.2272628 -0.13822407 … 0.11296004 -0.10701066; 0.17436433 0.20521313 … 0.15868717 0.23479712] (tracked), Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked)])

In [61]:
cr_z = Nt/cr
Dzᶠ = 1/cr_z * Tridiagonal(-ones(cr-1), ones(cr), zeros(cr-1))
Dzᶜ = 1/cr_z * Tridiagonal(zeros(cr-1), -ones(cr), ones(cr-1))
Dzᶠ[1, 1] = 0
Dzᶜ[cr, cr] = 0
Dzᶜ

32×32 Tridiagonal{Float64,Array{Float64,1}}:
 -0.0277537   0.0277537    ⋅         …    ⋅           ⋅          ⋅       
  0.0        -0.0277537   0.0277537       ⋅           ⋅          ⋅       
   ⋅          0.0        -0.0277537       ⋅           ⋅          ⋅       
   ⋅           ⋅          0.0             ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅         …    ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅         …    ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅       
   ⋅           ⋅           ⋅              ⋅           ⋅          ⋅ 

In [62]:
#neural_pde_prediction(T₀) = neural_ode(dudt_, T₀, tspan_train, Tsit5(), saveat=t_train, reltol=1e-7, abstol=1e-9)
#function dudt_(du,u,p,t)
#    du = Flux.data(DiffEqFlux.restructure(ann,p[1:n_ann])(u))
#    println("du",du)
#end

In [63]:
foretold(u,p,t) = DiffEqFlux.restructure(ann,p[1:n_ann])(u)

foretold (generic function with 1 method)

In [64]:
prob = ODEProblem(foretold,T₀,tspan_train,p1)
Flux.Tracker.collect(diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6))

Tracked 32×16 Array{Float64,2}:
 19.0156  18.971   18.918   18.8491  …  16.9809  16.7607  16.6239  16.5437
 19.0469  18.9734  18.8836  18.7631     13.9446  13.1631  12.3825  11.5755
 19.0781  19.0817  19.0859  19.0912     19.526   19.6403  19.731   19.8042
 19.1094  19.1927  19.2942  19.4297     24.9482  25.993   26.9685  27.9093
 19.1406  19.1162  19.0864  19.0465     17.6889  17.5517  17.4831  17.4576
 19.1719  19.1929  19.2193  19.2558  …  21.1593  21.4233  21.6363  21.8242
 19.2031  19.1639  19.1156  19.0503     16.6969  16.3482  16.0168  15.6905
 19.2344  19.2319  19.2277  19.2203     17.9394  17.7074  17.5383  17.402 
 19.2656  19.2489  19.2293  19.2043     18.8467  19.0163  19.251   19.5254
 19.2969  19.2051  19.0938  18.9461     13.2622  12.4415  11.7417  11.0982
 19.3281  19.3193  19.3074  19.2897  …  17.883   17.4894  17.1068  16.7291
 19.3594  19.4373  19.5323  19.6593     24.7985  25.6569  26.4273  27.1576
 19.3906  19.4242  19.4647  19.5181     20.9213  20.9984  21.1001  2

In [87]:
function predict_adjoint()
  diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,saveat=t_train,reltol=1e-6, abstol=1e-8)
end

predict_adjoint (generic function with 1 method)

In [105]:
opt = ADAM(1)

data = [(T₀, T_cs[:, 1:n_train])]
#data = Iterators.repeated((T₀, T_cs[:, 1:n_train]), 1000)

loss_function(T₀, T_data) = sum(abs2, T_data .- predict_adjoint())

loss_function (generic function with 1 method)

In [106]:
x = rand(3)
println(x)
sum(abs2,x)

[0.9287829131549417, 0.27911308713942073, 0.4503168752776665]


1.1433271033409191

In [107]:
# Callback function to observe training.
cb = function ()
    #println(DiffEqFlux.destructure(ann)[1:10])
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
end

#35 (generic function with 1 method)

In [108]:
for _ in 1:100
    Flux.train!(loss_function, ps, data, opt, cb = cb)
end

MethodError: MethodError: no method matching back!(::Float64)
Closest candidates are:
  back!(::Any, !Matched::Any; once) at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\back.jl:75
  back!(!Matched::Tracker.TrackedReal; once) at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\lib\real.jl:14
  back!(!Matched::TrackedArray) at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\lib\array.jl:68

In [74]:
tspan = (t[1], t[end]) ./ 86400
nn_pred = neural_ode(dudt_, T₀, tspan, Tsit5(), saveat=t ./86400, reltol=1e-7, abstol=1e-9) |> Flux.data

z_cs = coarse_grain(z, cr)

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[:, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    if n <= n_train
        plot!(nn_pred[:, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[:, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end

UndefVarError: UndefVarError: dudt_ not defined