In [2]:
using Printf
using Statistics

using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots
using Pkg

using Flux: @epochs
#ENV["GRDIR"]=""
#Pkg.build("GR")

In [3]:
file = jldopen("../data/ocean_convection_profiles.jld2");

In [4]:
Is = keys(file["timeseries/t"])

Nz = file["grid/Nz"]
Nt = length(Is)

t = zeros(Nt)
T = T_data = zeros(Nt, Nz)

for (i, I) in enumerate(Is)
    t[i] = file["timeseries/t/$I"]
    T[i, :] = file["timeseries/T/$I"][1, 1, 2:Nz+1]
end

In [5]:
keys(file["timeseries/T/0"])

1×1×258 CartesianIndices{3,Tuple{Base.OneTo{Int64},Base.OneTo{Int64},Base.OneTo{Int64}}}:
[:, :, 1] =
 CartesianIndex(1, 1, 1)

[:, :, 2] =
 CartesianIndex(1, 1, 2)

[:, :, 3] =
 CartesianIndex(1, 1, 3)

...

[:, :, 256] =
 CartesianIndex(1, 1, 256)

[:, :, 257] =
 CartesianIndex(1, 1, 257)

[:, :, 258] =
 CartesianIndex(1, 1, 258)

In [6]:
z = file["grid/zC"]

256-element Array{Float64,1}:
 -99.8046875
 -99.4140625
 -99.0234375
 -98.6328125
 -98.2421875
 -97.8515625
 -97.4609375
 -97.0703125
 -96.6796875
 -96.2890625
 -95.8984375
 -95.5078125
 -95.1171875
   ⋮        
  -4.4921875
  -4.1015625
  -3.7109375
  -3.3203125
  -2.9296875
  -2.5390625
  -2.1484375
  -1.7578125
  -1.3671875
  -0.9765625
  -0.5859375
  -0.1953125

In [7]:

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T[n, :], z, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days", show=false)
end

display(anim)

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Dropbox (MIT)\6.S898\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95


In [8]:
function coarse_grain(data, resolution)
    @assert length(data) % resolution == 0
    s = length(data) / resolution
    
    data_cs = zeros(resolution)
    for i in 1:resolution
        t = data[Int((i-1)*s+1):Int(i*s)]
        data_cs[i] = mean(t)
    end
    
    return data_cs
end

coarse_grain (generic function with 1 method)

In [9]:
coarse_resolution = cr = 32
T_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    T_cs[n, :] = coarse_grain(T[n, :], coarse_resolution)
end

T_cs = transpose(T_cs) |> Array;

w_cs = zeros(Nt, coarse_resolution)
for n=1:Nt
    w_cs[n,:] = coarse_grain(T[n,:],coarse_resolution)
end
w_cs = transpose(w_cs) |> Array;

In [10]:
cr_z = Nt/cr
Dzᶠ = 1/cr_z * Tridiagonal(-ones(cr-1), ones(cr), zeros(cr-1))
Dzᶜ = 1/cr_z * Tridiagonal(zeros(cr-1), -ones(cr), ones(cr-1))
Dzᶠ[1, 1] = 0
Dzᶜ[cr, cr] = 0

0

In [11]:
ann = Chain(Dense(cr, 2cr, tanh),
            Dense(2cr, 2cr, tanh),
                Dense(2cr, cr))

n_ann = 2*cr^2 + 2cr + 4*cr^2+2cr+2*cr^2+cr


8352

In [26]:
function dTdt_NN(du,u,p,t)
    T = u
    du = Dzᶜ*(DiffEqFlux.restructure(ann,p[1:n_ann])(u).*(Dzᶠ*u))
end

foretold(u::Array{Float64,1},p::Array{Float64,1},t) = Dzᶜ*(ann(u).*(Dzᶠ*u))

p1 = Flux.data(DiffEqFlux.destructure(ann))
T₀ = T_cs[:, 1]
ps = Flux.params(p1,T₀)

n_train = round(Int, Nt/2)
t_train = t[1:n_train] ./ 86400
tspan_train = (t_train[1], t_train[end])

(0.0, 3.993078849612772)

In [27]:
println(typeof(T₀))
println(typeof(p1))
p2 = convert(Array{Float64},p1)
println(typeof(p2))
#foretold(T₀,p2,0.0)

Array{Float64,1}
Array{Float32,1}
Array{Float64,1}


