# Functions simulate NBDA

**NO NEED TO CHECK CODE AS THE PROCESS MAY BE DIFFERENT, IT IS JUST TO GENERATE DATA AND PARAMETERS TO COMPARE THE RESULTS BETWEEN R AND JAX**

In [2]:
import os
import sys
newPath = os.path.dirname(os.path.dirname(os.path.abspath("")))
if newPath not in sys.path:
    sys.path.append(newPath)

from main import *

bi()
import networkx as nx
import matplotlib.pyplot as plt
# Generating random network (In the future we will use STRAND network generations to include data collection biases)
def create_random_network(n, rate = 0.2, seed=0):
    """
    Create a random network adjacency matrix using the Erdős-Rényi model.
    
    Parameters:
        n (int): Number of nodes in the network.
        rate (float): ate parameter (mean of the distribution), must be >= 0.
        seed (int): Random seed for reproducibility.
        
    Returns:
        jax.numpy.ndarray: Adjacency matrix of the generated random network.
    """
    # Set the random seed for reproducibility
    key = jax.random.PRNGKey(seed)
    
    upper_tri = jax.random.poisson(key,lam = rate,  shape = (n, n))
    lower_tri = upper_tri
    
    # Make the matrix symmetric to represent an undirected graph
    m = upper_tri + lower_tri
    m = m.at[jnp.diag_indices(m.shape[0])].set(0)
    return m

def plot_network(m):
    """
    Plot the network adjacency matrix.
    
    Parameters:
    m (jnp.array): An array of agents' social contact matrix.
    
    Returns:
    None.
    """
    G = nx.from_numpy_array(m)
    plt.figure(figsize=(8, 6))
    pos = nx.spring_layout(G)  # Compute positions for visualization
    nx.draw(G, pos, with_labels=False, node_color='lightblue', edge_color='gray', node_size=100)
    plt.show()

def simulate_social_transmission(m, s=5, baseRate=1/100, BNoise=0.1, seed=0):
    """Simulate a social transmission process in a network.

    Args:
        m (2d array): Network adjacency matrix.
        s (int, optional): Social transmission coefficient. Defaults to 5.
        baseRate (float, optional): Base rate of transmission. Defaults to 1/100.
        BNoise (float, optional): Variance of s. Defaults to 0.1.
        seed (int, optional): Seed for reproducibility. Defaults to 0.

    Returns:
        tuple: Two arrays, time of acquisition and order of acquisition.
    """
    key = jax.random.PRNGKey(seed)
    N = m.shape[0]

    # Initialize the asocial learning propensity for all nodes to 1
    asocialLP = jnp.ones(N)

    # Social transmission coefficient vector with noise
    key, subkey = jax.random.split(key)
    BVect = jnp.exp(jax.random.normal(subkey, (N,)) * BNoise + jnp.log(2))

    # Arrays for acquisition status, acquisition order, and acquisition time
    z = jnp.zeros(N)
    orderAcq = jnp.zeros(N, dtype=int)
    timeAcq = jnp.zeros(N)

    runningTime = 0

    # Define the step function for the simulation
    def step(i, state):
        z, runningTime, orderAcq, timeAcq, key = state

        # m is transposed to account for the *i* propensity to transmit information to *j* proportional to *j*'s link weights toward *i*
        # We retranspose this matrix to get the i,j reading order.
        # We multiply it by *z* to have information on individual *i*'s status.
        # We multiply it by *s* to amplify the social influence effect, scaling how strongly an individual's status influences their neighbors.
        # This results in a transmission rate that combines asocial learning and weighted social influence, indicating the likelihood of transmission to each node.
        rate = baseRate * (jnp.exp(asocialLP) + s * (z @ (m.T * BVect).T)) * (1 - z)

        # Sample the time until the next transmission event
        key, subkey = jax.random.split(key)
        times = jax.random.exponential(subkey, shape=(N,)) / rate

        # Find the next individual to acquire the trait
        min_idx = jnp.argmin(times)
        min_time = times[min_idx]

        # Update cumulative time and acquisition status
        runningTime += min_time
        timeAcq = timeAcq.at[min_idx].set(runningTime)
        orderAcq = orderAcq.at[i].set(min_idx)
        z = z.at[min_idx].set(1)

        return z, runningTime, orderAcq, timeAcq, key

    # Use `lax.scan` to loop through each transmission step
    z, runningTime, orderAcq, timeAcq, _ = jax.lax.fori_loop(0, N, step, (z, runningTime, orderAcq, timeAcq, key))

    return timeAcq, orderAcq



jax.local_device_count 32


In [3]:
# Create a random network adjacency matrix with 100 nodes and a connection rate of 0.2
network = create_random_network(10, rate=0.5)
network

Array([[0, 2, 0, 4, 2, 0, 2, 2, 0, 4],
       [0, 0, 0, 2, 2, 0, 0, 0, 2, 0],
       [2, 4, 0, 0, 2, 2, 0, 8, 0, 0],
       [0, 0, 0, 0, 2, 0, 0, 0, 4, 0],
       [0, 0, 2, 0, 0, 2, 2, 0, 0, 2],
       [2, 0, 2, 0, 4, 0, 2, 2, 2, 2],
       [0, 2, 0, 0, 4, 2, 0, 0, 4, 0],
       [0, 2, 4, 2, 2, 0, 4, 0, 2, 2],
       [2, 0, 2, 0, 0, 0, 0, 0, 0, 2],
       [2, 2, 0, 4, 2, 2, 0, 2, 0, 0]], dtype=int64)

In [5]:
StimeAcq, SorderAcq = simulate_social_transmission(network, s=5, BNoise=0.1, seed=0)

