In [None]:
from dataclasses import dataclass
from typing import Any, Callable

from bidict import bidict
import networkx as nx
import network_diffusion as nd
import numpy as np
import torch
import torch_geometric as pyg

In [None]:
# Toy multilayer network provided by network_diffusion; printed to inspect its structure.
net_p = nd.mln.functions.get_toy_network_piotr()
print(net_p)

In [None]:
# Multilayer network built from networkx's Les Miserables graph with layers l1, l2, l3.
# NOTE(review): presumably the same graph is replicated into every layer — confirm in
# nd.MultilayerNetwork.from_nx_layer.
net_m = nd.MultilayerNetwork.from_nx_layer(nx.les_miserables_graph(), ["l1", "l2", "l3"])
print(net_m)

In [None]:
def prepare_mln_for_conversion(net: nd.MultilayerNetwork) -> tuple[nd.MultilayerNetwork, bidict, dict[str, set[Any]] | None]:
    """
    Prepare nd.MultilayerNetwork for conversion to torch representation.

    If net is not multiplex, then multiplicity across all layers is imposed. Names of actors are
    converted to consecutive integers 0, 1, ..., number_of_actors - 1 (in the order returned by
    `get_actors()` of the prepared network).

    :param net: a multilayer network to prepare for conversion
    :return: a new instance of nd.MultilayerNetwork prepared for conversion, a bi-directional map of
        old and new names of the actors, a dict of sets of nodes added to make the net multiplex
        (None if net was already multiplex)
    """
    if not net.is_multiplex():
        net, added_nodes = net.to_multiplex()
    else:
        net, added_nodes = net.copy(), None
    ac_map = {ac.actor_id: idx for idx, ac in enumerate(net.get_actors())}
    for l_graph in net.layers.values():
        # Relabel into a copy rather than in place: when the original actor ids are integers the
        # old and new label sets overlap, and nx.relabel_nodes(copy=False) raises
        # NetworkXUnfeasible whenever the mapping contains a cycle (e.g. {1: 2, 2: 1}).
        # The result is written back into the same graph object, so net.layers keeps its graphs.
        relabelled = nx.relabel_nodes(l_graph, mapping=ac_map, copy=True)
        l_graph.clear()
        l_graph.update(relabelled)
    return net, bidict(ac_map), added_nodes


def coalesce_and_check(tensor_raw: torch.Tensor) -> torch.Tensor:
    """
    Coalesce a sparse tensor and check that nothing changed during that operation.

    The input is expected to already satisfy the coalesced invariants (sorted indices, no
    duplicates); this function returns the coalesced tensor and verifies that coalescing was a
    no-op.

    :param tensor_raw: a sparse COO tensor
    :return: the coalesced tensor
    :raises AssertionError: if coalescing altered the layout, size, nnz, indices or values
    """
    tensor_coalesced = tensor_raw.coalesce()
    # Check the metadata first: if nnz differs, an element-wise comparison of indices/values
    # would fail with a shape-mismatch RuntimeError instead of a meaningful AssertionError.
    assert tensor_raw.layout == tensor_coalesced.layout
    assert tensor_raw.size() == tensor_coalesced.size()
    assert tensor_raw._nnz() == tensor_coalesced._nnz()
    assert torch.equal(tensor_raw._indices(), tensor_coalesced._indices())
    assert torch.equal(tensor_raw._values(), tensor_coalesced._values())
    return tensor_coalesced


def mln_to_sparse(net: nd.MultilayerNetwork, actor_order: list[int]) -> tuple[torch.Tensor, list[str]]:
    """
    Convert nd.MultilayerNetwork to per-layer adjacency matrices stacked into one sparse tensor.

    :param net: nd.MultilayerNetwork to be converted, must be multiplex and have actors' ids as ints
    :param actor_order: order of actors' ids to be used in the output adjacency tensor
    :return: a sparse tensor shaped [number_of_layers x number_of_actors x number_of_actors] and a
        list of layer names ordered as the first dimension of that tensor
    """
    n_actors = len(actor_order)
    layer_adjs: list[torch.Tensor] = []
    layer_names: list[str] = []
    for layer_name, layer_graph in net.layers.items():
        scipy_adj = nx.adjacency_matrix(layer_graph, actor_order)
        edge_index, edge_values = pyg.utils.from_scipy_sparse_matrix(scipy_adj)
        layer_adjs.append(
            torch.sparse_coo_tensor(
                indices=edge_index,
                values=edge_values,
                size=[n_actors, n_actors],
                is_coalesced=True,
                check_invariants=True,
            )
        )
        layer_names.append(layer_name)
    return coalesce_and_check(torch.stack(layer_adjs)), layer_names


