# Escape Operator from STREL

In [None]:
import torch
import torch.nn as nn

In [None]:
class DiffEscape():
    """
    Escape operator for STREL. Models escape condition over a spatial graph.

    For each location ``l`` the operator looks for a route that starts at
    ``l``, only visits locations where the signal ``s1`` holds, and ends at a
    location ``l'`` whose shortest-path distance from ``l`` lies in
    ``[d1, d2]``. Values are combined with ``min`` (AND) along a route and
    with ``max`` (OR) across routes, so a 0/1 signal gives the Boolean
    semantics. ``0`` doubles as the "no route" value, so the signal is
    expected to be non-negative (e.g. 0/1).

    Args:
        adjacency_matrix: square tensor; ``adjacency_matrix[i, j] > 0`` is a
            directed edge ``i -> j`` whose value is used as its length.
        d1, d2: inclusive lower/upper bounds on the shortest-path distance
            between the start and the end of an escape route.
        graph_nodes: 1-D tensor of node labels. Labels are used directly as
            row/column indices, so they are assumed to be ``0..n-1``.
        distance_domain_min, distance_domain_max: stored on the instance;
            not used by the Boolean evaluation.
    """

    def __init__(
        self,
        adjacency_matrix,
        d1,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = d1
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.adjacency_matrix = adjacency_matrix
        self.graph_nodes = graph_nodes

    def __call__(self, s1):
        """Evaluate the operator on ``s1`` (one value per node)."""
        return self._boolean(s1)

    def time_depth(self) -> int:
        """Return the operator's time depth (always 0: it is purely spatial)."""
        return 0

    def neighbors_fn(self, node):
        """
        Return ``(i, weight)`` for every node ``i`` with an edge ``i -> node``.

        Labels and weights are returned as Python scalars (via ``.item()``).
        """
        aa = (self.adjacency_matrix[:, node] > 0)
        neigh = self.graph_nodes[aa]
        neigh_pairs = [(i.item(), self.adjacency_matrix[i, node].item()) for i in neigh]
        return neigh_pairs

    def compute_min_distance_matrix(self):
        """
        Return the ``(n, n)`` matrix of shortest directed path lengths.

        ``D[i, j]`` is the length of the shortest path ``i -> j`` using
        ``adjacency_matrix[i, j]`` as edge length (Floyd-Warshall with one
        vectorised relaxation per pivot ``k``). Unreachable pairs are ``inf``
        and ``D[i, i]`` is always 0, even when node ``i`` has a self-loop.
        The result is a constant CPU tensor (not tracked by autograd).
        """
        n = len(self.graph_nodes)
        adj = torch.as_tensor(self.adjacency_matrix)[:n, :n]
        adj = adj.detach().to(device="cpu", dtype=torch.get_default_dtype())

        D = torch.where(adj > 0, adj, torch.full_like(adj, float("inf")))
        # Set the diagonal after copying the edges so that a self-loop weight
        # cannot replace the zero-length "stay in place" distance.
        D.fill_diagonal_(0.0)
        for k in range(n):
            # Row and column k are unchanged by pivot k (D[k, k] == 0), so the
            # whole matrix can be relaxed in one vectorised step.
            D = torch.minimum(D, D[:, k:k + 1] + D[k:k + 1, :])
        return D

    def _boolean(self, s1):
        """
        Return the escape value of every node as a tensor of shape ``(n,)``.

        Entry ``l`` is the max of ``e[l, l']`` over all targets ``l'`` with
        ``d1 <= D[l, l'] <= d2``, or 0 if no target qualifies. With
        ``d2 = inf``, unreachable targets (``D = inf``) also qualify; they
        contribute ``e = 0``. Gradients flow back to ``s1`` when it requires
        grad.
        """
        n = len(self.graph_nodes)
        if n == 0:
            return torch.zeros(0, requires_grad=True)

        # Work in the default float dtype, like a float ``e`` buffer would.
        signal = s1.to(dtype=torch.get_default_dtype())
        D = self.compute_min_distance_matrix()
        e = self._propagate_escape(signal, n)
        return self._aggregate(e, D)

    def _propagate_escape(self, signal, n):
        """Return ``e[l, l']``: best value of a route ``l -> l'`` inside ``signal``."""
        # Base case: the empty route from l to itself is worth signal[l].
        e = torch.diag(signal[:n])
        # Predecessor lists never change, so look them up once per node.
        preds = [[src for src, _w in self.neighbors_fn(node)] for node in range(n)]

        frontier = [(i, i) for i in range(n)]
        while frontier:
            # Best candidate for each (l1', l2) reached this round. All
            # candidates are combined before writing, so a later, lower
            # candidate can no longer overwrite a higher one. Reads from ``e``
            # are never followed by in-place writes to ``e``, so autograd is safe.
            best = {}
            for l1, l2 in frontier:
                for l1p in preds[l1]:
                    cand = torch.minimum(signal[l1p], e[l1, l2])
                    prev = best.get((l1p, l2))
                    best[(l1p, l2)] = cand if prev is None else torch.maximum(prev, cand)

            e_next = e.clone()
            frontier = []
            for (l1p, l2), cand in best.items():
                # Values only grow; a pair re-enters the frontier only when
                # it actually improved.
                if cand > e[l1p, l2]:
                    e_next[l1p, l2] = cand
                    frontier.append((l1p, l2))
            e = e_next
        return e

    def _aggregate(self, e, D):
        """Row-wise max of ``e`` over the targets allowed by ``[d1, d2]``."""
        allowed = ((D >= self.d1) & (D <= self.d2)).to(e.device)
        masked = torch.where(allowed, e, torch.full_like(e, float("-inf")))
        row_best = torch.amax(masked, dim=1)
        # Nodes without an admissible target keep 0. The zero vector is a leaf
        # with requires_grad=True so that the result is always attached to an
        # autograd graph, even for a constant signal.
        fallback = torch.zeros(
            e.shape[0], dtype=e.dtype, device=e.device, requires_grad=True
        )
        return torch.where(allowed.any(dim=1), row_best, fallback)

# Test

In [None]:
# 1. Graph initialization (16 nodes, 0-based indexing).
#    adjacency_matrix[i, j] = 1 is a directed edge i -> j of length 1.
graph_nodes = torch.arange(16)
N = len(graph_nodes)
adjacency_matrix = torch.zeros((N, N))

edges = [
    (7, 0), (7, 3), (7, 5), (7, 6), (6, 7), (6, 4), (4, 6), (4, 5),
    (6, 1), (6, 9), (9, 6), (9, 2), (9, 13), (9, 8), (8, 9), (8, 13),
    (9, 15), (15, 9), (9, 10), (10, 9), (15, 10), (10, 15), (15, 12),
    (10, 14), (10, 11),
    (0, 7), (3, 7), (5, 7), (5, 4),
    (1, 6), (2, 9), (13, 9), (13, 8),
    (12, 15), (14, 10), (11, 10)
]
for i, j in edges:
    adjacency_matrix[i, j] = 1

# 2. Define node roles
router_nodes = [4, 6, 7, 8, 10, 15]
coord_node = 9
end_dev_nodes = [0, 1, 2, 3, 5, 11, 12, 13, 14]

# 3. Node features: type (0=coord, 1=router, 2=end_dev) + battery ∈ [0,1]
node_features = torch.zeros((N, 2))  # [:, 0] = node type, [:, 1] = battery

for i in router_nodes:
    node_features[i, 0] = 1  # router
node_features[coord_node, 0] = 0  # coordinator (already 0; kept for clarity)
for i in end_dev_nodes:
    node_features[i, 0] = 2  # end device

# Random battery levels (not used by the escape check below).
node_features[:, 1] = torch.rand(N)

# 4. Escape signal: not(end_dev), i.e. node type != 2 (1.0 = holds, 0.0 = not).
escape_signal = (node_features[:, 0] != 2).float()

# 5. Initialize and run the Escape operator: from each node, escape through
#    not(end_dev) nodes to some node at shortest-path distance >= 2 (d2 = inf).
escape_op = DiffEscape(
    adjacency_matrix=adjacency_matrix,
    d1=2,
    d2=float("inf"),
    graph_nodes=graph_nodes,
    distance_domain_min=0,
    distance_domain_max=float("inf"),
)

result = escape_op(escape_signal)

# 6. Report, for every node (not only routers), whether the escape property
#    holds; an escape value > 0.5 counts as satisfied.
print("\nVerifica: ogni nodo può scappare da not(end_dev) con distanza almeno 2?\n")
for i in range(N):
    is_valid = result[i].item() > 0.5
    print(f"Nodo {i} ({'not end_dev' if node_features[i, 0] != 2 else 'end_dev'}) → {'✅ Sì' if is_valid else '❌ No'}")

# 7. Summary matrix: row 0 = node index, row 1 = escape value,
#    row 2 = escape value thresholded at 0.5.
Q_escape = torch.stack([
    torch.arange(N),
    result,
    (result > 0.5).float(),
])
print("\nMatrice Q_escape (nodo, soddisfazione, binarizzato):")
print(Q_escape)


Verifica: ogni nodo può scappare da not(end_dev) con distanza almeno 2?

Nodo 0 (end_dev) → ❌ No
Nodo 1 (end_dev) → ❌ No
Nodo 2 (end_dev) → ❌ No
Nodo 3 (end_dev) → ❌ No
Nodo 4 (not end_dev) → ✅ Sì
Nodo 5 (end_dev) → ❌ No
Nodo 6 (not end_dev) → ✅ Sì
Nodo 7 (not end_dev) → ✅ Sì
Nodo 8 (not end_dev) → ✅ Sì
Nodo 9 (not end_dev) → ✅ Sì
Nodo 10 (not end_dev) → ✅ Sì
Nodo 11 (end_dev) → ❌ No
Nodo 12 (end_dev) → ❌ No
Nodo 13 (end_dev) → ❌ No
Nodo 14 (end_dev) → ❌ No
Nodo 15 (not end_dev) → ✅ Sì

Matrice Q_escape (nodo, soddisfazione, binarizzato):
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15.],
        [ 0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,
          0.,  1.],
        [ 0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,
          0.,  1.]], grad_fn=<StackBackward0>)


