In [1]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots

using Flux: @epochs

In [2]:
file = jldopen("../data/ocean_convection_profiles.jld2");

In [3]:
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
end

In [4]:
z = file["grid/zC"]
'''
anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T[n, :], z, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days", show=false)
end

display(anim)
'''

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Documents\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95


'\'': ASCII/Unicode U+0027 (category Po: Punctuation, other)

In [5]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [6]:
coarse_resolution = cr = 32
T_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    T_cs[n, :] = coarse_grain(T[n, :], coarse_resolution)
end

T_cs = transpose(T_cs) |> Array;

In [7]:
T₀ = T_cs[:, 1]
n_train = round(Int, Nt/2)
t_train = t[1:n_train] ./ 86400
tspan_train = (t_train[1], t_train[end])

(0.0, 3.993078849612772)

In [8]:
#dTdt_NN = Chain(Dense(cr, 2cr, tanh),
#                Dense(2cr, cr))
#cr = 33
ann = Chain(Dense(cr, 2cr, tanh),
            Dense(2cr, 4cr, tanh),
            Dense(4cr, 2cr, tanh),
                Dense(2cr, cr))

n_ann = 2*cr^2 + 2cr + 2*cr^2+cr

#function dTdt_NN(du,u,p,t)
#    du = DiffEqFlux.restructure(ann,p[1:n_ann])(u)
#end

4192

In [9]:
p1 = Flux.data(DiffEqFlux.destructure(ann))
p2 = Float64[1,1]
p3 = param([p1;p2])
ps = Flux.params(p3,T₀)

Params([[-0.23029476404190063, -0.06634944677352905, 0.03485292196273804, 0.1414080262184143, 0.21903347969055176, 0.020103812217712402, 0.005270600318908691, 0.10784345865249634, 0.10896903276443481, 0.20171856880187988  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0] (tracked)])

In [10]:
cr_z = Nt/cr
#Dzᶠ = 1/cr_z * Tridiagonal(zeros(cr-1), -ones(cr),ones(cr-1))
#Dzᶜ = 1/cr_z * Tridiagonal( -ones(cr-1), ones(cr),zeros(cr-1))
Dzᶠ = Tridiagonal(zeros(cr-1), -ones(cr),ones(cr-1))
Dzᶜ = 1/cr_z*Tridiagonal( -ones(cr-1), ones(cr),zeros(cr-1))
Dzᶠ[cr, cr] = 0
Dzᶜ[1, 1] = 0
Dzᶠ

32×32 Tridiagonal{Float64,Array{Float64,1}}:
 -1.0   1.0    ⋅     ⋅     ⋅     ⋅   …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
  0.0  -1.0   1.0    ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅    0.0  -1.0   1.0    ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅    0.0  -1.0   1.0    ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅    0.0  -1.0   1.0       ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅    0.0  -1.0  …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅    0.0       ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅   …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅

In [11]:
#neural_pde_prediction(T₀) = neural_ode(dudt_, T₀, tspan_train, Tsit5(), saveat=t_train, reltol=1e-7, abstol=1e-9)
function dudt_(du,u,p,t)
    x = (Dzᶜ*(DiffEqFlux.restructure(ann,p[1:length(p)-2])(u).*(Dzᶠ*u)))
    for i in 1:length(u)
        du[i] = x[i]
    end
    #du[length(u)] = p[end-1]*u[end] + p[end]*u[end]
    #println("du",du)
end

dudt_ (generic function with 1 method)

In [12]:
#foretold(u,p,t) = DiffEqFlux.restructure(ann,p[1:n_ann])(u)

In [13]:
prob = ODEProblem(dudt_,T₀,tspan_train,p3)
diffeq_adjoint(p3,gpu(prob),Tsit5(),u0=gpu(T₀),abstol=1e-8,reltol=1e-6)
#solve(prob)

