In [46]:
using Revise
using MatrixProductBP, MatrixProductBP.Models
using Graphs, IndexedGraphs, Statistics, Random, LinearAlgebra
import ProgressMeter; ProgressMeter.ijulia_behavior(:clear)
using TensorTrains: summary_compact
using SparseArrays;

In [47]:
# Problem size, time horizon and RNG seed for the synthetic ground-truth run.
T = 15
N = 4
seed = 4

# Erdős–Rényi graph with mean degree `c`. With c == N the edge probability c/N
# equals 1, so every pair of nodes is connected.
c = 4
gg = erdos_renyi(N, c/N; seed)
g = IndexedGraph(gg)

# Uniform "true" rates: λ_unif on every edge of g, ρ_unif on every node.
λ_unif = 0.8
ρ_unif = 0.6
λ = zeros(N, N)
for idx in findall(!iszero, g.A)
    λ[idx] = λ_unif        # swap in rand() for heterogeneous edge rates
end
λ = sparse(λ)
ρ = fill(ρ_unif, N)        # swap in rand(N) for heterogeneous node rates
γ = 0.8;

# Alternative two-node configuration (uncomment to override the values above):
# T = 7
# N = 2
# seed = 6

# A = [0 1; 1 0]
# g = IndexedGraph(A)

# λ_unif = 0.7
# ρ_unif = 0.6
# λ = sparse(λ_unif .* A)
# # λ = sparse([0 1e-12; λ_unif 0])
# ρ = fill(ρ_unif, N)
# γ = 0.5;
;

In [48]:
# Ground-truth heterogeneous SIS model. No graph argument is passed here, so the
# topology is presumably taken from λ's sparsity pattern (TODO confirm). The model
# is then wrapped in a matrix-product BP object.
sis = SIS_heterogeneous(λ, ρ, T; γ = γ);
bp_obs = mpbp(sis);

In [49]:
# Display g's adjacency matrix; the stored entries are edge indices, not weights.
g.A

4×4 SparseMatrixCSC{Int64, Int64} with 12 stored entries:
 ⋅  1  2  4
 1  ⋅  3  5
 2  3  ⋅  6
 4  5  6  ⋅

In [50]:
# Draw observations from the ground-truth model `bp_obs`.
# `obs_fraction` is the fraction of all (node, time) pairs to observe over times
# 0:T; 1.0 observes every node at every time step.
obs_fraction = 1.0
obs_times = collect(0:T)
nobs = floor(Int, N * length(obs_times) * obs_fraction)
rng = MersenneTwister(seed)
# Times are shifted by +1, presumably because the package indexes time 0 as 1,
# and softinf = Inf presumably makes the observations hard — TODO confirm both.
X, observed = draw_node_observations!(
    bp_obs, nobs; times = obs_times .+ 1, softinf = Inf, rng
);

In [51]:
# Display the drawn observations. Here X is 4×16, i.e. N × (T+1): presumably one row
# per node and one column per observed time. Entries look like 1/2 state codes
# (presumably S/I) — confirm.
X

4×16 Matrix{Int64}:
 2  2  1  2  1  1  2  1  1  2  1  2  2  2  2  2
 2  2  2  2  1  2  2  2  1  2  2  1  2  1  2  1
 2  1  2  1  2  2  2  1  2  2  1  2  2  2  1  2
 2  2  1  2  1  1  2  1  2  1  2  1  2  2  1  2

In [52]:
# Starting values for the parameters to be inferred.
λinit = 0.5
ρinit = 0.5

# The inference model starts from a fully connected graph without self-loops.
# Every edge rate is λinit and every node rate is ρinit. ϕ is deep-copied from the
# ground-truth model (presumably it carries the observations drawn into bp_obs), so
# the two BP objects do not share those arrays.
A_complete = ones(N, N) - I
g_complete = IndexedGraph(A_complete)
λ_complete = sparse(A_complete .* λinit)
ρ_complete = fill(ρinit, N)

sis_inf = SIS_heterogeneous(g_complete, λ_complete, ρ_complete, T;
                            γ = γ, ϕ = deepcopy(bp_obs.ϕ))
bp_inf = mpbp(sis_inf);

In [53]:
# Bond-dimension truncation (max bond 10).
# NOTE(review): svd_trunc is never passed to inference_parameters! below, so the
# call uses whatever truncation that function applies by default. Confirm whether
# it should be forwarded.
svd_trunc = TruncBond(10)