In [None]:
 '''
 class DiffEscape():
    """
    Escape operator for STREL. Models escape condition over a spatial graph.
    """
    def __init__(
        self,
        adjacency_matrix,
        d1,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = d1
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.adjacency_matrix = adjacency_matrix
        self.graph_nodes = graph_nodes

    def __call__(self, s1):
        return self._boolean(s1)

    def time_depth(self) -> int:
        return 0

    def neighbors_fn(self, node):
        aa = (self.adjacency_matrix[:, node] > 0)
        neigh = self.graph_nodes[aa]
        neigh_pairs = [(i.item(), self.adjacency_matrix[i, node].item()) for i in neigh]
        return neigh_pairs

    def distance_function(self, weight):
        return weight

    def compute_min_distance_matrix(self):
        # TODO BETTER !!!
        """
        Computes all-pairs shortest distances D using the provided distance function.
        """
        N = self.adjacency_matrix.shape[0]
        device = self.adjacency_matrix.device
        inf = 1e9

        D = torch.full((N, N), inf, dtype=torch.float32, device=device)
        mask = self.adjacency_matrix > 0
        D[mask] = self.distance_function(self.adjacency_matrix[mask])
        D.fill_diagonal_(0.0)

        for k in range(N):
            D = torch.minimum(D, D[:, k].unsqueeze(1) + D[k, :].unsqueeze(0))

        return D  # shape: (N, N)

    def _boolean(self, s1):
        L = self.graph_nodes
        N = len(L)

        D = self.compute_min_distance_matrix()

        # Inizializza e
        #    - Crea un dizionario e con tutte le coppie ordinate di nodi (l1, l2) del grafo.
        #    - Ogni coppia viene inizializzata a distance_domain_min (in semantica booleana sarà tipicamente False)
        e = {
            (l.item(), l2.item()): self.distance_domain_min
            for l in L for l2 in L
        }
        for l in L:
            e[(l.item(), l.item())] = s1[l.item()] # sovrascrivi solo le diagonali

        # Inizializza T: Crea un insieme di coppie (l, l) per cui iniziare l'espansione dell'informazione
        T = {(l.item(), l.item()) for l in L}

        while T:
        T_prime = {}  # dict per nuovi aggiornamenti (l1p, l2): v
        e_prime = e.copy()

        for (l1, l2) in T:
            for (l1p, w) in self.neighbors_fn(l1):
                v = torch.maximum(
                    e[(l1p, l2)],
                    torch.minimum(s1[l1p], e[(l1, l2)])
                )
                if v != e[(l1p, l2)]:
                    e_prime[(l1p, l2)] = v
                    T_prime[(l1p, l2)] = v

        T = T_prime.keys()
        e = e_prime

        # Calcolo finale di s
        s = torch.zeros(N, dtype=torch.bool)
        for l in L:
            l_idx = l.item()
            vals = []
            for l2 in L:
                l2_idx = l2.item()
                if self.d1 <= D[l_idx, l2_idx] <= self.d2:
                    vals.append(e[(l_idx, l2_idx)])
            s[l_idx] = torch.any(torch.stack(vals)) if vals else torch.tensor(False)
        return s
'''

