In [1]:
using Statistics
using Printf
using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots

using Flux: @epochs

In [2]:
file = jldopen("../data/ocean_convection_profiles.jld2");

In [3]:
Is = keys(file["timeseries/t"])
println(keys(file["timeseries"]))

Nz = file["grid/Nz"]
Lz = file["grid/Lz"]
Nt = length(Is)
println([Nz,Lz,Nt])

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)
wT = zeros(Nt, Nz)
v = zeros(Nt,Nz)
S = zeros(Nt,Nz)


for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
    wT[i, :] = file["timeseries/wT/$I"][1, 1, 2:Nz+1]
    v[i, :] = file["timeseries/v/$I"][1, 1, 2:Nz+1]
    S[i, :] = file["timeseries/S/$I"][1, 1, 2:Nz+1]
end

["t", "vv", "w", "nu", "kappaT", "kappaS", "S", "dTdt", "v", "uv", "ww", "u", "uw", "T", "wT", "uu", "wS", "vw"]
[256.0, 100.0, 1153.0]


In [5]:
z = file["grid/zC"]

256-element Array{Float64,1}:
 -99.8046875
 -99.4140625
 -99.0234375
 -98.6328125
 -98.2421875
 -97.8515625
 -97.4609375
 -97.0703125
 -96.6796875
 -96.2890625
 -95.8984375
 -95.5078125
 -95.1171875
   ⋮        
  -4.4921875
  -4.1015625
  -3.7109375
  -3.3203125
  -2.9296875
  -2.5390625
  -2.1484375
  -1.7578125
  -1.3671875
  -0.9765625
  -0.5859375
  -0.1953125

In [6]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [7]:
coarse_resolution = cr = 32

T_cs = zeros(coarse_resolution+2,Nt)
wT_cs = zeros(coarse_resolution+2, Nt)
v_cs = zeros(coarse_resolution+2, Nt)
S_cs =zeros(coarse_resolution+2, Nt)
for n=1:Nt
    T_cs[2:end-1, n] .= coarse_grain(T[n, :], coarse_resolution)
    wT_cs[2:end-1, n] .= coarse_grain(wT[n, :], coarse_resolution)
    v_cs[2:end-1, n] .= coarse_grain(v[n, :], coarse_resolution)
    S_cs[2:end-1, n] .= coarse_grain(S[n, :], coarse_resolution)    
end

T_cs[1,   :] .= T_cs[2,     :]
T_cs[end, :] .= T_cs[end-1, :]

wT_cs[1,   :] .= wT_cs[2,     :]
wT_cs[end, :] .= wT_cs[end-1, :]

v_cs[1,   :] .= v_cs[2,     :]
v_cs[end, :] .= v_cs[end-1, :]

S_cs[1,   :] .= S_cs[2,     :]
S_cs[end, :] .= S_cs[end-1, :]

1153-element view(::Array{Float64,2}, 34, :) with eltype Float64:
 19.98437500000003 
 19.98437471826549 
 19.984374436636237
 19.984374157576745
 19.984371277247593
 19.981201908710993
 19.978617505303095
 19.977228637085094
 19.97579463937481 
 19.973957230666027
 19.971742666169227
 19.96971476370894 
 19.96810126206183 
  ⋮                
 19.65204807845557 
 19.65187701843649 
 19.651724194934523
 19.651607789308834
 19.651502852581864
 19.651348799710888
 19.651191505681844
 19.651055146609114
 19.650895899712605
 19.65073316242373 
 19.65058084340377 
 19.650321221853364

In [8]:
T₀ = T_cs[:, 1]
wT₀ = wT_cs[:,1]
n_train = round(Int, 4*Nt/5)
t_train = t[1:n_train] ./ 86400
tspan_train = (t_train[1], t_train[end])

(0.0, 6.395836659960485)

In [9]:
#dTdt_NN = Chain(Dense(cr, 2cr, tanh),
#                Dense(2cr, cr))
#cr = 33
cr = 34
ann = Chain(Dense(2cr, 4cr, tanh),
            Dense(4cr, 3cr, tanh),
            Dense(3cr, 3cr, tanh),
            Dense(3cr, 2cr, tanh),
            #Dense(8cr, 4cr, tanh),
            #Dense(4cr, 4cr, tanh),
            #Dense(4cr, 2cr, tanh),
                Dense(2cr, cr))

