We can use the survival analysis to evaluate rate of social acquisition versus rate of asocial acquisition. We need simply to generate surv_object for this type of data. 

$$
\lambda_i(t) = \lambda_0(t) (1- z_i(t)) \left[ 1 + s \sum_{j = 1}^{N} a_{ij} z_j (t)  \right] 
$$

where $\lambda_i(t)$ is the asocial rate and $s$ is the social learning rate.

In [2]:
import os
import sys
newPath = os.path.dirname(os.path.abspath(""))
if newPath not in sys.path:
    sys.path.append(newPath)

from main import *
m = bi(platform='cpu')
import networkx as nx
import matplotlib.pyplot as plt
# Generating random network (In the future we will use STRAND network generations to include data collection biases)
def create_random_network(n, rate = 0.2, seed=0):
    """
    Create a random network adjacency matrix using the Erdős-Rényi model.
    
    Parameters:
        n (int): Number of nodes in the network.
        rate (float): ate parameter (mean of the distribution), must be >= 0.
        seed (int): Random seed for reproducibility.
        
    Returns:
        jax.numpy.ndarray: Adjacency matrix of the generated random network.
    """
    # Set the random seed for reproducibility
    key = jax.random.PRNGKey(seed)
    
    upper_tri = jax.random.poisson(key,lam = rate,  shape = (n, n))
    lower_tri = upper_tri
    
    # Make the matrix symmetric to represent an undirected graph
    m = upper_tri + lower_tri
    m = m.at[jnp.diag_indices(m.shape[0])].set(0)
    return m

def plot_network(m):
    """
    Plot the network adjacency matrix.
    
    Parameters:
    m (jnp.array): An array of agents' social contact matrix.
    
    Returns:
    None.
    """
    G = nx.from_numpy_array(m)
    plt.figure(figsize=(8, 6))
    pos = nx.spring_layout(G)  # Compute positions for visualization
    nx.draw(G, pos, with_labels=False, node_color='lightblue', edge_color='gray', node_size=100)
    plt.show()

def simulate_social_transmission(m, s=5, baseRate=1/100, BNoise=0.1, seed=0):
    """Simulate a social transmission process in a network.

    Args:
        m (2d array): Network adjacency matrix.
        s (int, optional): Social transmission coefficient. Defaults to 5.
        baseRate (float, optional): Base rate of transmission. Defaults to 1/100.
        BNoise (float, optional): Variance of s. Defaults to 0.1.
        seed (int, optional): Seed for reproducibility. Defaults to 0.

    Returns:
        tuple: Two arrays, time of acquisition and order of acquisition.
    """
    key = jax.random.PRNGKey(seed)
    N = m.shape[0]

    # Initialize the asocial learning propensity for all nodes to 1
    asocialLP = jnp.ones(N)

    # Social transmission coefficient vector with noise
    key, subkey = jax.random.split(key)
    BVect = jnp.exp(jax.random.normal(subkey, (N,)) * BNoise + jnp.log(2))

    # Arrays for acquisition status, acquisition order, and acquisition time
    z = jnp.zeros(N)
    orderAcq = jnp.zeros(N, dtype=int)
    timeAcq = jnp.zeros(N)

    runningTime = 0

    # Define the step function for the simulation
    def step(i, state):
        z, runningTime, orderAcq, timeAcq, key = state

        # m is transposed to account for the *i* propensity to transmit information to *j* proportional to *j*'s link weights toward *i*
        # We retranspose this matrix to get the i,j reading order.
        # We multiply it by *z* to have information on individual *i*'s status.
        # We multiply it by *s* to amplify the social influence effect, scaling how strongly an individual's status influences their neighbors.
        # This results in a transmission rate that combines asocial learning and weighted social influence, indicating the likelihood of transmission to each node.
        rate = baseRate * (jnp.exp(asocialLP) + s * (z @ (m.T * BVect).T)) * (1 - z)

        # Sample the time until the next transmission event
        key, subkey = jax.random.split(key)
        times = jax.random.exponential(subkey, shape=(N,)) / rate

        # Find the next individual to acquire the trait
        min_idx = jnp.argmin(times)
        min_time = times[min_idx]

        # Update cumulative time and acquisition status
        runningTime += min_time
        timeAcq = timeAcq.at[min_idx].set(runningTime)
        orderAcq = orderAcq.at[i].set(min_idx)
        z = z.at[min_idx].set(1)

        return z, runningTime, orderAcq, timeAcq, key

    # Use `lax.scan` to loop through each transmission step
    z, runningTime, orderAcq, timeAcq, _ = jax.lax.fori_loop(0, N, step, (z, runningTime, orderAcq, timeAcq, key))

    return timeAcq, orderAcq