Tracked 32×4 Array{Float64,2}:
 19.0156  19.0156  19.0156  19.0156
 19.0469  19.0468  19.0459  19.0451
 19.0781  19.0782  19.079   19.0798
 19.1094  19.1092  19.1074  19.1054
 19.1406  19.1408  19.1422  19.1437
 19.1719  19.1717  19.1704  19.1691
 19.2031  19.203   19.202   19.2005
 19.2344  19.2347  19.2383  19.242 
 19.2656  19.2656  19.2651  19.2648
 19.2969  19.2968  19.2957  19.2946
 19.3281  19.3281  19.3281  19.3281
 19.3594  19.3594  19.3601  19.3608
 19.3906  19.3906  19.3907  19.3907
  ⋮                                
 19.6406  19.6406  19.6408  19.641 
 19.6719  19.6719  19.6722  19.6725
 19.7031  19.703   19.7022  19.7013
 19.7344  19.7344  19.7346  19.7348
 19.7656  19.7654  19.7628  19.76  
 19.7969  19.797   19.7983  19.8   
 19.8281  19.8284  19.8311  19.8337
 19.8594  19.8593  19.8583  19.8575
 19.8906  19.8904  19.8876  19.8845
 19.9219  19.9221  19.9248  19.9278
 19.9531  19.9532  19.9536  19.954 
 19.9844  19.9843  19.9831  19.9819

In [14]:
function predict_adjoint()
  diffeq_adjoint(p3,gpu(prob),Tsit5(),u0=gpu(T₀),saveat=t_train,reltol=1e-6, abstol=1e-8)
end

predict_adjoint (generic function with 1 method)

In [15]:
opt = ADAM(0.1)

#data = [(T₀, T_cs[:, 1:n_train])]
data = Iterators.repeated((gpu(T₀), gpu(T_cs[:, 1:n_train])), 1000)

loss_function(T₀, T_data) = sum(abs2, gpu(T_data) .- predict_adjoint())

loss_function (generic function with 1 method)

In [16]:
# Callback function to observe training.
cb = function ()
    #println(DiffEqFlux.destructure(ann)[1:10])
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
    loss < 100 && Flux.stop()
end

#3 (generic function with 1 method)

In [None]:
Flux.train!(loss_function, ps, data, opt, cb = cb)

loss = 179.71609112700816 (tracked)
loss = 178.64590654025687 (tracked)
loss = 174.4296263144533 (tracked)
loss = 170.3160594338408 (tracked)
loss = 168.06691547614838 (tracked)


InterruptException: InterruptException:




In [47]:
#for _ in 1:20
#    Flux.train!(loss_function, ps, data, opt, cb = cb)
#end

Chain(Dense(32, 64, tanh), Dense(64, 128, tanh), Dense(128, 64, tanh), Dense(64, 32))

In [18]:
tspan = (t[1], t[end]) ./ 86400
#nn_pred = neural_ode(dudt_, T₀, tspan, Tsit5(), saveat=t ./86400, reltol=1e-7, abstol=1e-9) |> Flux.data
prob2 = ODEProblem(dudt_,T₀,tspan,p3)
nn_pred = diffeq_adjoint(p3,prob2,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6,saveat=t ./86400) |> Flux.data
z_cs = coarse_grain(z, cr)

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    #println("here1")
    plot(T_cs[:, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    #println("here2")
    if n <= n_train
        plot!(nn_pred[:, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[:, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Documents\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95


In [135]:
nn_pred = diffeq_adjoint(p3,prob2,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6,saveat=t ./86400) |> Flux.data


32×1153 Array{Float64,2}:
 19.0156  19.0153  19.015   19.0147  …  18.6447  18.6444  18.6441  18.6438
 19.0469  19.0471  19.0472  19.0474     19.2497  19.2498  19.25    19.2502
 19.0781  19.0776  19.0771  19.0766     18.512   18.5115  18.511   18.5106
 19.1094  19.1084  19.1074  19.1065     17.9969  17.9959  17.9949  17.994 
 19.1406  19.1412  19.1418  19.1423     19.7957  19.7962  19.7968  19.7974
 19.1719  19.1724  19.1729  19.1734  …  19.7699  19.7704  19.7709  19.7714
 19.2031  19.2022  19.2013  19.2004     18.166   18.1651  18.1642  18.1633
 19.2344  19.2343  19.2343  19.2342     19.1856  19.1856  19.1856  19.1855
 19.2656  19.2656  19.2655  19.2655     19.1993  19.1992  19.1992  19.1991
 19.2969  19.2973  19.2978  19.2982     19.7993  19.7997  19.8002  19.8006
 19.3281  19.3284  19.3287  19.3289  …  19.6353  19.6356  19.6358  19.6361
 19.3594  19.3586  19.3579  19.3572     18.5278  18.5271  18.5264  18.5256
 19.3906  19.3911  19.3917  19.3922     19.984   19.9845  19.985   19.9856