In [None]:
# NOTE(review): archived draft of DiffEscape, kept only as an unused string
# literal; nothing below is executed. If revived: `_boolean` is written at
# column 0, i.e. as a module-level function rather than a method of the
# class, and it writes in place into leaf tensors created with
# requires_grad=True (`e`, `s`), which autograd rejects.
'''
import torch

class DiffEscape():
    """
    Escape operator for STREL. Models escape condition over a spatial graph.
    """
    def __init__(
        self,
        adjacency_matrix,
        d1,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = d1
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.adjacency_matrix = adjacency_matrix
        self.graph_nodes = graph_nodes

    def __call__(self, s1):
        return self._boolean(s1)

    def time_depth(self) -> int:
        return 0

    def neighbors_fn(self, node):
        aa = (self.adjacency_matrix[:, node] > 0)
        neigh = self.graph_nodes[aa]
        neigh_pairs = [(i.item(), self.adjacency_matrix[i, node].item()) for i in neigh]
        return neigh_pairs

    def compute_min_distance_matrix(self):
        # Floyd-Warshall for simplicity
        n = len(self.graph_nodes)
        D = torch.full((n, n), float("inf"))
        for i in range(n):
            D[i, i] = 0.0
        for i in range(n):
            for j in range(n):
                if self.adjacency_matrix[i, j] > 0:
                    D[i, j] = self.adjacency_matrix[i, j].item()

        for k in range(n):
            for i in range(n):
                for j in range(n):
                    if D[i, k] + D[k, j] < D[i, j]:
                        D[i, j] = D[i, k] + D[k, j]
        return D

def _boolean(self, s1):
    L = self.graph_nodes
    n = len(L)

    # Step 1: compute shortest-path distances
    D = self.compute_min_distance_matrix()

    # Step 2: initialize e with e[l, l] = s1(l), others = 0
    e = torch.zeros((n, n), requires_grad=True)
    for i in range(n):
        e[i, i] = s1[i]

    # Step 3: initialize T as dictionary with (l, l) → s1(l)
    T = {(i, i): s1[i] for i in range(n)}

    while T:
        T_prime = {}
        e_prime = e.clone()

        for (l1, l2), val in T.items():
            for l1_prime, w in self.neighbors_fn(l1):
                # Semiring combine: min(s1[l1'], e[l1, l2]) → boolean: AND
                new_val = torch.minimum(s1[l1_prime], e[l1, l2])

                # Semiring choose: max(e[l1', l2], new_val) → boolean: OR
                combined = torch.maximum(e[l1_prime, l2], new_val)

                if not torch.equal(combined, e[l1_prime, l2]):
                    e_prime[l1_prime, l2] = combined
                    T_prime[(l1_prime, l2)] = combined

        T = T_prime
        e = e_prime

    # Final aggregation: s[l] = max({ e[l, l'] | D[l, l'] ∈ [d1, d2] })
    s = torch.zeros(n, requires_grad=True)
    for i in range(n):
        candidates = [
            e[i, j] for j in range(n)
            if self.d1 <= D[i, j] <= self.d2
        ]
        if candidates:
            s[i] = torch.stack(candidates).max()

    return s
'''