jax.local_device_count 32


In [3]:
# Create a random network adjacency matrix with 100 nodes and a connection rate of 0.2
network = create_random_network(10, rate=0.5)
network

Array([[0, 2, 0, 4, 2, 0, 2, 2, 0, 4],
       [0, 0, 0, 2, 2, 0, 0, 0, 2, 0],
       [2, 4, 0, 0, 2, 2, 0, 8, 0, 0],
       [0, 0, 0, 0, 2, 0, 0, 0, 4, 0],
       [0, 0, 2, 0, 0, 2, 2, 0, 0, 2],
       [2, 0, 2, 0, 4, 0, 2, 2, 2, 2],
       [0, 2, 0, 0, 4, 2, 0, 0, 4, 0],
       [0, 2, 4, 2, 2, 0, 4, 0, 2, 2],
       [2, 0, 2, 0, 0, 0, 0, 0, 0, 2],
       [2, 2, 0, 4, 2, 2, 0, 2, 0, 0]], dtype=int64)

In [4]:
StimeAcq, SorderAcq = simulate_social_transmission(network, s=5, BNoise=0.1, seed=0)
SorderAcq

Array([8, 2, 0, 7, 1, 6, 3, 9, 4, 5], dtype=int64)

In [9]:
def get_status_by_t(SolveOrders): 
    times = len(SolveOrders)
    N_id = len(SolveOrders)
    z_i = jnp.zeros((N_id, times))
    for t in range(times):
        z_i = z_i.at[t, SolveOrders[:t+1] - 1].set(1)
    
    Zt0 = jnp.vstack([jnp.zeros(z_i.shape[1]), z_i])
    #Zt0 = jnp.delete(Zt0, -1, axis=0)
    return Zt0, Zt0[0:Zt0.shape[0]-1,:], Zt0[1:Zt0.shape[0],:],
Z, Zt0, Zt1  = get_status_by_t(SorderAcq)
Z.T # this is the equivalent of death matrix in survival object

Array([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float64)

In [6]:
# Parameters priors distributions-------------------------
## Base hazard distribution
lambda0 = bi.dist.gamma(0.01, 0.01, shape= Zt0.T.shape, name ='lambda0', sample=True)
## Covariate effect distribution
s = bi.dist.normal(0, 1000, shape = (1,),  name='beta',sample=True)


In [24]:
dyadic = create_random_network(10, rate=0.2) # regression ouput from all covariates (nodal, dyadic, fixed or varying)
Zt0.T[:,1] * network * dyadic 
tmp= Zt0.T[:,1] * network *  dyadic 
jnp.sum(tmp, axis=1)

jax.vmap(lambda x:x@(network*beta), in_axes=1)(Zt0.T).T

Array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  4.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0., 16.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  4.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  4.,  0.,  0.]], dtype=float64)

In [None]:
dyadic = create_random_network(10, rate=0.2) # regression ouput from all covariates (nodal, dyadic, fixed or varying)
# Now we need to encode a varying covariate that represent for an in individual i, the sum of links weights with informed individuals for each time steps.
# For this we multiply each column of Z.T with the network matrix and then sum by row.
# Example at time t = 0
Zt0.T[:,0] * network # As no one is informed at time 0, we have a 0 in the first column
# Example at time t = 1
tmp= Zt0.T[:,1] * network *  dyadic 
print(tmp)# As only individual 8 is informed at time 1, we have only its connections that are non-zero
print(jnp.sum(tmp, axis=1)) # We sum the columns to get the sum of the links weights with informed individuals at time 1 and rea5ise the dot product woth the c oefficient beta

t=jax.vmap(lambda x:x@(network*beta), in_axes=1)(Zt0.T).T
# repeat this for each networks with intercept in each ones that will work as a weigth of importance of this network in the transmission process.


In [None]:
[2, 0, 5, 0] # Multiple traits 

In [None]:
network2 = create_random_network(5, rate=0.5)
cat_expression1 = jnp.array([[2, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]])

cat_expression2 = jnp.array([[0, 2, 0],
[0, 0, 0],
[0, 0, 0],
[0, 1, 0],
[0, 0, 0]])

cat_expression3 = jnp.array([[0, 0, 5],
[0, 0, 0],
[0, 0, 1],
[0, 0, 0],
[0, 0, 1]])
jnp.stack([cat_expression1, cat_expression2, cat_expression3]) # this will be the ouput

jax.vmap(lambda x:x*network2)(cat_expression.T)

In [40]:
network2

Array([[0, 4, 2, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 2, 0],
       [0, 2, 0, 0, 4],
       [0, 0, 0, 2, 0]], dtype=int64)

In [32]:
Zt1.T

Array([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float64)