In [1]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots
using Pkg

using Flux: @epochs
#ENV["GRDIR"]=""
#Pkg.build("GR")

In [2]:
file = jldopen("../data/ocean_convection_profiles.jld2");

In [3]:
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
end

In [4]:
keys(file["timeseries/T/0"])

1×1×258 CartesianIndices{3,Tuple{Base.OneTo{Int64},Base.OneTo{Int64},Base.OneTo{Int64}}}:
[:, :, 1] =
 CartesianIndex(1, 1, 1)

[:, :, 2] =
 CartesianIndex(1, 1, 2)

[:, :, 3] =
 CartesianIndex(1, 1, 3)

...

[:, :, 256] =
 CartesianIndex(1, 1, 256)

[:, :, 257] =
 CartesianIndex(1, 1, 257)

[:, :, 258] =
 CartesianIndex(1, 1, 258)

In [5]:
z = file["grid/zC"]

256-element Array{Float64,1}:
 -99.8046875
 -99.4140625
 -99.0234375
 -98.6328125
 -98.2421875
 -97.8515625
 -97.4609375
 -97.0703125
 -96.6796875
 -96.2890625
 -95.8984375
 -95.5078125
 -95.1171875
   ⋮        
  -4.4921875
  -4.1015625
  -3.7109375
  -3.3203125
  -2.9296875
  -2.5390625
  -2.1484375
  -1.7578125
  -1.3671875
  -0.9765625
  -0.5859375
  -0.1953125

In [6]:

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T[n, :], z, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days", show=false)
end

display(anim)

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Documents\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95


In [7]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [8]:
coarse_resolution = cr = 32
T_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    T_cs[n, :] = coarse_grain(T[n, :], coarse_resolution)
end

T_cs = transpose(T_cs) |> Array;

w_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    w_cs[n,:] = coarse_grain(T[n,:],coarse_resolution)
end
w_cs = transpose(w_cs) |> Array;

In [9]:
cr_z = Nt/cr
Dzᶠ = 1/cr_z * Tridiagonal(-ones(cr-1), ones(cr), zeros(cr-1))
Dzᶜ = 1/cr_z * Tridiagonal(zeros(cr-1), -ones(cr), ones(cr-1))
Dzᶠ[1, 1] = 0
Dzᶜ[cr, cr] = 0

0

In [10]:
ann = Chain(Dense(cr, 2cr, tanh),
            Dense(2cr, 2cr, tanh),
                Dense(2cr, cr))

n_ann = 2*cr^2 + 2cr + 4*cr^2+2cr+2*cr^2+cr


8352

In [14]:
function dTdt_NN(du,u,p,t)
    T = u
    du = DiffEqFlux.restructure(ann,p[1:n_ann])(u)
end

foretold(u::Array{Float64,1},p::Array{Float64,1},t) = Dzᶜ*(ann(u).*(Dzᶠ*u))

p1 = Flux.data(DiffEqFlux.destructure(ann))
T₀ = T_cs[:, 1]
ps = Flux.params(ann,T₀)

n_train = round(Int, Nt/2)
t_train = t[1:n_train] ./ 86400
tspan_train = (t_train[1], t_train[end])

(0.0, 3.993078849612772)

In [15]:
u = T₀
Dzᶠ*u

32-element Array{Float64,1}:
 0.0                  
 0.0008673026886383273
 0.0008673026886383273
 0.0008673026886383273
 0.0008673026886383273
 0.0008673026886382162
 0.0008673026886383273
 0.0008673026886385493
 0.0008673026886383273
 0.0008673026886387714
 0.0008673026886374391
 0.0008673026886383273
 0.0008673026886387714
 ⋮                    
 0.0008673026886386603
 0.0008673026886384383
 0.0008673026886377722
 0.000867302688637106 
 0.0008673026886377722
 0.0008673026886364399
 0.0008673026886432122
 0.0008673026886376611
 0.0008673026886394375
 0.000867302688637106 
 0.0008673026886423241
 0.0008673026886402146

In [16]:
ps

Params([Float32[0.20073187 0.1302979 … -0.038915932 0.010187089; -0.076903105 -0.005175829 … -0.16207802 0.05962795; … ; 0.21402717 -0.21143895 … -0.075412035 -0.18205726; 0.046816707 -0.110325396 … 0.20571941 -0.24642944] (tracked), Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked), Float32[-0.03536685 0.00016931076 … 0.19197507 0.05285273; -0.062483467 -0.11126711 … 0.08967395 -0.052125726; … ; -0.08816853 -0.13742511 … 0.0010865312 -0.1627027; -0.13865845 -0.105989985 … 0.06403664 -0.14208084] (tracked), Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked), Float32[-0.12380403 0.21538436 … 0.20699781 0.19908774; -0.15590972 -0.100210786 … -0.1958363 -0.09916198; … ; 0.1042552 -0.19593686 … 0.21571916 -0.1138947; -0.0094794035 -0.2467292 … -0.07204455 -0.24131268] (tracked), Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.

In [19]:
prob = ODEProblem(dTdt_NN,T₀,saveat = [1,100],p1)
#prob2 = ODEProblem(f,T₀,tspan_train,p1)
s = diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6)
println(s[1:5])
Flux.Tracker.collect(s)

MethodError: MethodError: no method matching promote_tspan(::Array{Float32,1})
Closest candidates are:
  promote_tspan(!Matched::Nothing) at C:\Users\daddyj\.julia\packages\DiffEqBase\E16PL\src\problems\problem_utils.jl:13
  promote_tspan(!Matched::Tuple{T,T}) where T at C:\Users\daddyj\.julia\packages\DiffEqBase\E16PL\src\problems\problem_utils.jl:7
  promote_tspan(!Matched::Tuple{T1,T2}) where {T1, T2} at C:\Users\daddyj\.julia\packages\DiffEqBase\E16PL\src\problems\problem_utils.jl:9
  ...

In [84]:
function predict_adjoint(T₀)
  Flux.Tracker.collect(diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,saveat=t_train,reltol=1e-6, abstol=1e-8))
end

#neural_pde_prediction(T₀) = neural_ode(dTdt_NN, T₀, tspan_train, Tsit5(), saveat=t_train, reltol=1e-7, abstol=1e-9)

predict_adjoint (generic function with 2 methods)

In [90]:
opt = ADAM(0.1)

data = Iterators.repeated((T₀, T_cs[:, 1:n_train]), 1000)
data = [(T₀, T_cs[:, 1:n_train])]

loss_function(T₀, T_da) = sum(abs2, T_da .- predict_adjoint(T₀))

loss_function (generic function with 1 method)

In [93]:
# Callback function to observe training.
cb = function ()
    #println(ps)
    println(DiffEqFlux.destructure(ann)[1:10])
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
    loss < 10 && Flux.stop()
end

#33 (generic function with 1 method)

In [1]:
for i in 1:20
    Flux.train!(loss_function, ps, data, opt, cb = cb)
end

UndefVarError: UndefVarError: cb not defined

In [19]:
Flux.train!(loss_function, ps, data, opt, cb = cb)

loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (tracked)
loss = 180.21361539083918 (t

InterruptException: InterruptException:

In [96]:
tspan = (t[1], t[end]) ./ 86400
nn_pred = diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6,saveat=t ./86400) |> Flux.data
z_cs = coarse_grain(z, cr)

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[:, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    if n <= n_train
        plot!(nn_pred[:, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[:, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end

BoundsError: BoundsError: attempt to access 576-element Array{Array{Float64,1},1} at index [581]