In [1]:
using CSV, DataFrames, Dates, Plots, DiffEqFlux, DifferentialEquations.OrdinaryDiffEq, Statistics

In [2]:
# Load the case series from DadosMedia.csv and restrict them to the window
# 2020-03-18 .. 2020-06-30 (the "onda" = wave studied below).
df = DataFrame(CSV.File("DadosMedia.csv"))

datasTudo = df[:,"DatasMedia"]
infTudo = df[:,"InfectadosMedia"]
recTudo = df[:,"RecuperadosMedia"]
decTudo = df[:,"ObitosMedia"]

"""
    indice_da_data(datas, data)

Return the first index `i` with `isequal(datas[i], data)`. Throw an `ArgumentError`
naming the date when it is absent. `isequal` also tolerates `missing` entries.
"""
function indice_da_data(datas, data)
    i = findfirst(isequal(data), datas)
    i === nothing && throw(ArgumentError("date $data not found in DadosMedia.csv"))
    return i
end

limiteE = indice_da_data(datasTudo, Date(2020, 3, 18))
limiteD = indice_da_data(datasTudo, Date(2020, 6, 30))

datasOnda = datasTudo[limiteE:limiteD]
inf_onda = infTudo[limiteE:limiteD]
rec_onda = recTudo[limiteE:limiteD]
dec_onda = decTudo[limiteE:limiteD]

# One row per day of the window; columns are (infected, recovered, deaths).
dados_onda = hcat(inf_onda, rec_onda, dec_onda)

# Initial (I, R, D) taken from the first day of the window.
u₀  = [inf_onda[1], rec_onda[1], dec_onda[1]]

nothing

In [3]:
# Number of training days. It appears as `_<N>_dias` in the parameter file names
# and in the plot titles. NOTE: the ODE right-hand sides below each define a local
# `N` (population size), which shadows this global inside those functions.
N = 40
# Number of trained UODE parameter samples (rows of the parametros_UODE file)
# to simulate.
M = 10

# Not referenced elsewhere in the visible code.
SIR_model_name = "SIRD"
# Selects which `dudt_UODE!` variant is defined below. It is also the
# `parametros/<name>/` directory the parameter files are read from.
UODE_model_name = "SIRD_UODE_b"
# Cost-function tag embedded in the parameter file names.
cost_name = "SSR-max"

"""
    dudtSIR!(du, u, θ, t)

In-place right-hand side of the SIRD model with state `u = (S, I, R, D)` and
parameters `θ = (β, γ_R, γ_D)`. The effective rates are the squares `β^2`, `γ_R^2`,
`γ_D^2`. `du` receives (dS, dI, dR, dD); `t` is unused.

NOTE: the population in the mass-action term is `S + I + R` (deaths excluded),
whereas both `dudt_UODE!` variants use `S + I + R + D`.
"""
function dudtSIR!(du, u, θ, t)
    S, I, R, _ = u
    β, γ_R, γ_D = θ

    populacao = S + I + R
    novos_infectados = β^2 * I * S / populacao

    du[1] = -novos_infectados
    du[2] = novos_infectados - (γ_R^2 + γ_D^2) * I
    du[3] = γ_R^2 * I
    du[4] = γ_D^2 * I
end

# Define the neural network `NN` and the UODE right-hand side `dudt_UODE!` for the
# variant selected by `UODE_model_name`. In both variants θ[1:2] are (γ_R, γ_D),
# used squared as in `dudtSIR!`. θ[3:end] are the network parameters. The
# network's last layer squares its output, so the infection flow is non-negative.
if UODE_model_name == "SIRD_UODE_bSI"
    # Network input is (S/N, I); its output is used directly as the new-infection flow.
    NN = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1), (x, θ) -> x.^2)

    function dudt_UODE!(du, u, θ, t)
        S, I, R, D = u
        N = S + I + R + D
        γ_R, γ_D = θ[1:2]
    
        E_novos =  NN([S/N, I], θ[3:end])[1]
        dS = -E_novos 
        dI = E_novos - (γ_R^2 + γ_D^2)*I
        dR = γ_R^2*I
        dD = γ_D^2*I
    
        du[1] = dS; du[2] = dI; du[3] = dR; du[4] = dD
    end
elseif UODE_model_name == "SIRD_UODE_b"
    # Network input is the full state (S, I, R, D). Its output multiplies the
    # mass-action term I*S/N, taking the place of β^2 in `dudtSIR!`.
    NN = FastChain(FastDense(4,16,tanh), FastDense(16,16,tanh), FastDense(16,1), (x, θ) -> x.^2)

    function dudt_UODE!(du, u, θ, t)
        S, I, R, D = u
        N = S + I + R + D
        γ_R, γ_D = θ[1:2]

        E_novos =  NN(u, θ[3:end])[1] * I*S / N
        dS = -E_novos 
        dI = E_novos - (γ_R^2 + γ_D^2)*I
        dR = γ_R^2*I
        dD = γ_D^2*I

        du[1] = dS; du[2] = dI; du[3] = dR; du[4] = dD
    end