n_ann = 2*cr^2 + 2cr + 2*cr^2+cr

#function dTdt_NN(du,u,p,t)
#    du = DiffEqFlux.restructure(ann,p[1:n_ann])(u)
#end

4726

In [10]:
p1 = Flux.data(DiffEqFlux.destructure(ann)) |> Array{Float64}
p2 = i -> wT_cs[:,i]
p3 = param(p1)
ps = Flux.params(p3,T₀)

Params([[-0.13226325809955597, -0.03209342435002327, -0.015049483627080917, -0.14694732427597046, 0.12121245265007019, 0.15782110393047333, 0.035167377442121506, -0.1462775617837906, 0.07006002217531204, -0.032177653163671494  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked)])

In [11]:
cr_z = Nt/cr
#Dzᶠ = 1/cr_z * Tridiagonal(zeros(cr-1), -ones(cr),ones(cr-1))
#Dzᶜ = 1/cr_z * Tridiagonal( -ones(cr-1), ones(cr),zeros(cr-1))
Dzᶠ = Tridiagonal(zeros(cr-1), -ones(cr),ones(cr-1))
Dzᶜ = 1/cr_z*Tridiagonal( -ones(cr-1), ones(cr),zeros(cr-1))
Dzᶠ[cr, cr] = 0
Dzᶜ[1, 1] = 0
Dzᶠ

34×34 Tridiagonal{Float64,Array{Float64,1}}:
 -1.0   1.0    ⋅     ⋅     ⋅     ⋅   …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
  0.0  -1.0   1.0    ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅    0.0  -1.0   1.0    ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅    0.0  -1.0   1.0    ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅    0.0  -1.0   1.0       ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅    0.0  -1.0  …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅    0.0       ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅   …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅

In [14]:
#neural_pde_prediction(T₀) = neural_ode(dudt_, T₀, tspan_train, Tsit5(), saveat=t_train, reltol=1e-7, abstol=1e-9)
function curry(wT,v,S)
    function dudt_(du,u,p,t)
        #du .= (Dzᶜ*(DiffEqFlux.restructure(ann,p[1:length(p)])(u).*(Dzᶠ*u)))
        i = findmin(abs.(t_train.-t))[2]
        w = @view wT[:,i]
        #vi = v[:,i]
        #si = S[:,i]
        #println(size(w))
        du .= DiffEqFlux.restructure(ann,p[1:length(p)])([u;w])
        #println(findall(x->abs(x-t)<0.0035,t_train))
        
    end
end

curry (generic function with 1 method)

In [15]:
prob = ODEProblem(curry(wT_cs,v_cs,S_cs),T₀,tspan_train,p3)
diffeq_adjoint(p3,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6,saveat=t_train)
#solve(prob)

Tracked 34×922 Array{Float64,2}:
 19.0156  19.0223  19.0289  19.0355  …  24.5072  24.512   24.5168  24.5216
 19.0156  19.0171  19.0187  19.0201     19.3338  19.3326  19.3313  19.3301
 19.0469  19.046   19.0451  19.0442     18.4979  18.4974  18.4968  18.4962
 19.0781  19.0826  19.087   19.0914     22.1151  22.1172  22.1192  22.1212
 19.1094  19.1012  19.093   19.085      11.4994  11.4907  11.482   11.4734
 19.1406  19.1395  19.1383  19.1372  …  18.6746  18.6757  18.6768  18.678 
 19.1719  19.1718  19.1717  19.1716     19.2616  19.2616  19.2616  19.2617
 19.2031  19.1991  19.1952  19.1912     14.9791  14.9737  14.9683  14.9629
 19.2344  19.2316  19.2288  19.226      16.3397  16.3365  16.3333  16.3301
 19.2656  19.2652  19.2648  19.2644     18.4855  18.4843  18.4831  18.482 
 19.2969  19.2998  19.3027  19.3056  …  21.9954  21.9977  21.9999  22.0021
 19.3281  19.3241  19.32    19.316      15.895   15.8914  15.888   15.8845
 19.3594  19.3569  19.3544  19.352      16.3264  16.3221  16.3179  

In [16]:
function predict_adjoint()
  diffeq_adjoint(p3,prob,Tsit5(),u0=T₀,saveat=t_train,reltol=1e-6, abstol=1e-8)
end

predict_adjoint (generic function with 1 method)

