# Couple of differential equations and deep learning

## IMPORTANT: Activate Julia environment first

In [2]:
using Pkg
Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/Desktop/MyProjects/Julia_Tutorial_on_AI4MathBiology`


In [5]:
##
using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
using DataFrames
using CSV
using ComponentArrays
using OptimizationOptimisers
using Flux
using Plots
using LaTeXStrings
rng = Random.default_rng()
Random.seed!(1234);


# Gen
##
function model2(du, u, p, t)
    r, α = p
    du .= r .* u .* (1 .- u ./ α)
end
u_0 = [1.0]
p_data = [0.2, 30]
tspan_data = (0.0, 30)
prob_data = ODEProblem(model2, u_0, tspan_data, p_data)
data_solve = solve(prob_data, Tsit5(), abstol=1e-12, reltol=1e-12, saveat=1)
data_withoutnois = Array(data_solve)
data = data_withoutnois #+ Float32(2e-1)*randn(eltype(data_withoutnois), size(data_withoutnois))
tspan_predict = (0.0, 40)
prob_predict = ODEProblem(model2, u_0, tspan_predict, p_data)
test_data = solve(prob_predict, Tsit5(), abstol=1e-12, reltol=1e-12, saveat=1)
plot(test_data)

##
ann_node = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 1))
p, st = Lux.setup(rng, ann_node)
function model2_nn(du, u, p, t)
    du[1] = 0.1 * ann_node([t], p, st)[1][1] * u[1] - 0.1 * u[1]
end
prob_nn = ODEProblem(model2_nn, u_0, tspan_data, ComponentArray(p))
function train(θ)
    Array(concrete_solve(prob_nn, Tsit5(), u_0, θ, saveat=1,
        abstol=1e-6, reltol=1e-6))#,sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())))
end
#println(train(p))
function loss(θ)
    pred = train(θ)
    sum(abs2, (data .- pred)), pred # + 1e-5*sum(sum.(abs, params(ann)))
end

const losses = []
callback(θ, l, pred) = begin
    push!(losses, l)
    if length(losses) % 50 == 0
        println(losses[end])
    end
    false
end

pinit = ComponentArray(p)
println(loss(p))
callback(pinit, loss(pinit)...)


##
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(optprob,
    OptimizationOptimisers.ADAM(0.05),
    callback=callback,
    maxiters=500)

optprob2 = remake(optprob, u0=result_neuralode.u)

result_neuralode2 = Optimization.solve(optprob2,
    Optim.LBFGS(),
    callback=callback,
    allow_f_increases=false)


pfinal = result_neuralode2.u

println(pfinal)
prob_nn2 = ODEProblem(model2_nn, u_0, tspan_predict, pfinal)
s_nn = solve(prob_nn2, Tsit5(), saveat=1)

# I(t)
scatter(data_solve.t, data[1, :], label="Training Data")
plot!(test_data, label="Real Data")
plot!(s_nn, label="Neural Networks")
xlabel!("t(day)")
ylabel!("I(t)")
title!("Logistic Growth Model(I(t))")
savefig("Figures/logisticIt.png")
# R(t)
f(x) = 2 * (1 - x / p_data[2]) + 1
plot((f.(test_data))', label=L"R_t = 2(1-I(t)/K)+1")
plot!((f.(s_nn))', label=L"R_t = NN(t)")
xlabel!("t(day)")
ylabel!("Effective Reproduction Number")
title!("Logistic Growth Model(Rt)")
savefig("Figures/logisticRt.png")

(8302.391970147039, [1.0 0.8855641189397063 0.756511397493139 0.6326126042255029 0.5243433431241015 0.434004963091067 0.3600777364373986 0.299949804809959 0.25103783225124093 0.21112674937348183 0.17841543415053107 0.1514697977054114 0.12915666230690864 0.11058215498349554 0.0950400342679121 0.08197023036537579 0.07092727867473227 0.06155462749344566 0.05356596079755513 0.04672977687765638 0.04085824673311616 0.035797833164541576 0.031422753182942006 0.027628984941347585 0.024330559400754345 0.021455630986385792 0.01894416436916513 0.01674573740910931 0.014817612512391514 0.013123697335280472 0.011633194091910545])
1144.002831878164




32.453717564967214
12.433062811563678
16.43297706675565


15.70220191004424
10.764147709075822
104.10115362352289


22.570456109792058
137.14465902869964
8.892138590746951




└ @ SciMLBase /Users/aidishage/.julia/packages/SciMLBase/aft1j/src/integrator_interface.jl:611


DimensionMismatch: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 31 and 2