else
    # Fail here rather than with an UndefVarError for `dudt_UODE!` in a later cell.
    throw(ArgumentError(string("unknown UODE_model_name \"", UODE_model_name,
        "\"; expected \"SIRD_UODE_bSI\" or \"SIRD_UODE_b\"")))
end

nothing

In [4]:
# Fitted initial susceptible population for this model / cost / training length.
S₀ = CSV.read(string("parametros/", UODE_model_name, "/condicao_inicial_", cost_name, "_", N, "_dias.csv"), DataFrame, header = 
    false)[1, 1]

dadosTreino = dados_onda[1:N, :]    
# Total initial population S₀ + I₀ + R₀ + D₀. The UODE runs are integrated on
# states divided by this factor.
fator_reducao = sum([S₀; u₀])

# Fitted SIRD parameters (β, γ_R, γ_D): the first row of the file, collected into a
# plain Vector. The solver thus gets concrete numeric parameters instead of a
# DataFrameRow, which is slow to destructure on every right-hand-side call.
θ₁_SIR = [col[1] for col in eachcol(CSV.read(
    string("parametros/", UODE_model_name, "/parametros_SIRD_", cost_name, "_", N, "_dias.csv"),
    DataFrame, header = false))]

# Output step (days) used for `saveat` and for the plotting grid.
h = 0.02

# Integrate from day 1 to the last day of the wave window.
modeloSIR = solve(ODEProblem(dudtSIR!, [S₀; u₀], (1., size(dados_onda, 1)), θ₁_SIR), saveat = h)

inf_SIR = modeloSIR[2, :]
rec_SIR = modeloSIR[3, :]
dec_SIR = modeloSIR[4, :]

nothing

In [5]:
# Trained UODE parameter samples, one per row. The file's last column holds each
# sample's training error and is dropped here.
θ₁_UODE_amostral = CSV.read(string("parametros/", UODE_model_name, "/parametros_UODE_", cost_name, "_", N, "_dias.csv"),
    DataFrame, header = false)[:, 1:end-1]

1 <= M <= size(θ₁_UODE_amostral, 1) || throw(ArgumentError(string("M = ", M,
    " must be between 1 and the number of parameter samples (",
    size(θ₁_UODE_amostral, 1), ")")))

"""
    simula_UODE(θ)

Integrate `dudt_UODE!` with parameter vector `θ` from day 1 to the last day of the
wave window. Start from the initial state divided by `fator_reducao`, and return the
solution (saved every `h` days) multiplied back by `fator_reducao`.
"""
simula_UODE(θ) = solve(ODEProblem(dudt_UODE!, [S₀; u₀] ./ fator_reducao,
    (1., size(dados_onda, 1)), θ), saveat = h) .* fator_reducao

"""
    empilha_compartimento(modelos, k)

Return a (time × sample) matrix whose j-th column is row `k` of `modelos[j]`.
All solutions must share the same time grid.
"""
empilha_compartimento(modelos, k) =
    reshape(reduce(vcat, [modelo[k, :] for modelo in modelos]), :, length(modelos))

# Comprehensions instead of a top-level `for` loop that reassigns globals. Such a
# loop only works under REPL/IJulia soft-scope rules and fails with UndefVarError
# when run as a script. A single stacking step also avoids the repeated `hcat` copies.
θ₁_UODE_amostras = [[col[m] for col in eachcol(θ₁_UODE_amostral)] for m in 1:M]
modelos_UODE = [simula_UODE(θ) for θ in θ₁_UODE_amostras]

inf_UODE_amostra = empilha_compartimento(modelos_UODE, 2)
rec_UODE_amostra = empilha_compartimento(modelos_UODE, 3)
dec_UODE_amostra = empilha_compartimento(modelos_UODE, 4)

# Single-sample names, holding the last simulated sample.
θ₁_UODE = last(θ₁_UODE_amostras)
modelo_UODE = last(modelos_UODE)
inf_UODE = modelo_UODE[2, :]
rec_UODE = modelo_UODE[3, :]
dec_UODE = modelo_UODE[4, :]

nothing

In [6]:
# Training error of each UODE sample: the last column of the parameter file.
# Assigned to `errs` only; `θ₁_UODE_amostral` must keep holding the parameter table.
errs = CSV.read(string("parametros/", UODE_model_name, "/parametros_UODE_", cost_name, "_", N, "_dias.csv"),
    DataFrame, header = false)[:, end]