In [None]:
# NOTE(review): archived draft of DiffEscape, kept only as an unused string
# literal; nothing below is executed. This variant takes an explicit
# distance_function, computes distances with networkx Dijkstra and keeps `e`
# in a dict. Its distance filter uses range(self.d1, self.d2 + 1), which needs
# integer bounds (d2=float("inf") would raise TypeError) and only matches
# integral distances.
'''
class DiffEscape():
    """
    Escape operator for STREL. Models escape condition over a spatial graph.
    """
    def __init__(
        self,
        adjacency_matrix,
        distance_function,
        d1,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.adjacency_matrix = adjacency_matrix
        self.f = distance_function
        self.d1 = d1
        self.d2 = d2
        self.graph_nodes = graph_nodes
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max

    def __call__(self, s1):
        return self._boolean(s1)

    def neighbors_fn(self, node):
        aa = (self.adjacency_matrix[:, node] > 0)
        neigh = self.graph_nodes[aa]
        neigh_pairs = [(i.item(), self.adjacency_matrix[i, node].item()) for i in neigh]
        return neigh_pairs

    def compute_min_distance_matrix(self):
        """
        Computes all-pairs shortest distances D using the provided distance function.
        """
        import networkx as nx
        G = nx.DiGraph()
        for i in range(self.adjacency_matrix.shape[0]):
            for j in range(self.adjacency_matrix.shape[1]):
                if self.adjacency_matrix[i, j] > 0:
                    G.add_edge(i, j, weight=self.f(self.adjacency_matrix[i, j].item()))
        return dict(nx.all_pairs_dijkstra_path_length(G))

    def _boolean(self, s1):
        import torch
        L = self.graph_nodes
        D = self.compute_min_distance_matrix()

        # Inizializza e
        N = len(L)
        e = {
            (l.item(), l2.item()): self.distance_domain_min
            for l in L for l2 in L
        }
        for l in L:
            e[(l.item(), l.item())] = s1[l.item()]

        # Inizializza T
        T = {(l.item(), l.item()) for l in L}

        while T:
            T_prime = set()
            e_prime = e.copy()
            for (l1, l2) in T:
                for (l1p, w) in self.neighbors_fn(l1):
                    v = torch.logical_or(
                        e[(l1p, l2)],
                        torch.logical_and(s1[l1p], e[(l1, l2)])
                    )
                    if v != e[(l1p, l2)]:
                        e_prime[(l1p, l2)] = v
                        T_prime.add((l1p, l2))
            T = T_prime
            e = e_prime

        # Calcolo finale di s
        s = torch.zeros(N, dtype=torch.bool)
        for l in L:
            l_idx = l.item()
            vals = []
            for l2 in L:
                l2_idx = l2.item()
                if D.get(l_idx, {}).get(l2_idx, float('inf')) in range(self.d1, self.d2 + 1):
                    vals.append(e[(l_idx, l2_idx)])
            s[l_idx] = torch.any(torch.stack(vals)) if vals else torch.tensor(False)
        return s
'''