def create_nodes_mask(
        layers_order: list[str], actors_map: bidict, nodes_added: dict[str, set[Any]] | None
    ) -> torch.Tensor:
    """
    Create mask with nodes added artificially (while converting network to multiplex) marked as ones.

    :param layers_order: names of layers in an order that is preserved in A
    :param actors_map: map of actor names Any -> int between original network and its sparse repr
    :param nodes_added: a dict of sets of nodes added to make the net multiplex (None or empty if
        nothing was added)
    :return: a tensor shaped [number_of_layers x number_of_actors] with 1. for artificially added
        nodes and 0. for nodes present in the original network
    :raises ValueError: if nodes_added refers to a layer missing from layers_order
    """
    nodes_mask = torch.zeros([len(layers_order), len(actors_map)])
    if not nodes_added:
        return nodes_mask
    for l_name, l_added_nodes in nodes_added.items():
        # the layer's row is the same for all of its added nodes - resolve it once per layer
        l_idx = layers_order.index(l_name)
        for l_added_node in l_added_nodes:
            nodes_mask[l_idx, actors_map[l_added_node]] = 1.
    return nodes_mask


@dataclass(frozen=True)
class MultilayerNetworkTorch:
    """
    Representation of nd.MultilayerNetwork in a tensor notation.

    A frozen dataclass; `from_mln` builds it from an nd.MultilayerNetwork.
    """

    # adjacency matrices of all layers stacked into a single sparse tensor
    adjacency_tensor: torch.Tensor
    # names of layers in the order used along the first dimension of adjacency_tensor
    layers_order: list[str]
    # map of actor names Any -> int between the original network and its sparse repr
    actors_map: bidict
    # mask of nodes added while making the network multiplex (1. = added), ordered like adjacency_tensor
    nodes_mask: torch.Tensor

    @classmethod
    def from_mln(cls, net: nd.MultilayerNetwork) -> "MultilayerNetworkTorch":
        """Build the tensor representation of `net` (conversion runs on a prepared new instance)."""
        prepared_net, actors_map, added_nodes = prepare_mln_for_conversion(net=net)
        adjacency, layers = mln_to_sparse(net=prepared_net, actor_order=list(actors_map.values()))
        mask = create_nodes_mask(layers_order=layers, actors_map=actors_map, nodes_added=added_nodes)
        return cls(adjacency_tensor=adjacency, layers_order=layers, actors_map=actors_map, nodes_mask=mask)

    def __repr__(self) -> str:
        header = f"{self.__class__.__name__} at {id(self)}\n"
        body = (
            f"A: {self.adjacency_tensor}\n"
            f" L: {self.layers_order}\n"
            f" C: {self.actors_map}\n"
            f" S: {self.nodes_mask}\n"
        )
        return header + body

In [None]:
# Network used in the step-by-step conversion walkthrough below.
net_raw = net_p
# NOTE(review): the meaning of the second argument (100) of draw_mln is not visible here — confirm.
nd.mln.functions.draw_mln(net_raw, 100)

In [None]:
# Make the network multiplex and relabel actors to consecutive ints, then draw the result.
net_converted, ac_map, nodes_added = prepare_mln_for_conversion(net=net_raw)
nd.mln.functions.draw_mln(net_converted, 100)

In [None]:
# Stack the per-layer adjacency matrices; show them densified together with the layer order.
actor_order = list(ac_map.values())
adjacency_tensor, layers_order = mln_to_sparse(net=net_converted, actor_order=actor_order)
adjacency_tensor.to_dense(), layers_order

In [None]:
# The whole conversion pipeline in one call; the cell displays the object's __repr__.
MultilayerNetworkTorch.from_mln(net=net_raw)

## Diffusion-simulation functions

