In [None]:
import network_diffusion as nd
import numpy as np
import torch
import torch_geometric as pyg

In [None]:
# Probability that a single edge transmits activation in one simulation step.
p = 0.7

# Seed both RNGs used in this notebook (numpy for the initial node statuses,
# torch for the live-edge draws) so that Restart & Run All reproduces the same
# trajectory.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
net = nd.mln.functions.get_toy_network_piotr()
print(net)

# Give every (actor, layer) node a random initial status. The weights apply to
# states -1 (activated), 0 (inactive) and 1 (active), in that order, so no node
# starts as activated. Each layer draws independently, so one actor may get
# different statuses on different layers.
for actor in net.get_actors():
    for layer_name in actor.layers:
        initial_status = np.random.choice([-1, 0, 1], p=[0, 0.8, 0.2])
        net.layers[layer_name].nodes[actor.actor_id]["status"] = initial_status

In [None]:
# Convert layer "l1" of the multilayer network into a PyG Data object; its node
# attributes (e.g. "status") become available on the result as l1["status"].
l1 = pyg.utils.from_networkx(G=net["l1"])

In [None]:
# State vector of layer l1 and its sparse adjacency matrix. `size` is passed
# explicitly so that A is len(S) x len(S) even when the highest-numbered nodes
# have no edges; otherwise the size is inferred from the largest index in
# edge_index and A would not line up with S in the simulation functions.
S = l1["status"].to(torch.int8)
A = pyg.utils.to_torch_sparse_tensor(l1.edge_index, size=len(S))
S, A

## Functions

In [None]:
def draw_live_edges(A: torch.Tensor, p: float) -> torch.Tensor:
    """Draw edges which transmit the state (i.e. their random weight < p).

    Every stored entry of the sparse adjacency matrix ``A`` gets an independent
    uniform draw from [0, 1). The entry becomes 1.0 if the draw is below ``p``
    and 0.0 otherwise.

    :param A: sparse COO adjacency matrix
    :param p: transmission probability of a single edge
    :return: sparse COO float64 matrix with the same shape and the same stored
        indices as ``A``, holding 0/1 values
    """
    # indices()/values() are only defined for coalesced sparse tensors;
    # coalesce() is a no-op when A is already coalesced.
    A = A.coalesce()
    raw_signals = torch.rand_like(A.values(), dtype=float)
    thre_signals = (raw_signals < p).to(float)
    # Pass `size` explicitly: without it the shape is inferred from the largest
    # stored index, which shrinks T whenever trailing nodes have no edges.
    T = torch.sparse_coo_tensor(
        indices=A.indices(), values=thre_signals, size=A.shape
    )
    assert A.shape == T.shape
    # Sanity check; assumes A stores values >= 1 (e.g. unweighted 1.0 entries).
    assert ((A - T).to_dense() < 0).sum() == 0
    return T


def mask_Sy(S: torch.Tensor) -> torch.Tensor:
    """Create mask for T which discards signals from nodes which state != 1.

    Row ``i`` of the returned ``len(S) x len(S)`` matrix is all ones when
    ``S[i] > 0`` and all zeros otherwise. Multiplying T element-wise by it
    therefore keeps only the edges whose source node (row) is currently active.

    :param S: 1-D vector of node states
    :return: sparse COO int32 matrix of shape ``(len(S), len(S))``
    """
    is_active = (S > 0).to(torch.int)
    n = len(S)
    # Broadcast the column vector over all columns (a view, no dense N x N
    # copy before sparsification) instead of repeat + transpose.
    return is_active.unsqueeze(1).expand(n, n).to_sparse_coo()


def mask_Sx(S: torch.Tensor) -> torch.Tensor:
    """Build a mask that drops signals sent to nodes whose state != 0.

    For states in {-1, 0, 1}, ``||S| - 1|`` equals 1 exactly where a node is
    inactive (0) and 0 where it is active (1) or activated (-1).

    :param S: 1-D vector of node states
    :return: the mask as a sparse COO tensor of the same length as ``S``
    """
    magnitude = S.abs()
    inactive_indicator = (magnitude - 1).abs()
    return inactive_indicator.to_sparse_coo()


def get_active_nodes(T: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
    """Obtain newly active nodes (0 -> 1) in current simulation step.

    :param T: sparse matrix of live edges, rows being sources and columns
        targets (as produced by ``draw_live_edges``)
    :param S: 1-D vector of node states (0 - inactive, 1 - active, -1 - activated)
    :return: dense vector with 1 for every inactive node reached by at least
        one live edge from an active node, and 0 elsewhere
    """
    Sy = mask_Sy(S)
    Sx = mask_Sx(S)
    # Per target: number of live edges coming from active nodes, zeroed for
    # nodes that are not inactive.
    hits = (T * Sy).sum(dim=0).to_dense() * Sx.to_dense()
    # A node reached by several live edges is still activated only once, so
    # collapse the counts to a 0/1 indicator (keeping the dtype of `hits`).
    # Without this, states like 2 would leak into the state vector.
    S_new = (hits > 0).to(hits.dtype)
    assert torch.all(S[S_new.bool()] == 0), \
        "Some nodes were activated against rules!"
    return S_new


def decay_active_nodes(S: torch.Tensor) -> torch.Tensor:
    """Mark every non-inactive node as activated (1 -> -1, -1 stays -1).

    Each state is replaced by minus its magnitude, so inactive nodes (0) are
    left untouched.
    """
    magnitudes = S.abs()
    return magnitudes.mul(-1)


def simulation_step(S: torch.Tensor, A: torch.Tensor, p: float) -> torch.Tensor:
    """
    Perform a single step of the diffusion.

    1. for every edge, draw whether it is live (its random value is below p)
    2. activate inactive nodes that have a live edge from an active neighbour
    3. turn nodes that were active in this step into activated ones, so they
       are no longer treated as sources in later steps

    :param S: vector of node states (0 - inactive, 1 - active, -1 - activated)
    :param A: adjacency matrix
    :param p: probability of activation between active and inactive node
    :return: updated vector of node states as int8
    """
    live_edges = draw_live_edges(A, p)
    newly_active = get_active_nodes(live_edges, S)
    next_states = newly_active + decay_active_nodes(S)
    return next_states.to(torch.int8)

## Test

In [None]:
# Diffusion trajectory on layer l1: the initial state vector, then the states
# after each of 10 simulation steps.
S_i = S
print(S_i)
for i in range(10):
    S_i = simulation_step(S=S_i, A=A, p=p)
    print(S_i)