In [None]:
# NOTE(review): archived dict-based draft of DiffEscape, kept only as an
# unused string literal; nothing below is executed. If revived: off-diagonal
# entries of `e` start as distance_domain_min (a plain value, None by default)
# and are then passed to torch.maximum, which expects tensors, so this likely
# fails (TODO confirm). The final `s[l] = ...` also writes in place into a
# leaf tensor created with requires_grad=True, which autograd rejects.
'''
import torch

class DiffEscape():
    """
    Differentiable Boolean Escape operator for STREL.
    Dictionary-based version with full initialization of e.
    """

    def __init__(
        self,
        adjacency_matrix,
        d1,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        super().__init__()
        self.d1 = d1
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.adjacency_matrix = adjacency_matrix
        self.graph_nodes = graph_nodes

    def __call__(self, s1):
        return self._boolean(s1)

    def __str__(self) -> str:
        return f"Escape[{self.d1},{self.d2}]"

    def time_depth(self) -> int:
        return 0

    def neighbors_fn(self, node):
        aa = (self.adjacency_matrix[:, node] > 0)
        neigh = self.graph_nodes[aa]
        return [(i.item(), self.adjacency_matrix[i, node].item()) for i in neigh]

    def distance_function(self, weight):
        return weight

    def compute_distance_matrix(self):
        num_nodes = len(self.graph_nodes)
        D = torch.full((num_nodes, num_nodes), float('inf'))
        for i in range(num_nodes):
            D[i, i] = 0.0
        for i in range(num_nodes):
            for j in range(num_nodes):
                if self.adjacency_matrix[i, j] > 0:
                    D[i, j] = self.distance_function(self.adjacency_matrix[i, j].item())

        for k in range(num_nodes):
            for i in range(num_nodes):
                for j in range(num_nodes):
                    D[i, j] = min(D[i, j], D[i, k] + D[k, j])
        return D

    def _boolean(self, s1):
        num_nodes = len(self.graph_nodes)
        D = self.compute_distance_matrix()

        # Inizializza e su L x L
        e = {}
        for l in range(num_nodes):
            for l_prime in range(num_nodes):
                if l == l_prime:
                    e[(l, l)] = s1[l]
                else:
                    e[(l, l_prime)] = self.distance_domain_min

        # Inizializza T = {(l, l) | l in L}
        T = {(l, l): True for l in range(num_nodes)}

        while T:
            T_prime = {}
            e_prime = e.copy()

            for (l1, l2) in T.keys():
                for l1_prime, w in self.neighbors_fn(l1):
                    key = (l1_prime, l2)
                    old_val = e.get(key, self.distance_domain_min)
                    new_val = torch.maximum(old_val, torch.minimum(s1[l1_prime], e[(l1, l2)]))
                    if not torch.equal(new_val, old_val):
                        e_prime[key] = new_val
                        T_prime[key] = True

            T = T_prime
            e = e_prime

        # Aggregazione finale
        s = torch.zeros(num_nodes, requires_grad=True)
        for l in range(num_nodes):
            vals = [
                e[(l, l_prime)]
                for l_prime in range(num_nodes)
                if self.d1 <= D[l, l_prime] <= self.d2
            ]
            if vals:
                s[l] = torch.max(torch.stack(vals))

        return s
'''