# Gradient ascent on the λ/ρ parameters of bp_inf with fixed step sizes. `cb.data`
# records one PARAMS snapshot per step; `iters` is presumably the iteration count.
iters, cb = inference_parameters!(bp_inf; method = 2, maxiter = 200,
                                  λstep = 0.01, ρstep = 0.01);

[32mRunning Gradient Ascent: iter 2    Time: 0:00:00[39m[K

[32mRunning Gradient Ascent: iter 3    Time: 0:00:01[39m[K

[32mRunning Gradient Ascent: iter 4    Time: 0:00:01[39m[K

[32mRunning Gradient Ascent: iter 5    Time: 0:00:01[39m[K

[32mRunning Gradient Ascent: iter 6    Time: 0:00:02[39m[K

[32mRunning Gradient Ascent: iter 7    Time: 0:00:02[39m[K

[32mRunning Gradient Ascent: iter 8    Time: 0:00:02[39m[K

[32mRunning Gradient Ascent: iter 9    Time: 0:00:03[39m[K

[32mRunning Gradient Ascent: iter 10    Time: 0:00:03[39m[K

[32mRunning Gradient Ascent: iter 11    Time: 0:00:03[39m[K

[32mRunning Gradient Ascent: iter 12    Time: 0:00:04[39m[K

[32mRunning Gradient Ascent: iter 13    Time: 0:00:04[39m[K

[32mRunning Gradient Ascent: iter 14    Time: 0:00:04[39m[K

[32mRunning Gradient Ascent: iter 15    Time: 0:00:05[39m[K

[32mRunning Gradient Ascent: iter 16    Time: 0:00:05[39m[K

[32mRunning Gradient Ascent: iter 17    Time: 0

In [54]:
# Print the parameter trajectory recorded by the inference callback.
@show cb.data

cb.data = PARAMS{Float64}[PARAMS{Float64}([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]], [0.5, 0.5, 0.5, 0.5]), PARAMS{Float64}([[0.495, 0.505, 0.505], [0.505, 0.505, 0.505], [0.505, 0.505, 0.505], [0.505, 0.505, 0.505]], [0.495, 0.495, 0.495, 0.505]), PARAMS{Float64}([[0.49005, 0.49995, 0.51005], [0.51005, 0.51005, 0.51005], [0.51005, 0.51005, 0.51005], [0.51005, 0.51005, 0.51005]], [0.49005, 0.49005, 0.49995, 0.51005]), PARAMS{Float64}([[0.4949505, 0.5049495, 0.5151505], [0.5151505, 0.5151505, 0.5151505], [0.5151505, 0.5151505, 0.5151505], [0.5151505, 0.5151505, 0.5151505]], [0.48514949999999996, 0.48514949999999996, 0.5049495, 0.5151505]), PARAMS{Float64}([[0.490000995, 0.49990000500000004, 0.520302005], [0.520302005, 0.520302005, 0.520302005], [0.520302005, 0.520302005, 0.520302005], [0.520302005, 0.520302005, 0.520302005]], [0.480298005, 0.480298005, 0.49990000500000004, 0.520302005]), PARAMS{Float64}([[0.49490100495, 0.5048990050500001, 0.52550502505], [0.5

201-element Vector{PARAMS{Float64}}:
 PARAMS{Float64}([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]], [0.5, 0.5, 0.5, 0.5])
 PARAMS{Float64}([[0.495, 0.505, 0.505], [0.505, 0.505, 0.505], [0.505, 0.505, 0.505], [0.505, 0.505, 0.505]], [0.495, 0.495, 0.495, 0.505])
 PARAMS{Float64}([[0.49005, 0.49995, 0.51005], [0.51005, 0.51005, 0.51005], [0.51005, 0.51005, 0.51005], [0.51005, 0.51005, 0.51005]], [0.49005, 0.49005, 0.49995, 0.51005])
 PARAMS{Float64}([[0.4949505, 0.5049495, 0.5151505], [0.5151505, 0.5151505, 0.5151505], [0.5151505, 0.5151505, 0.5151505], [0.5151505, 0.5151505, 0.5151505]], [0.48514949999999996, 0.48514949999999996, 0.5049495, 0.5151505])
 PARAMS{Float64}([[0.490000995, 0.49990000500000004, 0.520302005], [0.520302005, 0.520302005, 0.520302005], [0.520302005, 0.520302005, 0.520302005], [0.520302005, 0.520302005, 0.520302005]], [0.480298005, 0.480298005, 0.49990000500000004, 0.520302005])
 PARAMS{Float64}([[0.49490100495, 0.5048990050500001, 0.525505