In [None]:
def create_states_tensor(mln_torch: MultilayerNetworkTorch, seed_set: set[Any]) -> torch.Tensor:
    """
    Create the initial tensor of nodes' states.

    :param mln_torch: a network (in tensor representation) to create a states tensor for
    :param seed_set: a set of initially active actors (ids of actors given in the original form)
    :return: a tensor shaped as [number_of_layers x number_of_actors] with 1. for seed nodes,
        -inf for nodes that were artificially added during converting the network to the tensor
        representation, and 0. for the remaining nodes
    """
    seed_idx = [mln_torch.actors_map[actor_id] for actor_id in seed_set]
    print(f"{seed_set} -> {seed_idx}")
    states = mln_torch.nodes_mask.clone()
    # nodes flagged in the mask do not exist in the original network -> state -inf
    states[states == 1.] = -float("inf")
    # seeds are activated in every layer; -inf + 1 stays -inf for artificially added nodes
    states[:, seed_idx] += 1
    return states


def draw_live_edges(A: torch.Tensor, p: float) -> torch.Tensor:
    """
    Draw edges which transmit the state (i.e. their random weight < p).

    :param A: a coalesced sparse adjacency tensor
    :param p: probability that a single edge transmits the state
    :return: a sparse tensor with the same indices and shape as A and float64 values: 1. for live
        edges and 0. for the others (zeros are stored explicitly)
    """
    raw_signals = torch.rand_like(A.values(), dtype=float)
    thre_signals = (raw_signals < p).to(float)
    # Pass the size explicitly: inferring it from the indices shrinks the tensor whenever the
    # actors with the highest ids have no edges, which would break the shape check below.
    T = torch.sparse_coo_tensor(indices=A.indices(), values=thre_signals, size=A.shape)
    assert A.shape == T.shape
    # T shares A's (coalesced) indices, so comparing the values is equivalent to checking that
    # A - T has no negative entries, without densifying the whole tensor.
    assert not (A.values() < thre_signals).any()
    return T


def mask_S_from(S: torch.Tensor) -> torch.Tensor:
    """
    Create mask for T which discards signals from nodes whose state is not positive (i.e. != 1).

    :param S: tensor of nodes' states shaped [number_of_layers x number_of_actors]
    :return: a sparse int32 tensor shaped [number_of_layers x number_of_actors x number_of_actors]
        whose entry [l, i, j] is 1 iff S[l, i] > 0 (node i is a sender in layer l)
    """
    senders_active = (S > 0).to(torch.int).unsqueeze(-1)
    return senders_active.repeat(1, 1, S.shape[1]).to_sparse_coo()


def mask_S_to(S: torch.Tensor) -> torch.Tensor:
    """
    Create mask for T which discards signals to nodes whose state != 0.

    :param S: tensor of nodes' states (0 - inactive, 1 - active, -1 - activated, -inf - node does
        not exist)
    :return: a sparse tensor of S's shape and dtype with 1 where S == 0 and 0 elsewhere
    """
    # An explicit equality test instead of |(|S| - 1)|: the latter maps -inf (non-existing nodes)
    # to +inf instead of 0, which could produce inf/NaN if such a node were ever multiplied with
    # a (possibly zero) signal.
    return (S == 0).to(S.dtype).to_sparse_coo()