In [57]:
prob = ODEProblem(dTdt_NN,T₀,tspan_train,p1)
prob2 = ODEProblem(f,T₀,tspan_train,p1)
s = diffeq_adjoint(ps,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6)
size(s)

MethodError: MethodError: no method matching getindex(::Tracker.Params, ::UnitRange{Int64})
Closest candidates are:
  getindex(::Any, !Matched::AbstractTrees.ImplicitRootState) at C:\Users\daddyj\.julia\packages\AbstractTrees\z1wBY\src\AbstractTrees.jl:344

In [52]:
function predict_adjoint()
  diffeq_adjoint(ps,prob,Tsit5(),u0=T₀,saveat=t_train,reltol=1e-7, abstol=1e-9)
end

#neural_pde_prediction(T₀) = neural_ode(dTdt_NN, T₀, tspan_train, Tsit5(), saveat=t_train, reltol=1e-7, abstol=1e-9)

predict_adjoint (generic function with 1 method)

In [53]:
opt = ADAM(0.1)

data = Iterators.repeated(((T₀, T_cs[:, 1:n_train])), 1000)

loss_function(T₀, T_da) = sum(abs2, T_da .- predict_adjoint())

loss_function (generic function with 1 method)

In [54]:
T_cs[:, 1:n_train]

32×576 Array{Float64,2}:
 19.0156  19.0156  19.0156  19.0156  …  19.0141  19.0141  19.014   19.0141
 19.0469  19.0469  19.0469  19.0469     19.0467  19.0469  19.0469  19.0469
 19.0781  19.0781  19.0781  19.0781     19.0784  19.0781  19.0782  19.0783
 19.1094  19.1094  19.1094  19.1094     19.1098  19.1099  19.1097  19.1095
 19.1406  19.1406  19.1406  19.1406     19.1402  19.1402  19.1404  19.1404
 19.1719  19.1719  19.1719  19.1719  …  19.1721  19.1721  19.1721  19.1722
 19.2031  19.2031  19.2031  19.2031     19.2031  19.2032  19.2031  19.203 
 19.2344  19.2344  19.2344  19.2344     19.2345  19.2344  19.2346  19.2347
 19.2656  19.2656  19.2656  19.2656     19.2659  19.266   19.2657  19.2658
 19.2969  19.2969  19.2969  19.2969     19.2971  19.2968  19.2969  19.2967
 19.3281  19.3281  19.3281  19.3281  …  19.3281  19.3284  19.3284  19.3284
 19.3594  19.3594  19.3594  19.3594     19.3596  19.3597  19.3597  19.3597
 19.3906  19.3906  19.3906  19.3906     19.3908  19.3909  19.3907  19.3908


In [55]:
# Callback function to observe training.
cb = function ()
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
    loss < 10 && Flux.stop()
end

#9 (generic function with 1 method)

In [56]:
Flux.train!(loss_function, ps, data, opt, cb = cb)

MethodError: MethodError: no method matching getindex(::Tracker.Params, ::UnitRange{Int64})
Closest candidates are:
  getindex(::Any, !Matched::AbstractTrees.ImplicitRootState) at C:\Users\daddyj\.julia\packages\AbstractTrees\z1wBY\src\AbstractTrees.jl:344

In [23]:
tspan = (t[1], t[end]) ./ 86400
nn_pred = neural_ode(dTdt_NN, T₀, tspan, Tsit5(), saveat=t ./86400, reltol=1e-7, abstol=1e-9) |> Flux.data
z_cs = coarse_grain(z, cr)

anim = @gif for n=1:10:Nt
    t_str = @sprintf("%.2f", t[n] / 86400)
    plot(T_cs[:, n], z_cs, linewidth=2,
         xlim=(19, 20), ylim=(-100, 0), label="Data",
         xlabel="Temperature (C)", ylabel="Depth (z)",
         title="Free convection: $t_str days",
         legend=:bottomright, show=false)
    if n <= n_train
        plot!(nn_pred[:, n], z_cs, linewidth=2, label="Neural ODE (train)", show=false)
    else
        plot!(nn_pred[:, n], z_cs, linewidth=2, linestyle=:dash, label="Neural ODE (test)", show=false)
    end
end

┌ Info: Saved animation to 
│   fn = C:\Users\daddyj\Dropbox (MIT)\6.S898\6S898-climate-parameterization\notebooks\tmp.gif
└ @ Plots C:\Users\daddyj\.julia\packages\Plots\Iuc9S\src\animation.jl:95