In [None]:
# NOTE(review): archived tensor-based draft of DiffEscape, kept only as an
# unused string literal; nothing below is executed. If revived: `e` and `s`
# are leaf tensors created with requires_grad=True and then written in place
# (`e[i, i] = ...`, `s[i] = ...`), which autograd rejects; torch.full also
# needs a numeric fill value, but distance_domain_min defaults to None.
'''
import torch

class DiffEscape():
    """
    Escape operator for STREL. Models escape condition over a spatial graph.
    """
    def __init__(
        self,
        adjacency_matrix,
        d1,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = d1
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.adjacency_matrix = adjacency_matrix
        self.graph_nodes = graph_nodes

    def __call__(self, s1):
        return self._boolean(s1)

    def time_depth(self) -> int:
        return 0

    def neighbors_fn(self, node):
        aa = (self.adjacency_matrix[:, node] > 0)
        neigh = self.graph_nodes[aa]
        neigh_pairs = [(i.item(), self.adjacency_matrix[i, node].item()) for i in neigh]
        return neigh_pairs

    def compute_min_distance_matrix(self):
        # Floyd-Warshall for simplicity
        n = len(self.graph_nodes)
        D = torch.full((n, n), float("inf"))
        for i in range(n):
            D[i, i] = 0.0
        for i in range(n):
            for j in range(n):
                if self.adjacency_matrix[i, j] > 0:
                    D[i, j] = self.adjacency_matrix[i, j].item()

        for k in range(n):
            for i in range(n):
                for j in range(n):
                    if D[i, k] + D[k, j] < D[i, j]:
                        D[i, j] = D[i, k] + D[k, j]
        return D

    def _boolean(self, s1):
        L = self.graph_nodes
        n = len(L)

        # Step 1: compute distance matrix D
        D = self.compute_min_distance_matrix()

        # Step 2: initialize e
        e = torch.full((n, n), self.distance_domain_min, requires_grad=True)
        for i in range(n):
            e[i, i] = s1[i]

        # Step 3: initialize T as dict with (l, l)
        T = {(i, i) for i in range(n)}

        while T:
            T_prime = set()
            e_prime = e.clone()

            for (l1, l2) in T:
                for l1_prime, w in self.neighbors_fn(l1):
                    # compute new value
                    v = torch.maximum(
                        e[l1_prime, l2],
                        torch.minimum(s1[l1_prime], e[l1, l2])
                    )
                    if not torch.equal(v, e[l1_prime, l2]):
                        e_prime[l1_prime, l2] = v
                        T_prime.add((l1_prime, l2))

            T = T_prime
            e = e_prime

        # Final aggregation
        s = torch.full((n,), self.distance_domain_min, requires_grad=True)
        for i in range(n):
            vals = [
                e[i, j] for j in range(n)
                if self.d1 <= D[i, j] <= self.d2
            ]
            if vals:
                s[i] = torch.stack(vals).max()

        return s
'''

In [None]:
import torch
import torch.nn as nn

