# Couple of differential equations and deep learning

In [1]:
##
using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
using DataFrames
using CSV
using ComponentArrays
using OptimizationOptimisers
using Flux
using Plots
using LaTeXStrings
rng = Random.default_rng()
Random.seed!(14);


# Gen
##
function model2(du, u, p, t)
    r, α = p
    du .= r .* u .* (1 .- u ./ α)
end
u_0 = [1.0]
p_data = [0.2, 30]
tspan_data = (0.0, 30)
prob_data = ODEProblem(model2, u_0, tspan_data, p_data)
data_solve = solve(prob_data, Tsit5(), abstol=1e-12, reltol=1e-12, saveat=1)
data_withoutnois = Array(data_solve)
data = data_withoutnois #+ Float32(2e-1)*randn(eltype(data_withoutnois), size(data_withoutnois))
tspan_predict = (0.0, 40)
prob_predict = ODEProblem(model2, u_0, tspan_predict, p_data)
test_data = solve(prob_predict, Tsit5(), abstol=1e-12, reltol=1e-12, saveat=1)
plot(test_data)

##
ann_node = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 1))
p, st = Lux.setup(rng, ann_node)
function model2_nn(du, u, p, t)
    du[1] = 0.1 * ann_node([t], p, st)[1][1] * u[1] - 0.1 * u[1]
end
prob_nn = ODEProblem(model2_nn, u_0, tspan_data, ComponentArray(p))
function train(θ)
    Array(concrete_solve(prob_nn, Tsit5(), u_0, θ, saveat=1,
        abstol=1e-6, reltol=1e-6))#,sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())))
end
#println(train(p))
function loss(θ)
    pred = train(θ)
    sum(abs2, (data .- pred)), pred # + 1e-5*sum(sum.(abs, params(ann)))
end

const losses = []
callback(θ, l, pred) = begin
    push!(losses, l)
    if length(losses) % 50 == 0
        println(losses[end])
    end
    false
end

pinit = ComponentArray(p)
println(loss(p))
callback(pinit, loss(pinit)...)


##
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(optprob,
    OptimizationOptimisers.ADAM(0.05),
    callback=callback,
    maxiters=500)

optprob2 = remake(optprob, u0=result_neuralode.u)

result_neuralode2 = Optimization.solve(optprob2,
    Optim.LBFGS(),
    callback=callback,
    allow_f_increases=false)


pfinal = result_neuralode2.u

println(pfinal)
prob_nn2 = ODEProblem(model2_nn, u_0, tspan_predict, pfinal)
s_nn = solve(prob_nn2, Tsit5(), saveat=1)

# I(t)
scatter(data_solve.t, data[1, :], label="Training Data")
plot!(test_data, label="Real Data")
plot!(s_nn, label="Neural Networks")
xlabel!("t(day)")
ylabel!("I(t)")
title!("Logistic Growth Model(I(t))")
savefig("Figures/logisticIt.png")
# R(t)
f(x) = 2 * (1 - x / p_data[2]) + 1
plot((f.(test_data))', label=L"R_t = 2(1-I(t)/K)+1")
plot!((f.(s_nn))', label=L"R_t = NN(t)")
xlabel!("t(day)")
ylabel!("Effective Reproduction Number")
title!("Logistic Growth Model(Rt)")
savefig("Figures/logisticRt.png")

(8330.73547088911, [1.0 0.8771770665184498 0.7293117579561292 0.5843960520602985 0.45760441410349345 0.353328550747374 0.27049872551384274 0.20601254263609214 0.15639938986701096 0.1184992296544495 0.08967138289781403 0.06780223512301703 0.05123945751219878 0.03870867291804422 0.029234821427150757 0.02207544853999207 0.016666863974797684 0.01258182049377875 0.009496966744758339 0.007167802832776178 0.005409399777060485 0.004081984268505488 0.0030801055786887563 0.0023239084597578602 0.0017532549406910083 0.0013226248433428189 0.0009976971002976143 0.0007525376093134055 0.0005675935199794983 0.00042805002154461583 0.0003228281154815041])
8265.245440927507


8260.345587115187
8246.85820325101


8226.177374084065
8189.218981351939
8101.911114036986


7644.234922979894
1016.1580996122227
871.6756464709824


701.0408236743888
0.4712655276221033


0.18086628628791138


0.05654521592013416


0.049716995565033825




└ @ SciMLBase /Users/aidishage/.julia/packages/SciMLBase/aft1j/src/integrator_interface.jl:611


DimensionMismatch: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 31 and 29