In [35]:
using Revise
using MatrixProductBP, MatrixProductBP.Models
using Graphs, IndexedGraphs, Statistics, Random, LinearAlgebra
import ProgressMeter; ProgressMeter.ijulia_behavior(:clear)
using TensorTrains: summary_compact
using SparseArrays;

In [36]:
T = 30
N = 30
seed = 4

c = 4
gg = erdos_renyi(N, c/N; seed)
g = IndexedGraph(gg)

λ_unif = 0.8
ρ_unif = 0.6
λ = zeros(N,N)
for i in CartesianIndices(λ)
    if !iszero(g.A[i])
        λ[i] = rand()
        # λ[i] = λ_unif
    end
end
λ = sparse(λ)
ρ = rand(N)
# ρ = fill(ρ_unif,N)
γ = 0.8;

# T = 7
# N = 2
# seed = 6

# A = [0 1; 1 0]
# g = IndexedGraph(A)

# λ_unif = 0.7
# ρ_unif = 0.6
# λ = sparse(λ_unif .* A)
# # λ = sparse([0 1e-12; λ_unif 0])
# ρ = fill(ρ_unif, N)
# γ = 0.5;
;

In [37]:
sis = SIS_heterogeneous(λ, ρ, T; γ);
bp_obs = mpbp(sis);

In [38]:
g.A

30×30 SparseMatrixCSC{Int64, Int64} with 134 stored entries:
⠀⠀⠀⢢⠤⡀⢡⠂⠁⠐⠴⡀⠀⢠⠠
⠠⣀⠀⠀⠀⡀⢈⡀⠀⢀⠀⡀⠛⠄⠐
⠀⠣⠀⠠⢀⠐⠀⠠⠈⠀⠂⠠⢁⠁⠀
⠡⠒⠂⠰⠀⡀⠊⠀⠥⠀⠀⠁⠈⠤⠅
⢁⠀⠀⢀⠂⠀⠁⠃⡀⡨⡁⠀⠒⠁⠁
⠐⠣⠀⠠⠈⡀⠄⠀⠁⠈⠀⡠⣉⠠⠐
⠀⣀⠛⠄⠅⠐⠂⡄⠜⠀⠃⡘⢀⠐⠒
⠀⠂⠐⠀⠀⠀⠁⠁⠁⠀⠐⠀⠘⠀⠀

In [39]:
obs_times = collect(0:T)
nobs = floor(Int, N * length(obs_times) * 1.0)
obs_fraction = nobs / N
rng = MersenneTwister(seed)
X, observed = draw_node_observations!(bp_obs, nobs, times = obs_times .+ 1, softinf=Inf; rng);

In [40]:
X

30×31 Matrix{Int64}:
 2  1  1  2  1  2  1  2  1  1  2  1  2  …  2  2  1  2  2  1  2  1  1  2  1  1
 2  1  2  1  2  1  2  1  2  1  2  1  1     1  2  2  1  2  2  1  1  2  2  1  2
 2  1  2  1  2  1  2  1  2  1  2  1  2     1  2  1  2  2  2  1  2  2  1  2  1
 2  1  2  1  2  2  2  2  2  1  2  1  2     1  2  2  1  2  1  2  1  2  1  2  1
 2  2  2  2  1  2  2  1  2  2  1  2  1     2  1  1  2  2  2  1  2  2  2  1  1
 2  1  1  2  2  1  2  1  1  2  1  2  1  …  2  1  1  2  1  2  1  2  1  2  1  1
 2  2  2  2  2  2  1  2  1  2  2  1  1     2  1  2  2  2  2  1  2  2  1  2  1
 2  1  2  1  1  2  1  2  1  2  1  2  1     2  1  2  1  2  1  2  1  2  1  2  1
 2  2  2  1  2  2  1  1  2  2  1  2  2     2  2  2  1  2  1  2  1  2  1  2  2
 2  1  2  1  1  2  1  2  2  2  1  2  1     1  2  1  2  1  1  2  1  2  1  1  1
 ⋮              ⋮              ⋮        ⋱     ⋮              ⋮              ⋮
 2  2  2  1  2  2  1  2  2  1  2  2  1     2  2  1  2  1  1  2  2  1  2  1  1
 2  2  2  1  2  1  2  1  2  2  1  2  1     

