In [4]:
using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots
using Statistics

using Flux: @epochs

using CuArrays
using CUDAdrv; CUDAdrv.name(CuDevice(0))

"GeForce GTX 960"

In [5]:
file = jldopen("../data/ocean_convection_profiles.jld2");

In [6]:
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
end

In [7]:
z = file["grid/zC"]

256-element Array{Float64,1}:
 -99.8046875
 -99.4140625
 -99.0234375
 -98.6328125
 -98.2421875
 -97.8515625
 -97.4609375
 -97.0703125
 -96.6796875
 -96.2890625
 -95.8984375
 -95.5078125
 -95.1171875
   ⋮        
  -4.4921875
  -4.1015625
  -3.7109375
  -3.3203125
  -2.9296875
  -2.5390625
  -2.1484375
  -1.7578125
  -1.3671875
  -0.9765625
  -0.5859375
  -0.1953125

In [8]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [9]:
coarse_resolution = cr = 32
T_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    T_cs[n, :] = coarse_grain(T[n, :], coarse_resolution)
end

T_cs = transpose(T_cs) |> gpu;

In [10]:
T₀ = T_cs[:, 1]
n_train = round(Int, Nt/2)
t_train = t[1:n_train] ./ 86400
tspan_train = (t_train[1], t_train[end])

└ @ GPUArrays C:\Users\daddyj\.julia\packages\GPUArrays\tIMl5\src\indexing.jl:16


(0.0, 3.993078849612772)

In [11]:
#dTdt_NN = Chain(Dense(cr, 2cr, tanh),
#                Dense(2cr, cr))
#cr = 33
ann = Chain(Dense(cr, 2cr, tanh),
            Dense(2cr, 4cr, tanh),
            Dense(4cr, 2cr, tanh),
                Dense(2cr, cr)) |> gpu

n_ann = 2*cr^2 + 2cr + 2*cr^2+cr

#function dTdt_NN(du,u,p,t)
#    du = DiffEqFlux.restructure(ann,p[1:n_ann])(u)
#end

4192

In [12]:
p1 = Flux.data(DiffEqFlux.destructure(ann))
p2 = Float64[1,1]
p3 = param(p1)
ps = Flux.params(p3,T₀)

Params([Float32[-0.19673109, -0.105196595, 0.2448945, 0.18395615, 0.027187467, 0.013831139, -0.025221705, 0.15300018, 0.17175686, 0.13892204  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked)])

In [13]:
cr_z = Nt/cr
#Dzᶠ = 1/cr_z * Tridiagonal(zeros(cr-1), -ones(cr),ones(cr-1))
#Dzᶜ = 1/cr_z * Tridiagonal( -ones(cr-1), ones(cr),zeros(cr-1))
Dzᶠ = Tridiagonal(zeros(cr-1), -ones(cr),ones(cr-1)) |> Array{Float32}
Dzᶠ = Tridiagonal(Dzᶠ)
Dzᶜ = 1/cr_z*Tridiagonal( -ones(cr-1), ones(cr),zeros(cr-1)) |> Array{Float32}
Dzᶜ = Tridiagonal(Dzᶜ)
Dzᶠ[cr, cr] = 0
Dzᶜ[1, 1] = 0
Dzᶠ

32×32 Tridiagonal{Float32,Array{Float32,1}}:
 -1.0   1.0    ⋅     ⋅     ⋅     ⋅   …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
  0.0  -1.0   1.0    ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅    0.0  -1.0   1.0    ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅    0.0  -1.0   1.0    ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅    0.0  -1.0   1.0       ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅    0.0  -1.0  …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅    0.0       ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅   …    ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅    ⋅ 
   ⋅     ⋅     ⋅     ⋅     ⋅     ⋅        ⋅     ⋅     ⋅     ⋅     ⋅

In [17]:
#neural_pde_prediction(T₀) = neural_ode(dudt_, T₀, tspan_train, Tsit5(), saveat=t_train, reltol=1e-7, abstol=1e-9)
function dudt_(du,u,p,t)
    x = (Dzᶜ*(DiffEqFlux.restructure(ann,p[1:length(p)])(u).*(Dzᶠ*u)))
    for i in 1:length(u)
        du[i] = x[i]
    end
    #du[length(u)] = p[end-1]*u[end] + p[end]*u[end]
    #println("du",du)
end

dudt_ (generic function with 1 method)

In [18]:
#foretold(u,p,t) = DiffEqFlux.restructure(ann,p[1:n_ann])(u)

In [19]:
prob = ODEProblem(dudt_,T₀,tspan_train,p3)
diffeq_adjoint(p3,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6)
#solve(prob)

