In [2]:
include("./BayesianUtilities-main/BayesianUtilities.jl")
using Distributions, LinearAlgebra, Plots, Random, .BayesianUtilities, Zygote
using DataStructures: Queue, enqueue!, dequeue!
using Flux.Optimise
using CSV, DataFrames, JLD

In [4]:
include("./ICM functions/predict_ICM.jl");
include("./ICM functions/optimize_ICM.jl");

In [5]:
data = CSV.read(".\\Covid_RIVM.csv",DataFrame, delim=";");
observation = Float64.(data.Tested_positive)
T = length(observation);

In [6]:
# Newtonian dynamics assumption
A = [1. 1.; 0. 1.]
B = [1. 0.]

# Priors
p_z_init = MvNormal(zeros(2),diagm(0=>ones(2)))
#p_W = Wishart(2,diagm(0=>ones(2)))
W = diagm(0=>ones(2))

# Non-conjugate loglikelihood function
function log_likelihood(y,z)
    x = exp((B*z)[1]) # B*z returns 1 dimensional array
    logpdf(Poisson(x),y)
end

log_likelihood (generic function with 1 method)

In [7]:
function approx_message_z(obs::Real,mes_income::MvNormal,η::Real)
    opt = AdaMax(η) #change at this location: different optimizer
    q = mes_income # initialize variational distribution with closed form incoming message
    _, _, λ_q, _, _ = exp_family(q) # initial variational parameters
    λ_m = zeros(6) # initial message parameters
    violation = 0 # keep track of number of updates lead to invalid parameters 
    converge = 1.0; # convergent condition
    while converge > 1e-4
        logp(z) = log_likelihood(obs,z)
        grad = cvi(logp,q) # grad estimates
        λ_m_old = copy(λ_m); #check convergence
        update!(opt,λ_m,λ_m - grad)
        λ_q = λ_q .+ λ_m # update variational parameters
        try
            q = exp_family(MvNormal, λ_q) # update q
            converge = norm(λ_m - λ_m_old) 
        catch
            λ_q = λ_q .- λ_m # avoid updates that violate the support of Gaussian
            violation += 1
            converge = 1.0
            @show violation
        end
    end
    return λ_m # return approximate message parameters
end

approx_message_z (generic function with 1 method)

In [8]:
function filter_CVI(W, T)
    time = [];
    forward_estimates = Array{MvNormal}(undef, T);
    λ_m = approx_message_z(observation[1],p_z_init,0.4);   
    forward_estimates[1] = collide(p_z_init,Canonical(MvNormal, λ_m));
    for t=2:T
       t1 = @elapsed predict = A*forward_estimates[t-1] + MvNormal(zeros(2),Matrix(Hermitian(inv(W))));
       t2 = @elapsed λ_m = approx_message_z(observation[t],predict,0.4);
       t3 = @elapsed forward_estimates[t] = collide(predict,Canonical(MvNormal, λ_m));
        time = append!(time,t1+t2+t3); #start from t=2
    end
    return forward_estimates, time
end


function smooth(forward_estimates,W, T)
    smooth_estimates = Array{MvNormal}(undef, T);
    smooth_estimates[T] = forward_estimates[T];
    joint_dists = Array{MvNormal}(undef, T-1);
    for t=T-1:-1:1
        smooth_estimates[t], joint_dists[t] = transit(forward_estimates[t],smooth_estimates[t+1],A,W);
    end
    return smooth_estimates, joint_dists
end

smooth (generic function with 1 method)

In [9]:
Random.seed!(1)
@time begin
forward_estimates_cvi, time_cvi = filter_CVI(W,T);
smooth_estimates_CVI, joint_dists_CVI = smooth(forward_estimates_cvi, W, T);
end
;

691.379332 seconds (2.74 G allocations: 221.649 GiB, 8.21% gc time, 1.79% compilation time)


In [10]:
Random.seed!(10)
FE_cvi = []; #store FE values at each time step
for t=1:T
    if t==1
        FE_temp = -entropy(smooth_estimates_CVI[t]) + cross_entropy(smooth_estimates_CVI[t], p_z_init);
        FE_temp -= normal_conditional_entropy(smooth_estimates_CVI[t], smooth_estimates_CVI[t+1], joint_dists_CVI[t]);
        FE_temp += transit(smooth_estimates_CVI[t],smooth_estimates_CVI[t+1],joint_dists_CVI[t],A,W)
        Av_t = 0;
        for s=1:500
            Av_t -= log_likelihood(observation[t],rand(smooth_estimates_CVI[t]));
        end
        FE_temp += Av_t/500;
        append!(FE_cvi,FE_temp);
    elseif t==T
        Av_t = 0;
        for s=1:500
            Av_t -= log_likelihood(observation[t],rand(smooth_estimates_CVI[t]));
        end
        FE_temp = Av_t/500;
        append!(FE_cvi,FE_temp);
    else
        FE_temp = -normal_conditional_entropy(smooth_estimates_CVI[t], smooth_estimates_CVI[t+1], joint_dists_CVI[t]);
        FE_temp =+ transit(smooth_estimates_CVI[t],smooth_estimates_CVI[t+1],joint_dists_CVI[t],A,W)
        Av_t = 0;
        for s=1:500
            Av_t -= log_likelihood(observation[t],rand(smooth_estimates_CVI[t]));
        end
        FE_temp += Av_t/500;
        append!(FE_cvi,FE_temp);
    end
end
sum(FE_cvi)

2277.861348917162

In [11]:
state1_est_CVI = zeros(T)
var1_est_CVI = zeros(T)

for t=1:T
    state1_est_CVI[t] = mean(smooth_estimates_CVI[t])[1]
    var1_est_CVI[t] = var(smooth_estimates_CVI[t])[1]
end

In [12]:
save(".\\results\\state1_estimates_CVI.jld","state1_estimate",state1_est_CVI);
save(".\\results\\var1_estimates_CVI.jld","var1_estimate",var1_est_CVI);
save(".\\results\\inference_time_cvi.jld","inference_time",time_cvi);