class DiffEscape():
    """
    Escape operator for STREL. Models escape condition over a spatial graph.

    For each location ``l`` the operator looks for a route that starts at
    ``l``, only visits locations where the signal ``s1`` holds, and ends at a
    location ``l'`` whose shortest-path distance from ``l`` lies in
    ``[d1, d2]``. Values are combined with ``min`` (AND) along a route and
    with ``max`` (OR) across routes, so a 0/1 signal gives the Boolean
    semantics. ``0`` doubles as the "no route" value, so the signal is
    expected to be non-negative (e.g. 0/1).

    Args:
        adjacency_matrix: square tensor; ``adjacency_matrix[i, j] > 0`` is a
            directed edge ``i -> j`` whose value is used as its length.
        d1, d2: inclusive lower/upper bounds on the shortest-path distance
            between the start and the end of an escape route.
        graph_nodes: 1-D tensor of node labels. Labels are used directly as
            row/column indices, so they are assumed to be ``0..n-1``.
        distance_domain_min, distance_domain_max: stored on the instance;
            not used by the Boolean evaluation.
    """

    def __init__(
        self,
        adjacency_matrix,
        d1,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = d1
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.adjacency_matrix = adjacency_matrix
        self.graph_nodes = graph_nodes

    def __call__(self, s1):
        """Evaluate the operator on ``s1`` (one value per node)."""
        return self._boolean(s1)

    def time_depth(self) -> int:
        """Return the operator's time depth (always 0: it is purely spatial)."""
        return 0

    def neighbors_fn(self, node):
        """
        Return ``(i, weight)`` for every node ``i`` with an edge ``i -> node``.

        Labels and weights are returned as Python scalars (via ``.item()``).
        """
        aa = (self.adjacency_matrix[:, node] > 0)
        neigh = self.graph_nodes[aa]
        neigh_pairs = [(i.item(), self.adjacency_matrix[i, node].item()) for i in neigh]
        return neigh_pairs

    def compute_min_distance_matrix(self):
        """
        Return the ``(n, n)`` matrix of shortest directed path lengths.

        ``D[i, j]`` is the length of the shortest path ``i -> j`` using
        ``adjacency_matrix[i, j]`` as edge length (Floyd-Warshall with one
        vectorised relaxation per pivot ``k``). Unreachable pairs are ``inf``
        and ``D[i, i]`` is always 0, even when node ``i`` has a self-loop.
        The result is a constant CPU tensor (not tracked by autograd).
        """
        n = len(self.graph_nodes)
        adj = torch.as_tensor(self.adjacency_matrix)[:n, :n]
        adj = adj.detach().to(device="cpu", dtype=torch.get_default_dtype())

        D = torch.where(adj > 0, adj, torch.full_like(adj, float("inf")))
        # Set the diagonal after copying the edges so that a self-loop weight
        # cannot replace the zero-length "stay in place" distance.
        D.fill_diagonal_(0.0)
        for k in range(n):
            # Row and column k are unchanged by pivot k (D[k, k] == 0), so the
            # whole matrix can be relaxed in one vectorised step.
            D = torch.minimum(D, D[:, k:k + 1] + D[k:k + 1, :])
        return D

    def _boolean(self, s1):
        """
        Return the escape value of every node as a tensor of shape ``(n,)``.

        Entry ``l`` is the max of ``e[l, l']`` over all targets ``l'`` with
        ``d1 <= D[l, l'] <= d2``, or 0 if no target qualifies. With
        ``d2 = inf``, unreachable targets (``D = inf``) also qualify; they
        contribute ``e = 0``. Gradients flow back to ``s1`` when it requires
        grad.
        """
        n = len(self.graph_nodes)
        if n == 0:
            return torch.zeros(0, requires_grad=True)

        # Work in the default float dtype, like a float ``e`` buffer would.
        signal = s1.to(dtype=torch.get_default_dtype())
        D = self.compute_min_distance_matrix()
        e = self._propagate_escape(signal, n)
        return self._aggregate(e, D)

    def _propagate_escape(self, signal, n):
        """Return ``e[l, l']``: best value of a route ``l -> l'`` inside ``signal``."""
        # Base case: the empty route from l to itself is worth signal[l].
        e = torch.diag(signal[:n])
        # Predecessor lists never change, so look them up once per node.
        preds = [[src for src, _w in self.neighbors_fn(node)] for node in range(n)]

        frontier = [(i, i) for i in range(n)]
        while frontier:
            # Best candidate for each (l1', l2) reached this round. All
            # candidates are combined before writing, so a later, lower
            # candidate can no longer overwrite a higher one. Reads from ``e``
            # are never followed by in-place writes to ``e``, so autograd is safe.
            best = {}
            for l1, l2 in frontier:
                for l1p in preds[l1]:
                    cand = torch.minimum(signal[l1p], e[l1, l2])
                    prev = best.get((l1p, l2))
                    best[(l1p, l2)] = cand if prev is None else torch.maximum(prev, cand)

            e_next = e.clone()
            frontier = []
            for (l1p, l2), cand in best.items():
                # Values only grow; a pair re-enters the frontier only when
                # it actually improved.
                if cand > e[l1p, l2]:
                    e_next[l1p, l2] = cand
                    frontier.append((l1p, l2))
            e = e_next
        return e

    def _aggregate(self, e, D):
        """Row-wise max of ``e`` over the targets allowed by ``[d1, d2]``."""
        allowed = ((D >= self.d1) & (D <= self.d2)).to(e.device)
        masked = torch.where(allowed, e, torch.full_like(e, float("-inf")))
        row_best = torch.amax(masked, dim=1)
        # Nodes without an admissible target keep 0. The zero vector is a leaf
        # with requires_grad=True so that the result is always attached to an
        # autograd graph, even for a constant signal.
        fallback = torch.zeros(
            e.shape[0], dtype=e.dtype, device=e.device, requires_grad=True
        )
        return torch.where(allowed.any(dim=1), row_best, fallback)

In [None]:
# 1. Graph initialization (16 nodes, 0-based indexing).
#    adjacency_matrix[i, j] = 1 is a directed edge i -> j of length 1.
graph_nodes = torch.arange(16)
N = len(graph_nodes)
adjacency_matrix = torch.zeros((N, N))

edges = [
    (7, 0), (7, 3), (7, 5), (7, 6), (6, 7), (6, 4), (4, 6), (4, 5),
    (6, 1), (6, 9), (9, 6), (9, 2), (9, 13), (9, 8), (8, 9), (8, 13),
    (9, 15), (15, 9), (9, 10), (10, 9), (15, 10), (10, 15), (15, 12),
    (10, 14), (10, 11),
    (0, 7), (3, 7), (5, 7), (5, 4),
    (1, 6), (2, 9), (13, 9), (13, 8),
    (12, 15), (14, 10), (11, 10)
]
for i, j in edges:
    adjacency_matrix[i, j] = 1

# 2. Define node roles
router_nodes = [4, 6, 7, 8, 10, 15]
coord_node = 9
end_dev_nodes = [0, 1, 2, 3, 5, 11, 12, 13, 14]

# 3. Node features: type (0=coord, 1=router, 2=end_dev) + battery ∈ [0,1]
node_features = torch.zeros((N, 2))  # [:, 0] = node type, [:, 1] = battery

for i in router_nodes:
    node_features[i, 0] = 1  # router
node_features[coord_node, 0] = 0  # coordinator (already 0; kept for clarity)
for i in end_dev_nodes:
    node_features[i, 0] = 2  # end device

# Random battery levels (not used by the escape check below).
node_features[:, 1] = torch.rand(N)

# 4. Escape signal: not(end_dev), i.e. node type != 2 (1.0 = holds, 0.0 = not).
escape_signal = (node_features[:, 0] != 2).float()

# 5. Initialize and run the Escape operator: from each node, escape through
#    not(end_dev) nodes to some node at shortest-path distance >= 2 (d2 = inf).
escape_op = DiffEscape(
    adjacency_matrix=adjacency_matrix,
    d1=2,
    d2=float("inf"),
    graph_nodes=graph_nodes,
    distance_domain_min=0,
    distance_domain_max=float("inf"),
)

result = escape_op(escape_signal)

# 6. Report, for every node (not only routers), whether the escape property
#    holds; an escape value > 0.5 counts as satisfied.
print("\nVerifica: ogni nodo può scappare da not(end_dev) con distanza almeno 2?\n")
for i in range(N):
    is_valid = result[i].item() > 0.5
    print(f"Nodo {i} ({'not end_dev' if node_features[i, 0] != 2 else 'end_dev'}) → {'✅ Sì' if is_valid else '❌ No'}")

# 7. Summary matrix: row 0 = node index, row 1 = escape value,
#    row 2 = escape value thresholded at 0.5.
Q_escape = torch.stack([
    torch.arange(N),
    result,
    (result > 0.5).float(),
])
print("\nMatrice Q_escape (nodo, soddisfazione, binarizzato):")
print(Q_escape)


Verifica: ogni nodo può scappare da not(end_dev) con distanza almeno 2?

Nodo 0 (end_dev) → ❌ No
Nodo 1 (end_dev) → ❌ No
Nodo 2 (end_dev) → ❌ No
Nodo 3 (end_dev) → ❌ No
Nodo 4 (not end_dev) → ✅ Sì
Nodo 5 (end_dev) → ❌ No
Nodo 6 (not end_dev) → ✅ Sì
Nodo 7 (not end_dev) → ✅ Sì
Nodo 8 (not end_dev) → ✅ Sì
Nodo 9 (not end_dev) → ✅ Sì
Nodo 10 (not end_dev) → ✅ Sì
Nodo 11 (end_dev) → ❌ No
Nodo 12 (end_dev) → ❌ No
Nodo 13 (end_dev) → ❌ No
Nodo 14 (end_dev) → ❌ No
Nodo 15 (not end_dev) → ✅ Sì

Matrice Q_escape (nodo, soddisfazione, binarizzato):
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15.],
        [ 0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,
          0.,  1.],
        [ 0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,
          0.,  1.]], grad_fn=<StackBackward0>)
