In [1]:
using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random

┌ Info: Precompiling OrdinaryDiffEq [1dea7af3-3e70-54e6-95c3-0bf5283fa5ed]
└ @ Base loading.jl:1278
┌ Info: Precompiling DiffEqSensitivity [41bf760c-e81c-5289-8e54-58b1f1f8abe2]
└ @ Base loading.jl:1278


In [2]:
T = Float32

const bs = 512
X = [rand(T, 10, 50) for _ in 1:bs*10]

function create_spline(i)
    x = X[i]
    t = x[end, :]
    t = (t .- minimum(t)) ./ (maximum(t) - minimum(t))

    spline = QuadraticInterpolation(x, t)
end

splines = [create_spline(i) for i in tqdm(1:length(X))];

100.0%┣███████████████████████████████████┫ 5120/5120 [00:01<00:00, 6399.4 it/s]


In [3]:
rand_inds = randperm(length(X))

i_sz = size(X[1], 1)
h_sz = 16

use_gpu = true
batches = [[splines[rand_inds[(i-1)*bs+1:i*bs]]] for i in tqdm(1:length(X)÷bs)];

100.0%┣█████████████████████████████████████████┫ 10/10 [00:00<00:00, 73.6 it/s]


In [4]:
data_ = Iterators.cycle(batches)

function call_and_cat(splines, t)
    vals = Zygote.ignore() do
        vals = reduce(hcat,[spline(t) for spline in splines])
    end
    vals |> (use_gpu ? gpu : cpu)
end

function derivative(A::QuadraticInterpolation, t::Number)
    idx = findfirst(x -> x >= t, A.t) - 1
    idx == 0 ? idx += 1 : nothing
    if idx == length(A.t) - 1
        i₀ = idx - 1; i₁ = idx; i₂ = i₁ + 1;
    else
        i₀ = idx; i₁ = i₀ + 1; i₂ = i₁ + 1;
    end
    dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
    dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
    dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
    @views @. A.u[:, i₀] * dl₀ + A.u[:, i₁] * dl₁ + A.u[:, i₂] * dl₂
end

function derivative_call_and_cat(splines, t)
    vals = Zygote.ignore() do
        reduce(hcat,[derivative(spline, t) for spline in splines]) |> (use_gpu ? gpu : cpu)
    end
end

derivative_call_and_cat (generic function with 1 method)

In [5]:
cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> (use_gpu ? gpu : cpu)

h_to_out = Dense(h_sz, 2) |> (use_gpu ? gpu : cpu)

initial = Dense(i_sz, h_sz) |> (use_gpu ? gpu : cpu)

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out);

In [6]:
basic_tgrad(u,p,t) = zero(u)

function predict_func(p, BX)
    By = call_and_cat(BX, 1)

    x0 = call_and_cat(BX, 0)
    i = 1
    j = (i-1)+length(initial_p)

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        x = derivative_call_and_cat(BX, t)
        bs = size(h, 2)
        a = reshape(cde_re(p)(h), (i_sz, h_sz, bs))
        b = reshape(x, (1, i_sz, bs))

        dh = batched_mul(b,a)[1,:,:]
    end

    i = j+1
    j = (i-1)+length(cde_p)

    tspan = (0.0f0, 0.8f0)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    solver = Tsit5()

    sol = solve(prob,solver,u0=h0,saveat=tspan[end], save_start=false, sensealg=sense)
    #@show sol.destats
    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(sol[end])

    y_hat, By[1:2, :]
end

predict_func (generic function with 1 method)

In [12]:
initial_p

176-element CUDA.CuArray{Float32,1}:
 -0.053237036
  0.32876566
  0.08456616
  0.24081425
 -0.40230474
  0.075121574
  0.40263265
  0.38744345
  0.11204001
  0.24094518
 -0.34157017
 -0.4137195
  0.42026117
  ⋮
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0

In [7]:
function loss_func(p, BX)
    y_hat, y = predict_func(p, BX)

    mean(sum(sqrt.((y .- y_hat).^2), dims=1))
end

p = vcat(initial_p, cde_p, h_to_out_p)

callback = function (p, l)
  display(l)
  return false
end

#16 (generic function with 1 method)

In [8]:
using DiffEqFlux

In [9]:
Zygote.gradient((p)->loss_func(p, first(data_)...),p)
@time Zygote.gradient((p)->loss_func(p, first(data_)...),p)

  1.118676 seconds (4.39 M allocations: 381.118 MiB, 18.04% gc time)


(Float32[-39.13369, -28.500084, 82.741745, -154.48625, 128.3319, -124.11713, 21.34516, -44.174145, -7.20418, 49.7554  …  33.494568, -17.520622, -7.33539, -0.3212571, 12.6459, -21.196257, -21.12806, -5.5857267, -0.16796875, -0.375],)

In [10]:
result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.05),
    data_,
    cb = callback,
    maxiters = 10)

98.03215f0

205.55554f0

[32mloss: 206:  20%|████████                                |  ETA: 0:00:16[39m

119.18447f0

[32mloss: 119:  30%|████████████                            |  ETA: 0:00:12[39m

71.97586f0

[32mloss: 72:  40%|████████████████▍                        |  ETA: 0:00:10[39m

162.17494f0

[32mloss: 162:  50%|████████████████████                    |  ETA: 0:00:08[39m

295.29974f0

[32mloss: 295:  60%|████████████████████████                |  ETA: 0:00:07[39m

50.824574f0

[32mloss: 50.8:  70%|███████████████████████████▎           |  ETA: 0:00:05[39m

57.74757f0

[32mloss: 57.7:  80%|███████████████████████████████▎       |  ETA: 0:00:03[39m

239.55023f0

[32mloss: 240:  90%|████████████████████████████████████    |  ETA: 0:00:02[39m

54.899662f0

[32mloss: 54.9: 100%|███████████████████████████████████████| Time: 0:00:16[39m


50.824574f0

 * Status: success

 * Candidate solution
    Final objective value:     5.489966e+01

 * Found with
    Algorithm:     ADAM

 * Convergence measures
    |x - x'|               = NaN ≰ 0.0e+00
    |x - x'|/|x'|          = NaN ≰ 0.0e+00
    |f(x) - f(x')|         = NaN ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = NaN ≰ 0.0e+00
    |g(x)|                 = NaN ≰ 0.0e+00

 * Work counters
    Seconds run:   19  (vs limit Inf)
    Iterations:    10
    f(x) calls:    10
    ∇f(x) calls:   10