In [17]:
opt = ADAM(0.1)

#data = [(T₀, T_cs[:, 1:n_train])]
data = Iterators.repeated((T₀, T_cs[:, 1:n_train]), 1000)

loss_function(T₀, T_data) = sum(abs2, T_data .- predict_adjoint())

loss_function (generic function with 1 method)

In [18]:
# Callback function to observe training.
cb = function ()
    #println(DiffEqFlux.destructure(ann)[1:10])
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
    loss < 2 && Flux.stop()
end

#7 (generic function with 1 method)

In [19]:
Flux.train!(loss_function, ps, data, opt, cb = cb)

loss = 975766.369701097 (tracked)
loss = 7.686857793021376e6 (tracked)
loss = 1.8358624466933676e6 (tracked)
loss = 510044.5908239147 (tracked)
loss = 1.2737700267459005e6 (tracked)
loss = 922069.578914303 (tracked)
loss = 319198.1721096053 (tracked)
loss = 567877.4484396947 (tracked)
loss = 693489.5430116712 (tracked)
loss = 422137.7516481764 (tracked)
loss = 282559.53883173334 (tracked)
loss = 361146.1964000608 (tracked)
loss = 395070.07368244603 (tracked)
loss = 301478.93742253917 (tracked)
loss = 208794.00532481534 (tracked)
loss = 205809.9071887528 (tracked)
loss = 243854.4695541676 (tracked)
loss = 221785.91934523598 (tracked)
loss = 151160.18661132656 (tracked)
loss = 122291.6860885652 (tracked)
loss = 149228.7648392668 (tracked)
loss = 158188.61204204286 (tracked)
loss = 114484.46094263333 (tracked)
loss = 74655.77563140343 (tracked)
loss = 89113.92098752529 (tracked)
loss = 111366.63730077802 (tracked)
loss = 84271.95684976016 (tracked)
loss = 47571.91523891278 (tracked)
loss 

InterruptException: InterruptException:

In [20]:
#for _ in 1:20
#    Flux.train!(loss_function, ps, data, opt, cb = cb)
#end
cr =32

32

In [23]:
tspan = (t[1], t[end]) ./ 86400
#nn_pred = neural_ode(dudt_, T₀, tspan, Tsit5(), saveat=t ./86400, reltol=1e-7, abstol=1e-9) |> Flux.data
prob2 = ODEProblem(curry(wT_cs,v_cs,S_cs),T₀,tspan,p1)
nn_pred = diffeq_adjoint(p1,prob2,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6,saveat=t ./86400) |> Flux.data
z_cs = coarse_grain(z, cr)

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    #println("here1")
    plot(T_cs[2:end-1, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    #println("here2")
    if n <= n_train
        plot!(nn_pred[2:end-1, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[2:end-1, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end

BoundsError: BoundsError: attempt to access 32-element Array{Float64,1} at index [1:34]

In [135]:
nn_pred = diffeq_adjoint(p3,prob2,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6,saveat=t ./86400) |> Flux.data


32×1153 Array{Float64,2}:
 19.0156  19.0153  19.015   19.0147  …  18.6447  18.6444  18.6441  18.6438
 19.0469  19.0471  19.0472  19.0474     19.2497  19.2498  19.25    19.2502
 19.0781  19.0776  19.0771  19.0766     18.512   18.5115  18.511   18.5106
 19.1094  19.1084  19.1074  19.1065     17.9969  17.9959  17.9949  17.994 
 19.1406  19.1412  19.1418  19.1423     19.7957  19.7962  19.7968  19.7974
 19.1719  19.1724  19.1729  19.1734  …  19.7699  19.7704  19.7709  19.7714
 19.2031  19.2022  19.2013  19.2004     18.166   18.1651  18.1642  18.1633
 19.2344  19.2343  19.2343  19.2342     19.1856  19.1856  19.1856  19.1855
 19.2656  19.2656  19.2655  19.2655     19.1993  19.1992  19.1992  19.1991
 19.2969  19.2973  19.2978  19.2982     19.7993  19.7997  19.8002  19.8006
 19.3281  19.3284  19.3287  19.3289  …  19.6353  19.6356  19.6358  19.6361
 19.3594  19.3586  19.3579  19.3572     18.5278  18.5271  18.5264  18.5256
 19.3906  19.3911  19.3917  19.3922     19.984   19.9845  19.985   19.9856