# Testing inference on Network Diffusion 

Here, I examine the utility of sampling and variational inference for inferring parameter values of a simple network diffusion model on Erdos-Renyi random graphs. The primary aim of this document is to assess how well inference scales as the network grows in size and as its topology (in this case, the connection probability) changes.

### Environment
First, we check the environment to ensure that all required packages are present and document their versions.

```julia
using Pkg
Pkg.status();
```

```
Status `~/Projects/NetworkTopology/Project.toml`
  [76274a88] Bijectors v0.8.14
  [0c46a032] DifferentialEquations v6.16.0
  [31c24e10] Distributions v0.24.13
  [7073ff75] IJulia v1.23.1
  [093fc24a] LightGraphs v1.3.5
  [c7f686f2] MCMCChains v4.7.0
  [91a5bcdd] Plots v1.10.4
  [37e2e3b7] ReverseDiff v1.5.0
  [f3b207a7] StatsPlots v0.14.19
  [fce5fe82] Turing v0.15.10
  [e88e6eb3] Zygote v0.6.3
```


### Model Setup 

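The first step in defining our model is to initialise a graph on which to run the model. We do this using `LightGraphs` to generate an Erdos-Renyi random graph with `N` nodes, in which each possible edge is present independently with probability `P`, and then compute its graph Laplacian `L`.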

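```julia
using LightGraphs

N = 5    # number of nodes
P = 0.5  # connection (edge) probability

G = erdos_renyi(N, P);     # Erdos-Renyi random graph G(N, P)
L = laplacian_matrix(G);   # graph Laplacian L = D - A (sparse)
```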

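`LightGraphs` handles both steps above. To show what they compute, here is a minimal sketch in base Julia (only the `LinearAlgebra` and `Random` standard libraries) that samples a $G(N, P)$ adjacency matrix and forms the Laplacian $\mathbf{L} = \mathbf{D} - \mathbf{A}$, where $\mathbf{D}$ is the diagonal degree matrix. The helper names `er_adjacency` and `laplacian` are hypothetical and not part of `LightGraphs`, and the sketch does not reproduce `LightGraphs`' random number stream.

```julia
using LinearAlgebra, Random

# Hypothetical helper: sample an undirected G(n, p) Erdos-Renyi adjacency matrix.
# Each of the n(n-1)/2 possible edges is included independently with
# probability p; there are no self-loops.
function er_adjacency(n::Integer, p::Real; rng::AbstractRNG=Random.default_rng())
    A = zeros(Int, n, n)
    for i in 1:n, j in (i + 1):n
        if rand(rng) < p
            A[i, j] = 1
            A[j, i] = 1
        end
    end
    return A
end

# Hypothetical helper: graph Laplacian L = D - A, with D the diagonal matrix of degrees.
laplacian(A::AbstractMatrix) = Diagonal(vec(sum(A, dims=2))) - A

A = er_adjacency(5, 0.5; rng=MersenneTwister(1))
Lhand = laplacian(A)   # named Lhand so the notebook's `L` is not overwritten
```

Every row of the Laplacian sums to zero. This is the property that makes diffusion on the graph conserve the total quantity.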


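The second step of the modelling process is to define the ODE model. For network diffusion, this is given by:

$$ \frac{d\mathbf{u}}{dt} = -\rho \mathbf{L} \mathbf{u} $$

where $\mathbf{u}$ is the vector of values at each node, $\rho$ is the diffusion rate and $\mathbf{L}$ is the graph Laplacian.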
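Because the system is linear, its solution and long-run behaviour can be written down directly. The following derivation uses standard properties of the Laplacian of an undirected graph and is not part of the original analysis:

$$ \mathbf{u}(t) = e^{-\rho \mathbf{L} t}\, \mathbf{u}(0) $$

Each row of $\mathbf{L}$ sums to zero and $\mathbf{L}$ is symmetric, so $\mathbf{1}^\top \mathbf{L} = \mathbf{0}^\top$ and the total quantity is conserved:

$$ \frac{d}{dt}\left(\mathbf{1}^\top \mathbf{u}\right) = -\rho\, \mathbf{1}^\top \mathbf{L}\, \mathbf{u} = 0 $$

If $G$ is connected, $\mathbf{L}$ has a single zero eigenvalue (with eigenvector $\mathbf{1}$) and all other eigenvalues satisfy $0 < \lambda_2 \le \dots \le \lambda_N$. Every node therefore converges to the initial mean, and the slowest mode decays at rate $\rho \lambda_2$:

$$ \left\lVert \mathbf{u}(t) - \frac{\mathbf{1}^\top \mathbf{u}(0)}{N}\, \mathbf{1} \right\rVert_2 \le e^{-\rho \lambda_2 t} \left\lVert \mathbf{u}(0) - \frac{\mathbf{1}^\top \mathbf{u}(0)}{N}\, \mathbf{1} \right\rVert_2 $$

This matches the simulated values below, which settle near 0.337, the mean of the random initial condition `u0`.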

We can set this up as a Julia function, with the diffusion rate $\rho$ passed as the parameter `p`:

```julia
NetworkDiffusion(u, p, t) = -p * L * u
```

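The function above reads `L` from a non-constant global variable and allocates a new vector on every call. A common alternative is an in-place right-hand side that writes into a preallocated `du`. Below is a sketch under that assumption. The name `network_diffusion!` and the choice to pass `(ρ, L)` as the parameter tuple are illustrative, not the original author's code.

```julia
using LinearAlgebra

# Hypothetical in-place variant of NetworkDiffusion: writes -ρ * Lap * u into du.
# The Laplacian travels in the parameter tuple instead of being read from a global.
function network_diffusion!(du, u, params, t)
    ρ, Lap = params
    mul!(du, Lap, u, -ρ, false)   # 5-argument mul!: du = Lap * u * (-ρ) + du * false
    return nothing
end

# Small check on the path graph 1 - 2 - 3
Lpath = [1 -1 0; -1 2 -1; 0 -1 1]
u = [1.0, 0.0, 0.0]
du = similar(u)
network_diffusion!(du, u, (2.0, Lpath), 0.0)
```

`DifferentialEquations` treats a four-argument function as in-place, so this could be used as `ODEProblem(network_diffusion!, u0, t_span, (p, L))`. The `Turing` model would then need to pass the tuple as well, e.g. `remake(prob, p=(ρ, L))`. That combination has not been checked here with automatic differentiation.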

To run a simulation, we choose random initial conditions, a diffusion rate and a time span, and define an `ODEProblem` using `DifferentialEquations`.

```julia
u0 = rand(N)          # random initial value at each node
p = 2.0               # diffusion rate ρ used to simulate the data
t_span = (0.0, 1.0);
```


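```julia
using DifferentialEquations

# u0 is converted to the element type of p (here Float64)
problem = ODEProblem(NetworkDiffusion, eltype(p).(u0), t_span, p);
sol = solve(problem, Tsit5(), saveat=0.05)   # save the solution every 0.05 time units
```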

```
retcode: Success
Interpolation: 1st order linear
t: 21-element Array{Float64,1}:
 0.0
 0.05
 0.1
 0.15
 0.2
 0.25
 0.3
 0.35
 0.4
 0.45
 0.5
 0.55
 0.6
 0.65
 0.7
 0.75
 0.8
 0.85
 0.9
 0.95
 1.0
u: 21-element Array{Array{Float64,1},1}:
 [0.6257917078751645, 0.07217480676699561, 0.4505773085190672, 0.43029838343893534, 0.10786392392132527]
 [0.5122950028113811, 0.14721969625786693, 0.4315982335201711, 0.3937225004492241, 0.20187069748284486]
 [0.44345463482663255, 0.1980466216169032, 0.4138287803308324, 0.37153773972895626, 0.2598383540181636]
 [0.4017016625427205, 0.2332213124415705, 0.39846172823703885, 0.35808226692143474, 0.2952391603787235]
 [0.37637772781787593, 0.2581146033535708, 0.38574228061554433, 0.34992127869878575, 0.3165502400357113]
 [0.3610196681446935, 0.27612624180308853, 0.37548493976035086, 0.34497193155042905, 0.3291033492629263]
 [0.35169948393176753, 0.28943879840726194, 0.3673440383032259, 0.34196837327952634, 0.33625543659970664]
 [0.3460530990881128, 0.299463
```
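Because the model is linear, its exact solution $\mathbf{u}(t) = e^{-\rho \mathbf{L} t}\mathbf{u}(0)$ can serve as an independent check on the numerical solver. Below is a minimal sketch using the dense matrix exponential from `LinearAlgebra`. `exact_diffusion` is a hypothetical helper, and the path graph and initial condition are made-up inputs for illustration.

```julia
using LinearAlgebra

# Hypothetical helper: exact solution u(t) = exp(-ρ L t) u0 of du/dt = -ρ L u,
# evaluated at each time in ts. Returns an N × length(ts) matrix, one column per time.
function exact_diffusion(Lap, u0, ρ, ts)
    Ld = Matrix{Float64}(Lap)   # dense copy; works for dense or sparse Lap
    return reduce(hcat, [exp(-ρ * t * Ld) * u0 for t in ts])
end

# Example on a path graph with a hypothetical initial condition
Lpath = [1 -1 0; -1 2 -1; 0 -1 1]
u0_demo = [0.9, 0.1, 0.2]
ts = 0.0:0.05:1.0
U = exact_diffusion(Lpath, u0_demo, 2.0, ts)
```

In the notebook, something like `U = exact_diffusion(L, u0, p, sol.t)` followed by `maximum(abs.(U .- Array(sol)))` would measure the solver error at the saved time points. A dense matrix exponential costs $O(N^3)$ per time point, so this check is practical only for modest `N`.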

And we can view the solution. 

```julia
using Plots
plotly()    # use the interactive Plotly backend
plot(sol)
```

### Inference

Now that we have a model, we can generate some data and use `Turing` to perform inference.
To do this, we first define a generative model.

Our data $\mathbf{y}$ are assumed to be normally distributed around the model solution $f(\mathbf{u}_0, \rho)$ at the observation times, with standard deviation $\sigma$:

$$\mathbf{y} \sim \mathcal{N}\left(f(\mathbf{u}_0, \rho), \sigma^2 \mathbf{I}\right)$$

and we assume our parameters are drawn from the following prior distributions:

$$\sigma \sim \Gamma^{-1}(2, 3) \quad \text{(shape 2, scale 3)}$$
$$\rho \sim \mathcal{N}(5, 10^2), \text{ truncated to } [0, 10]$$
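To make the likelihood concrete, here is a by-hand sketch of the Gaussian log-likelihood for this observation model, using only `LinearAlgebra`. It assumes, as in the `Turing` model that follows, that $\sigma$ is a standard deviation shared by all nodes and time points. For a scalar second argument, `MvNormal(μ, σ)` in the Distributions version listed above builds the isotropic covariance $\sigma^2 \mathbf{I}$. `diffusion_loglik` is a hypothetical helper for illustration; `Turing` computes this term itself.

```julia
using LinearAlgebra

# Hypothetical helper: log-likelihood of y_i ~ N(f_i, σ² I), where column i of
# `data` is the observation at time t_i and column i of `predicted` is the
# ODE solution at the same time.
function diffusion_loglik(data::AbstractMatrix, predicted::AbstractMatrix, σ::Real)
    size(data) == size(predicted) ||
        throw(DimensionMismatch("data and predicted must have the same size"))
    n = size(data, 1)                 # number of nodes per observation
    ll = 0.0
    for i in axes(data, 2)            # one isotropic Gaussian term per time point
        r = data[:, i] .- predicted[:, i]
        ll += -n / 2 * log(2π * σ^2) - dot(r, r) / (2 * σ^2)
    end
    return ll
end
```

Summing one term per time point treats the observations at different times as conditionally independent given $\rho$ and $\sigma$, which is the same assumption the loop in the `Turing` model encodes.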

We can make this into a `Turing` model. 





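```julia
using Turing
Turing.setadbackend(:forwarddiff)   # forward-mode AD for the NUTS gradients

@model function fit(data, prob)
    # Priors: observation noise (standard deviation) and diffusion rate
    σ ~ InverseGamma(2, 3)
    ρ ~ truncated(Normal(5, 10.0), 0.0, 10.0)

    # Re-solve the ODE with the sampled diffusion rate
    # (uses the `prob` argument rather than the global `problem`)
    prob_ρ = remake(prob, p=ρ)
    predicted = solve(prob_ρ, Tsit5(), saveat=0.05)

    # Likelihood: one isotropic Gaussian term per saved time point
    for i = 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ)
    end
end
```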


To fit this model, we first need some data. We then pass the data and the ODE problem to the `Turing` model and sample from the posterior.

For now, we simply use the ODE solution computed above as our data, so the observations contain no added noise.

```julia
data = Array(sol)
```

```
5×21 Array{Float64,2}:
 0.625792   0.512295  0.443455  0.401702  …  0.337379  0.337364  0.337355
 0.0721748  0.14722   0.198047  0.233221     0.333553  0.334243  0.334807
 0.450577   0.431598  0.413829  0.398462     0.33934   0.338962  0.338657
 0.430298   0.393723  0.371538  0.358082     0.337353  0.337349  0.337346
 0.107864   0.201871  0.259838  0.295239     0.339082  0.338789  0.338541
```
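Since `data` is the exact ODE solution, a natural next step is to test inference on noisy observations. Here is a minimal sketch of generating them using only the `Random` standard library. `add_noise` and the noise level 0.02 are hypothetical choices for illustration, and `clean` stands in for `Array(sol)`.

```julia
using Random

# Hypothetical helper: add i.i.d. Gaussian observation noise with standard
# deviation σ_obs to every entry of a noise-free solution matrix.
add_noise(data::AbstractMatrix, σ_obs::Real; rng::AbstractRNG=Random.default_rng()) =
    data .+ σ_obs .* randn(rng, size(data))

clean = fill(0.34, 5, 21)                                # stand-in for Array(sol)
noisy = add_noise(clean, 0.02; rng=MersenneTwister(42))  # seeded for reproducibility
```

In the notebook, `fit(add_noise(data, 0.02), problem)` would then build the model on noisy data. With a known noise level, the posterior for `σ` can also be checked against the value used to generate it.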

```julia
model = fit(data, problem)
```


```julia
chain = sample(model, NUTS(0.65), 1000)   # 1000 NUTS draws, target acceptance rate 0.65
```

```
┌ Info: Found initial step size
│   ϵ = 0.4
└ @ Turing.Inference /home/chaggar/.julia/packages/Turing/XLLTf/src/inference/hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:02
```


```
Chains MCMC chain (1000×14×1 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = ρ, σ
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

           ρ    2.0323    0.1578     0.0050    0.0041    910.6086    1.0007
           σ    0.0285    0.0026     0.0001    0.0001   1055.8662    0.9990

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64
```
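The summary statistics above can also be computed directly from the raw posterior draws. This is useful when comparing runs across many graph sizes and connection probabilities. Below is a minimal sketch using the `Statistics` standard library. `summarize_draws` is a hypothetical helper, and the input `1.0:5.0` is a toy vector rather than real draws.

```julia
using Statistics

# Hypothetical helper: posterior mean, standard deviation and the
# 2.5 / 25 / 50 / 75 / 97.5 % quantiles of a vector of draws.
function summarize_draws(x::AbstractVector{<:Real})
    qs = quantile(x, [0.025, 0.25, 0.5, 0.75, 0.975])
    return (mean = mean(x), std = std(x), quantiles = qs)
end

s = summarize_draws(collect(1.0:5.0))
```

With MCMCChains, the draws for ρ could be extracted with something like `vec(Array(chain[:ρ]))` and passed in. For reference, the data were simulated with `p = 2.0`, and the posterior mean of ρ reported above (2.0323) lies within one posterior standard deviation (0.1578) of that value.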

```julia
using StatsPlots
plot(chain)
```