# Sample indices ordered from lowest to highest training error.
sorted_errs_indexes = sortperm(errs)

errs[sorted_errs_indexes]

10-element Vector{Float64}:
 1.6115121110310942e-5
 1.6337175862435352e-5
 1.6376157327089985e-5
 1.644103950441966e-5
 1.644166722961457e-5
 1.6445077252323783e-5
 1.645252489533188e-5
 1.64653615831097e-5
 1.6470577859930503e-5
 1.649830264019323e-5

In [7]:
# Relative gap between the best and the second-best training errors
# (the value displayed by this cell).
let ordenados = errs[sorted_errs_indexes]
    (ordenados[2] - ordenados[1]) / ordenados[1]
end

0.013779279138171195

In [None]:
# For each of the (up to) 5 lowest-error UODE samples, plot:
# - all simulated UODE trajectories (gray),
# - the data (dotted),
# - the SIRD fit (dashed),
# - the highlighted sample.
# The red vertical line marks the end of the N-day training window. The time grid
# is derived from the data length, so it matches the solvers' `saveat = h` output.
let t_final = size(dados_onda, 1), dias = 1.:h:t_final
    for best in 1:min(5, length(sorted_errs_indexes))
        pl = plot(dias, inf_UODE_amostra, color = "gray80", label = "", xlim = (1, t_final),
            ylim = (-1_000, 45_000), title = "SIR x UODE para $N dias de treino - INFECTADOS",
            titlefontsize = 14)

        plot!(pl, inf_onda, lw = 3, color = 1, linestyle = :dot, label = "dados")

        plot!(pl, dias, inf_SIR, lw = 2, color = 1, linestyle = :dash, label = "SIR")

        plot!(pl, dias, inf_UODE_amostra[:, sorted_errs_indexes[best]], lw = 2, color = 1,
            label = string(best, "ᵃ \u0022melhor\u0022 UODE"))

        vline!(pl, [N], label = "", color = "red")

        display(pl)
    end
end

In [None]:
# For each of the (up to) 5 lowest-error UODE samples, plot:
# - all simulated UODE trajectories (gray),
# - the data (dotted),
# - the SIRD fit (dashed),
# - the highlighted sample.
# The red vertical line marks the end of the N-day training window. The time grid
# is derived from the data length, so it matches the solvers' `saveat = h` output.
let t_final = size(dados_onda, 1), dias = 1.:h:t_final
    for best in 1:min(5, length(sorted_errs_indexes))
        pl = plot(dias, rec_UODE_amostra, color = "gray80", label = "", xlim = (1, t_final),
            ylim = (-1_000, 65_000), title = "SIR x UODE para $N dias de treino - RECUPERADOS",
            titlefontsize = 14, legend = :topleft)

        plot!(pl, rec_onda, lw = 3, color = 2, linestyle = :dot, label = "dados")

        plot!(pl, dias, rec_SIR, lw = 2, color = 2, linestyle = :dash, label = "SIR")

        plot!(pl, dias, rec_UODE_amostra[:, sorted_errs_indexes[best]], lw = 2, color = 2,
            label = string(best, "ᵃ \u0022melhor\u0022 UODE"))

        vline!(pl, [N], label = "", color = "red")

        display(pl)
    end
end

In [None]:
# For each of the (up to) 5 lowest-error UODE samples, plot:
# - all simulated UODE trajectories (gray),
# - the data (dotted),
# - the SIRD fit (dashed),
# - the highlighted sample.
# The red vertical line marks the end of the N-day training window. The time grid
# is derived from the data length, so it matches the solvers' `saveat = h` output.
let t_final = size(dados_onda, 1), dias = 1.:h:t_final
    for best in 1:min(5, length(sorted_errs_indexes))
        pl = plot(dias, dec_UODE_amostra, color = "gray80", label = "", xlim = (1, t_final),
            ylim = (-200, 10_000), title = "SIR x UODE para $N dias de treino - DECESSOS",
            titlefontsize = 14, legend = :topleft)

        plot!(pl, dec_onda, lw = 3, color = 3, linestyle = :dot, label = "dados")

        plot!(pl, dias, dec_SIR, lw = 2, color = 3, linestyle = :dash, label = "SIR")

        plot!(pl, dias, dec_UODE_amostra[:, sorted_errs_indexes[best]], lw = 2, color = 3,
            label = string(best, "ᵃ \u0022melhor\u0022 UODE"))

        vline!(pl, [N], label = "", color = "red")

        display(pl)
    end
end