In [41]:
λinit = 0.5
ρinit = 0.5

A_complete = ones(N,N) - I
g_complete = IndexedGraph(A_complete)
λ_complete = sparse(λinit.*A_complete)
ρ_complete = fill(ρinit, N)

sis_inf = SIS_heterogeneous(g_complete, λ_complete, ρ_complete, T; γ, ϕ=deepcopy(bp_obs.ϕ))
bp_inf = mpbp(sis_inf);

In [42]:
svd_trunc = TruncBond(10)

iters, cb = inference_parameters!(bp_inf, method=2, maxiter=100, λstep=0.01, ρstep=0.01);

[32mRunning Gradient Ascent: iter 2    Time: 0:01:33[39m[K

[32mRunning Gradient Ascent: iter 3    Time: 0:02:18[39m[K

[32mRunning Gradient Ascent: iter 4    Time: 0:03:02[39m[K

[32mRunning Gradient Ascent: iter 5    Time: 0:03:47[39m[K

[32mRunning Gradient Ascent: iter 6    Time: 0:04:32[39m[K

[32mRunning Gradient Ascent: iter 7    Time: 0:05:17[39m[K

[32mRunning Gradient Ascent: iter 8    Time: 0:06:02[39m[K

[32mRunning Gradient Ascent: iter 9    Time: 0:06:46[39m[K

[32mRunning Gradient Ascent: iter 10    Time: 0:07:30[39m[K

[32mRunning Gradient Ascent: iter 11    Time: 0:08:15[39m[K

[32mRunning Gradient Ascent: iter 12    Time: 0:08:59[39m[K

[32mRunning Gradient Ascent: iter 13    Time: 0:09:44[39m[K

[32mRunning Gradient Ascent: iter 14    Time: 0:10:29[39m[K

[32mRunning Gradient Ascent: iter 15    Time: 0:11:14[39m[K

[32mRunning Gradient Ascent: iter 16    Time: 0:11:58[39m[K

[32mRunning Gradient Ascent: iter 17    Time: 0

In [43]:
@show cb.data

Excessive output truncated after 1669503 bytes.

101-element Vector{PARAMS{Float64}}:
 PARAMS{Float64}([[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5  …  0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], 

 PARAMS{Float64}([[0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401  …  0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.51509898495, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.49990000500000004, 0.6739244576664531, 0.5048990050500001, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401  …  0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401], [0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401  …  0.3698501866941401, 0.6739244576664531, 0.6739244576664531, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401], [0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401  …  0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.6739244576664531, 0.5202499747995001, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.5253999293000404, 0.504848515149495, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.509998995, 0.5047475504949502, 0.6739244576664531, 0.499800029998, 0.6739244576664531, 0.6739244576664531, 0.49480202969802006, 0.6739244576664531, 0.6739244576664531], [0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531  …  0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401]  …  [0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401  …  0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401], [0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531  …  0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531], [0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531  …  0.3698501866941401, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.48990299970101003, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531], [0.6739244576664531, 0.3698501866941401, 0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531  …  0.6739244576664531, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531], [0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531  …  0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531, 0.6739244576664531]], [0.4992505247725683, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.5858790430363657, 0.5408330793222207, 0.44279372217784657, 0.3698501866941401, 0.3698501866941401, 0.6739244576664531  …  0.6739244576664531, 0.46086509131070436, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.3698501866941401, 0.4992505247725683, 0.3698501866941401, 0.4254309647157214, 0.6739244576664531])