def get_active_nodes(T: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
    """
    Obtain newly active nodes (0 -> 1) in the current simulation step.

    :param T: sparse tensor of live edges (as returned by draw_live_edges)
    :param S: tensor of nodes' states shaped [number_of_layers x number_of_actors]
    :return: a dense tensor shaped like S holding, per node, the sum of live edges coming from
        active senders, masked by mask_S_to(S) so that only nodes with state 0 can receive impulses
    :raises AssertionError: if a node whose state is not 0 received a non-zero impulse
    """
    S_f = mask_S_from(S)
    S_t = mask_S_to(S)
    # mask_S_from indexes senders along dim=1, so summing over it leaves one value per receiver
    S_new = ((T * S_f).sum(dim=1) * S_t).to_dense()
    # Any non-zero impulse (including fractional or non-finite ones) must land on a 0-state node.
    assert torch.all(S[S_new != 0] == 0), \
        "Some nodes were activated against rules (i.e. only these with state 0 can be activated)!"
    return S_new


def protocol_AND(S_raw: torch.Tensor, net: MultilayerNetworkTorch) -> torch.Tensor:
    """
    Aggregate positive impulses from the layers using AND strategy.

    An actor is activated only if every layer gives a positive value after adding net.nodes_mask;
    the mask (1. for artificially added nodes) makes such layers count as satisfied.

    :param S_raw: raw impulses obtained by the nodes, shaped [number_of_layers x number_of_actors]
    :param net: a network which is a medium for the diffusion
    :return: a vector of length number_of_actors with 1. denoting actors that were activated in
        this simulation step and 0. denoting actors that weren't activated
    """
    impulses_with_mask = S_raw + net.nodes_mask
    return torch.all(impulses_with_mask > 0, dim=0).float()


def protocol_OR(S_raw: torch.Tensor, net: MultilayerNetworkTorch) -> torch.Tensor:
    """
    Aggregate positive impulses from the layers using OR strategy.

    An actor is activated if it got a positive impulse in at least one layer.

    :param S_raw: raw impulses obtained by the nodes, shaped [number_of_layers x number_of_actors]
    :param net: a network which is a medium for the diffusion; not used by this protocol, accepted
        so that it can be called like the other protocols (protocol(S_raw=..., net=...))
    :return: a vector of length number_of_actors with 1. denoting actors that were activated in
        this simulation step and 0. denoting actors that weren't activated
    """
    return (S_raw > 0).any(dim=0).to(torch.float)


def decay_active_nodes(S: torch.Tensor) -> torch.Tensor:
    """
    Change states of nodes that are active to become activated (1 -> -1).

    Every state is replaced by minus its magnitude, so 0, -1 and -inf are kept as they are; zeros
    are normalized to +0. The result is a new floating-point tensor.
    """
    negated = -1. * S.abs()
    # -|0| yields -0.; replace it with a plain +0.
    return torch.where(negated == 0, torch.zeros_like(negated), negated)


def simulation_step(net: MultilayerNetworkTorch, p: float, protocol: Callable, S0: torch.Tensor) -> torch.Tensor:
    """
    Make a single simulation step.

    1. draw live edges, i.e. the ones whose random value is below p
    2. pass impulses from active (1.) nodes to their inactive (0.) neighbours along live edges only
    3. aggregate positive impulses across layers with `protocol` to find actors activated in this step
    4. decay the activation potential of nodes that were active in the current step (1. -> -1.)
    5. add the results of 3. and 4. to get the tensor of states after this step

    :param net: a network which is a medium of the diffusion
    :param p: a probability of activation between active and inactive node
    :param protocol: a function that aggregates positive impulses from the network's layers; it is
        called as protocol(S_raw=..., net=...)
    :param S0: initial tensor of nodes' states (0 - inactive, 1 - active, -1 - activated, -inf - node does not exist)
    :return: updated tensor with nodes' states
    """
    live_edges = draw_live_edges(net.adjacency_tensor, p)
    impulses = get_active_nodes(live_edges, S0)
    newly_active = protocol(S_raw=impulses, net=net)
    return newly_active + decay_active_nodes(S0)


def S_nodes_to_actors(S: torch.Tensor) -> torch.Tensor:
    """
    Convert tensor of nodes' states to a vector of actors' states.

    Non-existing nodes (-inf) are ignored; the remaining states are summed across layers (dim 0)
    and clamped to [-1, 1]. The input tensor is not modified.
    """
    existing_only = torch.where(torch.isneginf(S), torch.zeros_like(S), S)
    per_actor = existing_only.sum(dim=0)
    return per_actor.clamp(-1, 1)

## Test 2: simulation on the Les Misérables multiplex network

In [None]:
# Convert the Les Miserables multilayer network to its tensor representation.
net_nd = net_m
net_torch = MultilayerNetworkTorch.from_mln(net=net_nd)

In [None]:
# Each actor independently becomes a seed with probability SEED_FRACTION.
# A fixed RNG seed keeps the selection reproducible under Restart & Run All.
SEED_FRACTION = 0.2
seed_rng = np.random.default_rng(42)
seeds = {actor.actor_id for actor in net_nd.get_actors() if seed_rng.random() < SEED_FRACTION}
seeds

In [None]:
# Initial states: 1. for seeds, -inf for artificially added nodes, 0. otherwise.
S0 = create_states_tensor(mln_torch=net_torch, seed_set=seeds)
S0

In [None]:
# Run a few simulation steps and print the actors' states before the first and after each step.
P_ACTIVATION = 0.3  # probability of activation between an active and an inactive node
N_STEPS = 5
torch.manual_seed(42)  # draw_live_edges samples with torch.rand_like; fix it for reproducibility
Sn = S0
print(S_nodes_to_actors(Sn))
for _ in range(N_STEPS):
    Sn = simulation_step(net_torch, P_ACTIVATION, protocol_AND, Sn)
    print(S_nodes_to_actors(Sn))