Tracked 32×7 CuArray{Float32,2}:
 19.0156  19.0156  19.0156  19.0156  19.0156  19.0156  19.0156
 19.0469  19.0469  19.047   19.047   19.0471  19.0473  19.0474
 19.0781  19.0784  19.0787  19.0792  19.08    19.081   19.0824
 19.1094  19.1092  19.1089  19.1084  19.1077  19.1068  19.1055
 19.1406  19.1407  19.1408  19.141   19.1412  19.1416  19.1422
 19.1719  19.1718  19.1717  19.1715  19.1711  19.1707  19.17  
 19.2031  19.2033  19.2036  19.2039  19.2045  19.2054  19.2065
 19.2344  19.2343  19.2343  19.2342  19.234   19.2338  19.2336
 19.2656  19.2656  19.2657  19.2657  19.2657  19.2658  19.2659
 19.2969  19.2969  19.297   19.2971  19.2973  19.2976  19.2979
 19.3281  19.328   19.3278  19.3275  19.3271  19.3265  19.3257
 19.3594  19.3594  19.3595  19.3596  19.3598  19.36    19.3603
 19.3906  19.3907  19.3909  19.3911  19.3914  19.3919  19.3925
  ⋮                                            ⋮              
 19.6406  19.6406  19.6407  19.6407  19.6408  19.6409  19.641 
 19.6719  19.672   19.

In [20]:
function predict_adjoint()
  diffeq_adjoint(p3,prob,Tsit5(),u0=T₀,saveat=t_train,reltol=1e-6, abstol=1e-8)
end

predict_adjoint (generic function with 1 method)

In [21]:
opt = ADAM(0.1)

#data = [(T₀, T_cs[:, 1:n_train])]
data = Iterators.repeated((T₀, T_cs[:, 1:n_train]), 1000)

loss_function(T₀, T_data) = sum(abs2, T_data .- predict_adjoint())

loss_function (generic function with 1 method)

In [22]:
# Callback function to observe training.
cb = function ()
    #println(DiffEqFlux.destructure(ann)[1:10])
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
    loss < 100 && Flux.stop()
end

#3 (generic function with 1 method)

In [None]:
Flux.train!(loss_function, ps, data, opt, cb = cb)

In [47]:
#for _ in 1:20
#    Flux.train!(loss_function, ps, data, opt, cb = cb)
#end

Chain(Dense(32, 64, tanh), Dense(64, 128, tanh), Dense(128, 64, tanh), Dense(64, 32))

In [18]:
tspan = (t[1], t[end]) ./ 86400
#nn_pred = neural_ode(dudt_, T₀, tspan, Tsit5(), saveat=t ./86400, reltol=1e-7, abstol=1e-9) |> Flux.data
prob2 = ODEProblem(dudt_,T₀,tspan,p3)
nn_pred = diffeq_adjoint(p3,prob2,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6,saveat=t ./86400) |> Flux.data
z_cs = coarse_grain(z, cr)

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    #println("here1")
    plot(T_cs[:, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    #println("here2")
    if n <= n_train
        plot!(nn_pred[:, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[:, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Documents\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95


In [135]:
nn_pred = diffeq_adjoint(p3,prob2,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6,saveat=t ./86400) |> Flux.data


32×1153 Array{Float64,2}:
 19.0156  19.0153  19.015   19.0147  …  18.6447  18.6444  18.6441  18.6438
 19.0469  19.0471  19.0472  19.0474     19.2497  19.2498  19.25    19.2502
 19.0781  19.0776  19.0771  19.0766     18.512   18.5115  18.511   18.5106
 19.1094  19.1084  19.1074  19.1065     17.9969  17.9959  17.9949  17.994 
 19.1406  19.1412  19.1418  19.1423     19.7957  19.7962  19.7968  19.7974
 19.1719  19.1724  19.1729  19.1734  …  19.7699  19.7704  19.7709  19.7714
 19.2031  19.2022  19.2013  19.2004     18.166   18.1651  18.1642  18.1633
 19.2344  19.2343  19.2343  19.2342     19.1856  19.1856  19.1856  19.1855
 19.2656  19.2656  19.2655  19.2655     19.1993  19.1992  19.1992  19.1991
 19.2969  19.2973  19.2978  19.2982     19.7993  19.7997  19.8002  19.8006
 19.3281  19.3284  19.3287  19.3289  …  19.6353  19.6356  19.6358  19.6361
 19.3594  19.3586  19.3579  19.3572     18.5278  18.5271  18.5264  18.5256
 19.3906  19.3911  19.3917  19.3922     19.984   19.9845  19.985   19.9856