In [6]:
def get_status_by_t(SolveOrders): 
    times = len(SolveOrders)
    N_id = len(SolveOrders)
    z_i = jnp.zeros((N_id, times))
    for t in range(times):
        z_i = z_i.at[t, SolveOrders[:t+1] - 1].set(1)
    
    Zt0 = jnp.vstack([jnp.zeros(z_i.shape[1]), z_i])
    #Zt0 = jnp.delete(Zt0, -1, axis=0)
    return Zt0, Zt0[0:Zt0.shape[0]-1,:], Zt0[1:Zt0.shape[0],:],
Z, Zt0, Zt1  = get_status_by_t(SorderAcq)


In [7]:
from rpy2.robjects import numpy2ri
import rpy2.robjects as robjects
# Convert JAX matrix to a nested list
Rnet = network.tolist()
Rz = Z.tolist()
s = bi.dist.uniform(0, 10, shape = (1,), name = 's', sample=True)
lambda0 = bi.dist.gamma(0.01, 0.01, shape= (1,), name = 'lambda0', sample=True)
Rs = float(s[0])
Rl = float(lambda0[0])

# R code

In [8]:
%load_ext rpy2.ipython

The rpy2.ipython extension is already loaded. To reload it, use:
  %reload_ext rpy2.ipython


In [9]:
%%R -i Rz,Rnet,Rs,Rl
print(Rz)
print(Rnet)
print(Rs)
print(Rl)


      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11]
 [1,]    0    0    0    0    0    1    1    1    1     1     1
 [2,]    0    0    1    1    1    1    1    1    1     1     1
 [3,]    0    0    0    0    0    0    0    1    1     1     1
 [4,]    0    0    0    0    0    0    0    0    0     1     1
 [5,]    0    0    0    0    0    0    0    0    0     0     1
 [6,]    0    0    0    0    0    0    1    1    1     1     1
 [7,]    0    0    0    0    1    1    1    1    1     1     1
 [8,]    0    1    1    1    1    1    1    1    1     1     1
 [9,]    0    0    0    0    0    0    0    0    1     1     1
[10,]    0    0    0    1    1    1    1    1    1     1     1
      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
 [1,]    0    0    2    0    0    2    0    0    2     2
 [2,]    2    0    4    0    0    0    2    2    0     2
 [3,]    0    0    0    0    2    2    0    4    2     0
 [4,]    4    2    0    0    0    0    0    2    0     4
 [5,]    2    2    2  

In (function (package, help, pos = 2, lib.loc = NULL, character.only = FALSE,  :
  libraries ‘/usr/local/lib/R/site-library’, ‘/usr/lib/R/site-library’ contain no packages


In [10]:
%%R
# R model equivalent 
X = Rz #NxT array of individuals status
X2 = matrix(NA, nrow=dim(X)[1], ncol=dim(X)[2])
A = Rnet # Network
T = dim(X)[2]
N = dim(X)[1]
for(t in 1:T){
  if(t==1){
    X2[,t] = Rl
  }else{
    for(i in 1:N){
      if(X[i, t-1] == 0){
        scrap = Rl  + Rs * A[i,] %*% (X[,(t-1)] + 1)
        X2[i,t] = scrap
      }
    }
  }
}
X2

              [,1]     [,2]     [,3]     [,4]     [,5]     [,6]     [,7]
 [1,] 8.909894e-07 17.30364 17.30364 17.30364 21.62955 21.62955       NA
 [2,] 8.909894e-07 25.95546 30.28136       NA       NA       NA       NA
 [3,] 8.909894e-07 21.62955 30.28136 30.28136 30.28136 30.28136 30.28136
 [4,] 8.909894e-07 25.95546 30.28136 34.60727 43.25909 43.25909 51.91091
 [5,] 8.909894e-07 43.25909 47.58500 51.91091 56.23682 64.88864 69.21455
 [6,] 8.909894e-07 17.30364 17.30364 17.30364 21.62955 25.95546 25.95546
 [7,] 8.909894e-07 21.62955 30.28136 30.28136 30.28136       NA       NA
 [8,] 8.909894e-07 30.28136       NA       NA       NA       NA       NA
 [9,] 8.909894e-07 30.28136 34.60727 38.93318 38.93318 47.58500 47.58500
[10,] 8.909894e-07 25.95546 30.28136 30.28136       NA       NA       NA
          [,8]     [,9]    [,10]    [,11]
 [1,]       NA       NA       NA       NA
 [2,]       NA       NA       NA       NA
 [3,] 34.60727       NA       NA       NA
 [4,] 51.91091 51.91091 51.91

Currently, JAX output do not have lambda0 as first value in the column, but it is easy to add.

In [13]:
vmap(lambda x : (lambda0 + (s * ((x + 1) @ network))) * (1-x))(Zt0).T# Transpose to compare the results

Array([[17.30363726, 17.30363726, 17.30363726, 21.62954635, 21.62954635,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [25.95545544, 30.28136454,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [21.62954635, 30.28136454, 30.28136454, 30.28136454, 30.28136454,
        30.28136454, 34.60727363,  0.        ,  0.        ,  0.        ],
       [25.95545544, 30.28136454, 34.60727363, 43.25909181, 43.25909181,
        51.91091   , 51.91091   , 51.91091   , 51.91091   ,  0.        ],
       [43.25909181, 47.5850009 , 51.91091   , 56.23681909, 64.88863727,
        69.21454636, 77.86636455, 82.19227364, 82.19227364, 86.51818273],
       [17.30363726, 17.30363726, 17.30363726, 21.62954635, 25.95545544,
        25.95545544,  0.        ,  0.        ,  0.        ,  0.        ],
       [21.62954635, 30.28136454, 30.28136454, 30.28136454,  0.        ,
         0.        ,  0.        ,  0.        